In [4]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-03 10:58:29.427804: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:58:29.437285: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:58:29.472094: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743667109.558531  633877 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743667109.576418  633877 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-03 10:58:29.628376: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [5]:
# Number of epochs passed to model.fit() for every experiment.
NUM_EPOCHS = 200
# Number of repeated experiments; the experiment loop derives one seed per run
# as INIT_SEED + run index.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-GRU classifier with a single sigmoid output.

    Args:
        train: array-like whose ``shape[1]`` and ``shape[2]`` are used as the
            model's input shape (only the shape is read, never the values).

    Returns:
        A compiled ``Sequential`` model: four Dropout(0.2) + GRU stages
        (20, 20, 10, 10 units) followed by ``Dense(1, sigmoid)``, trained with
        binary cross-entropy and reporting accuracy and AUC.
    """
    input_shape = (train.shape[1], train.shape[2])

    # (units, return_sequences) for each recurrent stage, in stacking order.
    # Only the last stage collapses the sequence to a single vector.
    gru_stages = [(20, True), (20, True), (10, True), (10, False)]

    model = Sequential()
    model.add(Input(shape=input_shape))
    for units, emit_sequence in gru_stages:
        # Every GRU is preceded by its own dropout layer.
        model.add(Dropout(0.2))
        model.add(GRU(units, return_sequences=emit_sequence, activation='tanh'))
    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [6]:
# Column holding the per-recording identifier; reshape_dataset() builds one
# sequence per distinct value of this column.
ID = ["ID"]
# Column identifying the subject; the experiment loop splits train/validation
# by subject using this column.
USER = ["SubjectID"]
# Grouping keys used when computing lagged features, so that shifts never
# cross from one (subject, video) group into another.
IDS = ["SubjectID", "VideoID"]
# Label column; reshape_dataset() reduces it to one integer per recording.
TARGET = ["predefinedlabel"]
# Base feature columns. NOTE: this list is modified in place by the
# data-preparation cell, which appends the lagged column names.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Lags (in rows) for which a shifted copy of every base feature is created.
LAGS = [1, 2]
# First seed of the experiment series; run k uses INIT_SEED + k.
INIT_SEED = 5412

In [7]:
# NOTE(review): absolute, machine-specific path; consider moving it to the
# configuration cell so the notebook can run elsewhere.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# FEATURES is updated in place at the end of this cell, so on a re-run it
# already contains the lag names. Strip anything that looks like a generated
# "<feature>_<lag>" name so lags are always computed from the base features
# of the freshly loaded frame; this keeps the cell idempotent.
generated_names = {f"{name}_{lag}" for name in FEATURES for lag in LAGS}
base_features = [name for name in FEATURES if name not in generated_names]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (subject, video) group; the first `lag` rows of a
        # group have no predecessor and are filled with 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Slice assignment (rather than extend) so repeated runs do not keep growing
# the list, while other cells still see the same list object.
FEATURES[:] = base_features + new_features

# One integer id per (subject, video) pair: subject * number_of_videos + video.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,...,Gamma2_1,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,...,8293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,...,2740.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [8]:
def reshape_dataset(data):
    """Convert the long-format frame into padded per-recording sequences.

    Rows are grouped by the value of the ``ID`` column(s). For every distinct
    id (in ``np.unique`` order) the ``FEATURES`` columns form one sequence and
    the mean of the ``TARGET`` column(s), cast to int, forms its label.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET``
            columns.

    Returns:
        Tuple ``(features, target)``: the sequences stacked by
        ``pad_sequences`` and an array with one integer label per id.
    """
    # Extracted once; the comparison below reuses it for every id.
    id_values = data[ID].to_numpy()

    sequences = []
    labels = []
    for cur_id in np.unique(id_values):
        rows = data[id_values == cur_id]
        labels.append(np.mean(rows[TARGET].to_numpy()).astype("int"))
        sequences.append(rows[FEATURES].to_numpy())

    padded = pad_sequences(sequences)
    return np.array(padded), np.array(labels)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Pad 2-D arrays at the end of axis 0 to a common length and stack them.

    Args:
        arrays: iterable of 2-D arrays shaped ``(length_i, n_features)``; all
            must share the same second dimension.
        pad_value: constant written into the padded rows.
        max_length: target length of axis 0. ``None`` (default) uses the
            longest input, which is the original behaviour. Pass an explicit
            value to pad separately processed splits to the same length.

    Returns:
        Array of shape ``(len(arrays), max_length, n_features)``.

    Raises:
        ValueError: if ``arrays`` is empty, or if ``max_length`` is shorter
            than the longest input (this function never truncates).
    """
    # Materialise first: the input is traversed twice, which would silently
    # exhaust a generator.
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array to pad")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif max_length < longest:
        raise ValueError(
            f"max_length={max_length} is shorter than the longest sequence ({longest})"
        )

    padded_arrays = [
        np.pad(
            arr,
            # Pad only after the last row; leave the feature axis untouched.
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [9]:
# Build the model once on the full dataset purely to display its architecture;
# only the shape of X is used by create_model(), and the labels are discarded.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-04-03 10:58:35.465453: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [10]:
%%time

# Number of subjects sampled (without replacement) into the training split of
# every experiment; all remaining subjects are used for validation.
NUM_TRAIN_SUBJECTS = 7

# Per-experiment training curves: one list of per-epoch values per run.
all_acc = []
all_loss = []
all_auc = []

# Per-experiment validation curves, same layout as above.
all_val_acc = []
all_val_loss = []
all_val_auc = []

# Loop-invariant: the subject of every row (flattened to 1-D so it can be used
# as a plain boolean row mask) and the distinct subjects to sample from.
subject_of_row = np.ravel(data[USER].to_numpy())
subjects = np.unique(subject_of_row)

# tqdm wraps the seed array itself rather than an enumerate() object, so it
# knows the total and can show real progress instead of "0it [00:00, ?it/s]".
for seed in tqdm(np.arange(NUM_EXPERIMENTS) + INIT_SEED, desc="experiments"):
    seed = int(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-wise split: no subject contributes to both train and validation.
    train_id = np.random.choice(subjects, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(subject_of_row, train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Each split is padded to its own longest sequence, but create_model()
    # fixes the time dimension from X_train. Align the validation sequences
    # to the training length: zero-pad at the end if shorter, drop trailing
    # steps if longer. This is a no-op when the lengths already match.
    n_steps = X_train.shape[1]
    if X_test.shape[1] < n_steps:
        X_test = np.pad(
            X_test,
            ((0, 0), (0, n_steps - X_test.shape[1]), (0, 0)),
            mode="constant",
        )
    elif X_test.shape[1] > n_steps:
        X_test = X_test[:, :n_steps]

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

# Average every metric over the experiments, leaving one value per epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 391ms/step - accuracy: 0.6344 - auc: 0.4849 - loss: 0.6946 - val_accuracy: 0.7667 - val_auc: 0.7822 - val_loss: 0.6508
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 274ms/step - accuracy: 0.6845 - auc: 0.7684 - loss: 0.6527 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5887
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 305ms/step - accuracy: 0.7341 - auc: 0.7983 - loss: 0.6172 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5380
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 226ms/step - accuracy: 0.7228 - auc: 0.6260 - loss: 0.5984 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.5002
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 248ms/step - accuracy: 0.7341 - auc: 0.8326 - loss: 0.5465 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.4727
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [07:09, 429.20s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 409ms/step - accuracy: 0.4807 - auc: 0.6360 - loss: 0.6707 - val_accuracy: 0.7000 - val_auc: 0.7600 - val_loss: 0.6357
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 215ms/step - accuracy: 0.7287 - auc: 0.8114 - loss: 0.6139 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5651
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 273ms/step - accuracy: 0.7287 - auc: 0.7407 - loss: 0.5631 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5071
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 231ms/step - accuracy: 0.7287 - auc: 0.8301 - loss: 0.5165 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4785
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 232ms/step - accuracy: 0.7287 - auc: 0.7358 - loss: 0.5261 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.4729
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [14:08, 423.61s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 378ms/step - accuracy: 0.6118 - auc: 0.7248 - loss: 0.6388 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.5739
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.8058 - auc: 0.7560 - loss: 0.5727 - val_accuracy: 0.7667 - val_auc: 0.7489 - val_loss: 0.5199
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 233ms/step - accuracy: 0.7841 - auc: 0.6580 - loss: 0.5292 - val_accuracy: 0.7667 - val_auc: 0.7444 - val_loss: 0.4913
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.7841 - auc: 0.7056 - loss: 0.5044 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.4901
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 244ms/step - accuracy: 0.7841 - auc: 0.7526 - loss: 0.4848 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.4839
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [21:11, 423.18s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 426ms/step - accuracy: 0.4937 - auc: 0.3386 - loss: 0.7251 - val_accuracy: 0.7333 - val_auc: 0.5778 - val_loss: 0.6695
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7899 - auc: 0.7791 - loss: 0.6155 - val_accuracy: 0.7333 - val_auc: 0.6889 - val_loss: 0.5978
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.7899 - auc: 0.7653 - loss: 0.5528 - val_accuracy: 0.7333 - val_auc: 0.6978 - val_loss: 0.5532
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 223ms/step - accuracy: 0.7899 - auc: 0.8291 - loss: 0.4907 - val_accuracy: 0.7333 - val_auc: 0.7222 - val_loss: 0.5336
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 258ms/step - accuracy: 0.7899 - auc: 0.7979 - loss: 0.4823 - val_accuracy: 0.7333 - val_auc: 0.7267 - val_loss: 0.5254
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [28:23, 426.62s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 422ms/step - accuracy: 0.5572 - auc: 0.2193 - loss: 0.7468 - val_accuracy: 0.6667 - val_auc: 0.5711 - val_loss: 0.6891
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.8383 - auc: 0.8233 - loss: 0.6560 - val_accuracy: 0.7333 - val_auc: 0.7289 - val_loss: 0.6283
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 243ms/step - accuracy: 0.8258 - auc: 0.8177 - loss: 0.5919 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5793
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 296ms/step - accuracy: 0.8258 - auc: 0.8135 - loss: 0.5246 - val_accuracy: 0.7333 - val_auc: 0.7533 - val_loss: 0.5402
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.8258 - auc: 0.8331 - loss: 0.4514 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.5236
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [35:29, 425.83s/it]

CPU times: user 48min 56s, sys: 13min 4s, total: 1h 2min
Wall time: 35min 29s





In [11]:
# Build the per-epoch report once so the console output and the log file can
# never drift apart (the format strings used to be duplicated).
report_lines = []
for i in range(NUM_EPOCHS):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

# Create the log directory if needed; otherwise open() raises
# FileNotFoundError after the long training run has already finished.
log_path = Path("./logs/GRUv2_emb2.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(line + "\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.574 Loss = 0.682 AUC = 0.535
Epoch 1: VAL Accuracy = 0.727 Loss = 0.644 AUC = 0.694
Epoch 2: TRAIN Accuracy = 0.771 Loss = 0.612 AUC = 0.78
Epoch 2: VAL Accuracy = 0.76 Loss = 0.58 AUC = 0.743
Epoch 3: TRAIN Accuracy = 0.774 Loss = 0.56 AUC = 0.774
Epoch 3: VAL Accuracy = 0.76 Loss = 0.534 AUC = 0.748
Epoch 4: TRAIN Accuracy = 0.771 Loss = 0.52 AUC = 0.77
Epoch 4: VAL Accuracy = 0.76 Loss = 0.509 AUC = 0.75
Epoch 5: TRAIN Accuracy = 0.774 Loss = 0.493 AUC = 0.781
Epoch 5: VAL Accuracy = 0.767 Loss = 0.496 AUC = 0.755
Epoch 6: TRAIN Accuracy = 0.783 Loss = 0.481 AUC = 0.758
Epoch 6: VAL Accuracy = 0.773 Loss = 0.483 AUC = 0.766
Epoch 7: TRAIN Accuracy = 0.786 Loss = 0.473 AUC = 0.771
Epoch 7: VAL Accuracy = 0.78 Loss = 0.471 AUC = 0.78
Epoch 8: TRAIN Accuracy = 0.789 Loss = 0.463 AUC = 0.799
Epoch 8: VAL Accuracy = 0.787 Loss = 0.46 AUC = 0.806
Epoch 9: TRAIN Accuracy = 0.797 Loss = 0.454 AUC = 0.813
Epoch 9: VAL Accuracy = 0.787 Loss = 0.449 AUC = 0.81
Epoch