In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 02:12:28.262334: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 02:12:28.276533: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 02:12:28.309598: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742512348.358969  273836 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742512348.374373  273836 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 02:12:28.439004: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of epochs passed to model.fit for every experiment run.
NUM_EPOCHS = 200
# Number of independent runs; each run gets its own seed and train/validation split.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-LSTM binary classifier.

    Only the shape of ``train`` is used: its axis-1 and axis-2 sizes become the
    per-sample input shape of the network.

    The network is three Dropout -> LSTM(10) stages, followed by a final
    Dropout, a 6-unit ReLU layer and a single sigmoid output. It is compiled
    with Adam (learning rate 0.003), binary cross-entropy loss, and the
    "accuracy" and "auc" metrics.

    Returns the compiled ``Sequential`` model.
    """
    input_shape = (train.shape[1], train.shape[2])

    model = Sequential()
    model.add(Input(shape=input_shape))

    # The first two recurrent stages keep the full sequence for the next LSTM;
    # the last one returns only its final output.
    for keep_sequence in (True, True, False):
        model.add(Dropout(0.2))
        model.add(LSTM(10, return_sequences=keep_sequence))

    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=0.003),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column lists are kept as lists so they can be concatenated and used directly
# as DataFrame column selectors.

# Column identifying one sequence; rows sharing an ID are stacked into one sample.
ID = ["ID"]
# Column identifying the subject; the train/validation split is made on this column.
USER = ["SubjectID"]
# Grouping keys used when computing lagged features.
IDS = ["SubjectID", "VideoID"]
# Label column.
TARGET = ["predefinedlabel"]
# Model input columns. NOTE: the data-loading cell appends lagged column names to this list.
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Row offsets for which a shifted copy of every feature is added.
LAGS = [1]
# Seed of the first experiment run; run k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# NOTE(review): absolute, machine-specific path; consider making it configurable.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# FEATURES is extended in place at the end of this cell, so on a re-run it
# already contains the lagged names. Filter those out first; otherwise a second
# execution would create "Raw_1_1"-style columns and duplicate entries.
derived_names = {f"{name}_{lag}" for name in FEATURES for lag in LAGS}
base_features = [name for name in FEATURES if name not in derived_names]

# For every lag, add a copy of each base feature shifted within its
# (SubjectID, VideoID) group; the rows that have no predecessor are filled with 0.
new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Slice-assign so the same list object is updated (as the original extend did)
# while making the result independent of how many times the cell has run.
FEATURES[:] = base_features + new_features

# One integer ID per (SubjectID, VideoID) pair.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_1,Delta_1,Theta_1,Alpha1_1,Alpha2_1,Beta1_1,Beta2_1,Gamma1_1,Gamma2_1,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the row-per-timestep table into padded per-sequence arrays.

    Rows are grouped by the value of the ``ID`` column. For each distinct ID
    (visited in sorted order) the ``FEATURES`` columns form one sequence and
    the mean of the ``TARGET`` column, cast to int, forms its label.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns
    -------
    tuple of numpy.ndarray
        ``(features, target)``: ``features`` is the stacked array returned by
        ``pad_sequences`` (one padded sequence per ID) and ``target`` holds one
        integer label per sequence.
    """
    # Computed once instead of on every iteration, and flattened so the row
    # mask below is 1-D rather than an (n_rows, 1) array.
    row_ids = np.ravel(data[ID].to_numpy())

    features = []
    target = []
    for cur_id in np.unique(row_ids):
        cur_id_data = data[row_ids == cur_id]
        # The mean is truncated to int; presumably the label is constant
        # within one ID - TODO confirm.
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    # pad_sequences already returns a stacked ndarray; no extra copy is needed.
    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays to a common length along axis 0 and stack them.

    Each array is extended at the end of its first axis with rows of
    ``pad_value`` until it is as long as the longest array in ``arrays``; the
    second axis is left unchanged.

    Returns the padded arrays stacked along a new leading axis.
    """
    target_len = max(arr.shape[0] for arr in arrays)

    padded = []
    for arr in arrays:
        missing_rows = target_len - arr.shape[0]
        # Pad only after the existing rows; never along the second axis.
        pad_width = ((0, missing_rows), (0, 0))
        padded.append(
            np.pad(arr, pad_width, mode='constant', constant_values=pad_value)
        )
    return np.stack(padded)

In [6]:
# Build one model on the full dataset only to display its architecture; the
# labels returned by reshape_dataset are not needed for that.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 02:12:34.406414: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Per-run metric histories: one list of per-epoch values is appended per run.
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

# Number of subjects drawn into the training split of each run; the rows of
# all remaining subjects are used as validation data.
NUM_TRAIN_SUBJECTS = 7

# Flattened once so the split mask is 1-D (np.isin on the one-column frame
# data[USER] would give an (n_rows, 1) mask).
row_subjects = np.ravel(data[USER])
subjects = np.unique(row_subjects)

# tqdm wraps the seed array itself rather than enumerate(...), so it knows the
# total number of runs and can show real progress.
seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED
for seed in tqdm(seeds):
    seed = int(seed)
    # Seed every source of randomness before the split is drawn and the model is built.
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-wise split: a subject's rows are either all training or all validation.
    train_id = np.random.choice(subjects, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(row_subjects, train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Column vectors, matching the single-unit output of the model.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    all_acc.append(history.history['accuracy'])
    all_loss.append(history.history['loss'])
    all_auc.append(history.history['auc'])

    all_val_acc.append(history.history['val_accuracy'])
    all_val_loss.append(history.history['val_loss'])
    all_val_auc.append(history.history['val_auc'])

# Average each metric over the runs, epoch by epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)


0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 450ms/step - accuracy: 0.6040 - auc: 0.6746 - loss: 0.6831 - val_accuracy: 0.4667 - val_auc: 0.6689 - val_loss: 0.6556
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 241ms/step - accuracy: 0.5050 - auc: 0.5937 - loss: 0.6720 - val_accuracy: 0.6667 - val_auc: 0.7000 - val_loss: 0.6070
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 228ms/step - accuracy: 0.6739 - auc: 0.6678 - loss: 0.6439 - val_accuracy: 0.7000 - val_auc: 0.7467 - val_loss: 0.5948
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 257ms/step - accuracy: 0.6604 - auc: 0.6445 - loss: 0.6290 - val_accuracy: 0.7000 - val_auc: 0.7244 - val_loss: 0.5679
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 298ms/step - accuracy: 0.7018 - auc: 0.6635 - loss: 0.6153 - val_accuracy: 0.7333 - val_auc: 0.7044 - val_loss: 0.5528
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [07:19, 439.30s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 408ms/step - accuracy: 0.5212 - auc: 0.6001 - loss: 0.6910 - val_accuracy: 0.5667 - val_auc: 0.5978 - val_loss: 0.6889
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 237ms/step - accuracy: 0.6859 - auc: 0.7026 - loss: 0.6798 - val_accuracy: 0.6333 - val_auc: 0.6867 - val_loss: 0.6774
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 225ms/step - accuracy: 0.6585 - auc: 0.7355 - loss: 0.6630 - val_accuracy: 0.6000 - val_auc: 0.7000 - val_loss: 0.6692
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.6523 - auc: 0.7515 - loss: 0.6571 - val_accuracy: 0.7000 - val_auc: 0.7578 - val_loss: 0.6448
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 250ms/step - accuracy: 0.7183 - auc: 0.6610 - loss: 0.6445 - val_accuracy: 0.7667 - val_auc: 0.7511 - val_loss: 0.6361
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [14:28, 433.53s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 448ms/step - accuracy: 0.4425 - auc: 0.6820 - loss: 0.6855 - val_accuracy: 0.6333 - val_auc: 0.7644 - val_loss: 0.6703
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 238ms/step - accuracy: 0.5721 - auc: 0.7403 - loss: 0.6689 - val_accuracy: 0.6333 - val_auc: 0.8133 - val_loss: 0.6292
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 258ms/step - accuracy: 0.4878 - auc: 0.7297 - loss: 0.6435 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5950
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 242ms/step - accuracy: 0.6593 - auc: 0.7542 - loss: 0.6174 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5805
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 258ms/step - accuracy: 0.7113 - auc: 0.7460 - loss: 0.5866 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5737
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [21:38, 431.78s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 465ms/step - accuracy: 0.4673 - auc: 0.4839 - loss: 0.6926 - val_accuracy: 0.5000 - val_auc: 0.7022 - val_loss: 0.6809
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 238ms/step - accuracy: 0.6319 - auc: 0.6979 - loss: 0.6780 - val_accuracy: 0.4667 - val_auc: 0.6356 - val_loss: 0.6770
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.5333 - auc: 0.6982 - loss: 0.6563 - val_accuracy: 0.6667 - val_auc: 0.7022 - val_loss: 0.6627
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 269ms/step - accuracy: 0.7899 - auc: 0.7626 - loss: 0.6269 - val_accuracy: 0.7333 - val_auc: 0.7489 - val_loss: 0.6375
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 253ms/step - accuracy: 0.7663 - auc: 0.7548 - loss: 0.6034 - val_accuracy: 0.7000 - val_auc: 0.7489 - val_loss: 0.6087
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [28:46, 430.38s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 445ms/step - accuracy: 0.4723 - auc: 0.4888 - loss: 0.6919 - val_accuracy: 0.5333 - val_auc: 0.6089 - val_loss: 0.6839
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 240ms/step - accuracy: 0.6594 - auc: 0.7100 - loss: 0.6798 - val_accuracy: 0.6667 - val_auc: 0.6822 - val_loss: 0.6697
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.7895 - auc: 0.8226 - loss: 0.6599 - val_accuracy: 0.6333 - val_auc: 0.7222 - val_loss: 0.6511
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 245ms/step - accuracy: 0.7627 - auc: 0.7717 - loss: 0.6509 - val_accuracy: 0.6333 - val_auc: 0.7511 - val_loss: 0.6292
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 260ms/step - accuracy: 0.7246 - auc: 0.8310 - loss: 0.6266 - val_accuracy: 0.6667 - val_auc: 0.7578 - val_loss: 0.6140
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [35:44, 428.81s/it]

CPU times: user 52min 51s, sys: 14min 53s, total: 1h 7min 44s
Wall time: 35min 44s





In [8]:
# Print the run-averaged metrics: one TRAIN line and one VAL line per epoch.
for epoch_idx in range(NUM_EPOCHS):
    epoch_no = epoch_idx + 1

    mean_train_acc = np.round(epoch_acc[epoch_idx], 3)
    mean_train_loss = np.round(epoch_loss[epoch_idx], 3)
    mean_train_auc = np.round(epoch_auc[epoch_idx], 3)
    print(f"Epoch {epoch_no}: TRAIN Accuracy = {mean_train_acc} Loss = {mean_train_loss} AUC = {mean_train_auc}")

    mean_val_acc = np.round(epoch_val_acc[epoch_idx], 3)
    mean_val_loss = np.round(epoch_val_loss[epoch_idx], 3)
    mean_val_auc = np.round(epoch_val_auc[epoch_idx], 3)
    print(f"Epoch {epoch_no}: VAL Accuracy = {mean_val_acc} Loss = {mean_val_loss} AUC = {mean_val_auc}")

Epoch 1: TRAIN Accuracy = 0.534 Loss = 0.684 AUC = 0.627
Epoch 1: VAL Accuracy = 0.54 Loss = 0.676 AUC = 0.668
Epoch 2: TRAIN Accuracy = 0.62 Loss = 0.67 AUC = 0.696
Epoch 2: VAL Accuracy = 0.613 Loss = 0.652 AUC = 0.704
Epoch 3: TRAIN Accuracy = 0.651 Loss = 0.646 AUC = 0.727
Epoch 3: VAL Accuracy = 0.66 Loss = 0.635 AUC = 0.732
Epoch 4: TRAIN Accuracy = 0.703 Loss = 0.627 AUC = 0.727
Epoch 4: VAL Accuracy = 0.693 Loss = 0.612 AUC = 0.754
Epoch 5: TRAIN Accuracy = 0.726 Loss = 0.605 AUC = 0.735
Epoch 5: VAL Accuracy = 0.713 Loss = 0.597 AUC = 0.75
Epoch 6: TRAIN Accuracy = 0.76 Loss = 0.568 AUC = 0.792
Epoch 6: VAL Accuracy = 0.733 Loss = 0.575 AUC = 0.742
Epoch 7: TRAIN Accuracy = 0.743 Loss = 0.565 AUC = 0.755
Epoch 7: VAL Accuracy = 0.753 Loss = 0.551 AUC = 0.763
Epoch 8: TRAIN Accuracy = 0.774 Loss = 0.539 AUC = 0.764
Epoch 8: VAL Accuracy = 0.76 Loss = 0.535 AUC = 0.776
Epoch 9: TRAIN Accuracy = 0.766 Loss = 0.52 AUC = 0.776
Epoch 9: VAL Accuracy = 0.753 Loss = 0.53 AUC = 0.788
E