In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 00:51:17.518956: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 00:51:17.527500: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 00:51:17.551310: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742507477.586181   90842 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742507477.595166   90842 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 00:51:17.628794: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of training epochs run for every experiment (passed to `model.fit`).
NUM_EPOCHS = 200
# Number of independent experiments; each one gets its own seed
# (INIT_SEED + experiment index) and its own random subject-level split.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-LSTM binary classifier.

    Args:
        train: Array whose second and third dimensions define the model's
            input shape, i.e. `(train.shape[1], train.shape[2])`; the first
            dimension is not used. Presumably laid out as
            (samples, timesteps, features) -- confirm against the caller.

    Returns:
        A compiled `Sequential` model made of three Dropout(0.2) + LSTM(10)
        stages, a further Dropout(0.2), a Dense(6, relu) layer and a single
        sigmoid output unit. It is compiled with Adam (learning rate 0.003),
        binary cross-entropy loss, and the "accuracy" and "auc" metrics.
    """
    input_shape = (train.shape[1], train.shape[2])

    model = Sequential()
    model.add(Input(shape=input_shape))

    # The first two LSTMs pass the full sequence on to the next recurrent
    # layer; the last one returns only its final output.
    for keep_sequence in (True, True, False):
        model.add(Dropout(0.2))
        model.add(LSTM(10, return_sequences=keep_sequence))

    # Classification head.
    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=0.003),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column lists are kept as lists so they can be concatenated and used directly
# for DataFrame column selection.

# Identifier of one sequence; the "ID" column is derived from SubjectID and
# VideoID when the data is loaded.
ID = ["ID"]
# Column identifying the subject; the train/validation split is done on it.
USER = ["SubjectID"]
# Raw identifier columns of the CSV. NOTE(review): not referenced in the
# visible cells -- confirm whether it is still needed.
IDS = ["SubjectID", "VideoID"]
# Label column; it is reduced to one integer label per ID in `reshape_dataset`.
TARGET = ["predefinedlabel"]
# Columns used as the per-row model input features.
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Base random seed; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# Directory that contains EEG_data.csv. Set the EEG_DATA_DIR environment
# variable to run the notebook on another machine; the fallback is the path
# the notebook was originally written against.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer ID per (SubjectID, VideoID) pair. The IDs are unique as long as
# every VideoID is smaller than the number of distinct videos -- presumably
# they are 0-based consecutive integers; verify against the CSV.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
# Keep only the columns that the following cells use.
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long-format frame into one padded sequence and one label per ID.

    Args:
        data: DataFrame containing the `ID`, `FEATURES` and `TARGET` columns.

    Returns:
        A `(features, target)` tuple of numpy arrays. `features` stacks, for
        each unique ID in sorted order, that ID's `FEATURES` rows after
        `pad_sequences` has brought them to a common length. `target` holds,
        for each ID, the mean of its `TARGET` values truncated to an integer.
    """
    id_column = data[ID].to_numpy()

    sequences = []
    labels = []
    for sample_id in np.unique(id_column):
        rows = data[id_column == sample_id]
        # The mean is cast (truncated) to int, giving one label per sequence.
        labels.append(np.mean(rows[TARGET].to_numpy()).astype("int"))
        sequences.append(rows[FEATURES].to_numpy())

    padded = pad_sequences(sequences)
    return np.array(padded), np.array(labels)

def pad_sequences(arrays, pad_value=0):
    """Pad variable-length arrays along their first axis and stack them.

    Each array is extended at the END of axis 0 with `pad_value` until it is
    as long as the longest one. All other axes are left untouched, so they
    must already agree across arrays for the final stack to succeed. Arrays
    of any rank >= 1 are supported.

    Args:
        arrays: Non-empty iterable of array-likes with at least one dimension.
        pad_value: Constant written into the padded positions.

    Returns:
        np.ndarray of shape `(len(arrays), max_length, *trailing_dims)`.

    Raises:
        ValueError: If `arrays` is empty, or (raised by `np.stack`) if the
            trailing dimensions of the arrays differ.
    """
    # Materialise once so one-shot iterables and plain nested lists also work.
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences() requires at least one array")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            # Pad only the end of axis 0, whatever the rank of the array.
            [(0, max_length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build the model once on the full dataset so its architecture can be
# inspected before the experiments run. `create_model` only reads the shape
# of X, and the labels are not needed here.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 00:51:22.563509: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Number of subjects drawn (without replacement) into each experiment's
# training split; every remaining subject is used for validation.
NUM_TRAIN_SUBJECTS = 7

# Per-experiment learning curves: one list of per-epoch values per run.
all_acc, all_loss, all_auc = [], [], []
all_val_acc, all_val_loss, all_val_auc = [], [], []

# 1-D array with the subject of every row. Using it yields a plain 1-D boolean
# row mask instead of the (n_rows, 1) mask obtained from `data[USER]`.
subject_per_row = data[USER[0]].to_numpy()
subject_ids = np.unique(subject_per_row)
assert NUM_TRAIN_SUBJECTS < len(subject_ids), "at least one subject must be held out for validation"

seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED

# tqdm wraps the seeds (not the `enumerate` generator) so that it knows the
# total number of experiments and can render a real progress bar with an ETA.
for j, seed in enumerate(tqdm(seeds, desc="experiments")):
    seed = int(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Split by subject so that no subject contributes to both splits.
    train_id = np.random.choice(subject_ids, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(subject_per_row, train_id)

    train = data[train_index]
    test = data[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Column vectors to match the single sigmoid output unit.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    # A fresh model per experiment so that runs do not share weights.
    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    curves = history.history
    all_acc.append(curves['accuracy'])
    all_loss.append(curves['loss'])
    all_auc.append(curves['auc'])

    all_val_acc.append(curves['val_accuracy'])
    all_val_loss.append(curves['val_loss'])
    all_val_auc.append(curves['val_auc'])

# Average every metric over the experiments, epoch by epoch (axis 0 = runs).
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)


0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 403ms/step - accuracy: 0.5835 - auc: 0.5905 - loss: 0.6901 - val_accuracy: 0.5000 - val_auc: 0.5822 - val_loss: 0.6851
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.5074 - auc: 0.5542 - loss: 0.6880 - val_accuracy: 0.5333 - val_auc: 0.7422 - val_loss: 0.6672
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 228ms/step - accuracy: 0.6176 - auc: 0.6474 - loss: 0.6739 - val_accuracy: 0.5667 - val_auc: 0.7511 - val_loss: 0.6459
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.5876 - auc: 0.7138 - loss: 0.6500 - val_accuracy: 0.6000 - val_auc: 0.7067 - val_loss: 0.6227
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 211ms/step - accuracy: 0.6558 - auc: 0.7089 - loss: 0.6249 - val_accuracy: 0.6667 - val_auc: 0.6711 - val_loss: 0.6095
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [06:17, 377.31s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 406ms/step - accuracy: 0.5298 - auc: 0.7168 - loss: 0.6766 - val_accuracy: 0.5000 - val_auc: 0.7333 - val_loss: 0.6692
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 231ms/step - accuracy: 0.6124 - auc: 0.7007 - loss: 0.6598 - val_accuracy: 0.6667 - val_auc: 0.6933 - val_loss: 0.6558
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 229ms/step - accuracy: 0.6896 - auc: 0.7439 - loss: 0.6361 - val_accuracy: 0.7000 - val_auc: 0.7156 - val_loss: 0.6301
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 230ms/step - accuracy: 0.6865 - auc: 0.7660 - loss: 0.6214 - val_accuracy: 0.7333 - val_auc: 0.7222 - val_loss: 0.6047
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 238ms/step - accuracy: 0.7195 - auc: 0.6956 - loss: 0.5870 - val_accuracy: 0.7667 - val_auc: 0.8044 - val_loss: 0.5605
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:33, 376.73s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 409ms/step - accuracy: 0.3999 - auc: 0.6488 - loss: 0.6936 - val_accuracy: 0.5667 - val_auc: 0.7200 - val_loss: 0.6802
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 240ms/step - accuracy: 0.5477 - auc: 0.6857 - loss: 0.6831 - val_accuracy: 0.4667 - val_auc: 0.7200 - val_loss: 0.6719
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 226ms/step - accuracy: 0.5912 - auc: 0.6729 - loss: 0.6777 - val_accuracy: 0.6333 - val_auc: 0.8644 - val_loss: 0.6582
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.4936 - auc: 0.7512 - loss: 0.6695 - val_accuracy: 0.5333 - val_auc: 0.7311 - val_loss: 0.6469
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 228ms/step - accuracy: 0.5506 - auc: 0.6849 - loss: 0.6493 - val_accuracy: 0.6667 - val_auc: 0.7778 - val_loss: 0.6167
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [18:50, 376.87s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 427ms/step - accuracy: 0.5118 - auc: 0.6390 - loss: 0.6845 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.6554
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7194 - auc: 0.7474 - loss: 0.6422 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.6215
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 225ms/step - accuracy: 0.7782 - auc: 0.7953 - loss: 0.5959 - val_accuracy: 0.7333 - val_auc: 0.7356 - val_loss: 0.5894
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 251ms/step - accuracy: 0.7751 - auc: 0.7635 - loss: 0.5633 - val_accuracy: 0.7333 - val_auc: 0.7600 - val_loss: 0.5657
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.7864 - auc: 0.7450 - loss: 0.5292 - val_accuracy: 0.7667 - val_auc: 0.7533 - val_loss: 0.5348
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [25:12, 378.65s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 422ms/step - accuracy: 0.4891 - auc: 0.4134 - loss: 0.6963 - val_accuracy: 0.6667 - val_auc: 0.6467 - val_loss: 0.6737
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 223ms/step - accuracy: 0.7682 - auc: 0.8538 - loss: 0.6577 - val_accuracy: 0.6667 - val_auc: 0.6622 - val_loss: 0.6522
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 235ms/step - accuracy: 0.7726 - auc: 0.8416 - loss: 0.6313 - val_accuracy: 0.6667 - val_auc: 0.7156 - val_loss: 0.6282
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.7803 - auc: 0.8336 - loss: 0.6025 - val_accuracy: 0.7000 - val_auc: 0.7533 - val_loss: 0.6020
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 212ms/step - accuracy: 0.7949 - auc: 0.8098 - loss: 0.5693 - val_accuracy: 0.7000 - val_auc: 0.7178 - val_loss: 0.5849
Epoch 6/200
[1m7/7[0m [32m━━━━━━

5it [31:39, 379.99s/it]

CPU times: user 57min 32s, sys: 21min 2s, total: 1h 18min 34s
Wall time: 31min 39s





In [10]:
# Report the epoch-wise metrics averaged over all experiments: one TRAIN line
# followed by one VAL line per epoch, values rounded to three decimals.
for epoch_idx in range(NUM_EPOCHS):
    epoch_no = epoch_idx + 1

    mean_train_acc = np.round(epoch_acc[epoch_idx], 3)
    mean_train_loss = np.round(epoch_loss[epoch_idx], 3)
    mean_train_auc = np.round(epoch_auc[epoch_idx], 3)
    print(f"Epoch {epoch_no}: TRAIN Accuracy = {mean_train_acc} Loss = {mean_train_loss} AUC = {mean_train_auc}")

    mean_val_acc = np.round(epoch_val_acc[epoch_idx], 3)
    mean_val_loss = np.round(epoch_val_loss[epoch_idx], 3)
    mean_val_auc = np.round(epoch_val_auc[epoch_idx], 3)
    print(f"Epoch {epoch_no}: VAL Accuracy = {mean_val_acc} Loss = {mean_val_loss} AUC = {mean_val_auc}")

Epoch 1: TRAIN Accuracy = 0.531 Loss = 0.684 AUC = 0.639
Epoch 1: VAL Accuracy = 0.593 Loss = 0.673 AUC = 0.686
Epoch 2: TRAIN Accuracy = 0.64 Loss = 0.665 AUC = 0.715
Epoch 2: VAL Accuracy = 0.607 Loss = 0.654 AUC = 0.711
Epoch 3: TRAIN Accuracy = 0.7 Loss = 0.637 AUC = 0.75
Epoch 3: VAL Accuracy = 0.66 Loss = 0.63 AUC = 0.756
Epoch 4: TRAIN Accuracy = 0.674 Loss = 0.618 AUC = 0.76
Epoch 4: VAL Accuracy = 0.66 Loss = 0.608 AUC = 0.735
Epoch 5: TRAIN Accuracy = 0.703 Loss = 0.586 AUC = 0.737
Epoch 5: VAL Accuracy = 0.713 Loss = 0.581 AUC = 0.745
Epoch 6: TRAIN Accuracy = 0.746 Loss = 0.557 AUC = 0.769
Epoch 6: VAL Accuracy = 0.727 Loss = 0.552 AUC = 0.763
Epoch 7: TRAIN Accuracy = 0.754 Loss = 0.536 AUC = 0.78
Epoch 7: VAL Accuracy = 0.76 Loss = 0.524 AUC = 0.789
Epoch 8: TRAIN Accuracy = 0.737 Loss = 0.516 AUC = 0.777
Epoch 8: VAL Accuracy = 0.773 Loss = 0.502 AUC = 0.81
Epoch 9: TRAIN Accuracy = 0.763 Loss = 0.497 AUC = 0.79
Epoch 9: VAL Accuracy = 0.767 Loss = 0.493 AUC = 0.828
Epoc