In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, LSTM, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-07 16:36:24.795292: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:36:24.798884: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:36:24.810625: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744032984.837271  267896 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744032984.846778  267896 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 16:36:24.878342: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration and model definition

In [2]:
# Training budget: epochs per fit, and how many separately seeded runs to average.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5

def create_model(train):
    """Build and compile the stacked-LSTM binary classifier.

    Args:
        train: array-like whose ``shape[1]`` and ``shape[2]`` are used as the
            (timesteps, features) input shape; its values are not read.

    Returns:
        A compiled ``Sequential`` model ending in a single sigmoid unit,
        optimized with Adam on binary cross-entropy and reporting accuracy
        and AUC (metric name ``"auc"``).
    """
    # (units, return_sequences) per LSTM layer. Each LSTM is preceded by
    # Dropout(0.2); only the last one drops the time axis for the Dense head.
    # NOTE(review): pad_sequences pads shorter recordings with zeros at the END
    # and there is no Masking layer, so the final LSTM output for those
    # recordings is taken after the padding steps. Inputs are also unscaled
    # (raw band powers in the 1e3-1e5 range). Consider Masking/pre-padding and
    # feature scaling.
    lstm_stack = [(20, True), (20, True), (10, True), (10, False)]

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    for units, keep_sequence in lstm_stack:
        model.add(Dropout(0.2))
        model.add(LSTM(units, return_sequences=keep_sequence, activation='tanh'))
    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups used throughout the notebook (each is a list so they can be
# concatenated when selecting columns).
ID = ["ID"]                      # per-recording key, derived from IDS below
USER = ["SubjectID"]             # column used for the subject-level split
IDS = ["SubjectID", "VideoID"]   # columns that together identify a recording
TARGET = ["predefinedlabel"]
FEATURES = [
    "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
LAGS = [1, 2]        # row offsets used to build shifted copies of FEATURES
INIT_SEED = 5412     # experiment k is seeded with INIT_SEED + k

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# FEATURES is extended below, so on a re-run of this cell it already holds the
# lag names. Rebuild the base (un-lagged) list from the raw CSV columns so the
# cell is idempotent and never creates lag-of-lag columns ("Delta_1_1", ...).
raw_columns = set(data.columns)
base_features = [name for name in FEATURES if name in raw_columns]
lag_names = {f"{name}_{lag}" for name in base_features for lag in LAGS}
unknown = [name for name in FEATURES if name not in raw_columns and name not in lag_names]
if unknown:
    raise KeyError(
        f"FEATURES not found in EEG_data.csv (re-run the constants cell?): {unknown}"
    )

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (SubjectID, VideoID) recording so a lag never pulls
        # values across recordings; the rows left without a predecessor get 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
# Replace the contents in place so FEATURES stays the same list object.
FEATURES[:] = base_features + new_features

# One integer ID per (SubjectID, VideoID) recording, numbered in sorted key
# order. ngroup() makes no assumption about the VideoID value range, whereas
# n_videos * SubjectID + VideoID can collide unless VideoID is 0..n_videos-1.
data["ID"] = data.groupby(IDS, sort=True).ngroup().astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,...,Gamma2_1,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,...,8293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,...,2740.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn a long frame of rows into one padded sequence per ``ID``.

    Groups are emitted in sorted ``ID`` order, and rows keep their original
    order inside each group.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        features: array of shape (n_ids, max_len, len(FEATURES)); shorter
            sequences are padded by ``pad_sequences`` (zeros at the end).
        target: int array of shape (n_ids,) holding the per-ID mean of
            ``TARGET`` converted with ``astype("int")``.
    """
    features = []
    target = []
    # A single sorted groupby replaces the per-ID boolean scan of the whole
    # frame. That scan was O(ids * rows) and indexed with a 2-D (n, 1) mask.
    for _, cur_id_data in data.groupby(ID, sort=True):
        # NOTE(review): astype("int") truncates the mean, so this assumes the
        # label is constant within an ID; a mixed group is floored, not rounded.
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        arrays: non-empty sequence of arrays shaped (length_i, n_features) with
            the same ``n_features``.
        pad_value: value written into the padded rows.
        padding: ``"post"`` (default, original behavior) appends padding rows
            after the data; ``"pre"`` inserts them before it. "pre" keeps the
            real last timestep at the end, which matters for recurrent layers
            that only return their final output.

    Returns:
        Array of shape (len(arrays), max_length, n_features).

    Raises:
        ValueError: if ``arrays`` is empty or ``padding`` is not "post"/"pre".
    """
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")
    if len(arrays) == 0:
        raise ValueError("pad_sequences needs at least one array")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        n_pad = max_length - arr.shape[0]
        row_pad = (0, n_pad) if padding == "post" else (n_pad, 0)
        padded_arrays.append(
            np.pad(arr, (row_pad, (0, 0)), mode='constant', constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full dataset only to display its architecture; the
# training loop below creates fresh models, and the labels are unused here.
X, _ = reshape_dataset(data)
model = create_model(train=X)
model.summary()

2025-04-07 16:36:28.352600: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Subjects sampled for training in each experiment; all remaining subjects
# form the validation set (subject-level split, no subject in both).
N_TRAIN_SUBJECTS = 7

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

# The subject pool does not change between experiments, so compute it once.
subjects = np.unique(data[USER].to_numpy().ravel())
if len(subjects) <= N_TRAIN_SUBJECTS:
    raise ValueError(
        f"Need more than {N_TRAIN_SUBJECTS} subjects for a subject-level split, "
        f"found {len(subjects)}"
    )

# range() has a length, so tqdm shows a real progress bar (enumerate over a
# numpy array gave tqdm no total).
for seed in tqdm(range(INIT_SEED, INIT_SEED + NUM_EXPERIMENTS)):
    # Seed every RNG in use so each experiment is reproducible.
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    train_id = np.random.choice(subjects, N_TRAIN_SUBJECTS, replace=False)
    # 1-D row mask: np.isin on the (n, 1) frame produced a 2-D mask that .iloc
    # only accepted by accident.
    train_index = np.isin(data[USER].to_numpy().ravel(), train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

# Average each metric per epoch across experiments (every run trains for
# exactly NUM_EPOCHS epochs, so the histories have equal length).
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 308ms/step - accuracy: 0.4860 - auc: 0.5182 - loss: 0.6921 - val_accuracy: 0.7333 - val_auc: 0.7800 - val_loss: 0.6642
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 186ms/step - accuracy: 0.6537 - auc: 0.7374 - loss: 0.6674 - val_accuracy: 0.7333 - val_auc: 0.7889 - val_loss: 0.6299
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 155ms/step - accuracy: 0.7100 - auc: 0.7395 - loss: 0.6364 - val_accuracy: 0.7333 - val_auc: 0.7778 - val_loss: 0.5909
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 157ms/step - accuracy: 0.6838 - auc: 0.6719 - loss: 0.6239 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.5614
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 161ms/step - accuracy: 0.6976 - auc: 0.7148 - loss: 0.5955 - val_accuracy: 0.8000 - val_auc: 0.7578 - val_loss: 0.5273
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [05:42, 342.31s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 341ms/step - accuracy: 0.4700 - auc: 0.5105 - loss: 0.6905 - val_accuracy: 0.6333 - val_auc: 0.7622 - val_loss: 0.6732
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 201ms/step - accuracy: 0.6373 - auc: 0.6615 - loss: 0.6719 - val_accuracy: 0.7000 - val_auc: 0.7467 - val_loss: 0.6473
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.6840 - auc: 0.6991 - loss: 0.6456 - val_accuracy: 0.7000 - val_auc: 0.7511 - val_loss: 0.6018
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 251ms/step - accuracy: 0.7205 - auc: 0.7189 - loss: 0.6042 - val_accuracy: 0.7333 - val_auc: 0.7689 - val_loss: 0.5572
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7013 - auc: 0.7461 - loss: 0.5747 - val_accuracy: 0.7667 - val_auc: 0.7822 - val_loss: 0.5273
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [11:58, 362.38s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 350ms/step - accuracy: 0.5719 - auc: 0.4169 - loss: 0.6924 - val_accuracy: 0.5000 - val_auc: 0.8556 - val_loss: 0.6831
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 173ms/step - accuracy: 0.6167 - auc: 0.7403 - loss: 0.6752 - val_accuracy: 0.7667 - val_auc: 0.7622 - val_loss: 0.6644
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.8001 - auc: 0.7346 - loss: 0.6583 - val_accuracy: 0.7333 - val_auc: 0.7311 - val_loss: 0.6308
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 205ms/step - accuracy: 0.7326 - auc: 0.7364 - loss: 0.6249 - val_accuracy: 0.7000 - val_auc: 0.7289 - val_loss: 0.5892
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 181ms/step - accuracy: 0.7326 - auc: 0.7231 - loss: 0.5954 - val_accuracy: 0.7333 - val_auc: 0.7111 - val_loss: 0.5642
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [18:05, 364.47s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 353ms/step - accuracy: 0.3278 - auc: 0.2589 - loss: 0.7097 - val_accuracy: 0.5667 - val_auc: 0.5911 - val_loss: 0.6849
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.6746 - auc: 0.7971 - loss: 0.6710 - val_accuracy: 0.6333 - val_auc: 0.6444 - val_loss: 0.6647
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 213ms/step - accuracy: 0.7512 - auc: 0.7490 - loss: 0.6247 - val_accuracy: 0.5667 - val_auc: 0.6444 - val_loss: 0.6539
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 174ms/step - accuracy: 0.7136 - auc: 0.7445 - loss: 0.5646 - val_accuracy: 0.7000 - val_auc: 0.7400 - val_loss: 0.5876
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 212ms/step - accuracy: 0.7632 - auc: 0.8048 - loss: 0.5178 - val_accuracy: 0.7333 - val_auc: 0.7556 - val_loss: 0.5550
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [24:15, 366.74s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 368ms/step - accuracy: 0.5809 - auc: 0.7479 - loss: 0.6711 - val_accuracy: 0.7000 - val_auc: 0.7378 - val_loss: 0.6449
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 208ms/step - accuracy: 0.7377 - auc: 0.7971 - loss: 0.6294 - val_accuracy: 0.6667 - val_auc: 0.7244 - val_loss: 0.6162
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.7271 - auc: 0.7496 - loss: 0.5926 - val_accuracy: 0.6667 - val_auc: 0.7111 - val_loss: 0.5830
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 239ms/step - accuracy: 0.7991 - auc: 0.7850 - loss: 0.5386 - val_accuracy: 0.7333 - val_auc: 0.7267 - val_loss: 0.5534
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 204ms/step - accuracy: 0.8145 - auc: 0.8215 - loss: 0.5041 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5339
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [30:13, 362.71s/it]

CPU times: user 39min 34s, sys: 11min 5s, total: 50min 40s
Wall time: 30min 13s





In [8]:
LOG_PATH = Path("./logs/LSTMv2_emb2.txt")

# Format each epoch's averaged metrics once and reuse the same lines for the
# notebook output and the log file, so the two can never drift apart.
log_lines = []
for i in range(NUM_EPOCHS):
    log_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    log_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in log_lines:
    print(line)

# open(..., "w") raises FileNotFoundError if ./logs does not exist yet.
LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(LOG_PATH, "w") as f:
    for line in log_lines:
        f.write(f"{line}\n")

Epoch 1: TRAIN Accuracy = 0.529 Loss = 0.688 AUC = 0.545
Epoch 1: VAL Accuracy = 0.627 Loss = 0.67 AUC = 0.745
Epoch 2: TRAIN Accuracy = 0.663 Loss = 0.661 AUC = 0.73
Epoch 2: VAL Accuracy = 0.7 Loss = 0.644 AUC = 0.733
Epoch 3: TRAIN Accuracy = 0.714 Loss = 0.631 AUC = 0.728
Epoch 3: VAL Accuracy = 0.68 Loss = 0.612 AUC = 0.723
Epoch 4: TRAIN Accuracy = 0.731 Loss = 0.59 AUC = 0.736
Epoch 4: VAL Accuracy = 0.727 Loss = 0.57 AUC = 0.744
Epoch 5: TRAIN Accuracy = 0.743 Loss = 0.555 AUC = 0.754
Epoch 5: VAL Accuracy = 0.753 Loss = 0.542 AUC = 0.751
Epoch 6: TRAIN Accuracy = 0.76 Loss = 0.533 AUC = 0.762
Epoch 6: VAL Accuracy = 0.767 Loss = 0.524 AUC = 0.75
Epoch 7: TRAIN Accuracy = 0.763 Loss = 0.518 AUC = 0.753
Epoch 7: VAL Accuracy = 0.76 Loss = 0.502 AUC = 0.788
Epoch 8: TRAIN Accuracy = 0.774 Loss = 0.496 AUC = 0.787
Epoch 8: VAL Accuracy = 0.773 Loss = 0.491 AUC = 0.804
Epoch 9: TRAIN Accuracy = 0.774 Loss = 0.498 AUC = 0.776
Epoch 9: VAL Accuracy = 0.787 Loss = 0.483 AUC = 0.829
Ep