In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, Input, LSTM, Masking
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from ncps.keras import LTC
from ncps.wirings import AutoNCP

2025-04-07 16:36:16.939915: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:36:16.945538: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:36:16.957307: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744032976.976908  267751 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744032976.982123  267751 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 16:36:17.008642: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Epochs per training run, and how many independently seeded runs are averaged.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5

def create_model(train, mask_value=None):
    """Build and compile the stacked-LSTM binary classifier used by the experiments.

    Parameters
    ----------
    train : array with a 3-D ``.shape`` of (n_sequences, n_timesteps, n_features)
        Only ``train.shape[2]`` (the feature count) is read. The time axis is left
        unspecified, so inputs padded to different lengths are all accepted.
    mask_value : float or None, default None
        When given, a ``Masking`` layer is inserted right after the input. Timesteps
        whose features all equal this value (e.g. ``0`` for the zero padding that
        ``pad_sequences`` appends) are then skipped by the LSTMs. ``None`` keeps the
        original, unmasked architecture.

    Returns
    -------
    A ``Sequential`` model with a single sigmoid output. It is compiled with Adam and
    binary cross-entropy, and tracks "accuracy" and "auc" metrics.
    """
    model = Sequential()
    # The time axis is None because LSTM weights do not depend on sequence length.
    # Train and test are padded separately, each to its own longest sequence, so
    # their lengths may differ.
    model.add(Input(shape=(None, train.shape[2])))
    if mask_value is not None:
        model.add(Masking(mask_value=mask_value))

    # (units, return_sequences) for each Dropout -> LSTM stage. Only the last LSTM
    # collapses the sequence into a single vector for the Dense head.
    lstm_stages = ((20, True), (20, True), (10, True), (10, False))
    for units, return_sequences in lstm_stages:
        model.add(Dropout(0.2))
        model.add(LSTM(units, return_sequences=return_sequences, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of the EEG frame.
ID = ["ID"]                  # per-(subject, video) key, built in the next cell
USER = ["SubjectID"]         # column the train/validation split is drawn over
IDS = [*USER, "VideoID"]     # columns that together identify one recording
TARGET = ["predefinedlabel"]

# Raw feature columns (named after EEG bands); the next cell appends lagged copies.
FEATURES = [
    "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
LAGS = [1]          # row shifts used to build the lagged features
INIT_SEED = 5412    # experiment k is seeded with INIT_SEED + k

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Drop any lagged names that a previous run of this cell appended. This keeps
# re-runs idempotent: no "Delta_1_1" columns, and FEATURES does not keep growing.
lag_suffixes = tuple(f"_{lag}" for lag in LAGS)
base_features = [name for name in FEATURES if not name.endswith(lag_suffixes)]

# Lagged copies of each raw feature. The shift runs within a (SubjectID, VideoID)
# group, so values never carry across recordings. The first `lag` rows of each
# group have no predecessor and are filled with 0.
new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
FEATURES = base_features + new_features

# One integer key per (SubjectID, VideoID) pair. ngroup() numbers the pairs in
# sorted key order, the same ordering the old n_videos * SubjectID + VideoID
# formula gave for VideoIDs in 0..n_videos-1. Unlike that formula, it cannot
# produce collisions for arbitrary VideoID values.
data["ID"] = data.groupby(IDS).ngroup()
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,Delta_1,Theta_1,Alpha1_1,Alpha2_1,Beta1_1,Beta2_1,Gamma1_1,Gamma2_1,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the long per-row frame into padded per-ID sequences plus one label each.

    Rows are grouped by the ``ID`` column. Each group becomes one sequence of its
    ``FEATURES`` rows, in original row order. Groups are emitted in ascending ID
    order, then right-padded with zeros by ``pad_sequences`` to the longest
    sequence in ``data``.

    Returns
    -------
    (features, target)
        ``features``: array of shape (n_ids, longest_sequence, len(FEATURES)).
        ``target``: 1-D integer array with one label per ID.

    Raises
    ------
    ValueError
        If ``data`` has no rows (propagated from ``pad_sequences``).
    """
    features = []
    target = []
    # groupby(sort=True) visits IDs in the same ascending order as np.unique did.
    # It does so in one pass, instead of rescanning the whole frame once per ID.
    for _, group in data.groupby(ID, sort=True):
        labels = group[TARGET].to_numpy()
        # Majority label, assuming 0/1 targets; ties round to 0. The old
        # mean().astype(int) truncated, so any group containing a single 0 label
        # became 0.
        target.append(int(np.rint(labels.mean())))
        features.append(group[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Right-pad arrays along axis 0 with ``pad_value`` and stack them.

    Parameters
    ----------
    arrays : iterable of numpy arrays
        Arrays of rank >= 1. All axes except the first must already have equal
        sizes. Any iterable is accepted, including a one-shot generator.
    pad_value : scalar, default 0
        Value written into the padded rows.
    max_length : int or None, default None
        Target length of axis 0. ``None`` pads to the longest input.

    Returns
    -------
    numpy array of shape (len(arrays), max_length, *arrays[0].shape[1:]).

    Raises
    ------
    ValueError
        If ``arrays`` is empty, or if ``max_length`` is shorter than the longest input.
    """
    # Materialize once: the length scan and the padding pass both iterate over it,
    # which silently broke for generators.
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array to pad")

    longest = max(arr.shape[0] for arr in arrays)
    target_length = longest if max_length is None else max_length
    if target_length < longest:
        raise ValueError(
            f"max_length={max_length} is shorter than the longest sequence ({longest})"
        )

    padded_arrays = [
        np.pad(
            arr,
            # Pad only the end of axis 0; leave every trailing axis untouched.
            ((0, target_length - arr.shape[0]),) + ((0, 0),) * (arr.ndim - 1),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full dataset only to show its architecture; each
# experiment in the next cell builds a fresh model from its own training split.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-04-07 16:36:19.691829: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 297ms/step - accuracy: 0.6088 - auc: 0.6405 - loss: 0.6852 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.6575
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 145ms/step - accuracy: 0.6871 - auc: 0.6851 - loss: 0.6639 - val_accuracy: 0.8000 - val_auc: 0.7689 - val_loss: 0.6171
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 131ms/step - accuracy: 0.7130 - auc: 0.6804 - loss: 0.6433 - val_accuracy: 0.8000 - val_auc: 0.7800 - val_loss: 0.5797
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 130ms/step - accuracy: 0.7228 - auc: 0.6971 - loss: 0.6129 - val_accuracy: 0.8000 - val_auc: 0.7800 - val_loss: 0.5478
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 123ms/step - accuracy: 0.7285 - auc: 0.7312 - loss: 0.5827 - val_accuracy: 0.8000 - val_auc: 0.7622 - val_loss: 0.5201
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [05:57, 357.49s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 379ms/step - accuracy: 0.4057 - auc: 0.4556 - loss: 0.6908 - val_accuracy: 0.6000 - val_auc: 0.7044 - val_loss: 0.6690
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.7006 - auc: 0.7107 - loss: 0.6705 - val_accuracy: 0.7000 - val_auc: 0.7378 - val_loss: 0.6375
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 189ms/step - accuracy: 0.7149 - auc: 0.7293 - loss: 0.6473 - val_accuracy: 0.7333 - val_auc: 0.7644 - val_loss: 0.5948
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 214ms/step - accuracy: 0.7170 - auc: 0.7912 - loss: 0.5987 - val_accuracy: 0.7667 - val_auc: 0.6867 - val_loss: 0.5563
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 203ms/step - accuracy: 0.7287 - auc: 0.7576 - loss: 0.5729 - val_accuracy: 0.7667 - val_auc: 0.8022 - val_loss: 0.5285
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:01, 361.36s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 374ms/step - accuracy: 0.5035 - auc: 0.4822 - loss: 0.6932 - val_accuracy: 0.5000 - val_auc: 0.8289 - val_loss: 0.6792
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 184ms/step - accuracy: 0.6912 - auc: 0.7266 - loss: 0.6716 - val_accuracy: 0.7667 - val_auc: 0.7578 - val_loss: 0.6527
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 187ms/step - accuracy: 0.7326 - auc: 0.6664 - loss: 0.6596 - val_accuracy: 0.7333 - val_auc: 0.7378 - val_loss: 0.6230
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 225ms/step - accuracy: 0.7624 - auc: 0.7412 - loss: 0.6176 - val_accuracy: 0.7333 - val_auc: 0.7289 - val_loss: 0.5875
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 184ms/step - accuracy: 0.7407 - auc: 0.6858 - loss: 0.6069 - val_accuracy: 0.7333 - val_auc: 0.7133 - val_loss: 0.5633
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [17:54, 357.56s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 350ms/step - accuracy: 0.3253 - auc: 0.2376 - loss: 0.7124 - val_accuracy: 0.5000 - val_auc: 0.5133 - val_loss: 0.6882
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 200ms/step - accuracy: 0.5672 - auc: 0.7144 - loss: 0.6621 - val_accuracy: 0.5000 - val_auc: 0.6022 - val_loss: 0.6725
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 196ms/step - accuracy: 0.7038 - auc: 0.7394 - loss: 0.6195 - val_accuracy: 0.6333 - val_auc: 0.6111 - val_loss: 0.6485
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 217ms/step - accuracy: 0.7596 - auc: 0.7650 - loss: 0.5627 - val_accuracy: 0.6333 - val_auc: 0.6867 - val_loss: 0.6192
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 211ms/step - accuracy: 0.7632 - auc: 0.7376 - loss: 0.5050 - val_accuracy: 0.7000 - val_auc: 0.7378 - val_loss: 0.5885
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [24:10, 364.87s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 385ms/step - accuracy: 0.6213 - auc: 0.6281 - loss: 0.6861 - val_accuracy: 0.6333 - val_auc: 0.7044 - val_loss: 0.6627
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.7762 - auc: 0.8373 - loss: 0.6378 - val_accuracy: 0.6000 - val_auc: 0.7133 - val_loss: 0.6405
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 197ms/step - accuracy: 0.7874 - auc: 0.7787 - loss: 0.6038 - val_accuracy: 0.6667 - val_auc: 0.7111 - val_loss: 0.6187
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 185ms/step - accuracy: 0.7762 - auc: 0.8053 - loss: 0.5497 - val_accuracy: 0.6667 - val_auc: 0.7111 - val_loss: 0.6017
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 189ms/step - accuracy: 0.8104 - auc: 0.8070 - loss: 0.5087 - val_accuracy: 0.6667 - val_auc: 0.7067 - val_loss: 0.6000
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [30:27, 365.49s/it]

CPU times: user 39min 37s, sys: 10min 57s, total: 50min 34s
Wall time: 30min 27s





In [8]:
# Render each epoch's averaged TRAIN/VAL metrics once, so the console output and
# the log file cannot drift apart (they used to be two copies of the same f-strings).
# Iterate over the epochs actually recorded rather than NUM_EPOCHS; otherwise
# editing that constant without re-training would index past the end.
log_lines = []
for i in range(len(epoch_acc)):
    log_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    log_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in log_lines:
    print(line)

log_path = Path("./logs/LSTMv2_emb1.txt")
# open(..., "w") raises FileNotFoundError when ./logs does not exist yet.
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text("".join(f"{line}\n" for line in log_lines))

Epoch 1: TRAIN Accuracy = 0.526 Loss = 0.69 AUC = 0.538
Epoch 1: VAL Accuracy = 0.6 Loss = 0.671 AUC = 0.701
Epoch 2: TRAIN Accuracy = 0.677 Loss = 0.66 AUC = 0.733
Epoch 2: VAL Accuracy = 0.673 Loss = 0.644 AUC = 0.716
Epoch 3: TRAIN Accuracy = 0.729 Loss = 0.632 AUC = 0.726
Epoch 3: VAL Accuracy = 0.713 Loss = 0.613 AUC = 0.721
Epoch 4: TRAIN Accuracy = 0.746 Loss = 0.585 AUC = 0.76
Epoch 4: VAL Accuracy = 0.72 Loss = 0.583 AUC = 0.719
Epoch 5: TRAIN Accuracy = 0.757 Loss = 0.551 AUC = 0.751
Epoch 5: VAL Accuracy = 0.733 Loss = 0.56 AUC = 0.744
Epoch 6: TRAIN Accuracy = 0.766 Loss = 0.533 AUC = 0.771
Epoch 6: VAL Accuracy = 0.727 Loss = 0.546 AUC = 0.744
Epoch 7: TRAIN Accuracy = 0.766 Loss = 0.515 AUC = 0.753
Epoch 7: VAL Accuracy = 0.747 Loss = 0.528 AUC = 0.753
Epoch 8: TRAIN Accuracy = 0.769 Loss = 0.503 AUC = 0.777
Epoch 8: VAL Accuracy = 0.74 Loss = 0.515 AUC = 0.768
Epoch 9: TRAIN Accuracy = 0.771 Loss = 0.491 AUC = 0.79
Epoch 9: VAL Accuracy = 0.767 Loss = 0.491 AUC = 0.812
E