In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-17 01:54:35.993049: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-17 01:54:36.065058: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-17 01:54:36.158918: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739746476.238827  257755 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739746476.263036  257755 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-17 01:54:36.461333: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(GRU(64, return_sequences=True))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))

    model.add(GRU(32, return_sequences=False))
    model.add(BatchNormalization())
    model.add(Dropout(0.3))

    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.0003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-02-17 01:54:40.420764: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [01:43, 10.34s/it]


Epoch 5: TRAIN Accuracy = 0.52 Loss = 0.86 AUC = 0.489
Epoch 5: VAL Accuracy = 0.553 Loss = 0.702 AUC = 0.435


10it [01:02,  6.26s/it]


Epoch 10: TRAIN Accuracy = 0.619 Loss = 0.69 AUC = 0.656
Epoch 10: VAL Accuracy = 0.743 Loss = 0.639 AUC = 0.723


10it [01:00,  6.08s/it]


Epoch 15: TRAIN Accuracy = 0.7 Loss = 0.596 AUC = 0.742
Epoch 15: VAL Accuracy = 0.783 Loss = 0.594 AUC = 0.782


10it [01:00,  6.09s/it]


Epoch 20: TRAIN Accuracy = 0.721 Loss = 0.552 AUC = 0.778
Epoch 20: VAL Accuracy = 0.8 Loss = 0.561 AUC = 0.778


10it [00:58,  5.85s/it]


Epoch 25: TRAIN Accuracy = 0.744 Loss = 0.519 AUC = 0.8
Epoch 25: VAL Accuracy = 0.797 Loss = 0.542 AUC = 0.775


10it [01:02,  6.26s/it]


Epoch 30: TRAIN Accuracy = 0.77 Loss = 0.514 AUC = 0.797
Epoch 30: VAL Accuracy = 0.793 Loss = 0.528 AUC = 0.777


10it [00:58,  5.83s/it]


Epoch 35: TRAIN Accuracy = 0.76 Loss = 0.508 AUC = 0.79
Epoch 35: VAL Accuracy = 0.79 Loss = 0.516 AUC = 0.788


10it [01:09,  6.96s/it]


Epoch 40: TRAIN Accuracy = 0.771 Loss = 0.488 AUC = 0.808
Epoch 40: VAL Accuracy = 0.793 Loss = 0.506 AUC = 0.802


10it [01:10,  7.03s/it]


Epoch 45: TRAIN Accuracy = 0.796 Loss = 0.458 AUC = 0.831
Epoch 45: VAL Accuracy = 0.793 Loss = 0.501 AUC = 0.793


10it [01:08,  6.81s/it]


Epoch 50: TRAIN Accuracy = 0.786 Loss = 0.461 AUC = 0.816
Epoch 50: VAL Accuracy = 0.793 Loss = 0.494 AUC = 0.795


10it [01:06,  6.60s/it]


Epoch 55: TRAIN Accuracy = 0.796 Loss = 0.447 AUC = 0.825
Epoch 55: VAL Accuracy = 0.787 Loss = 0.486 AUC = 0.8


10it [01:09,  6.95s/it]


Epoch 60: TRAIN Accuracy = 0.773 Loss = 0.462 AUC = 0.805
Epoch 60: VAL Accuracy = 0.787 Loss = 0.483 AUC = 0.811


10it [01:12,  7.21s/it]


Epoch 65: TRAIN Accuracy = 0.807 Loss = 0.434 AUC = 0.838
Epoch 65: VAL Accuracy = 0.78 Loss = 0.48 AUC = 0.805


10it [01:16,  7.68s/it]


Epoch 70: TRAIN Accuracy = 0.784 Loss = 0.457 AUC = 0.815
Epoch 70: VAL Accuracy = 0.78 Loss = 0.475 AUC = 0.813


10it [01:15,  7.54s/it]


Epoch 75: TRAIN Accuracy = 0.8 Loss = 0.414 AUC = 0.853
Epoch 75: VAL Accuracy = 0.787 Loss = 0.466 AUC = 0.812


10it [01:31,  9.19s/it]


Epoch 80: TRAIN Accuracy = 0.807 Loss = 0.419 AUC = 0.848
Epoch 80: VAL Accuracy = 0.787 Loss = 0.462 AUC = 0.821


10it [01:51, 11.18s/it]


Epoch 85: TRAIN Accuracy = 0.803 Loss = 0.408 AUC = 0.865
Epoch 85: VAL Accuracy = 0.783 Loss = 0.46 AUC = 0.824


10it [01:58, 11.87s/it]


Epoch 90: TRAIN Accuracy = 0.803 Loss = 0.405 AUC = 0.864
Epoch 90: VAL Accuracy = 0.783 Loss = 0.46 AUC = 0.82


10it [01:14,  7.43s/it]


Epoch 95: TRAIN Accuracy = 0.809 Loss = 0.408 AUC = 0.858
Epoch 95: VAL Accuracy = 0.78 Loss = 0.455 AUC = 0.824


10it [01:21,  8.11s/it]


Epoch 100: TRAIN Accuracy = 0.816 Loss = 0.396 AUC = 0.866
Epoch 100: VAL Accuracy = 0.777 Loss = 0.453 AUC = 0.815


10it [01:21,  8.10s/it]


Epoch 105: TRAIN Accuracy = 0.817 Loss = 0.39 AUC = 0.876
Epoch 105: VAL Accuracy = 0.773 Loss = 0.442 AUC = 0.823


10it [01:28,  8.83s/it]


Epoch 110: TRAIN Accuracy = 0.821 Loss = 0.392 AUC = 0.869
Epoch 110: VAL Accuracy = 0.777 Loss = 0.443 AUC = 0.819


10it [01:10,  7.00s/it]


Epoch 115: TRAIN Accuracy = 0.826 Loss = 0.379 AUC = 0.883
Epoch 115: VAL Accuracy = 0.783 Loss = 0.435 AUC = 0.832


10it [01:06,  6.68s/it]


Epoch 120: TRAIN Accuracy = 0.831 Loss = 0.369 AUC = 0.89
Epoch 120: VAL Accuracy = 0.79 Loss = 0.436 AUC = 0.837


10it [01:25,  8.50s/it]


Epoch 125: TRAIN Accuracy = 0.824 Loss = 0.368 AUC = 0.891
Epoch 125: VAL Accuracy = 0.787 Loss = 0.436 AUC = 0.852


10it [01:24,  8.46s/it]


Epoch 130: TRAIN Accuracy = 0.824 Loss = 0.354 AUC = 0.905
Epoch 130: VAL Accuracy = 0.783 Loss = 0.441 AUC = 0.844


10it [01:09,  6.93s/it]


Epoch 135: TRAIN Accuracy = 0.84 Loss = 0.347 AUC = 0.905
Epoch 135: VAL Accuracy = 0.78 Loss = 0.455 AUC = 0.84


10it [01:07,  6.76s/it]


Epoch 140: TRAIN Accuracy = 0.847 Loss = 0.348 AUC = 0.902
Epoch 140: VAL Accuracy = 0.777 Loss = 0.456 AUC = 0.833


10it [01:08,  6.87s/it]


Epoch 145: TRAIN Accuracy = 0.857 Loss = 0.326 AUC = 0.924
Epoch 145: VAL Accuracy = 0.777 Loss = 0.454 AUC = 0.836


10it [01:05,  6.56s/it]


Epoch 150: TRAIN Accuracy = 0.866 Loss = 0.313 AUC = 0.935
Epoch 150: VAL Accuracy = 0.773 Loss = 0.465 AUC = 0.831


10it [01:08,  6.82s/it]


Epoch 155: TRAIN Accuracy = 0.86 Loss = 0.318 AUC = 0.919
Epoch 155: VAL Accuracy = 0.77 Loss = 0.478 AUC = 0.83


10it [01:13,  7.37s/it]


Epoch 160: TRAIN Accuracy = 0.854 Loss = 0.315 AUC = 0.926
Epoch 160: VAL Accuracy = 0.767 Loss = 0.476 AUC = 0.829


10it [01:18,  7.84s/it]


Epoch 165: TRAIN Accuracy = 0.874 Loss = 0.291 AUC = 0.943
Epoch 165: VAL Accuracy = 0.77 Loss = 0.49 AUC = 0.827


10it [01:24,  8.48s/it]


Epoch 170: TRAIN Accuracy = 0.887 Loss = 0.283 AUC = 0.939
Epoch 170: VAL Accuracy = 0.753 Loss = 0.508 AUC = 0.825


10it [01:23,  8.38s/it]


Epoch 175: TRAIN Accuracy = 0.883 Loss = 0.27 AUC = 0.949
Epoch 175: VAL Accuracy = 0.753 Loss = 0.535 AUC = 0.821


10it [01:31,  9.15s/it]


Epoch 180: TRAIN Accuracy = 0.897 Loss = 0.251 AUC = 0.959
Epoch 180: VAL Accuracy = 0.75 Loss = 0.54 AUC = 0.818


10it [01:46, 10.66s/it]


Epoch 185: TRAIN Accuracy = 0.886 Loss = 0.253 AUC = 0.961
Epoch 185: VAL Accuracy = 0.743 Loss = 0.564 AUC = 0.82


10it [02:01, 12.17s/it]


Epoch 190: TRAIN Accuracy = 0.887 Loss = 0.251 AUC = 0.958
Epoch 190: VAL Accuracy = 0.75 Loss = 0.571 AUC = 0.82


10it [02:22, 14.24s/it]


Epoch 195: TRAIN Accuracy = 0.91 Loss = 0.223 AUC = 0.971
Epoch 195: VAL Accuracy = 0.753 Loss = 0.609 AUC = 0.817


10it [02:41, 16.17s/it]


Epoch 200: TRAIN Accuracy = 0.907 Loss = 0.222 AUC = 0.969
Epoch 200: VAL Accuracy = 0.76 Loss = 0.648 AUC = 0.814


10it [02:17, 13.72s/it]


Epoch 205: TRAIN Accuracy = 0.92 Loss = 0.201 AUC = 0.981
Epoch 205: VAL Accuracy = 0.75 Loss = 0.67 AUC = 0.813


10it [03:59, 23.91s/it]


Epoch 210: TRAIN Accuracy = 0.92 Loss = 0.196 AUC = 0.977
Epoch 210: VAL Accuracy = 0.75 Loss = 0.693 AUC = 0.814


10it [09:49, 58.94s/it]


Epoch 215: TRAIN Accuracy = 0.92 Loss = 0.187 AUC = 0.979
Epoch 215: VAL Accuracy = 0.75 Loss = 0.712 AUC = 0.813


10it [10:57, 65.74s/it]


Epoch 220: TRAIN Accuracy = 0.933 Loss = 0.171 AUC = 0.983
Epoch 220: VAL Accuracy = 0.73 Loss = 0.763 AUC = 0.811


10it [19:50, 119.02s/it]


Epoch 225: TRAIN Accuracy = 0.956 Loss = 0.14 AUC = 0.988
Epoch 225: VAL Accuracy = 0.74 Loss = 0.811 AUC = 0.81


10it [12:30, 75.04s/it]


Epoch 230: TRAIN Accuracy = 0.964 Loss = 0.126 AUC = 0.996
Epoch 230: VAL Accuracy = 0.73 Loss = 0.835 AUC = 0.805


10it [21:25, 128.59s/it]


Epoch 235: TRAIN Accuracy = 0.974 Loss = 0.114 AUC = 0.997
Epoch 235: VAL Accuracy = 0.73 Loss = 0.895 AUC = 0.801


10it [01:19,  7.96s/it]


Epoch 240: TRAIN Accuracy = 0.967 Loss = 0.111 AUC = 0.994
Epoch 240: VAL Accuracy = 0.723 Loss = 0.902 AUC = 0.802


10it [01:16,  7.68s/it]


Epoch 245: TRAIN Accuracy = 0.98 Loss = 0.101 AUC = 0.996
Epoch 245: VAL Accuracy = 0.723 Loss = 0.937 AUC = 0.798


10it [01:18,  7.89s/it]


Epoch 250: TRAIN Accuracy = 0.986 Loss = 0.085 AUC = 0.998
Epoch 250: VAL Accuracy = 0.713 Loss = 0.996 AUC = 0.798


10it [01:18,  7.89s/it]


Epoch 255: TRAIN Accuracy = 0.986 Loss = 0.08 AUC = 0.998
Epoch 255: VAL Accuracy = 0.707 Loss = 1.034 AUC = 0.807


10it [01:21,  8.11s/it]


Epoch 260: TRAIN Accuracy = 0.983 Loss = 0.073 AUC = 0.998
Epoch 260: VAL Accuracy = 0.727 Loss = 1.002 AUC = 0.799


10it [01:21,  8.12s/it]


Epoch 265: TRAIN Accuracy = 0.983 Loss = 0.072 AUC = 0.997
Epoch 265: VAL Accuracy = 0.727 Loss = 1.021 AUC = 0.807


10it [01:19,  7.91s/it]


Epoch 270: TRAIN Accuracy = 0.989 Loss = 0.067 AUC = 0.997
Epoch 270: VAL Accuracy = 0.743 Loss = 1.028 AUC = 0.804


10it [01:19,  7.91s/it]


Epoch 275: TRAIN Accuracy = 0.99 Loss = 0.056 AUC = 0.999
Epoch 275: VAL Accuracy = 0.72 Loss = 1.085 AUC = 0.803


10it [01:17,  7.79s/it]


Epoch 280: TRAIN Accuracy = 0.994 Loss = 0.048 AUC = 1.0
Epoch 280: VAL Accuracy = 0.727 Loss = 1.112 AUC = 0.798


10it [01:20,  8.10s/it]


Epoch 285: TRAIN Accuracy = 0.993 Loss = 0.042 AUC = 1.0
Epoch 285: VAL Accuracy = 0.713 Loss = 1.14 AUC = 0.8


10it [01:20,  8.07s/it]


Epoch 290: TRAIN Accuracy = 0.993 Loss = 0.041 AUC = 1.0
Epoch 290: VAL Accuracy = 0.733 Loss = 1.18 AUC = 0.794


10it [01:19,  7.95s/it]


Epoch 295: TRAIN Accuracy = 0.996 Loss = 0.034 AUC = 1.0
Epoch 295: VAL Accuracy = 0.717 Loss = 1.218 AUC = 0.789


10it [01:19,  7.98s/it]

Epoch 300: TRAIN Accuracy = 0.999 Loss = 0.027 AUC = 1.0
Epoch 300: VAL Accuracy = 0.713 Loss = 1.219 AUC = 0.79
CPU times: user 2h 14min 33s, sys: 1h 4min 9s, total: 3h 18min 43s
Wall time: 2h 32min 36s



