In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 20:08:21.691992: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 20:08:21.699567: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 20:08:21.725232: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742576901.779437 1088846 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742576901.795089 1088846 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 20:08:21.849731: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training length per model, and how many repeated random train/test splits to run.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5

def create_model(train, lstm_units=10, num_lstm_layers=3, dropout_rate=0.2,
                 dense_units=6, learning_rate=0.003):
    """Build and compile the stacked-LSTM binary classifier.

    With the default arguments the architecture is exactly:
    Input -> (Dropout -> LSTM) x 3 -> Dropout -> Dense(6, relu) -> Dense(1, sigmoid).

    Parameters
    ----------
    train : array-like
        Only its shape is used: ``train.shape[1]`` (time steps) and
        ``train.shape[2]`` (features per step) define the input layer.
    lstm_units : int, default 10
        Units in every LSTM layer.
    num_lstm_layers : int, default 3
        Number of stacked LSTM layers, each preceded by a Dropout layer.
    dropout_rate : float, default 0.2
        Rate used by every Dropout layer.
    dense_units : int, default 6
        Width of the ReLU hidden layer in the classification head.
    learning_rate : float, default 0.003
        Adam learning rate.

    Returns
    -------
    Sequential
        Compiled model (binary cross-entropy, accuracy + AUC metrics) that
        outputs one sigmoid probability per sample.

    Raises
    ------
    ValueError
        If ``num_lstm_layers`` is smaller than 1.
    """
    if num_lstm_layers < 1:
        raise ValueError(f"num_lstm_layers must be >= 1, got {num_lstm_layers}")

    model = Sequential()
    # NOTE(review): there is no Masking layer, so any zero padding added to
    # shorter sequences is processed by the LSTMs like real time steps.
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    for layer_idx in range(num_lstm_layers):
        is_last = layer_idx == num_lstm_layers - 1
        model.add(Dropout(dropout_rate))
        # Intermediate LSTMs must emit the whole sequence to feed the next LSTM;
        # the last one returns only its final output vector.
        model.add(LSTM(lstm_units, return_sequences=not is_last))

    model.add(Dropout(dropout_rate))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of the EEG table.
ID = ["ID"]                       # integer per-recording key (built during data prep)
USER = ["SubjectID"]              # subject column used to split train/test
IDS = ["SubjectID", "VideoID"]    # columns that together identify one recording
TARGET = ["predefinedlabel"]      # label column (treated as binary by the model)
FEATURES = [
    "Raw",
    "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]

# Feature-engineering switches.
LAGS = []         # row shifts for extra lagged copies of FEATURES; empty = no lag columns
USE_DIFF = True   # replace each feature by its first difference within a recording

# Seed of the first experiment; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Lagged copies: for every lag in LAGS and every feature, add "<feature>_<lag>"
# holding the value `lag` rows earlier within the same (SubjectID, VideoID)
# recording, with 0 where no earlier row exists.
new_features = []
for lag in LAGS:
    grouped = data.groupby(IDS)
    lagged = {f"{name}_{lag}": grouped[name].shift(lag).fillna(0) for name in FEATURES}
    for new_feature_name, values in lagged.items():
        data[new_feature_name] = values
    new_features += list(lagged)
# NOTE(review): FEATURES is extended in place. Re-running this cell with a
# non-empty LAGS, without re-running the config cell first, lags the lag columns again.
FEATURES.extend(new_features)

if USE_DIFF:
    # First difference inside each recording; the first row of a recording keeps
    # its raw value because its missing predecessor is filled with 0.
    for feature_name in FEATURES:
        previous_value = data.groupby(IDS)[feature_name].shift(1).fillna(0)
        data[feature_name] = data[feature_name] - previous_value

# Integer recording key: n_videos * SubjectID + VideoID.
# NOTE(review): collision-free only if VideoID takes values in [0, n_videos) — TODO confirm.
n_videos = len(np.unique(data["VideoID"]))
data["ID"] = (n_videos * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-328.0,-228176.0,-62529.0,-32296.0,-21751.0,-25200.0,-41410.0,-27935.0,-5553.0,0.0
2,0,0.0,151.0,684566.0,355662.0,200560.0,59867.0,33547.0,126849.0,51950.0,22614.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long per-sample table into one padded sequence per recording.

    Rows are grouped by the ``ID`` column(s). Groups are visited in ascending
    ID order, and each group keeps its original row order.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the columns listed in ``ID``, ``FEATURES`` and ``TARGET``.

    Returns
    -------
    features : np.ndarray
        Shape ``(n_recordings, max_len, len(FEATURES))``, built by
        ``pad_sequences`` (shorter recordings padded with 0 after their rows).
    target : np.ndarray
        One integer per recording: the mean of its ``TARGET`` values cast to
        int. NOTE(review): presumably the label is constant within a recording
        (TODO confirm). A recording with mixed 0/1 labels would be truncated
        toward 0.

    Raises
    ------
    ValueError
        From ``pad_sequences`` when ``data`` has no rows.
    """
    features = []
    target = []
    # One groupby pass instead of a full-table boolean mask per ID
    # (O(n_ids * n_rows)). sort=True yields groups in the same ascending order
    # as np.unique.
    for _, recording in data.groupby(ID, sort=True):
        target.append(np.mean(recording[TARGET].to_numpy()).astype("int"))
        features.append(recording[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad variable-length arrays along axis 0 to a common length and stack them.

    Parameters
    ----------
    arrays : iterable of np.ndarray
        Arrays that agree on every dimension except the first one. 1-D, 2-D
        and higher-rank arrays are supported.
    pad_value : scalar, default 0
        Value written into the padded positions.
    padding : {"post", "pre"}, default "post"
        "post" appends the padding after each array's rows (original behaviour).
        "pre" inserts it before them.

    Returns
    -------
    np.ndarray
        Shape ``(n_arrays, max_length, *trailing_dims)``.

    Raises
    ------
    ValueError
        If ``arrays`` is empty or ``padding`` is not "post"/"pre".
    """
    # Materialise once so generators are not exhausted by the max() pass.
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        missing = max_length - arr.shape[0]
        leading_pad = (0, missing) if padding == "post" else (missing, 0)
        # Pad only the leading axis; every trailing axis is left untouched.
        pad_width = [leading_pad] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            np.pad(arr, pad_width, mode='constant', constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full dataset's padded shape purely to inspect the
# architecture; every experiment below builds its own model.
# The targets are not needed here. They are deliberately not bound to `_`,
# because in IPython that would shadow the "last output" history variable.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-03-21 20:08:28.548278: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Repeated random subject-level splits: each experiment trains on
# NUM_TRAIN_SUBJECTS randomly drawn subjects and validates on the remaining
# ones, so no subject contributes rows to both sets.
# NOTE(review): the held-out subjects double as `validation_data`. Picking the
# "best" epoch from the averaged validation curves below therefore tunes on the
# test subjects, so such numbers are optimistic.
NUM_TRAIN_SUBJECTS = 7
METRIC_KEYS = ["accuracy", "loss", "auc", "val_accuracy", "val_loss", "val_auc"]

# metric name -> list of per-epoch curves, one curve per experiment
histories = {key: [] for key in METRIC_KEYS}

seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED
subjects = np.unique(data[USER[0]].to_numpy())

for seed in tqdm(seeds, desc="experiments"):
    seed = int(seed)
    # Drop model/graph state left over from the previous experiment.
    tf.keras.backend.clear_session()
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    train_id = np.random.choice(subjects, NUM_TRAIN_SUBJECTS, replace=False)
    # Use a 1-D row mask. np.isin(data[USER], ...) on the one-column frame
    # returns an (n, 1) array, which is a fragile thing to pass to .iloc.
    train_index = data[USER[0]].isin(train_id).to_numpy()

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    for key in METRIC_KEYS:
        histories[key].append(history.history[key])

# Keep the original names used by the reporting cell.
all_acc, all_loss, all_auc = histories["accuracy"], histories["loss"], histories["auc"]
all_val_acc, all_val_loss, all_val_auc = (
    histories["val_accuracy"], histories["val_loss"], histories["val_auc"]
)

# Average every metric over experiments, epoch by epoch
# (each run records exactly NUM_EPOCHS values).
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)


0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 461ms/step - accuracy: 0.5065 - auc: 0.6152 - loss: 0.6859 - val_accuracy: 0.6333 - val_auc: 0.7756 - val_loss: 0.6639
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.5618 - auc: 0.6211 - loss: 0.6749 - val_accuracy: 0.4667 - val_auc: 0.7133 - val_loss: 0.6375
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 233ms/step - accuracy: 0.5929 - auc: 0.7106 - loss: 0.6540 - val_accuracy: 0.5333 - val_auc: 0.7200 - val_loss: 0.6061
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 246ms/step - accuracy: 0.6373 - auc: 0.8168 - loss: 0.6096 - val_accuracy: 0.5667 - val_auc: 0.7267 - val_loss: 0.5952
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.5317 - auc: 0.6505 - loss: 0.6149 - val_accuracy: 0.6333 - val_auc: 0.7333 - val_loss: 0.5457
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [05:55, 355.82s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 437ms/step - accuracy: 0.5173 - auc: 0.5751 - loss: 0.6873 - val_accuracy: 0.5000 - val_auc: 0.6644 - val_loss: 0.6824
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 216ms/step - accuracy: 0.5822 - auc: 0.6360 - loss: 0.6747 - val_accuracy: 0.5000 - val_auc: 0.6756 - val_loss: 0.6693
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 215ms/step - accuracy: 0.6261 - auc: 0.6671 - loss: 0.6556 - val_accuracy: 0.6667 - val_auc: 0.7511 - val_loss: 0.6490
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 243ms/step - accuracy: 0.5446 - auc: 0.7418 - loss: 0.6374 - val_accuracy: 0.7000 - val_auc: 0.7067 - val_loss: 0.6292
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 221ms/step - accuracy: 0.6498 - auc: 0.6599 - loss: 0.6140 - val_accuracy: 0.7000 - val_auc: 0.7067 - val_loss: 0.6061
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [11:59, 360.64s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 372ms/step - accuracy: 0.4240 - auc: 0.6014 - loss: 0.6928 - val_accuracy: 0.5333 - val_auc: 0.7356 - val_loss: 0.6788
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 226ms/step - accuracy: 0.5640 - auc: 0.6894 - loss: 0.6822 - val_accuracy: 0.4667 - val_auc: 0.7067 - val_loss: 0.6712
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 233ms/step - accuracy: 0.5909 - auc: 0.6484 - loss: 0.6757 - val_accuracy: 0.4667 - val_auc: 0.6556 - val_loss: 0.6606
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 235ms/step - accuracy: 0.5126 - auc: 0.6744 - loss: 0.6697 - val_accuracy: 0.5667 - val_auc: 0.7800 - val_loss: 0.6351
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.5746 - auc: 0.7074 - loss: 0.6435 - val_accuracy: 0.4667 - val_auc: 0.7067 - val_loss: 0.6363
Epoch 6/200
[1m7/7[0m [32m━━━━━━

3it [18:15, 367.32s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 378ms/step - accuracy: 0.5607 - auc: 0.7373 - loss: 0.6768 - val_accuracy: 0.5667 - val_auc: 0.7067 - val_loss: 0.6650
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 218ms/step - accuracy: 0.6079 - auc: 0.7318 - loss: 0.6435 - val_accuracy: 0.6667 - val_auc: 0.7600 - val_loss: 0.6346
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 217ms/step - accuracy: 0.7571 - auc: 0.7719 - loss: 0.5967 - val_accuracy: 0.7667 - val_auc: 0.7800 - val_loss: 0.6083
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 270ms/step - accuracy: 0.7843 - auc: 0.7188 - loss: 0.5839 - val_accuracy: 0.7333 - val_auc: 0.8200 - val_loss: 0.5799
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 245ms/step - accuracy: 0.7826 - auc: 0.7429 - loss: 0.5364 - val_accuracy: 0.7333 - val_auc: 0.8378 - val_loss: 0.5621
Epoch 6/200
[1m7/7[0m [32m━━━━━━

4it [24:14, 364.14s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 401ms/step - accuracy: 0.4957 - auc: 0.4864 - loss: 0.6923 - val_accuracy: 0.6333 - val_auc: 0.6956 - val_loss: 0.6796
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.6867 - auc: 0.7051 - loss: 0.6804 - val_accuracy: 0.6667 - val_auc: 0.6511 - val_loss: 0.6655
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 214ms/step - accuracy: 0.7551 - auc: 0.8064 - loss: 0.6543 - val_accuracy: 0.7000 - val_auc: 0.6400 - val_loss: 0.6497
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 228ms/step - accuracy: 0.7459 - auc: 0.8037 - loss: 0.6206 - val_accuracy: 0.7000 - val_auc: 0.7000 - val_loss: 0.6266
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 241ms/step - accuracy: 0.7955 - auc: 0.7997 - loss: 0.5885 - val_accuracy: 0.7000 - val_auc: 0.7089 - val_loss: 0.6055
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [30:15, 363.08s/it]

CPU times: user 55min 4s, sys: 19min 42s, total: 1h 14min 47s
Wall time: 30min 15s





In [8]:
# Print the experiment-averaged train and validation metrics for every epoch.
for i in range(NUM_EPOCHS):
    tr_acc, tr_loss, tr_auc = (np.round(v, 3) for v in (epoch_acc[i], epoch_loss[i], epoch_auc[i]))
    va_acc, va_loss, va_auc = (np.round(v, 3) for v in (epoch_val_acc[i], epoch_val_loss[i], epoch_val_auc[i]))
    print(f"Epoch {(i + 1)}: TRAIN Accuracy = {tr_acc} Loss = {tr_loss} AUC = {tr_auc}")
    print(f"Epoch {(i + 1)}: VAL Accuracy = {va_acc} Loss = {va_loss} AUC = {va_auc}")

Epoch 1: TRAIN Accuracy = 0.529 Loss = 0.683 AUC = 0.638
Epoch 1: VAL Accuracy = 0.573 Loss = 0.674 AUC = 0.716
Epoch 2: TRAIN Accuracy = 0.611 Loss = 0.669 AUC = 0.691
Epoch 2: VAL Accuracy = 0.553 Loss = 0.656 AUC = 0.701
Epoch 3: TRAIN Accuracy = 0.68 Loss = 0.644 AUC = 0.734
Epoch 3: VAL Accuracy = 0.627 Loss = 0.635 AUC = 0.709
Epoch 4: TRAIN Accuracy = 0.651 Loss = 0.617 AUC = 0.745
Epoch 4: VAL Accuracy = 0.653 Loss = 0.613 AUC = 0.747
Epoch 5: TRAIN Accuracy = 0.683 Loss = 0.59 AUC = 0.738
Epoch 5: VAL Accuracy = 0.647 Loss = 0.591 AUC = 0.739
Epoch 6: TRAIN Accuracy = 0.723 Loss = 0.577 AUC = 0.741
Epoch 6: VAL Accuracy = 0.707 Loss = 0.571 AUC = 0.754
Epoch 7: TRAIN Accuracy = 0.754 Loss = 0.55 AUC = 0.766
Epoch 7: VAL Accuracy = 0.74 Loss = 0.548 AUC = 0.766
Epoch 8: TRAIN Accuracy = 0.749 Loss = 0.526 AUC = 0.759
Epoch 8: VAL Accuracy = 0.753 Loss = 0.531 AUC = 0.79
Epoch 9: TRAIN Accuracy = 0.757 Loss = 0.517 AUC = 0.764
Epoch 9: VAL Accuracy = 0.767 Loss = 0.521 AUC = 0.8