In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 17:04:54.081671: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 17:04:54.090566: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 17:04:54.125287: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742565894.182559  888115 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742565894.197148  888115 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 17:04:54.256281: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=False))
    
    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
USER = ["SubjectID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
LAGS = [2]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

new_features = []
for lag in LAGS:
    for feature_name in FEATURES:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
FEATURES.extend(new_features)

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 17:05:00.639916: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 411ms/step - accuracy: 0.6051 - auc: 0.7220 - loss: 0.6593 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5768
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.6728 - auc: 0.6662 - loss: 0.6294 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5257
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7249 - auc: 0.7093 - loss: 0.6046 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.5004
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.6943 - auc: 0.7208 - loss: 0.5797 - val_accuracy: 0.8000 - val_auc: 0.8333 - val_loss: 0.4825
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7203 - auc: 0.7086 - loss: 0.5650 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.4670
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [06:21, 381.84s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 515ms/step - accuracy: 0.5643 - auc: 0.6001 - loss: 0.6856 - val_accuracy: 0.7667 - val_auc: 0.7867 - val_loss: 0.6762
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 258ms/step - accuracy: 0.7138 - auc: 0.7503 - loss: 0.6696 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.6531
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 269ms/step - accuracy: 0.7441 - auc: 0.7869 - loss: 0.6494 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.6257
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 242ms/step - accuracy: 0.7287 - auc: 0.8335 - loss: 0.6272 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.5890
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 257ms/step - accuracy: 0.7183 - auc: 0.6654 - loss: 0.5946 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.5413
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:54, 388.44s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 469ms/step - accuracy: 0.5543 - auc: 0.7246 - loss: 0.6365 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5950
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 245ms/step - accuracy: 0.6792 - auc: 0.7125 - loss: 0.6106 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5586
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 255ms/step - accuracy: 0.7918 - auc: 0.7828 - loss: 0.5810 - val_accuracy: 0.7667 - val_auc: 0.7889 - val_loss: 0.5404
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.6491 - auc: 0.7120 - loss: 0.5661 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5333
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 235ms/step - accuracy: 0.7841 - auc: 0.7515 - loss: 0.5457 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5301
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [19:20, 387.10s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 471ms/step - accuracy: 0.4905 - auc: 0.5413 - loss: 0.6903 - val_accuracy: 0.7333 - val_auc: 0.7844 - val_loss: 0.6567
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 243ms/step - accuracy: 0.6596 - auc: 0.7480 - loss: 0.6676 - val_accuracy: 0.8333 - val_auc: 0.8111 - val_loss: 0.6253
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 256ms/step - accuracy: 0.7883 - auc: 0.8320 - loss: 0.6231 - val_accuracy: 0.8333 - val_auc: 0.7667 - val_loss: 0.5772
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.7899 - auc: 0.7459 - loss: 0.5896 - val_accuracy: 0.8333 - val_auc: 0.7667 - val_loss: 0.5260
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7956 - auc: 0.7792 - loss: 0.5257 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.4911
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [25:48, 387.49s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 467ms/step - accuracy: 0.6890 - auc: 0.8180 - loss: 0.6203 - val_accuracy: 0.7333 - val_auc: 0.7333 - val_loss: 0.5771
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.8600 - auc: 0.8737 - loss: 0.5472 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5374
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.8258 - auc: 0.8813 - loss: 0.4839 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.5147
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.8258 - auc: 0.8454 - loss: 0.4534 - val_accuracy: 0.7333 - val_auc: 0.7311 - val_loss: 0.5088
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 247ms/step - accuracy: 0.8258 - auc: 0.8548 - loss: 0.4386 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.4983
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [32:02, 384.52s/it]

CPU times: user 1h 3min 45s, sys: 22min 56s, total: 1h 26min 42s
Wall time: 32min 2s





In [8]:
for i in range(NUM_EPOCHS):
    print(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    print(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

Epoch 1: TRAIN Accuracy = 0.634 Loss = 0.646 AUC = 0.693
Epoch 1: VAL Accuracy = 0.747 Loss = 0.616 AUC = 0.783
Epoch 2: TRAIN Accuracy = 0.734 Loss = 0.615 AUC = 0.766
Epoch 2: VAL Accuracy = 0.78 Loss = 0.58 AUC = 0.793
Epoch 3: TRAIN Accuracy = 0.777 Loss = 0.578 AUC = 0.795
Epoch 3: VAL Accuracy = 0.78 Loss = 0.552 AUC = 0.77
Epoch 4: TRAIN Accuracy = 0.746 Loss = 0.553 AUC = 0.763
Epoch 4: VAL Accuracy = 0.78 Loss = 0.528 AUC = 0.783
Epoch 5: TRAIN Accuracy = 0.771 Loss = 0.529 AUC = 0.756
Epoch 5: VAL Accuracy = 0.767 Loss = 0.506 AUC = 0.784
Epoch 6: TRAIN Accuracy = 0.771 Loss = 0.508 AUC = 0.792
Epoch 6: VAL Accuracy = 0.773 Loss = 0.484 AUC = 0.789
Epoch 7: TRAIN Accuracy = 0.771 Loss = 0.492 AUC = 0.798
Epoch 7: VAL Accuracy = 0.78 Loss = 0.463 AUC = 0.797
Epoch 8: TRAIN Accuracy = 0.791 Loss = 0.475 AUC = 0.783
Epoch 8: VAL Accuracy = 0.787 Loss = 0.45 AUC = 0.792
Epoch 9: TRAIN Accuracy = 0.789 Loss = 0.467 AUC = 0.786
Epoch 9: VAL Accuracy = 0.793 Loss = 0.443 AUC = 0.801