In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 11:22:37.379556: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 11:22:37.385139: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 11:22:37.399277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742545357.426367  659057 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742545357.434751  659057 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 11:22:37.472727: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=False))
    
    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
USER = ["SubjectID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
LAGS = [1, 2]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

new_features = []
for lag in LAGS:
    for feature_name in FEATURES:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
FEATURES.extend(new_features)

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 11:22:42.574495: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 337ms/step - accuracy: 0.5055 - auc: 0.6433 - loss: 0.6553 - val_accuracy: 0.8000 - val_auc: 0.8156 - val_loss: 0.6111
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.6491 - auc: 0.7268 - loss: 0.6364 - val_accuracy: 0.8000 - val_auc: 0.8178 - val_loss: 0.5672
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.7024 - auc: 0.7182 - loss: 0.6085 - val_accuracy: 0.7000 - val_auc: 0.7800 - val_loss: 0.5468
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.5882 - auc: 0.6926 - loss: 0.5879 - val_accuracy: 0.8000 - val_auc: 0.8178 - val_loss: 0.5221
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 240ms/step - accuracy: 0.7226 - auc: 0.7479 - loss: 0.5705 - val_accuracy: 0.8000 - val_auc: 0.8511 - val_loss: 0.5116
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [05:57, 357.56s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 464ms/step - accuracy: 0.5292 - auc: 0.5640 - loss: 0.6891 - val_accuracy: 0.5000 - val_auc: 0.7844 - val_loss: 0.6788
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.5458 - auc: 0.7437 - loss: 0.6745 - val_accuracy: 0.5000 - val_auc: 0.7867 - val_loss: 0.6581
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 242ms/step - accuracy: 0.5494 - auc: 0.7547 - loss: 0.6619 - val_accuracy: 0.5000 - val_auc: 0.7867 - val_loss: 0.6356
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.5860 - auc: 0.7706 - loss: 0.6407 - val_accuracy: 0.7333 - val_auc: 0.7867 - val_loss: 0.6074
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 256ms/step - accuracy: 0.7221 - auc: 0.7734 - loss: 0.6114 - val_accuracy: 0.7667 - val_auc: 0.7867 - val_loss: 0.5778
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:09, 366.27s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 444ms/step - accuracy: 0.4396 - auc: 0.7093 - loss: 0.6323 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5960
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.5493 - auc: 0.7256 - loss: 0.6078 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5575
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7975 - auc: 0.8053 - loss: 0.5744 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5413
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 247ms/step - accuracy: 0.6455 - auc: 0.7227 - loss: 0.5592 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5423
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 236ms/step - accuracy: 0.7624 - auc: 0.7240 - loss: 0.5465 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5353
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [18:18, 367.36s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 465ms/step - accuracy: 0.2944 - auc: 0.3455 - loss: 0.7012 - val_accuracy: 0.4000 - val_auc: 0.4578 - val_loss: 0.6918
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 234ms/step - accuracy: 0.4577 - auc: 0.5467 - loss: 0.6912 - val_accuracy: 0.7333 - val_auc: 0.6778 - val_loss: 0.6796
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 249ms/step - accuracy: 0.6844 - auc: 0.6511 - loss: 0.6783 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.6623
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 240ms/step - accuracy: 0.7262 - auc: 0.7762 - loss: 0.6603 - val_accuracy: 0.7333 - val_auc: 0.6911 - val_loss: 0.6340
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.7522 - auc: 0.7896 - loss: 0.6163 - val_accuracy: 0.7333 - val_auc: 0.6867 - val_loss: 0.5981
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [24:46, 375.42s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 470ms/step - accuracy: 0.7420 - auc: 0.7906 - loss: 0.6200 - val_accuracy: 0.7333 - val_auc: 0.7733 - val_loss: 0.5779
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 236ms/step - accuracy: 0.8600 - auc: 0.8814 - loss: 0.5441 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5404
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.8258 - auc: 0.8622 - loss: 0.4943 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5170
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.8258 - auc: 0.8500 - loss: 0.4540 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5116
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.8258 - auc: 0.8460 - loss: 0.4397 - val_accuracy: 0.7333 - val_auc: 0.7733 - val_loss: 0.5014
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [30:55, 371.02s/it]

CPU times: user 1h 1min 49s, sys: 22min 10s, total: 1h 23min 59s
Wall time: 30min 55s





In [8]:
for i in range(NUM_EPOCHS):
    print(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    print(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

Epoch 1: TRAIN Accuracy = 0.514 Loss = 0.649 AUC = 0.618
Epoch 1: VAL Accuracy = 0.627 Loss = 0.631 AUC = 0.724
Epoch 2: TRAIN Accuracy = 0.62 Loss = 0.623 AUC = 0.715
Epoch 2: VAL Accuracy = 0.707 Loss = 0.601 AUC = 0.768
Epoch 3: TRAIN Accuracy = 0.703 Loss = 0.594 AUC = 0.778
Epoch 3: VAL Accuracy = 0.68 Loss = 0.581 AUC = 0.772
Epoch 4: TRAIN Accuracy = 0.669 Loss = 0.572 AUC = 0.754
Epoch 4: VAL Accuracy = 0.753 Loss = 0.563 AUC = 0.771
Epoch 5: TRAIN Accuracy = 0.749 Loss = 0.552 AUC = 0.77
Epoch 5: VAL Accuracy = 0.76 Loss = 0.545 AUC = 0.782
Epoch 6: TRAIN Accuracy = 0.749 Loss = 0.533 AUC = 0.778
Epoch 6: VAL Accuracy = 0.76 Loss = 0.525 AUC = 0.786
Epoch 7: TRAIN Accuracy = 0.771 Loss = 0.51 AUC = 0.813
Epoch 7: VAL Accuracy = 0.767 Loss = 0.506 AUC = 0.792
Epoch 8: TRAIN Accuracy = 0.777 Loss = 0.491 AUC = 0.785
Epoch 8: VAL Accuracy = 0.78 Loss = 0.484 AUC = 0.799
Epoch 9: TRAIN Accuracy = 0.78 Loss = 0.477 AUC = 0.794
Epoch 9: VAL Accuracy = 0.78 Loss = 0.468 AUC = 0.791
E