In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, LSTM, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-07 16:37:00.962931: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:37:00.971566: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 16:37:00.997163: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744033021.036906  270979 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744033021.049027  270979 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 16:37:01.094517: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of epochs passed to `model.fit` for every experiment.
NUM_EPOCHS = 200
# Number of independent experiments; each one uses its own seed
# (INIT_SEED, INIT_SEED + 1, ...) and its own random subject split.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-LSTM binary classifier.

    Args:
        train: Array whose ``shape[1]`` and ``shape[2]`` give the input shape of
            the network. Only the shape is read, never the values.
            Presumably (samples, timesteps, features), as returned by
            ``reshape_dataset`` - confirm against the caller.

    Returns:
        A compiled ``Sequential`` model: four Dropout(0.2) + LSTM pairs followed
        by a single sigmoid unit, trained with binary cross-entropy and reporting
        accuracy and AUC (metric name ``"auc"``).
    """
    # (units, return_sequences) for each LSTM of the stack; only the last LSTM
    # returns a single vector instead of the full sequence.
    lstm_stack = (
        (20, True),
        (20, True),
        (10, True),
        (10, False),
    )

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    # Layers are created and added one at a time, in the same order as the
    # stack is listed, with a dropout layer in front of every LSTM.
    for units, keep_sequence in lstm_stack:
        model.add(Dropout(0.2))
        model.add(LSTM(units, return_sequences=keep_sequence, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column holding the per-recording identifier (computed from SubjectID and
# VideoID when the data is loaded); rows sharing it form one sequence.
ID = ["ID"]
# Column used to split rows into train / validation by subject.
USER = ["SubjectID"]
# Grouping keys used when shifting features (lags and differencing).
IDS = ["SubjectID", "VideoID"]
# Label column; averaged per recording in `reshape_dataset`.
TARGET = ["predefinedlabel"]
# Input feature columns. The data-loading cell appends lag column names here.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Shift amounts for which lagged copies of every feature are added (none by default).
LAGS = []
# When True, every feature is replaced by its difference with the previous row
# of the same (SubjectID, VideoID) group.
USE_DIFF = True
# Seed of the first experiment; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# NOTE(review): machine-specific absolute path - point it at the local copy of
# the dataset before running on another machine.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# `data` has just been re-read, so it holds only raw CSV columns. Any name in
# FEATURES that is not a raw column is treated as a lag column registered by a
# previous run of this cell. Rebuilding from the raw names keeps the cell
# idempotent: a re-run no longer creates lag-of-lag columns or grows FEATURES.
raw_columns = set(data.columns)
base_features = [name for name in FEATURES if name in raw_columns]
new_features = [f"{feature_name}_{lag}" for lag in LAGS for feature_name in base_features]

# Fail fast on names that are neither raw columns nor lag columns for the
# current LAGS (typo, or LAGS changed without re-running the config cell).
unknown_features = [
    name for name in FEATURES
    if name not in raw_columns and name not in new_features
]
if unknown_features:
    raise KeyError(
        f"Unknown feature columns {unknown_features}; "
        "re-run the configuration cell to reset FEATURES"
    )

# Lagged copies of every base feature, shifted within each (SubjectID, VideoID)
# group; positions without a predecessor are filled with 0.
for lag in LAGS:
    for feature_name in base_features:
        data[f"{feature_name}_{lag}"] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Slice assignment updates the existing list object in place, so every reader
# of FEATURES sees the same list: base features first, then lag columns.
FEATURES[:] = base_features + new_features

if USE_DIFF:
    # One vectorised shift for all features instead of one groupby per column.
    # NOTE(review): the first row of every group has no predecessor, so the
    # shifted value is 0 and that row keeps its raw (undifferenced) value.
    previous_rows = data.groupby(IDS)[FEATURES].shift(1).fillna(0)
    data[FEATURES] = data[FEATURES] - previous_rows

# Unique integer per (SubjectID, VideoID) pair.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-228176.0,-62529.0,-32296.0,-21751.0,-25200.0,-41410.0,-27935.0,-5553.0,0.0
2,0,0.0,684566.0,355662.0,200560.0,59867.0,33547.0,126849.0,51950.0,22614.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long-format frame into padded per-recording sequences.

    Rows are grouped by the ``ID`` column, visiting the identifiers in the
    sorted order produced by ``np.unique``. Each group contributes its
    ``FEATURES`` rows as one sequence and the integer-truncated mean of its
    ``TARGET`` values as its label. Sequences are zero-padded at the end to a
    common length by ``pad_sequences``.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        Tuple ``(features, target)``: the stacked padded sequences and an array
        with one integer label per recording.
    """
    id_values = data[ID].to_numpy()

    # One sub-frame per recording, in sorted-ID order.
    per_recording = [data[id_values == rec_id] for rec_id in np.unique(id_values)]

    labels = [
        np.mean(frame[TARGET].to_numpy()).astype("int")
        for frame in per_recording
    ]
    sequences = [frame[FEATURES].to_numpy() for frame in per_recording]

    padded = pad_sequences(sequences)
    return np.array(padded), np.array(labels)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        arrays: Sequence of 2-D arrays that share the size of their second axis.
        pad_value: Constant appended after the last row of the shorter arrays.

    Returns:
        Array of shape ``(len(arrays), max_length, width)`` where ``max_length``
        is the largest first-axis size among the inputs.
    """
    longest = max(arr.shape[0] for arr in arrays)

    padded = []
    for arr in arrays:
        missing_rows = longest - arr.shape[0]
        # Pad only at the end of the first axis; the second axis is untouched.
        pad_widths = ((0, missing_rows), (0, 0))
        padded.append(
            np.pad(arr, pad_widths, mode='constant', constant_values=pad_value)
        )
    return np.stack(padded)

In [6]:
# Build the padded tensor for the whole dataset; only its shape is needed here
# to instantiate the network (labels are discarded).
X, _ = reshape_dataset(data)
# This model is only used to display the architecture; the training loop
# creates a fresh model for every experiment and rebinds `model`.
model = create_model(X)
model.summary()

2025-04-07 16:37:06.637980: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Keras history keys collected for every experiment (one curve per key).
METRIC_KEYS = ("accuracy", "loss", "auc", "val_accuracy", "val_loss", "val_auc")
# Subjects drawn for training in each experiment; all remaining subjects form
# the validation split.
NUM_TRAIN_SUBJECTS = 7

# histories[key] is a list with one per-epoch curve per experiment.
histories = {key: [] for key in METRIC_KEYS}

# Loop-invariant: subject of every row (1-D) and the distinct subjects.
subject_of_row = np.ravel(data[USER])
subjects = np.unique(subject_of_row)

# Wrapping the range directly (instead of an enumerate object) lets tqdm see
# the number of experiments and display a proper total.
for seed in tqdm(range(INIT_SEED, INIT_SEED + NUM_EXPERIMENTS), desc="experiment"):
    # Seed every RNG in play so each experiment is reproducible.
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Split by subject so no subject appears on both sides of the split.
    train_id = np.random.choice(subjects, NUM_TRAIN_SUBJECTS, replace=False)
    # 1-D boolean row mask (the subject column is flattened beforehand).
    train_index = np.isin(subject_of_row, train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # One label column per sample, matching the single sigmoid output.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    for key in METRIC_KEYS:
        histories[key].append(history.history[key])

# Per-experiment curves, exposed under the names used by the rest of the notebook.
all_acc = histories["accuracy"]
all_loss = histories["loss"]
all_auc = histories["auc"]

all_val_acc = histories["val_accuracy"]
all_val_loss = histories["val_loss"]
all_val_auc = histories["val_auc"]

# Average every metric over the experiments, epoch by epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 390ms/step - accuracy: 0.5600 - auc: 0.5725 - loss: 0.6901 - val_accuracy: 0.7667 - val_auc: 0.7467 - val_loss: 0.6761
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 204ms/step - accuracy: 0.6739 - auc: 0.6356 - loss: 0.6816 - val_accuracy: 0.7667 - val_auc: 0.7000 - val_loss: 0.6563
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 175ms/step - accuracy: 0.6894 - auc: 0.6831 - loss: 0.6601 - val_accuracy: 0.7667 - val_auc: 0.7222 - val_loss: 0.6244
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 219ms/step - accuracy: 0.7132 - auc: 0.6735 - loss: 0.6327 - val_accuracy: 0.7333 - val_auc: 0.7178 - val_loss: 0.5933
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 244ms/step - accuracy: 0.7167 - auc: 0.6711 - loss: 0.6062 - val_accuracy: 0.7667 - val_auc: 0.6956 - val_loss: 0.5541
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [06:03, 363.80s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 375ms/step - accuracy: 0.5122 - auc: 0.4788 - loss: 0.6919 - val_accuracy: 0.5333 - val_auc: 0.7356 - val_loss: 0.6822
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 181ms/step - accuracy: 0.5671 - auc: 0.6788 - loss: 0.6813 - val_accuracy: 0.6667 - val_auc: 0.7311 - val_loss: 0.6627
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 224ms/step - accuracy: 0.6283 - auc: 0.7589 - loss: 0.6591 - val_accuracy: 0.7333 - val_auc: 0.7311 - val_loss: 0.6262
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 183ms/step - accuracy: 0.6953 - auc: 0.7243 - loss: 0.6200 - val_accuracy: 0.7333 - val_auc: 0.7422 - val_loss: 0.5810
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 182ms/step - accuracy: 0.7170 - auc: 0.7118 - loss: 0.5888 - val_accuracy: 0.7333 - val_auc: 0.7178 - val_loss: 0.5554
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:21, 372.07s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 376ms/step - accuracy: 0.5594 - auc: 0.2881 - loss: 0.6966 - val_accuracy: 0.5000 - val_auc: 0.7467 - val_loss: 0.6886
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 187ms/step - accuracy: 0.6144 - auc: 0.7537 - loss: 0.6823 - val_accuracy: 0.6667 - val_auc: 0.7022 - val_loss: 0.6735
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 212ms/step - accuracy: 0.7568 - auc: 0.6779 - loss: 0.6696 - val_accuracy: 0.6667 - val_auc: 0.6822 - val_loss: 0.6490
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 154ms/step - accuracy: 0.7282 - auc: 0.6889 - loss: 0.6440 - val_accuracy: 0.6667 - val_auc: 0.7156 - val_loss: 0.6224
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.7568 - auc: 0.7497 - loss: 0.5928 - val_accuracy: 0.6667 - val_auc: 0.7311 - val_loss: 0.6121
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [18:10, 361.51s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 402ms/step - accuracy: 0.5864 - auc: 0.6699 - loss: 0.6798 - val_accuracy: 0.5333 - val_auc: 0.6267 - val_loss: 0.6780
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.6054 - auc: 0.7200 - loss: 0.6376 - val_accuracy: 0.5667 - val_auc: 0.6356 - val_loss: 0.6689
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 216ms/step - accuracy: 0.7519 - auc: 0.7436 - loss: 0.5894 - val_accuracy: 0.6000 - val_auc: 0.6289 - val_loss: 0.6706
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 195ms/step - accuracy: 0.7632 - auc: 0.7632 - loss: 0.5263 - val_accuracy: 0.6333 - val_auc: 0.6489 - val_loss: 0.6458
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 184ms/step - accuracy: 0.7745 - auc: 0.7838 - loss: 0.5089 - val_accuracy: 0.6667 - val_auc: 0.6756 - val_loss: 0.6417
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [24:24, 366.32s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 392ms/step - accuracy: 0.7459 - auc: 0.7530 - loss: 0.6762 - val_accuracy: 0.6667 - val_auc: 0.6556 - val_loss: 0.6610
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 209ms/step - accuracy: 0.6900 - auc: 0.7908 - loss: 0.6466 - val_accuracy: 0.6667 - val_auc: 0.6400 - val_loss: 0.6379
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 216ms/step - accuracy: 0.7459 - auc: 0.7713 - loss: 0.6095 - val_accuracy: 0.7000 - val_auc: 0.6444 - val_loss: 0.6157
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 189ms/step - accuracy: 0.7836 - auc: 0.7965 - loss: 0.5706 - val_accuracy: 0.7333 - val_auc: 0.6511 - val_loss: 0.5935
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 185ms/step - accuracy: 0.8104 - auc: 0.7858 - loss: 0.5332 - val_accuracy: 0.7333 - val_auc: 0.6778 - val_loss: 0.5753
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [30:22, 364.52s/it]

CPU times: user 39min 45s, sys: 10min 58s, total: 50min 43s
Wall time: 30min 22s
Parser   : 308 ms





In [8]:
# Format every report line once, then reuse the same lines for the notebook
# output and for the log file so the two can never drift apart.
# Iterating over the stored curves (rather than NUM_EPOCHS) keeps the report
# consistent with what was actually trained if the config changed since.
report_lines = []
for i in range(len(epoch_acc)):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

log_path = Path("./logs/LSTMv2_emb4.txt")
# Create the log directory on first use instead of failing with FileNotFoundError.
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(line + "\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.569 Loss = 0.686 AUC = 0.56
Epoch 1: VAL Accuracy = 0.6 Loss = 0.677 AUC = 0.702
Epoch 2: TRAIN Accuracy = 0.634 Loss = 0.665 AUC = 0.705
Epoch 2: VAL Accuracy = 0.667 Loss = 0.66 AUC = 0.682
Epoch 3: TRAIN Accuracy = 0.711 Loss = 0.636 AUC = 0.734
Epoch 3: VAL Accuracy = 0.693 Loss = 0.637 AUC = 0.682
Epoch 4: TRAIN Accuracy = 0.734 Loss = 0.6 AUC = 0.714
Epoch 4: VAL Accuracy = 0.7 Loss = 0.607 AUC = 0.695
Epoch 5: TRAIN Accuracy = 0.749 Loss = 0.571 AUC = 0.733
Epoch 5: VAL Accuracy = 0.713 Loss = 0.588 AUC = 0.7
Epoch 6: TRAIN Accuracy = 0.751 Loss = 0.535 AUC = 0.76
Epoch 6: VAL Accuracy = 0.713 Loss = 0.572 AUC = 0.728
Epoch 7: TRAIN Accuracy = 0.754 Loss = 0.523 AUC = 0.763
Epoch 7: VAL Accuracy = 0.72 Loss = 0.557 AUC = 0.742
Epoch 8: TRAIN Accuracy = 0.766 Loss = 0.513 AUC = 0.762
Epoch 8: VAL Accuracy = 0.727 Loss = 0.55 AUC = 0.758
Epoch 9: TRAIN Accuracy = 0.78 Loss = 0.49 AUC = 0.785
Epoch 9: VAL Accuracy = 0.72 Loss = 0.543 AUC = 0.771
Epoch 10