In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, LSTM, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-05-03 19:07:25.529624: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-03 19:07:25.535568: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-03 19:07:25.552282: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746288445.578896  950822 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746288445.586978  950822 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-03 19:07:25.614267: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Epochs per training run; also the number of per-epoch rows printed in the final report.
NUM_EPOCHS = 400
# Number of independent runs; each run gets its own seed (INIT_SEED + run index) and subject split.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-LSTM binary classifier.

    Args:
        train: array whose ``shape[1]`` and ``shape[2]`` define the model's
            input shape (consumed by the LSTM layers as timesteps x features).
            Only the shape is read; the values are not used.

    Returns:
        A compiled ``Sequential`` model: four Dropout(0.2) + LSTM pairs
        (20, 20, 10, 10 units), followed by a single sigmoid unit. It is
        compiled with Adam and binary cross-entropy, and reports accuracy
        and AUC (metric name ``"auc"``).
    """
    timesteps, n_features = train.shape[1], train.shape[2]

    # (units, return_sequences) for each recurrent layer; only the last layer
    # collapses the sequence to a single vector for the Dense head.
    lstm_stack = [(20, True), (20, True), (10, True), (10, False)]

    model = Sequential()
    model.add(Input(shape=(timesteps, n_features)))
    for units, keep_sequence in lstm_stack:
        model.add(Dropout(0.2))
        model.add(LSTM(units, return_sequences=keep_sequence, activation='tanh'))
    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column names are kept as lists so they can be concatenated into one column selection.
# Identifier column for one sequence; it is derived from SubjectID and VideoID when the data is loaded.
ID = ["ID"]
# Subject column; the train/test split is drawn over its unique values.
USER = ["SubjectID"]
# NOTE(review): not referenced in the visible cells — confirm it is still needed.
IDS = ["SubjectID", "VideoID"]
# Label column; reshape_dataset averages it per ID and casts the mean to int.
TARGET = ["predefinedlabel"]
# Input feature columns (presumably EEG band-power values — TODO confirm against the dataset description).
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Base random seed; the experiment loop uses INIT_SEED, INIT_SEED + 1, ...
INIT_SEED = 5412

In [4]:
# Location of the dataset. Set the EEG_DATA_DIR environment variable to run the
# notebook on another machine; the default keeps the original location.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer ID per (SubjectID, VideoID) pair: subject * n_videos + video.
# The IDs are distinct as long as VideoID values lie in [0, n_videos).
n_videos = len(np.unique(data["VideoID"]))
data["ID"] = (n_videos * data["SubjectID"] + data["VideoID"]).astype("int")

# Keep only the columns used downstream.
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the long-format table into one padded sequence and one label per ID.

    Rows are grouped by the ``ID`` column (IDs are processed in sorted order).
    For each ID the ``FEATURES`` columns form one sequence, and the label is
    the mean of the ``TARGET`` column cast to int.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        ``(features, target)``: ``features`` is the stacked output of
        ``pad_sequences`` (sequences zero-padded at the end to the longest one),
        and ``target`` is a 1-D integer array with one label per ID.

    Raises:
        ValueError: if ``data`` has no rows.
    """
    # ID is a one-element column list, so data[ID].to_numpy() has shape (n, 1).
    # Flatten it once so row selection uses an ordinary 1-D boolean mask and
    # the array is not rebuilt on every iteration.
    row_ids = data[ID].to_numpy().ravel()

    features = []
    target = []
    for cur_id in np.unique(row_ids):
        cur_id_data = data[row_ids == cur_id]
        # The mean is truncated toward zero by the int cast; this is exact when
        # every row of an ID carries the same label.
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    if not features:
        raise ValueError("reshape_dataset received no rows; cannot build sequences")

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Pad arrays along their first axis to a common length and stack them.

    Padding is appended at the end of axis 0; all other axes are left as is.

    Args:
        arrays: iterable of array-likes with at least one dimension and
            identical trailing dimensions.
        pad_value: constant used for the padded entries.
        max_length: target length of axis 0. Defaults to the longest input,
            which is the original behaviour. Pass an explicit value to give
            separately padded batches (e.g. train and test) the same length.

    Returns:
        ``np.ndarray`` of shape ``(len(arrays), max_length, *trailing_dims)``.

    Raises:
        ValueError: if ``arrays`` is empty or ``max_length`` is shorter than
            the longest input.
    """
    # Materialise once so generators work and every item is an ndarray.
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences needs at least one array")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif max_length < longest:
        raise ValueError(
            f"max_length={max_length} is shorter than the longest sequence ({longest})"
        )

    padded_arrays = [
        np.pad(
            arr,
            # Pad only the end of axis 0; one (0, 0) entry per remaining axis.
            [(0, max_length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1),
            mode='constant',
            constant_values=pad_value)
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build the model once on the full dataset, only to display its architecture.
# The labels are not needed here; the experiment loop creates a fresh model per run.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-05-03 19:07:29.697973: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Learning curves collected across runs: one list of per-epoch values per experiment.
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

# Number of subjects drawn for training; all remaining subjects are held out.
NUM_TRAIN_SUBJECTS = 7

# Flat (n,) subject column. Indexing with the USER list gives an (n, 1) frame,
# which would turn the np.isin result into a 2-D mask.
subject_per_row = data[USER[0]].to_numpy()
subjects = np.unique(subject_per_row)
if len(subjects) <= NUM_TRAIN_SUBJECTS:
    raise ValueError(
        f"Need more than {NUM_TRAIN_SUBJECTS} subjects to hold some out, found {len(subjects)}"
    )


def _pad_time_axis(batch, n_steps):
    """Zero-pad a 3-D batch at the end of axis 1 so that it has n_steps steps."""
    extra = n_steps - batch.shape[1]
    if extra <= 0:
        return batch
    return np.pad(batch, ((0, 0), (0, extra), (0, 0)), mode='constant', constant_values=0)


seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED

# Wrap the seed array itself (not enumerate(...)) so tqdm knows the total.
for seed in tqdm(seeds, desc="experiments"):
    seed = int(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-wise split: no subject contributes to both train and test.
    train_id = np.random.choice(subjects, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(subject_per_row, train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # The two subsets are padded independently, so their sequence lengths can
    # differ. Bring both to the same length, because the model is built with a
    # fixed input shape. This is a no-op when the lengths already match.
    n_steps = max(X_train.shape[1], X_test.shape[1])
    X_train = _pad_time_axis(X_train, n_steps)
    X_test = _pad_time_axis(X_test, n_steps)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    # The held-out subjects are used as validation data for every epoch.
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    all_acc.append(history.history['accuracy'])
    all_loss.append(history.history['loss'])
    all_auc.append(history.history['auc'])

    all_val_acc.append(history.history['val_accuracy'])
    all_val_loss.append(history.history['val_loss'])
    all_val_auc.append(history.history['val_auc'])

# Average each metric over the experiments, epoch by epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 285ms/step - accuracy: 0.6532 - auc: 0.6711 - loss: 0.6838 - val_accuracy: 0.6333 - val_auc: 0.6289 - val_loss: 0.6727
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 154ms/step - accuracy: 0.6873 - auc: 0.7283 - loss: 0.6619 - val_accuracy: 0.7333 - val_auc: 0.7111 - val_loss: 0.6290
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 157ms/step - accuracy: 0.6773 - auc: 0.6961 - loss: 0.6422 - val_accuracy: 0.8000 - val_auc: 0.7200 - val_loss: 0.5931
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 173ms/step - accuracy: 0.6992 - auc: 0.6706 - loss: 0.6294 - val_accuracy: 0.8000 - val_auc: 0.6889 - val_loss: 0.5607
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 153ms/step - accuracy: 0.6851 - auc: 0.7001 - loss: 0.6063 - val_accuracy: 0.8000 - val_auc: 0.7111 - val_loss: 0.5346
Epoch 6/400
[1m7/7[0m [32m━━━━━━

1it [09:28, 568.05s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 346ms/step - accuracy: 0.4626 - auc: 0.4722 - loss: 0.6922 - val_accuracy: 0.5000 - val_auc: 0.6578 - val_loss: 0.6761
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 181ms/step - accuracy: 0.5885 - auc: 0.6944 - loss: 0.6673 - val_accuracy: 0.7333 - val_auc: 0.6667 - val_loss: 0.6365
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 190ms/step - accuracy: 0.6875 - auc: 0.7874 - loss: 0.6404 - val_accuracy: 0.7667 - val_auc: 0.7867 - val_loss: 0.5895
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 209ms/step - accuracy: 0.6783 - auc: 0.7238 - loss: 0.6015 - val_accuracy: 0.7667 - val_auc: 0.7178 - val_loss: 0.5508
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 240ms/step - accuracy: 0.7138 - auc: 0.7023 - loss: 0.5803 - val_accuracy: 0.7667 - val_auc: 0.7933 - val_loss: 0.5239
Epoch 6/400
[1m7/7[0m [32m━━━━━

2it [19:41, 594.68s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 371ms/step - accuracy: 0.5559 - auc: 0.4465 - loss: 0.6941 - val_accuracy: 0.5000 - val_auc: 0.7467 - val_loss: 0.6906
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 202ms/step - accuracy: 0.6254 - auc: 0.7594 - loss: 0.6720 - val_accuracy: 0.6333 - val_auc: 0.8156 - val_loss: 0.6752
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 203ms/step - accuracy: 0.6987 - auc: 0.7032 - loss: 0.6714 - val_accuracy: 0.7000 - val_auc: 0.7644 - val_loss: 0.6580
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 173ms/step - accuracy: 0.6779 - auc: 0.7188 - loss: 0.6619 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.6282
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 197ms/step - accuracy: 0.7031 - auc: 0.7096 - loss: 0.6392 - val_accuracy: 0.7000 - val_auc: 0.7733 - val_loss: 0.5907
Epoch 6/400
[1m7/7[0m [32m━━━━━

3it [30:07, 609.20s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 366ms/step - accuracy: 0.2650 - auc: 0.2582 - loss: 0.7116 - val_accuracy: 0.5333 - val_auc: 0.5533 - val_loss: 0.6877
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 200ms/step - accuracy: 0.6670 - auc: 0.6962 - loss: 0.6815 - val_accuracy: 0.6000 - val_auc: 0.6267 - val_loss: 0.6732
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 203ms/step - accuracy: 0.7610 - auc: 0.7612 - loss: 0.6512 - val_accuracy: 0.5667 - val_auc: 0.5933 - val_loss: 0.6649
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 196ms/step - accuracy: 0.7714 - auc: 0.7625 - loss: 0.5845 - val_accuracy: 0.5667 - val_auc: 0.6311 - val_loss: 0.6531
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 172ms/step - accuracy: 0.7714 - auc: 0.7537 - loss: 0.5199 - val_accuracy: 0.6333 - val_auc: 0.6644 - val_loss: 0.6320
Epoch 6/400
[1m7/7[0m [32m━━━━━

4it [40:01, 603.20s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 378ms/step - accuracy: 0.6160 - auc: 0.5894 - loss: 0.6848 - val_accuracy: 0.7333 - val_auc: 0.6956 - val_loss: 0.6683
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 191ms/step - accuracy: 0.7527 - auc: 0.8215 - loss: 0.6484 - val_accuracy: 0.7000 - val_auc: 0.7067 - val_loss: 0.6464
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 184ms/step - accuracy: 0.6954 - auc: 0.7962 - loss: 0.6245 - val_accuracy: 0.7000 - val_auc: 0.7000 - val_loss: 0.6197
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 210ms/step - accuracy: 0.7475 - auc: 0.7836 - loss: 0.5858 - val_accuracy: 0.7333 - val_auc: 0.6933 - val_loss: 0.5907
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 179ms/step - accuracy: 0.7774 - auc: 0.8089 - loss: 0.5425 - val_accuracy: 0.7333 - val_auc: 0.7089 - val_loss: 0.5724
Epoch 6/400
[1m7/7[0m [32m━━━━━

5it [49:33, 594.70s/it]

CPU times: user 1h 26min 46s, sys: 29min 28s, total: 1h 56min 14s
Wall time: 49min 33s





In [8]:
# Format every report line once so the console output and the log file cannot drift apart.
report_lines = []
for i in range(NUM_EPOCHS):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

# Create the log directory on demand; open(..., "w") raises FileNotFoundError
# when ./logs does not exist yet.
log_path = Path("./logs/LSTMv2_expanded.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(f"{line}\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.497 Loss = 0.694 AUC = 0.487
Epoch 1: VAL Accuracy = 0.58 Loss = 0.679 AUC = 0.656
Epoch 2: TRAIN Accuracy = 0.671 Loss = 0.667 AUC = 0.725
Epoch 2: VAL Accuracy = 0.68 Loss = 0.652 AUC = 0.705
Epoch 3: TRAIN Accuracy = 0.709 Loss = 0.643 AUC = 0.743
Epoch 3: VAL Accuracy = 0.707 Loss = 0.625 AUC = 0.713
Epoch 4: TRAIN Accuracy = 0.706 Loss = 0.61 AUC = 0.736
Epoch 4: VAL Accuracy = 0.727 Loss = 0.597 AUC = 0.7
Epoch 5: TRAIN Accuracy = 0.731 Loss = 0.575 AUC = 0.74
Epoch 5: VAL Accuracy = 0.727 Loss = 0.571 AUC = 0.73
Epoch 6: TRAIN Accuracy = 0.777 Loss = 0.541 AUC = 0.755
Epoch 6: VAL Accuracy = 0.733 Loss = 0.554 AUC = 0.747
Epoch 7: TRAIN Accuracy = 0.766 Loss = 0.519 AUC = 0.763
Epoch 7: VAL Accuracy = 0.733 Loss = 0.544 AUC = 0.748
Epoch 8: TRAIN Accuracy = 0.769 Loss = 0.511 AUC = 0.762
Epoch 8: VAL Accuracy = 0.747 Loss = 0.53 AUC = 0.762
Epoch 9: TRAIN Accuracy = 0.777 Loss = 0.489 AUC = 0.778
Epoch 9: VAL Accuracy = 0.747 Loss = 0.511 AUC = 0.794
