In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 11:22:41.068300: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 11:22:41.075452: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 11:22:41.092435: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742545361.122878  659117 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742545361.136511  659117 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 11:22:41.170754: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of training epochs per experiment (passed to model.fit as `epochs`).
NUM_EPOCHS = 200
# Number of independent train/test experiments; experiment k is seeded with INIT_SEED + k.
NUM_EXPERIMENTS = 5

def create_model(train, units=10, dropout_rate=0.2, dense_units=6, learning_rate=0.003):
    """Build and compile the stacked-LSTM binary classifier.

    Architecture: three (Dropout -> LSTM) stages, then Dropout, a ReLU Dense
    layer and a single sigmoid output unit.

    Args:
        train: Array of shape (samples, timesteps, features). Only its shape
            is read, to size the model's Input layer.
        units: Width of each LSTM layer.
        dropout_rate: Rate used by every Dropout layer.
        dense_units: Width of the hidden Dense layer.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        The compiled Sequential model (binary cross-entropy loss, metrics
        "accuracy" and an AUC metric named "auc").
    """
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    # The first two LSTMs return the full sequence so the next LSTM can
    # consume it; the last one returns only its final output.
    for return_sequences in (True, True, False):
        model.add(Dropout(dropout_rate))
        model.add(LSTM(units, return_sequences=return_sequences))

    model.add(Dropout(dropout_rate))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column holding the per-recording identifier (derived from SubjectID and VideoID after loading).
ID = ["ID"]
# Column identifying the subject; train/test splits are drawn at subject level.
USER = ["SubjectID"]
# Columns that together identify one recording; lag features are shifted within these groups.
IDS = ["SubjectID", "VideoID"]
# Label column.
TARGET = ["predefinedlabel"]
# Input feature columns; the data-loading cell adds the lagged column names to this list.
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Row offsets used to build lagged copies of every feature.
LAGS = [1, 2]
# Seed of the first experiment; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# NOTE(review): absolute, machine-specific path; adjust it to the local copy of the dataset.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# After a first run FEATURES already contains the lagged names. Strip them so
# that re-running this cell lags only the base features instead of trying to
# lag "Raw_1", "Raw_2", ... (which previously failed with a KeyError).
lagged_names = {f"{name}_{lag}" for name in FEATURES for lag in LAGS}
base_features = [name for name in FEATURES if name not in lagged_names]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (SubjectID, VideoID) recording so values never leak
        # across recordings; the rows without a predecessor are filled with 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
# Rebind rather than extend in place, so the result is the same on every run.
FEATURES = base_features + new_features

# One integer ID per (SubjectID, VideoID) pair: subject * number_of_videos + video.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long-format table into padded per-recording sequences.

    Rows are grouped by the ID column(s) in ascending ID order. For each group
    the FEATURES columns become one (rows, features) array and the TARGET
    column is averaged and truncated to an integer label. The per-group arrays
    are zero-padded to a common length by `pad_sequences`. Rows whose ID is
    NaN are skipped by the grouping.

    Args:
        data: DataFrame containing the ID, FEATURES and TARGET columns.

    Returns:
        (features, target): features has shape (groups, max_rows, n_features),
        target has shape (groups,).

    Raises:
        ValueError: if `data` yields no groups.
    """
    features = []
    target = []
    # A single groupby pass replaces one full-frame boolean scan per ID;
    # sort=True yields groups in ascending ID order and keeps row order
    # within each group.
    for _, cur_id_data in data.groupby(ID, sort=True):
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    if not features:
        raise ValueError("reshape_dataset: no rows to reshape")

    # pad_sequences already returns a stacked ndarray.
    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        arrays: Sequence of 2-D arrays sharing the same second dimension.
        pad_value: Constant appended after the last row of shorter arrays.

    Returns:
        Array of shape (len(arrays), longest, width).
    """
    lengths = [seq.shape[0] for seq in arrays]
    longest = max(lengths)

    padded = []
    for seq, n_rows in zip(arrays, lengths):
        # Pad only at the end of the first axis; the second axis is untouched.
        pad_width = ((0, longest - n_rows), (0, 0))
        padded.append(np.pad(seq, pad_width, mode='constant', constant_values=pad_value))
    return np.stack(padded)

In [6]:
# Build the padded (recordings, timesteps, features) array from the full dataset;
# the labels are not needed here.
X, _ = reshape_dataset(data)
# Instantiate the model once, sized from X, to inspect its architecture.
model = create_model(X)
model.summary()

2025-03-21 11:22:45.296175: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Number of subjects sampled into the training split of each experiment;
# all remaining subjects form the held-out split.
NUM_TRAIN_SUBJECTS = 7

# Per-experiment learning curves (one list of per-epoch values per experiment).
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm can then read the
# length of `seeds` and show a real progress bar instead of a bare counter.
for j, seed in enumerate(tqdm(seeds)):
    # Seed every RNG in play so each experiment is reproducible.
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    # Split by subject so that no subject contributes rows to both splits.
    subject_ids = data[USER[0]].to_numpy()
    train_id = np.random.choice(np.unique(subject_ids), NUM_TRAIN_SUBJECTS, replace=False)
    # 1-D row mask (the one-column DataFrame used before produced an (n, 1) mask).
    train_index = np.isin(subject_ids, train_id)

    train = data.loc[train_index]
    test = data.loc[~train_index]

    # NOTE(review): the two splits are padded independently, so their time
    # dimensions can differ - confirm this is intended.
    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

# Average every curve across experiments (axis 0), leaving one value per epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)


0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 414ms/step - accuracy: 0.5041 - auc: 0.6258 - loss: 0.6817 - val_accuracy: 0.5333 - val_auc: 0.7156 - val_loss: 0.6646
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 231ms/step - accuracy: 0.5849 - auc: 0.6401 - loss: 0.6680 - val_accuracy: 0.6333 - val_auc: 0.7467 - val_loss: 0.6327
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 206ms/step - accuracy: 0.6837 - auc: 0.6953 - loss: 0.6449 - val_accuracy: 0.7000 - val_auc: 0.7800 - val_loss: 0.5931
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 213ms/step - accuracy: 0.7398 - auc: 0.7136 - loss: 0.6315 - val_accuracy: 0.7000 - val_auc: 0.7844 - val_loss: 0.5664
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 211ms/step - accuracy: 0.6875 - auc: 0.6596 - loss: 0.5964 - val_accuracy: 0.7000 - val_auc: 0.7622 - val_loss: 0.5459
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [05:46, 346.44s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 356ms/step - accuracy: 0.4626 - auc: 0.6089 - loss: 0.6887 - val_accuracy: 0.4667 - val_auc: 0.7489 - val_loss: 0.6822
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.5654 - auc: 0.6986 - loss: 0.6776 - val_accuracy: 0.5000 - val_auc: 0.7089 - val_loss: 0.6659
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.6221 - auc: 0.7526 - loss: 0.6497 - val_accuracy: 0.7000 - val_auc: 0.7200 - val_loss: 0.6425
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 225ms/step - accuracy: 0.6831 - auc: 0.7644 - loss: 0.6342 - val_accuracy: 0.7333 - val_auc: 0.7844 - val_loss: 0.6104
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 235ms/step - accuracy: 0.6819 - auc: 0.7211 - loss: 0.6045 - val_accuracy: 0.7667 - val_auc: 0.8200 - val_loss: 0.5799
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [11:52, 358.13s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 434ms/step - accuracy: 0.4085 - auc: 0.6734 - loss: 0.6904 - val_accuracy: 0.5000 - val_auc: 0.7400 - val_loss: 0.6813
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 219ms/step - accuracy: 0.5322 - auc: 0.6661 - loss: 0.6840 - val_accuracy: 0.7000 - val_auc: 0.7800 - val_loss: 0.6715
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 194ms/step - accuracy: 0.6491 - auc: 0.7369 - loss: 0.6729 - val_accuracy: 0.7000 - val_auc: 0.8222 - val_loss: 0.6381
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.5360 - auc: 0.7730 - loss: 0.6475 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.6175
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 219ms/step - accuracy: 0.5761 - auc: 0.7470 - loss: 0.6383 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5958
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [17:42, 354.10s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 434ms/step - accuracy: 0.5493 - auc: 0.6030 - loss: 0.6871 - val_accuracy: 0.5667 - val_auc: 0.5333 - val_loss: 0.6870
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.7010 - auc: 0.7077 - loss: 0.6730 - val_accuracy: 0.6000 - val_auc: 0.6044 - val_loss: 0.6740
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.6903 - auc: 0.7298 - loss: 0.6451 - val_accuracy: 0.6000 - val_auc: 0.6667 - val_loss: 0.6628
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 230ms/step - accuracy: 0.7010 - auc: 0.7248 - loss: 0.6406 - val_accuracy: 0.6667 - val_auc: 0.7000 - val_loss: 0.6386
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.7565 - auc: 0.7527 - loss: 0.6119 - val_accuracy: 0.7000 - val_auc: 0.6644 - val_loss: 0.6296
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [23:42, 356.67s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 407ms/step - accuracy: 0.4109 - auc: 0.2890 - loss: 0.7010 - val_accuracy: 0.6000 - val_auc: 0.6333 - val_loss: 0.6914
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 245ms/step - accuracy: 0.6210 - auc: 0.5613 - loss: 0.6874 - val_accuracy: 0.6333 - val_auc: 0.7133 - val_loss: 0.6822
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 216ms/step - accuracy: 0.6688 - auc: 0.6281 - loss: 0.6788 - val_accuracy: 0.6333 - val_auc: 0.6622 - val_loss: 0.6697
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 202ms/step - accuracy: 0.7004 - auc: 0.7213 - loss: 0.6635 - val_accuracy: 0.6667 - val_auc: 0.6978 - val_loss: 0.6536
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 220ms/step - accuracy: 0.7838 - auc: 0.7743 - loss: 0.6349 - val_accuracy: 0.6000 - val_auc: 0.6667 - val_loss: 0.6419
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [29:41, 356.35s/it]

CPU times: user 55min 20s, sys: 19min 47s, total: 1h 15min 7s
Wall time: 29min 41s





In [8]:
# Report the experiment-averaged train and validation metrics for every epoch.
for i in range(NUM_EPOCHS):
    epoch_no = i + 1

    avg_train_acc = np.round(epoch_acc[i], 3)
    avg_train_loss = np.round(epoch_loss[i], 3)
    avg_train_auc = np.round(epoch_auc[i], 3)
    print(f"Epoch {epoch_no}: TRAIN Accuracy = {avg_train_acc} Loss = {avg_train_loss} AUC = {avg_train_auc}")

    avg_val_acc = np.round(epoch_val_acc[i], 3)
    avg_val_loss = np.round(epoch_val_loss[i], 3)
    avg_val_auc = np.round(epoch_val_auc[i], 3)
    print(f"Epoch {epoch_no}: VAL Accuracy = {avg_val_acc} Loss = {avg_val_loss} AUC = {avg_val_auc}")

Epoch 1: TRAIN Accuracy = 0.491 Loss = 0.685 AUC = 0.575
Epoch 1: VAL Accuracy = 0.533 Loss = 0.681 AUC = 0.674
Epoch 2: TRAIN Accuracy = 0.589 Loss = 0.675 AUC = 0.654
Epoch 2: VAL Accuracy = 0.613 Loss = 0.665 AUC = 0.711
Epoch 3: TRAIN Accuracy = 0.649 Loss = 0.655 AUC = 0.703
Epoch 3: VAL Accuracy = 0.667 Loss = 0.641 AUC = 0.73
Epoch 4: TRAIN Accuracy = 0.666 Loss = 0.635 AUC = 0.731
Epoch 4: VAL Accuracy = 0.693 Loss = 0.617 AUC = 0.751
Epoch 5: TRAIN Accuracy = 0.691 Loss = 0.61 AUC = 0.73
Epoch 5: VAL Accuracy = 0.693 Loss = 0.599 AUC = 0.74
Epoch 6: TRAIN Accuracy = 0.734 Loss = 0.582 AUC = 0.746
Epoch 6: VAL Accuracy = 0.707 Loss = 0.574 AUC = 0.771
Epoch 7: TRAIN Accuracy = 0.731 Loss = 0.563 AUC = 0.764
Epoch 7: VAL Accuracy = 0.753 Loss = 0.546 AUC = 0.781
Epoch 8: TRAIN Accuracy = 0.746 Loss = 0.525 AUC = 0.762
Epoch 8: VAL Accuracy = 0.76 Loss = 0.531 AUC = 0.786
Epoch 9: TRAIN Accuracy = 0.751 Loss = 0.514 AUC = 0.776
Epoch 9: VAL Accuracy = 0.76 Loss = 0.523 AUC = 0.79