In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 00:50:48.670265: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 00:50:48.677169: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 00:50:48.696459: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742507448.726163   89756 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742507448.738502   89756 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 00:50:48.781882: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=False))
    
    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
USER = ["SubjectID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 00:50:51.670535: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 280ms/step - accuracy: 0.4487 - auc: 0.6030 - loss: 0.6647 - val_accuracy: 0.7333 - val_auc: 0.6889 - val_loss: 0.6469
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 169ms/step - accuracy: 0.6047 - auc: 0.6564 - loss: 0.6426 - val_accuracy: 0.8000 - val_auc: 0.8178 - val_loss: 0.5855
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 155ms/step - accuracy: 0.7305 - auc: 0.7561 - loss: 0.6041 - val_accuracy: 0.8000 - val_auc: 0.8067 - val_loss: 0.5493
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 166ms/step - accuracy: 0.6871 - auc: 0.6990 - loss: 0.5925 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5026
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 172ms/step - accuracy: 0.7068 - auc: 0.7105 - loss: 0.5735 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.4866
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [06:15, 375.43s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 471ms/step - accuracy: 0.7474 - auc: 0.6997 - loss: 0.6776 - val_accuracy: 0.7667 - val_auc: 0.7867 - val_loss: 0.6481
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 257ms/step - accuracy: 0.7364 - auc: 0.7328 - loss: 0.6517 - val_accuracy: 0.7667 - val_auc: 0.7867 - val_loss: 0.6149
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 256ms/step - accuracy: 0.7554 - auc: 0.7675 - loss: 0.6267 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5773
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.7554 - auc: 0.8295 - loss: 0.5963 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.5382
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 243ms/step - accuracy: 0.7400 - auc: 0.7196 - loss: 0.5584 - val_accuracy: 0.7667 - val_auc: 0.7733 - val_loss: 0.5059
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:44, 383.54s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 487ms/step - accuracy: 0.4649 - auc: 0.7667 - loss: 0.6397 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5938
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.6085 - auc: 0.7318 - loss: 0.6082 - val_accuracy: 0.7667 - val_auc: 0.7889 - val_loss: 0.5420
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 250ms/step - accuracy: 0.7779 - auc: 0.7700 - loss: 0.5882 - val_accuracy: 0.7000 - val_auc: 0.7889 - val_loss: 0.5215
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 249ms/step - accuracy: 0.6099 - auc: 0.6808 - loss: 0.5578 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5150
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 260ms/step - accuracy: 0.7841 - auc: 0.7910 - loss: 0.5489 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5180
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [19:10, 384.43s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 444ms/step - accuracy: 0.5783 - auc: 0.5227 - loss: 0.6894 - val_accuracy: 0.7333 - val_auc: 0.6889 - val_loss: 0.6736
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 230ms/step - accuracy: 0.7155 - auc: 0.6926 - loss: 0.6718 - val_accuracy: 0.7667 - val_auc: 0.8089 - val_loss: 0.6464
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 241ms/step - accuracy: 0.7638 - auc: 0.7728 - loss: 0.6421 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.6106
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 255ms/step - accuracy: 0.7340 - auc: 0.7079 - loss: 0.6131 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5653
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7899 - auc: 0.7991 - loss: 0.5585 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5287
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [25:38, 385.79s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 416ms/step - accuracy: 0.6868 - auc: 0.7338 - loss: 0.6497 - val_accuracy: 0.7333 - val_auc: 0.7822 - val_loss: 0.5874
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.8413 - auc: 0.8680 - loss: 0.5969 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5408
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 236ms/step - accuracy: 0.8258 - auc: 0.8786 - loss: 0.5366 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5163
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.8258 - auc: 0.8435 - loss: 0.4674 - val_accuracy: 0.7333 - val_auc: 0.7333 - val_loss: 0.5057
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.8258 - auc: 0.8532 - loss: 0.4418 - val_accuracy: 0.7333 - val_auc: 0.8267 - val_loss: 0.5111
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [32:18, 387.78s/it]

CPU times: user 1h 3min 31s, sys: 23min 10s, total: 1h 26min 42s
Wall time: 32min 18s





In [9]:
for i in range(NUM_EPOCHS):
    print(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    print(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

Epoch 1: TRAIN Accuracy = 0.62 Loss = 0.661 AUC = 0.667
Epoch 1: VAL Accuracy = 0.747 Loss = 0.63 AUC = 0.752
Epoch 2: TRAIN Accuracy = 0.723 Loss = 0.629 AUC = 0.74
Epoch 2: VAL Accuracy = 0.767 Loss = 0.586 AUC = 0.79
Epoch 3: TRAIN Accuracy = 0.766 Loss = 0.587 AUC = 0.783
Epoch 3: VAL Accuracy = 0.753 Loss = 0.555 AUC = 0.774
Epoch 4: TRAIN Accuracy = 0.743 Loss = 0.557 AUC = 0.746
Epoch 4: VAL Accuracy = 0.767 Loss = 0.525 AUC = 0.781
Epoch 5: TRAIN Accuracy = 0.771 Loss = 0.531 AUC = 0.761
Epoch 5: VAL Accuracy = 0.767 Loss = 0.51 AUC = 0.8
Epoch 6: TRAIN Accuracy = 0.771 Loss = 0.508 AUC = 0.807
Epoch 6: VAL Accuracy = 0.767 Loss = 0.49 AUC = 0.799
Epoch 7: TRAIN Accuracy = 0.78 Loss = 0.491 AUC = 0.772
Epoch 7: VAL Accuracy = 0.78 Loss = 0.466 AUC = 0.8
Epoch 8: TRAIN Accuracy = 0.78 Loss = 0.477 AUC = 0.775
Epoch 8: VAL Accuracy = 0.8 Loss = 0.448 AUC = 0.811
Epoch 9: TRAIN Accuracy = 0.791 Loss = 0.47 AUC = 0.795
Epoch 9: VAL Accuracy = 0.813 Loss = 0.437 AUC = 0.807
Epoch 10