In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv1D, Dense, LSTM, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-07 22:51:57.859044: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 22:51:57.867123: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 22:51:57.884949: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744055517.919288  730871 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744055517.930470  730871 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 22:51:57.969397: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training budget: epochs per fit, and how many independently seeded runs are averaged.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5

def create_model(train, conv_filters=50, lstm_units=(20, 20, 10, 10), dropout=0.2,
                 learning_rate=0.002, weight_decay=1e-7):
    """Build and compile the causal-Conv1D + stacked-LSTM binary classifier.

    Args:
        train: 3-D array; only ``train.shape[1]`` and ``train.shape[2]`` are read,
            to size the input layer (timesteps, features).
        conv_filters: number of filters in the causal Conv1D front-end (kernel size 3).
        lstm_units: units of each stacked LSTM layer, in order. Every layer except
            the last returns full sequences; the last returns only its final output.
        dropout: rate of the Dropout layer inserted before every LSTM layer.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay (Adam is also configured with ``use_ema=True``).

    Returns:
        A compiled ``Sequential`` model with one sigmoid output, trained with binary
        cross-entropy and tracking accuracy and AUC (reported under the name "auc").

    Raises:
        ValueError: if ``lstm_units`` is empty (the head would then receive a sequence).

    NOTE(review): pad_sequences in this notebook zero-pads at the end of each sequence
    and no Masking layer is used, so the final LSTM output is computed after the padding
    rows — consider pre-padding or masking.
    """
    lstm_units = tuple(lstm_units)
    if not lstm_units:
        raise ValueError("lstm_units must contain at least one layer size")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    model.add(Conv1D(filters=conv_filters, kernel_size=3, activation='relu', padding='causal'))

    last = len(lstm_units) - 1
    for i, units in enumerate(lstm_units):
        model.add(Dropout(dropout))
        # Only the final LSTM collapses the time axis for the Dense classifier head.
        model.add(LSTM(units, return_sequences=i < last, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate, weight_decay=weight_decay, use_ema=True),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of EEG_data.csv, kept as lists so they can be concatenated for selection.
ID = ["ID"]                        # per-(subject, video) key, derived in the next cell
USER = ["SubjectID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = "Delta Theta Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2".split()

# Base seed: experiment k is seeded with INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# Location of EEG_data.csv; set the EEG_DATA_DIR environment variable on other machines.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data_raw = pd.read_csv(data_dir / "EEG_data.csv")

# One integer ID per (SubjectID, VideoID) recording. ngroup() numbers groups in sorted
# key order, so when every subject has every video 0..n_videos-1 this equals the former
# n_videos * SubjectID + VideoID encoding, but it cannot collide if video IDs have gaps.
data = data_raw.assign(ID=data_raw.groupby(IDS).ngroup().astype("int"))[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the row-per-sample frame into one padded sequence and one label per ID.

    Rows are grouped by the ``ID`` column(s) in ascending ID order; within each group
    the original row order is kept. Each group's ``FEATURES`` columns form one
    sequence, and the sequences are padded to a common length by ``pad_sequences``.

    The label of a group is the mean of its ``TARGET`` values truncated to int, so
    (assuming 0/1 labels) a group is labelled 1 only if all of its rows are 1.

    Returns:
        (features, target): a 3-D array (n_ids, max_len, n_features) and a 1-D int
        array of length n_ids, aligned by ID.
    """
    features = []
    target = []
    # Single pass over the frame instead of re-masking all rows for every ID.
    for _, group in data.groupby(ID, sort=True):
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None, padding="post"):
    """Pad 2-D arrays along axis 0 to a common length and stack them into a 3-D array.

    Args:
        arrays: non-empty sequence of 2-D arrays of shape (length_i, n_features); all
            must share ``n_features``.
        pad_value: constant written into the padding rows.
        max_length: target length. Defaults to the longest input (the original
            behaviour); pass a fixed value to give several splits the same time axis.
        padding: "post" appends padding rows after the data (original behaviour);
            "pre" inserts them before it, keeping the real rows at the end.

    Returns:
        Array of shape (len(arrays), max_length, n_features).

    Raises:
        ValueError: if ``arrays`` is empty, ``padding`` is unknown, or an input is
            longer than ``max_length``.
    """
    if len(arrays) == 0:
        raise ValueError("pad_sequences() needs at least one array")
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif longest > max_length:
        raise ValueError(f"sequence of length {longest} exceeds max_length={max_length}")

    def _pad_widths(length):
        # Pad only the time axis; the feature axis is left untouched.
        extra = max_length - length
        return ((0, extra) if padding == "post" else (extra, 0), (0, 0))

    return np.stack([
        np.pad(arr, _pad_widths(arr.shape[0]), mode='constant', constant_values=pad_value)
        for arr in arrays
    ])

In [6]:
# Build one model on the full dataset purely to inspect its layer stack.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-04-07 22:52:02.018970: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 288ms/step - accuracy: 0.5848 - auc: 0.6166 - loss: 0.6832 - val_accuracy: 0.7000 - val_auc: 0.6911 - val_loss: 0.6551
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 152ms/step - accuracy: 0.6950 - auc: 0.6843 - loss: 0.6527 - val_accuracy: 0.7000 - val_auc: 0.7089 - val_loss: 0.6064
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 158ms/step - accuracy: 0.6913 - auc: 0.6493 - loss: 0.6157 - val_accuracy: 0.7000 - val_auc: 0.7222 - val_loss: 0.5850
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 159ms/step - accuracy: 0.7068 - auc: 0.6060 - loss: 0.6009 - val_accuracy: 0.7000 - val_auc: 0.7756 - val_loss: 0.5656
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 155ms/step - accuracy: 0.6975 - auc: 0.6894 - loss: 0.5967 - val_accuracy: 0.7000 - val_auc: 0.7378 - val_loss: 0.5526
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [04:14, 254.56s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 296ms/step - accuracy: 0.4209 - auc: 0.4312 - loss: 0.6958 - val_accuracy: 0.6667 - val_auc: 0.6933 - val_loss: 0.6864
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 165ms/step - accuracy: 0.6991 - auc: 0.7323 - loss: 0.6829 - val_accuracy: 0.7667 - val_auc: 0.7067 - val_loss: 0.6748
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 161ms/step - accuracy: 0.7483 - auc: 0.8276 - loss: 0.6613 - val_accuracy: 0.6667 - val_auc: 0.6622 - val_loss: 0.6597
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 158ms/step - accuracy: 0.6802 - auc: 0.7702 - loss: 0.6543 - val_accuracy: 0.8000 - val_auc: 0.7111 - val_loss: 0.6281
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 165ms/step - accuracy: 0.7076 - auc: 0.7323 - loss: 0.6474 - val_accuracy: 0.7333 - val_auc: 0.7133 - val_loss: 0.5907
Epoch 6/200
[1m7/7[0m [32m━━━━━━

2it [08:04, 240.00s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 280ms/step - accuracy: 0.6195 - auc: 0.6710 - loss: 0.6799 - val_accuracy: 0.7000 - val_auc: 0.8178 - val_loss: 0.6564
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 147ms/step - accuracy: 0.6590 - auc: 0.6439 - loss: 0.6598 - val_accuracy: 0.6333 - val_auc: 0.7822 - val_loss: 0.6343
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 117ms/step - accuracy: 0.7253 - auc: 0.7411 - loss: 0.6288 - val_accuracy: 0.6333 - val_auc: 0.7444 - val_loss: 0.6232
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 145ms/step - accuracy: 0.6896 - auc: 0.7181 - loss: 0.6131 - val_accuracy: 0.6000 - val_auc: 0.6844 - val_loss: 0.6497
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 161ms/step - accuracy: 0.6792 - auc: 0.7013 - loss: 0.6122 - val_accuracy: 0.6333 - val_auc: 0.6533 - val_loss: 0.6549
Epoch 6/200
[1m7/7[0m [32m━━━━━━

3it [11:52, 234.42s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 299ms/step - accuracy: 0.5618 - auc: 0.5909 - loss: 0.6860 - val_accuracy: 0.5000 - val_auc: 0.5467 - val_loss: 0.6796
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 163ms/step - accuracy: 0.6403 - auc: 0.7505 - loss: 0.6599 - val_accuracy: 0.6667 - val_auc: 0.5733 - val_loss: 0.6664
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 166ms/step - accuracy: 0.7457 - auc: 0.7800 - loss: 0.6238 - val_accuracy: 0.6333 - val_auc: 0.6333 - val_loss: 0.6427
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 157ms/step - accuracy: 0.7228 - auc: 0.8099 - loss: 0.5759 - val_accuracy: 0.6000 - val_auc: 0.6022 - val_loss: 0.6326
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 171ms/step - accuracy: 0.7674 - auc: 0.7593 - loss: 0.5318 - val_accuracy: 0.7333 - val_auc: 0.6289 - val_loss: 0.5852
Epoch 6/200
[1m7/7[0m [32m━━━━━━

4it [15:39, 231.69s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 273ms/step - accuracy: 0.4517 - auc: 0.5140 - loss: 0.6936 - val_accuracy: 0.6667 - val_auc: 0.7000 - val_loss: 0.6756
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 161ms/step - accuracy: 0.7613 - auc: 0.7316 - loss: 0.6702 - val_accuracy: 0.7000 - val_auc: 0.6956 - val_loss: 0.6558
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 165ms/step - accuracy: 0.7459 - auc: 0.7734 - loss: 0.6390 - val_accuracy: 0.6667 - val_auc: 0.7267 - val_loss: 0.6248
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 159ms/step - accuracy: 0.7459 - auc: 0.7888 - loss: 0.5821 - val_accuracy: 0.6667 - val_auc: 0.7311 - val_loss: 0.6138
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 157ms/step - accuracy: 0.7762 - auc: 0.8454 - loss: 0.5238 - val_accuracy: 0.6667 - val_auc: 0.7000 - val_loss: 0.6060
Epoch 6/200
[1m7/7[0m [32m━━━━━━

5it [19:30, 234.19s/it]

CPU times: user 40min 11s, sys: 14min 29s, total: 54min 40s
Wall time: 19min 30s





In [8]:
def _epoch_report(epoch):
    """Return the TRAIN and VAL summary lines for the 0-based ``epoch`` index."""
    return [
        f"Epoch {(epoch + 1)}: TRAIN Accuracy = {np.round(epoch_acc[epoch], 3)} Loss = {np.round(epoch_loss[epoch], 3)} AUC = {np.round(epoch_auc[epoch], 3)}",
        f"Epoch {(epoch + 1)}: VAL Accuracy = {np.round(epoch_val_acc[epoch], 3)} Loss = {np.round(epoch_val_loss[epoch], 3)} AUC = {np.round(epoch_val_auc[epoch], 3)}",
    ]


# Format once, then both display and persist the identical text.
report_lines = [line for epoch in range(len(epoch_acc)) for line in _epoch_report(epoch)]
for line in report_lines:
    print(line)

log_path = Path("./logs/LSTMv2_conv.txt")
# open(..., "w") raises FileNotFoundError when ./logs does not exist yet.
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text("".join(f"{line}\n" for line in report_lines))

Epoch 1: TRAIN Accuracy = 0.54 Loss = 0.686 AUC = 0.566
Epoch 1: VAL Accuracy = 0.647 Loss = 0.671 AUC = 0.69
Epoch 2: TRAIN Accuracy = 0.711 Loss = 0.657 AUC = 0.742
Epoch 2: VAL Accuracy = 0.693 Loss = 0.648 AUC = 0.693
Epoch 3: TRAIN Accuracy = 0.731 Loss = 0.622 AUC = 0.756
Epoch 3: VAL Accuracy = 0.66 Loss = 0.627 AUC = 0.698
Epoch 4: TRAIN Accuracy = 0.72 Loss = 0.594 AUC = 0.742
Epoch 4: VAL Accuracy = 0.673 Loss = 0.618 AUC = 0.701
Epoch 5: TRAIN Accuracy = 0.729 Loss = 0.579 AUC = 0.74
Epoch 5: VAL Accuracy = 0.693 Loss = 0.598 AUC = 0.687
Epoch 6: TRAIN Accuracy = 0.74 Loss = 0.554 AUC = 0.748
Epoch 6: VAL Accuracy = 0.713 Loss = 0.577 AUC = 0.733
Epoch 7: TRAIN Accuracy = 0.754 Loss = 0.54 AUC = 0.753
Epoch 7: VAL Accuracy = 0.74 Loss = 0.56 AUC = 0.755
Epoch 8: TRAIN Accuracy = 0.766 Loss = 0.527 AUC = 0.775
Epoch 8: VAL Accuracy = 0.747 Loss = 0.551 AUC = 0.786
Epoch 9: TRAIN Accuracy = 0.76 Loss = 0.522 AUC = 0.79
Epoch 9: VAL Accuracy = 0.753 Loss = 0.532 AUC = 0.812
Epo