In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 17:04:42.483588: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 17:04:42.490108: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 17:04:42.508510: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742565882.533263  887875 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742565882.538925  887875 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 17:04:42.574372: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 200       # training epochs per experiment
NUM_EXPERIMENTS = 5    # repeated runs, each with its own seed and subject split

def create_model(train, units=10, dropout=0.2, dense_units=6, learning_rate=0.003,
                 num_lstm_layers=3, mask_value=None):
    """Build and compile the stacked-LSTM binary classifier.

    With the default arguments the model is exactly
    ``Input -> [Dropout(0.2) -> LSTM(10)] x 3 -> Dropout(0.2) -> Dense(6, relu) -> Dense(1, sigmoid)``,
    compiled with Adam(learning_rate=0.003), binary cross-entropy and
    accuracy / AUC (named "auc") metrics.

    Parameters
    ----------
    train : array-like
        Only ``train.shape[1]`` and ``train.shape[2]`` are read, as the
        (timesteps, n_features) of the Input layer; presumably the padded 3-D
        array produced by ``reshape_dataset``.
    units : int
        Hidden size of every LSTM layer.
    dropout : float
        Rate of the Dropout placed before each LSTM and before the Dense head.
    dense_units : int
        Width of the ReLU Dense layer.
    learning_rate : float
        Adam learning rate.
    num_lstm_layers : int
        Number of stacked LSTMs (>= 1); only the last one drops the time axis.
    mask_value : float or None
        If not None, a Masking layer makes the LSTMs skip timesteps whose
        features all equal this value, e.g. ``0`` for the zero padding that
        ``pad_sequences`` appends at the end of shorter sequences. Without it
        the last LSTM's output is taken after any trailing padding steps.
        Default None keeps the original, unmasked model.

    Returns
    -------
    A compiled ``Sequential`` model with a single sigmoid output.

    Raises
    ------
    ValueError
        If ``num_lstm_layers`` is smaller than 1.
    """
    if num_lstm_layers < 1:
        raise ValueError(f"num_lstm_layers must be >= 1, got {num_lstm_layers}")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    if mask_value is not None:
        model.add(tf.keras.layers.Masking(mask_value=mask_value))

    for layer_idx in range(num_lstm_layers):
        model.add(Dropout(dropout))
        # Intermediate LSTMs pass whole sequences on; the last one returns a single vector.
        model.add(LSTM(units, return_sequences=layer_idx < num_lstm_layers - 1))

    model.add(Dropout(dropout))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
# Column groups are lists so they can be concatenated when selecting columns.
USER = ["SubjectID"]            # subject column; experiments split train/validation by subject
IDS = [*USER, "VideoID"]        # one recording = (subject, video); lag features are shifted within it
ID = ["ID"]                     # integer sequence key derived from SubjectID and VideoID at load time
TARGET = ["predefinedlabel"]    # label column
FEATURES = [                    # base EEG columns; lagged copies are added when the data is loaded
    "Raw", "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]

LAGS = [2]                      # row shifts used to build lagged features
INIT_SEED = 5412                # experiment k is seeded with INIT_SEED + k

In [4]:
# Data location: set BRAIN_SIGNALS_DATA_DIR to point elsewhere instead of
# editing this hard-coded default.
data_dir = Path(os.environ.get("BRAIN_SIGNALS_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# Base (un-lagged) columns. FEATURES is rebuilt below instead of being extended
# in place, so drop lagged names left over from a previous run of this cell;
# otherwise a re-run would lag the lagged columns again and keep growing FEATURES.
lag_suffixes = tuple(f"_{lag}" for lag in LAGS)
base_features = [name for name in FEATURES if not name.endswith(lag_suffixes)]

# Lagged copies of every base feature, shifted within each (subject, video)
# recording so values never leak across recordings. Missing values (including
# the first `lag` rows of each recording) become 0.
new_features = []
lagged_frames = []
for lag in LAGS:
    lagged = data.groupby(IDS)[base_features].shift(lag).fillna(0)
    lagged.columns = [f"{name}_{lag}" for name in base_features]
    new_features.extend(lagged.columns)
    lagged_frames.append(lagged)
# One concat instead of one column insert per feature (avoids a fragmented frame).
data = pd.concat([data, *lagged_frames], axis=1)
FEATURES = base_features + new_features

# Integer key per (subject, video) pair: n_videos * SubjectID + VideoID
# (unique as long as VideoID values lie in 0..n_videos-1).
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long one-row-per-timestep frame into padded per-sequence arrays.

    Rows are grouped by the ``ID`` columns. Groups are visited in ascending ID
    order, and rows keep their original order inside each group.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns listed in ``ID``, ``FEATURES`` and ``TARGET``.

    Returns
    -------
    features : numpy.ndarray, shape (n_sequences, max_len, len(FEATURES))
        Each group's FEATURES rows, stacked by ``pad_sequences`` (which pads
        shorter sequences up to the longest one).
    target : numpy.ndarray of int, shape (n_sequences,)
        Per-sequence label: mean of the TARGET column over the sequence,
        truncated to int. For 0/1 labels this is 1 only if every row is 1.
        Empty ``data`` yields empty arrays.
    """
    features = []
    target = []
    # A single sorted groupby pass replaces re-scanning the whole frame once per
    # ID; sort=True yields the same ascending order as np.unique.
    for _, group in data.groupby(ID, sort=True):
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    if not features:
        # pad_sequences cannot infer a length from zero sequences.
        return np.empty((0, 0, len(FEATURES))), np.empty((0,), dtype=int)

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None, padding="post"):
    """Stack variable-length arrays into one array by padding along axis 0.

    Parameters
    ----------
    arrays : iterable of numpy.ndarray
        Arrays that agree on every dimension except the first one.
    pad_value : scalar, default 0
        Fill value for the padded positions.
    max_length : int, optional
        Target length of axis 0. Defaults to the longest input; pass a common
        value to give separately built datasets the same length.
    padding : {"post", "pre"}, default "post"
        Pad after the data ("post", the original behaviour) or before it ("pre").

    Returns
    -------
    numpy.ndarray
        Shape ``(n_arrays, max_length, *arrays[0].shape[1:])``.

    Raises
    ------
    ValueError
        If ``arrays`` is empty, ``padding`` is unknown, or ``max_length`` is
        shorter than the longest input.
    """
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")
    # Materialise once: the arrays are traversed twice (length scan + padding).
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif max_length < longest:
        raise ValueError(
            f"max_length={max_length} is shorter than the longest array ({longest})")

    padded_arrays = []
    for arr in arrays:
        missing = max_length - arr.shape[0]
        lead_pad = (0, missing) if padding == "post" else (missing, 0)
        # Only axis 0 is padded; every trailing axis is left untouched.
        pad_width = [lead_pad] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            np.pad(arr, pad_width, mode='constant', constant_values=pad_value))
    return np.stack(padded_arrays)

In [6]:
# Build one model from the full dataset only to display its architecture;
# the experiment cell below creates a fresh model for every split.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 17:04:46.306706: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)


0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 281ms/step - accuracy: 0.6483 - auc: 0.7102 - loss: 0.6843 - val_accuracy: 0.6333 - val_auc: 0.7289 - val_loss: 0.6573
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.5854 - auc: 0.6207 - loss: 0.6724 - val_accuracy: 0.6333 - val_auc: 0.7511 - val_loss: 0.6094
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 265ms/step - accuracy: 0.6192 - auc: 0.7108 - loss: 0.6470 - val_accuracy: 0.7333 - val_auc: 0.7956 - val_loss: 0.5795
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.6085 - auc: 0.6318 - loss: 0.6361 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5624
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 210ms/step - accuracy: 0.6806 - auc: 0.7345 - loss: 0.6184 - val_accuracy: 0.7000 - val_auc: 0.7956 - val_loss: 0.5403
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [06:06, 366.97s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 467ms/step - accuracy: 0.4151 - auc: 0.5245 - loss: 0.6924 - val_accuracy: 0.3667 - val_auc: 0.3044 - val_loss: 0.6939
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.5045 - auc: 0.6293 - loss: 0.6822 - val_accuracy: 0.4333 - val_auc: 0.3444 - val_loss: 0.6916
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 235ms/step - accuracy: 0.5289 - auc: 0.6129 - loss: 0.6719 - val_accuracy: 0.6000 - val_auc: 0.4444 - val_loss: 0.6861
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7190 - auc: 0.7627 - loss: 0.6688 - val_accuracy: 0.6333 - val_auc: 0.6689 - val_loss: 0.6726
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.6226 - auc: 0.6754 - loss: 0.6405 - val_accuracy: 0.6333 - val_auc: 0.6933 - val_loss: 0.6721
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:13, 366.51s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 424ms/step - accuracy: 0.4553 - auc: 0.6616 - loss: 0.6907 - val_accuracy: 0.5667 - val_auc: 0.7511 - val_loss: 0.6803
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 253ms/step - accuracy: 0.5454 - auc: 0.6922 - loss: 0.6775 - val_accuracy: 0.5667 - val_auc: 0.7556 - val_loss: 0.6435
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.4564 - auc: 0.6855 - loss: 0.6688 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.6056
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 247ms/step - accuracy: 0.6582 - auc: 0.7428 - loss: 0.6285 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5783
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 228ms/step - accuracy: 0.6908 - auc: 0.7347 - loss: 0.5972 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5924
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [18:33, 372.92s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 463ms/step - accuracy: 0.4807 - auc: 0.4911 - loss: 0.6924 - val_accuracy: 0.5000 - val_auc: 0.7178 - val_loss: 0.6832
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.5717 - auc: 0.7545 - loss: 0.6739 - val_accuracy: 0.5333 - val_auc: 0.6400 - val_loss: 0.6774
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.6378 - auc: 0.7486 - loss: 0.6490 - val_accuracy: 0.5667 - val_auc: 0.6578 - val_loss: 0.6583
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.7709 - auc: 0.7577 - loss: 0.6139 - val_accuracy: 0.6000 - val_auc: 0.7244 - val_loss: 0.6403
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 244ms/step - accuracy: 0.7709 - auc: 0.7248 - loss: 0.5854 - val_accuracy: 0.6667 - val_auc: 0.7511 - val_loss: 0.6105
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [24:43, 371.54s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 441ms/step - accuracy: 0.5114 - auc: 0.4581 - loss: 0.6958 - val_accuracy: 0.5000 - val_auc: 0.5578 - val_loss: 0.6863
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.5340 - auc: 0.6048 - loss: 0.6863 - val_accuracy: 0.6000 - val_auc: 0.7244 - val_loss: 0.6728
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.6604 - auc: 0.7774 - loss: 0.6688 - val_accuracy: 0.6667 - val_auc: 0.6733 - val_loss: 0.6593
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 231ms/step - accuracy: 0.7533 - auc: 0.7611 - loss: 0.6589 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.6475
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 247ms/step - accuracy: 0.6637 - auc: 0.7788 - loss: 0.6523 - val_accuracy: 0.6667 - val_auc: 0.7422 - val_loss: 0.6424
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [31:02, 372.40s/it]

CPU times: user 57min 19s, sys: 20min 54s, total: 1h 18min 14s
Wall time: 31min 2s





In [8]:
# Experiment-averaged learning curves: one TRAIN line and one VAL line per epoch.
curves_by_split = (
    ("TRAIN", epoch_acc, epoch_loss, epoch_auc),
    ("VAL", epoch_val_acc, epoch_val_loss, epoch_val_auc),
)
for i in range(NUM_EPOCHS):
    epoch_no = i + 1
    for split_name, acc_curve, loss_curve, auc_curve in curves_by_split:
        acc_r, loss_r, auc_r = (np.round(curve[i], 3) for curve in (acc_curve, loss_curve, auc_curve))
        print(f"Epoch {epoch_no}: {split_name} Accuracy = {acc_r} Loss = {loss_r} AUC = {auc_r}")

Epoch 1: TRAIN Accuracy = 0.529 Loss = 0.687 AUC = 0.603
Epoch 1: VAL Accuracy = 0.513 Loss = 0.68 AUC = 0.612
Epoch 2: TRAIN Accuracy = 0.569 Loss = 0.673 AUC = 0.677
Epoch 2: VAL Accuracy = 0.553 Loss = 0.659 AUC = 0.643
Epoch 3: TRAIN Accuracy = 0.609 Loss = 0.656 AUC = 0.695
Epoch 3: VAL Accuracy = 0.653 Loss = 0.638 AUC = 0.672
Epoch 4: TRAIN Accuracy = 0.691 Loss = 0.633 AUC = 0.73
Epoch 4: VAL Accuracy = 0.667 Loss = 0.62 AUC = 0.741
Epoch 5: TRAIN Accuracy = 0.691 Loss = 0.611 AUC = 0.735
Epoch 5: VAL Accuracy = 0.673 Loss = 0.612 AUC = 0.754
Epoch 6: TRAIN Accuracy = 0.706 Loss = 0.59 AUC = 0.763
Epoch 6: VAL Accuracy = 0.72 Loss = 0.581 AUC = 0.764
Epoch 7: TRAIN Accuracy = 0.734 Loss = 0.572 AUC = 0.742
Epoch 7: VAL Accuracy = 0.727 Loss = 0.561 AUC = 0.779
Epoch 8: TRAIN Accuracy = 0.757 Loss = 0.554 AUC = 0.77
Epoch 8: VAL Accuracy = 0.733 Loss = 0.554 AUC = 0.771
Epoch 9: TRAIN Accuracy = 0.754 Loss = 0.524 AUC = 0.786
Epoch 9: VAL Accuracy = 0.733 Loss = 0.555 AUC = 0.78