In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, Masking
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 02:12:28.564953: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 02:12:28.573265: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 02:12:28.599851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742512348.643476  273793 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742512348.657979  273793 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 02:12:28.713615: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 200       # training epochs per experiment
NUM_EXPERIMENTS = 5    # independent subject splits / seeds

def create_model(train, gru_units=10, num_gru_layers=3, dropout=0.2,
                 dense_units=6, learning_rate=0.003, mask_value=None):
    """Build and compile a stacked-GRU binary classifier.

    With the default arguments this produces exactly the original architecture:
    three (Dropout -> GRU) stages, a Dropout, a ReLU Dense layer and a sigmoid output.

    Parameters
    ----------
    train : array-like with a ``.shape`` attribute
        Only ``train.shape[1]`` and ``train.shape[2]`` are read, to size the
        input layer (presumably timesteps and features -- see ``reshape_dataset``).
    gru_units : int
        Units in every GRU layer.
    num_gru_layers : int
        Number of stacked GRU layers (>= 1). All but the last return full
        sequences so the next GRU can consume them.
    dropout : float
        Rate of the Dropout layer placed before each GRU layer and before the dense head.
    dense_units : int
        Units of the hidden ReLU Dense layer.
    learning_rate : float
        Adam learning rate.
    mask_value : float or None
        If not None, a Masking layer is inserted right after the input so that
        timesteps whose features all equal this value (e.g. zero padding added
        by ``pad_sequences``) are skipped by the GRUs. ``None`` (default) keeps
        the original, unmasked model.

    Returns
    -------
    tensorflow.keras.Sequential
        Model compiled with binary cross-entropy, tracking accuracy and AUC
        (metric name ``"auc"``).

    Raises
    ------
    ValueError
        If ``num_gru_layers`` is smaller than 1.
    """
    if num_gru_layers < 1:
        raise ValueError(f"num_gru_layers must be >= 1, got {num_gru_layers}")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    if mask_value is not None:
        model.add(Masking(mask_value=mask_value))

    for layer_idx in range(num_gru_layers):
        # Only the final GRU collapses the sequence into a single vector.
        is_last_gru = layer_idx == num_gru_layers - 1
        model.add(Dropout(dropout))
        model.add(GRU(gru_units, return_sequences=not is_last_gru))

    model.add(Dropout(dropout))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
# Column groups, kept as lists so they can be concatenated
# (e.g. ID + USER + FEATURES + TARGET when selecting columns).
USER = ["SubjectID"]            # unit of the train/validation split (whole subjects are held out)
IDS = USER + ["VideoID"]        # (subject, video) pair identifying one sequence
ID = ["ID"]                     # integer sequence key derived from IDS in the loading cell
TARGET = ["predefinedlabel"]    # binary label column

# Per-timestep input columns; lagged copies are appended in the loading cell.
FEATURES = [
    "Raw",
    "Delta",
    "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]

LAGS = [1]           # row offsets for which shifted copies of FEATURES are created
INIT_SEED = 5412     # experiment k is seeded with INIT_SEED + k

In [4]:
# The data location can be overridden with the EEG_DATA_DIR environment
# variable; the default is the original author's local path.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# Recover the un-lagged feature names first. FEATURES is rebuilt (not extended
# in place) below, so re-running this cell must not lag the lagged columns again.
lag_suffixes = tuple(f"_{lag}" for lag in LAGS)
base_features = [name for name in FEATURES if not name.endswith(lag_suffixes)]

# Shift every base feature within each (SubjectID, VideoID) group; the first
# `lag` rows of a group have no predecessor and are filled with 0.
grouped_features = data.groupby(IDS)[base_features]
new_features = []
for lag in LAGS:
    lagged = grouped_features.shift(lag).fillna(0)
    lagged.columns = [f"{name}_{lag}" for name in base_features]
    data = pd.concat([data, lagged], axis=1)
    new_features.extend(lagged.columns)
FEATURES = base_features + new_features

# One integer ID per (SubjectID, VideoID) group, numbered in sorted key order.
# Unlike `n_videos * SubjectID + VideoID`, this cannot collide when the
# VideoID values are not exactly 0..n_videos-1.
data["ID"] = data.groupby(IDS).ngroup().astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_1,Delta_1,Theta_1,Alpha1_1,Alpha2_1,Beta1_1,Beta2_1,Gamma1_1,Gamma2_1,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn long-format rows into one padded sequence and one label per ID.

    Rows are grouped by the ``ID`` column(s) in ascending ID order (the same
    order ``np.unique`` would give). Each group contributes its ``FEATURES``
    rows as one sequence and a single label.

    Parameters
    ----------
    data : pd.DataFrame
        Frame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns
    -------
    features : np.ndarray
        Shape (n_ids, max_len, len(FEATURES)); shorter sequences are padded by
        ``pad_sequences`` (trailing zeros by default).
    target : np.ndarray
        Integer array of shape (n_ids,): the per-ID label, i.e. the rounded
        mean of the group's ``TARGET`` values (a majority vote for 0/1 labels).

    Raises
    ------
    ValueError
        If ``data`` has no rows to group.
    """
    features = []
    target = []
    # A single groupby pass replaces re-scanning the whole frame once per ID.
    for _, group in data.groupby(ID, sort=True):
        # Rounding (rather than truncating) the mean keeps constant labels
        # unchanged but no longer maps a mostly-positive sequence to 0.
        target.append(int(np.rint(group[TARGET].to_numpy().mean())))
        features.append(group[FEATURES].to_numpy())

    if not features:
        raise ValueError("reshape_dataset received no rows to group")

    features = pad_sequences(features)
    return np.asarray(features), np.asarray(target, dtype=int)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad arrays along their first axis to a common length and stack them.

    Parameters
    ----------
    arrays : iterable of array-like
        Arrays that agree on every axis except the first (any rank >= 1).
    pad_value : scalar
        Value used for the padded positions.
    padding : {"post", "pre"}
        Whether padding is appended after ("post", the default and the
        original behaviour) or inserted before ("pre") each array's data.
        Pre-padding keeps the real data at the end of each sequence, which
        matters for recurrent layers that only return their final state.

    Returns
    -------
    np.ndarray
        Shape (len(arrays), max_length, *trailing_dims).

    Raises
    ------
    ValueError
        If ``arrays`` is empty or ``padding`` is not "post"/"pre".
    """
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences needs at least one array")
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        missing = max_length - arr.shape[0]
        first_axis = (missing, 0) if padding == "pre" else (0, missing)
        # Pad only the first axis; leave every trailing axis untouched.
        pad_width = [first_axis] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            np.pad(arr, pad_width, mode='constant', constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Inspect the architecture once on the full dataset; every experiment below
# builds its own fresh model, so this instance only serves the summary.
X, _ = reshape_dataset(data)
model = create_model(train=X)
model.summary()

2025-03-21 02:12:34.447588: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 465ms/step - accuracy: 0.6701 - auc: 0.7135 - loss: 0.6616 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5966
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 261ms/step - accuracy: 0.6112 - auc: 0.7285 - loss: 0.6329 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5484
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 277ms/step - accuracy: 0.7285 - auc: 0.7274 - loss: 0.6078 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5177
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 316ms/step - accuracy: 0.7341 - auc: 0.7228 - loss: 0.5806 - val_accuracy: 0.8000 - val_auc: 0.7556 - val_loss: 0.4994
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7285 - auc: 0.7182 - loss: 0.5674 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.4836
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [07:29, 449.53s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 521ms/step - accuracy: 0.7040 - auc: 0.6332 - loss: 0.6843 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.6693
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 273ms/step - accuracy: 0.7147 - auc: 0.7206 - loss: 0.6691 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.6481
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 266ms/step - accuracy: 0.7441 - auc: 0.7790 - loss: 0.6496 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.6217
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 236ms/step - accuracy: 0.7287 - auc: 0.8017 - loss: 0.6309 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5871
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 318ms/step - accuracy: 0.7554 - auc: 0.7298 - loss: 0.5923 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5508
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [14:52, 445.73s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 504ms/step - accuracy: 0.5326 - auc: 0.7269 - loss: 0.6492 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5959
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 268ms/step - accuracy: 0.5835 - auc: 0.7466 - loss: 0.6125 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5604
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 272ms/step - accuracy: 0.7520 - auc: 0.7772 - loss: 0.5838 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5426
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.5679 - auc: 0.7183 - loss: 0.5682 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5399
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 292ms/step - accuracy: 0.7009 - auc: 0.7289 - loss: 0.5523 - val_accuracy: 0.7667 - val_auc: 0.7889 - val_loss: 0.5417
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [22:11, 442.47s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 462ms/step - accuracy: 0.6240 - auc: 0.6572 - loss: 0.6778 - val_accuracy: 0.8000 - val_auc: 0.8133 - val_loss: 0.6450
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 233ms/step - accuracy: 0.7368 - auc: 0.7783 - loss: 0.6559 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.6086
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7297 - auc: 0.7931 - loss: 0.6287 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5669
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 277ms/step - accuracy: 0.7739 - auc: 0.7509 - loss: 0.5821 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5314
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 254ms/step - accuracy: 0.7899 - auc: 0.8058 - loss: 0.5111 - val_accuracy: 0.7667 - val_auc: 0.7644 - val_loss: 0.5052
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [29:34, 442.72s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 494ms/step - accuracy: 0.6736 - auc: 0.8196 - loss: 0.6371 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5839
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 283ms/step - accuracy: 0.8600 - auc: 0.8928 - loss: 0.5632 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5447
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.8537 - auc: 0.8428 - loss: 0.5072 - val_accuracy: 0.7333 - val_auc: 0.6933 - val_loss: 0.5242
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.8754 - auc: 0.8504 - loss: 0.4236 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5297
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 245ms/step - accuracy: 0.8600 - auc: 0.8522 - loss: 0.4226 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5323
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [36:41, 440.21s/it]

CPU times: user 58min 58s, sys: 16min 58s, total: 1h 15min 56s
Wall time: 36min 41s





In [8]:
# Report the experiment-averaged train and validation metrics for every epoch.
for epoch_idx in range(NUM_EPOCHS):
    epoch_no = epoch_idx + 1
    train_line = (
        f"Epoch {epoch_no}: TRAIN "
        f"Accuracy = {np.round(epoch_acc[epoch_idx], 3)} "
        f"Loss = {np.round(epoch_loss[epoch_idx], 3)} "
        f"AUC = {np.round(epoch_auc[epoch_idx], 3)}"
    )
    print(train_line)
    val_line = (
        f"Epoch {epoch_no}: VAL "
        f"Accuracy = {np.round(epoch_val_acc[epoch_idx], 3)} "
        f"Loss = {np.round(epoch_val_loss[epoch_idx], 3)} "
        f"AUC = {np.round(epoch_val_auc[epoch_idx], 3)}"
    )
    print(val_line)

Epoch 1: TRAIN Accuracy = 0.694 Loss = 0.649 AUC = 0.723
Epoch 1: VAL Accuracy = 0.767 Loss = 0.618 AUC = 0.794
Epoch 2: TRAIN Accuracy = 0.729 Loss = 0.616 AUC = 0.78
Epoch 2: VAL Accuracy = 0.773 Loss = 0.582 AUC = 0.777
Epoch 3: TRAIN Accuracy = 0.774 Loss = 0.581 AUC = 0.79
Epoch 3: VAL Accuracy = 0.773 Loss = 0.555 AUC = 0.771
Epoch 4: TRAIN Accuracy = 0.746 Loss = 0.549 AUC = 0.759
Epoch 4: VAL Accuracy = 0.767 Loss = 0.537 AUC = 0.768
Epoch 5: TRAIN Accuracy = 0.769 Loss = 0.528 AUC = 0.758
Epoch 5: VAL Accuracy = 0.767 Loss = 0.523 AUC = 0.772
Epoch 6: TRAIN Accuracy = 0.771 Loss = 0.505 AUC = 0.796
Epoch 6: VAL Accuracy = 0.767 Loss = 0.504 AUC = 0.779
Epoch 7: TRAIN Accuracy = 0.78 Loss = 0.493 AUC = 0.791
Epoch 7: VAL Accuracy = 0.767 Loss = 0.481 AUC = 0.792
Epoch 8: TRAIN Accuracy = 0.791 Loss = 0.474 AUC = 0.788
Epoch 8: VAL Accuracy = 0.767 Loss = 0.464 AUC = 0.796
Epoch 9: TRAIN Accuracy = 0.789 Loss = 0.471 AUC = 0.778
Epoch 9: VAL Accuracy = 0.787 Loss = 0.45 AUC = 0.