In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, LSTM, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-07 16:36:44.313811: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:36:44.320385: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:36:44.343987: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744033004.382712  269404 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744033004.395210  269404 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 16:36:44.450074: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of epochs passed to `model.fit` for every experiment; also the
# number of per-epoch rows printed/logged by the reporting cell.
NUM_EPOCHS = 200
# Number of independent runs; run k uses seed INIT_SEED + k and its own
# random subject split. Per-epoch metrics are averaged over these runs.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-LSTM binary classifier.

    Args:
        train: array-like whose ``shape[1]`` and ``shape[2]`` give the
            (timesteps, features) input shape; only the shape is read.

    Returns:
        A compiled ``Sequential`` model: four Dropout(0.2) + LSTM pairs
        (20, 20, 10, 10 units; only the last one collapses the sequence)
        followed by a single sigmoid unit. Compiled with Adam, binary
        cross-entropy, and the metrics ``accuracy`` and ``auc``.
    """
    timesteps, n_features = train.shape[1], train.shape[2]

    # (units, return_sequences) for each recurrent layer, in stacking order.
    lstm_stack = [(20, True), (20, True), (10, True), (10, False)]

    net = Sequential()
    net.add(Input(shape=(timesteps, n_features)))
    for units, keep_sequence in lstm_stack:
        # Every LSTM layer is preceded by its own dropout layer.
        net.add(Dropout(0.2))
        net.add(LSTM(units, return_sequences=keep_sequence, activation='tanh'))
    net.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    net.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return net

# Experiment

In [3]:
# Column lists are kept as lists so they can be concatenated when selecting
# columns (see `data[ID + USER + FEATURES + TARGET]` in the loading cell).

# Name of the per-sequence identifier column created in the loading cell from
# SubjectID and VideoID; `reshape_dataset` builds one sequence per value.
ID = ["ID"]
# Subject column; the experiment loop splits train/validation by its values.
USER = ["SubjectID"]
# Group keys used when computing lagged features, so shifts never cross a
# (subject, video) boundary.
IDS = ["SubjectID", "VideoID"]
# Label column; `reshape_dataset` reduces it to one integer per sequence.
TARGET = ["predefinedlabel"]
# Input feature columns read from the CSV. NOTE: the loading cell extends this
# list in place with the lagged column names.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Row offsets passed to `shift` when building lagged copies of each feature.
LAGS = [2]
# Seed of the first experiment; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# FEATURES is extended in place at the end of this cell. Restricting the lag
# computation to columns that exist in the freshly read CSV keeps the cell
# idempotent: on a re-run the already-registered lag names (e.g. "Delta_2")
# are not in the raw file, so no lags-of-lags are created.
base_features = [name for name in FEATURES if name in data.columns]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (SubjectID, VideoID) group so values never leak
        # across recordings; missing values (including the first `lag` rows of
        # every group) are filled with 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Register only names that are not already present, so re-running the cell
# does not duplicate feature columns.
for new_feature_name in new_features:
    if new_feature_name not in FEATURES:
        FEATURES.append(new_feature_name)

# One integer id per (subject, video) pair.
# Assumes VideoID values lie in 0..n_videos-1 so ids cannot collide — TODO confirm.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the long per-row table into padded per-sequence arrays.

    Rows are grouped by the ``ID`` column(s). Each group contributes one
    feature matrix (its ``FEATURES`` columns, rows in original order) and one
    label (mean of its ``TARGET`` values, truncated to an integer).

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET``
            columns.

    Returns:
        Tuple ``(features, target)``. ``features`` is the output of
        ``pad_sequences`` over the per-id matrices, i.e. shape
        ``(n_ids, max_len, n_features)``; ``target`` has shape ``(n_ids,)``.
        Sequences are ordered by ascending id.

    Raises:
        ValueError: if ``data`` has no rows (nothing to pad).
    """
    features = []
    target = []
    # A single groupby pass replaces one full-frame boolean scan per id.
    # sort=True yields groups in ascending id order (same order np.unique
    # produced) and row order inside each group is preserved.
    for _, cur_id_data in data.groupby(ID, sort=True):
        # The label is expected to be constant within a sequence; the mean is
        # truncated (not rounded) to an integer.
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        arrays: sequence of 2-D arrays sharing the same second dimension.
        pad_value: constant appended as extra trailing rows (default 0).

    Returns:
        Array of shape ``(len(arrays), max_length, width)`` where every input
        keeps its rows at the start and is filled with ``pad_value`` after.
    """
    longest = max(item.shape[0] for item in arrays)

    stacked_input = []
    for item in arrays:
        missing_rows = longest - item.shape[0]
        # Pad only at the end of axis 0; axis 1 is left untouched.
        widths = ((0, missing_rows), (0, 0))
        stacked_input.append(
            np.pad(item, widths, mode='constant', constant_values=pad_value)
        )
    return np.stack(stacked_input)

In [6]:
# Reshape the full dataset only to obtain the (timesteps, features) input
# shape for the summary below; the labels are discarded.
X, _ = reshape_dataset(data)
# Throwaway model used to inspect the architecture; the experiment loop
# builds a fresh model for every seed.
model = create_model(X)
model.summary()

2025-04-07 16:36:49.618665: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Per-experiment learning curves: each list receives one per-epoch series
# (length NUM_EPOCHS) for every seed.
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

# Number of subjects drawn (without replacement) for training in each
# experiment; all remaining subjects form the validation set.
NUM_TRAIN_SUBJECTS = 7

# tqdm wraps the seed array itself rather than the `enumerate` generator, so
# it knows the total and can show percentage and ETA.
seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED
for j, seed in enumerate(tqdm(seeds, desc="experiments")):
    # Seed every RNG in use so the subject split and the training run are
    # repeatable for a given seed.
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    # Split by subject so no subject appears in both train and validation.
    train_id = np.random.choice(np.unique(np.ravel(data[USER])), NUM_TRAIN_SUBJECTS, replace=False)
    # Flatten to a 1-D row mask; positional indexing expects one boolean per row.
    train_index = np.isin(np.ravel(data[USER]), train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Labels as column vectors to match the model's (batch, 1) output.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    # Fresh, untrained model for every experiment.
    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

# Average each metric over experiments (axis 0), leaving one value per epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 437ms/step - accuracy: 0.6011 - auc: 0.6363 - loss: 0.6856 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.6544
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7192 - auc: 0.6413 - loss: 0.6625 - val_accuracy: 0.7667 - val_auc: 0.7956 - val_loss: 0.6141
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 230ms/step - accuracy: 0.6970 - auc: 0.6762 - loss: 0.6390 - val_accuracy: 0.7667 - val_auc: 0.7800 - val_loss: 0.5768
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 229ms/step - accuracy: 0.7111 - auc: 0.7156 - loss: 0.6072 - val_accuracy: 0.8000 - val_auc: 0.7956 - val_loss: 0.5489
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 230ms/step - accuracy: 0.7151 - auc: 0.7079 - loss: 0.5940 - val_accuracy: 0.8000 - val_auc: 0.7867 - val_loss: 0.5226
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [06:20, 380.02s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 364ms/step - accuracy: 0.3821 - auc: 0.4796 - loss: 0.6893 - val_accuracy: 0.5333 - val_auc: 0.7333 - val_loss: 0.6665
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 195ms/step - accuracy: 0.6441 - auc: 0.6831 - loss: 0.6691 - val_accuracy: 0.7000 - val_auc: 0.7000 - val_loss: 0.6362
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 194ms/step - accuracy: 0.7021 - auc: 0.7451 - loss: 0.6375 - val_accuracy: 0.7000 - val_auc: 0.7511 - val_loss: 0.5938
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 203ms/step - accuracy: 0.7019 - auc: 0.7786 - loss: 0.5996 - val_accuracy: 0.7000 - val_auc: 0.7644 - val_loss: 0.5587
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 196ms/step - accuracy: 0.6932 - auc: 0.7592 - loss: 0.5707 - val_accuracy: 0.7667 - val_auc: 0.7489 - val_loss: 0.5258
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:10, 362.73s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 348ms/step - accuracy: 0.5377 - auc: 0.4264 - loss: 0.6943 - val_accuracy: 0.5000 - val_auc: 0.8489 - val_loss: 0.6825
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.6774 - auc: 0.7144 - loss: 0.6767 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.6574
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 205ms/step - accuracy: 0.7294 - auc: 0.6772 - loss: 0.6636 - val_accuracy: 0.7000 - val_auc: 0.7400 - val_loss: 0.6259
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 195ms/step - accuracy: 0.7624 - auc: 0.7049 - loss: 0.6258 - val_accuracy: 0.7000 - val_auc: 0.7400 - val_loss: 0.5840
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.7351 - auc: 0.6827 - loss: 0.6027 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.5684
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [18:10, 361.62s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 370ms/step - accuracy: 0.3093 - auc: 0.2725 - loss: 0.7093 - val_accuracy: 0.4333 - val_auc: 0.4867 - val_loss: 0.6887
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.6161 - auc: 0.7408 - loss: 0.6618 - val_accuracy: 0.4333 - val_auc: 0.5289 - val_loss: 0.6805
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.7359 - auc: 0.7593 - loss: 0.6183 - val_accuracy: 0.5667 - val_auc: 0.5978 - val_loss: 0.6514
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 211ms/step - accuracy: 0.7632 - auc: 0.7699 - loss: 0.5610 - val_accuracy: 0.6000 - val_auc: 0.6578 - val_loss: 0.6228
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 202ms/step - accuracy: 0.7751 - auc: 0.7826 - loss: 0.5094 - val_accuracy: 0.7000 - val_auc: 0.7156 - val_loss: 0.5959
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [24:13, 361.85s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 370ms/step - accuracy: 0.6260 - auc: 0.6448 - loss: 0.6850 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.6541
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 212ms/step - accuracy: 0.7805 - auc: 0.8419 - loss: 0.6412 - val_accuracy: 0.7000 - val_auc: 0.7511 - val_loss: 0.6296
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 163ms/step - accuracy: 0.7257 - auc: 0.7930 - loss: 0.6085 - val_accuracy: 0.7000 - val_auc: 0.7600 - val_loss: 0.6047
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 208ms/step - accuracy: 0.7762 - auc: 0.7790 - loss: 0.5577 - val_accuracy: 0.7000 - val_auc: 0.7644 - val_loss: 0.5980
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 184ms/step - accuracy: 0.8176 - auc: 0.8297 - loss: 0.4919 - val_accuracy: 0.7000 - val_auc: 0.7556 - val_loss: 0.6018
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [30:11, 362.35s/it]

CPU times: user 39min 9s, sys: 10min 50s, total: 50min
Wall time: 30min 11s





In [8]:
# Build every report line once so the console output and the log file are
# guaranteed to contain exactly the same text.
report_lines = []
for i in range(NUM_EPOCHS):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

# Create the log directory first: without it `open` raises FileNotFoundError
# and the report of a long training run is not written.
log_path = Path("./logs/LSTMv2_emb3.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    for line in report_lines:
        f.write(line + "\n")

Epoch 1: TRAIN Accuracy = 0.531 Loss = 0.689 AUC = 0.54
Epoch 1: VAL Accuracy = 0.6 Loss = 0.669 AUC = 0.72
Epoch 2: TRAIN Accuracy = 0.669 Loss = 0.66 AUC = 0.724
Epoch 2: VAL Accuracy = 0.68 Loss = 0.644 AUC = 0.707
Epoch 3: TRAIN Accuracy = 0.717 Loss = 0.629 AUC = 0.738
Epoch 3: VAL Accuracy = 0.687 Loss = 0.611 AUC = 0.726
Epoch 4: TRAIN Accuracy = 0.74 Loss = 0.587 AUC = 0.748
Epoch 4: VAL Accuracy = 0.7 Loss = 0.582 AUC = 0.744
Epoch 5: TRAIN Accuracy = 0.743 Loss = 0.555 AUC = 0.758
Epoch 5: VAL Accuracy = 0.733 Loss = 0.563 AUC = 0.748
Epoch 6: TRAIN Accuracy = 0.76 Loss = 0.532 AUC = 0.773
Epoch 6: VAL Accuracy = 0.733 Loss = 0.553 AUC = 0.748
Epoch 7: TRAIN Accuracy = 0.757 Loss = 0.521 AUC = 0.75
Epoch 7: VAL Accuracy = 0.74 Loss = 0.545 AUC = 0.78
Epoch 8: TRAIN Accuracy = 0.766 Loss = 0.506 AUC = 0.788
Epoch 8: VAL Accuracy = 0.74 Loss = 0.535 AUC = 0.787
Epoch 9: TRAIN Accuracy = 0.769 Loss = 0.497 AUC = 0.789
Epoch 9: VAL Accuracy = 0.753 Loss = 0.514 AUC = 0.804
Epoch 