In [4]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-03 10:58:08.700059: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:58:08.704650: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:58:08.724753: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743667088.766911  633557 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743667088.778867  633557 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-03 10:58:08.819837: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [5]:
# Number of epochs each model is trained for (passed as `epochs` to `model.fit`).
NUM_EPOCHS = 200
# Number of independent train/evaluate runs; run k is seeded with INIT_SEED + k.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-GRU binary classifier.

    Args:
        train: Array whose ``shape[1]`` and ``shape[2]`` define the model's
            input shape. In this notebook it is the padded feature array
            returned by ``reshape_dataset``; only its shape is read.

    Returns:
        A compiled ``Sequential`` model: four Dropout(0.2) + GRU(tanh) pairs
        followed by a single sigmoid unit, trained with binary cross-entropy
        and reporting accuracy and AUC (metric name ``"auc"``).
    """
    # (units, return_sequences) per recurrent layer. Only the last GRU drops
    # the sequence axis so the Dense head receives one vector per sample.
    gru_stack = (
        (20, True),
        (20, True),
        (10, True),
        (10, False),
    )

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    # Layers are created and added in the same order as a hand-written stack,
    # so automatic layer naming and seeded initialisation are unaffected.
    for units, keep_sequence in gru_stack:
        model.add(Dropout(0.2))
        model.add(GRU(units, return_sequences=keep_sequence, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [6]:
# Column-name lists; kept as lists so they can be concatenated when selecting columns.
# "ID" is a derived column (one value per SubjectID/VideoID combination) created after loading.
ID = ["ID"]
# Column used to split the data into train and held-out groups.
USER = ["SubjectID"]
# Grouping keys used when computing lagged features, so lags never cross a group boundary.
IDS = ["SubjectID", "VideoID"]
# Label column; it is averaged per ID and cast to int to get one label per sequence.
TARGET = ["predefinedlabel"]
# Input signal columns read from the CSV (presumably EEG band powers — confirm against the dataset docs).
# NOTE: this list is extended with the generated lag column names after loading.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Row offsets for which a "<feature>_<lag>" column is generated.
LAGS = [1]
# Seed of the first experiment; subsequent experiments use consecutive integers.
INIT_SEED = 5412

In [7]:
# NOTE(review): absolute, machine-specific path; consider making it configurable
# (e.g. relative to the repository or read from an environment variable).
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Keep this cell idempotent. FEATURES is defined in another cell and is updated
# below with the lag column names. If the cell is run again, those generated
# names are already in FEATURES and must not be lagged a second time (that would
# silently add "Delta_1_1", ... and duplicate entries). A name counts as generated
# when it equals "<other feature>_<lag>" for one of the configured lags.
generated_names = {f"{name}_{lag}" for name in FEATURES for lag in LAGS}
base_features = [name for name in FEATURES if name not in generated_names]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (SubjectID, VideoID) group so values never leak across
        # groups; the first `lag` rows of every group have no history and get 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# In-place update so FEATURES ends up as base features followed by lag features,
# no matter how many times this cell has been executed.
FEATURES[:] = base_features + new_features

# One integer per (SubjectID, VideoID) pair: subject * n_videos + video.
# Assumes VideoID values lie in [0, n_videos) so the combination is unique — TODO confirm.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,Delta_1,Theta_1,Alpha1_1,Alpha2_1,Beta1_1,Beta2_1,Gamma1_1,Gamma2_1,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0


In [8]:
def reshape_dataset(data):
    """Turn the long-format frame into one padded sequence and one label per ID.

    Rows are grouped by the ``ID`` column (groups visited in sorted ID order).
    For each group the ``FEATURES`` columns become one 2-D array and the
    ``TARGET`` column is averaged and cast to int (i.e. truncated) to give a
    single label.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        Tuple ``(features, target)``: ``features`` is the stacked, zero-padded
        array produced by ``pad_sequences`` (one entry per ID) and ``target``
        is a 1-D array with one integer label per ID.
    """
    # Extract the ID column once; it is compared against every unique ID below.
    id_column = data[ID].to_numpy()

    sequences = []
    labels = []
    for group_id in np.unique(id_column):
        group_rows = data[id_column == group_id]
        sequences.append(group_rows[FEATURES].to_numpy())
        labels.append(np.mean(group_rows[TARGET].to_numpy()).astype("int"))

    padded = pad_sequences(sequences)
    return np.array(padded), np.array(labels)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays at the end of axis 0 to a common length and stack them.

    Args:
        arrays: Sequence of 2-D arrays that share the same second dimension.
        pad_value: Constant written into the appended rows (default 0).

    Returns:
        Array of shape ``(len(arrays), longest, width)`` where ``longest`` is
        the largest first dimension among the inputs.

    Raises:
        ValueError: If ``arrays`` is empty.
    """
    longest = max(arr.shape[0] for arr in arrays)

    def _pad_rows(arr):
        # Append rows only (nothing before, nothing along the column axis).
        rows_missing = longest - arr.shape[0]
        return np.pad(
            arr,
            ((0, rows_missing), (0, 0)),
            mode='constant',
            constant_values=pad_value,
        )

    return np.stack([_pad_rows(arr) for arr in arrays])

In [9]:
# Build a model on the full dataset only to display its architecture; the
# experiments create a fresh model per run. The result is indexed rather than
# unpacked into `_`, because assigning `_` at notebook top level shadows
# IPython's last-output variable.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-04-03 10:58:15.011935: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [10]:
%%time

# Learning curves collected per experiment: one list of per-epoch values per run.
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

# Number of subjects drawn (without replacement) for training in every run;
# all remaining subjects form the held-out split.
NUM_TRAIN_SUBJECTS = 7

# Loop-invariant lookups: the flat subject id of every row and the distinct subjects.
subject_of_row = np.ravel(data[USER])
subject_ids = np.unique(subject_of_row)

# tqdm wraps the sized array directly so the bar knows the total and can show an ETA.
for seed in tqdm(np.arange(NUM_EXPERIMENTS) + INIT_SEED, desc="experiment"):
    seed = int(seed)
    # Seed every RNG involved before any random draw so each run is repeatable.
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-wise split: a subject's rows are either all in train or all held out.
    train_id = np.random.choice(subject_ids, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(subject_of_row, train_id)  # 1-D boolean row mask

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Column vectors to match the single sigmoid output unit.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    # The held-out split is passed as validation data, so its metrics are
    # recorded after every epoch alongside the training metrics.
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

# Average each curve over experiments (axis 0), giving one value per epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 660ms/step - accuracy: 0.5938 - auc: 0.5414 - loss: 0.6890 - val_accuracy: 0.7667 - val_auc: 0.8089 - val_loss: 0.6401
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 225ms/step - accuracy: 0.7558 - auc: 0.7352 - loss: 0.6435 - val_accuracy: 0.8333 - val_auc: 0.8000 - val_loss: 0.5844
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.7341 - auc: 0.7931 - loss: 0.6101 - val_accuracy: 0.8333 - val_auc: 0.7356 - val_loss: 0.5284
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.7341 - auc: 0.6771 - loss: 0.5784 - val_accuracy: 0.8333 - val_auc: 0.7467 - val_loss: 0.4807
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 205ms/step - accuracy: 0.7341 - auc: 0.8174 - loss: 0.5344 - val_accuracy: 0.8333 - val_auc: 0.8000 - val_loss: 0.4525
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [07:00, 420.13s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 409ms/step - accuracy: 0.4376 - auc: 0.5623 - loss: 0.6865 - val_accuracy: 0.5000 - val_auc: 0.7600 - val_loss: 0.6441
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 215ms/step - accuracy: 0.6409 - auc: 0.8041 - loss: 0.6285 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5815
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 255ms/step - accuracy: 0.7287 - auc: 0.7569 - loss: 0.5819 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5212
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 235ms/step - accuracy: 0.7287 - auc: 0.8318 - loss: 0.5115 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4851
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 232ms/step - accuracy: 0.7287 - auc: 0.7318 - loss: 0.5230 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4785
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [14:00, 420.19s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 406ms/step - accuracy: 0.7337 - auc: 0.5920 - loss: 0.6754 - val_accuracy: 0.8333 - val_auc: 0.8000 - val_loss: 0.6113
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.7841 - auc: 0.7543 - loss: 0.6128 - val_accuracy: 0.7667 - val_auc: 0.7222 - val_loss: 0.5418
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 234ms/step - accuracy: 0.7841 - auc: 0.6492 - loss: 0.5551 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.5039
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 221ms/step - accuracy: 0.8058 - auc: 0.7089 - loss: 0.5099 - val_accuracy: 0.7667 - val_auc: 0.7489 - val_loss: 0.4987
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 226ms/step - accuracy: 0.7841 - auc: 0.7392 - loss: 0.4777 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.4886
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [21:08, 423.78s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 385ms/step - accuracy: 0.5899 - auc: 0.4151 - loss: 0.7078 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.6003
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 231ms/step - accuracy: 0.7899 - auc: 0.7657 - loss: 0.5966 - val_accuracy: 0.7667 - val_auc: 0.7111 - val_loss: 0.5424
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 217ms/step - accuracy: 0.7899 - auc: 0.7713 - loss: 0.5326 - val_accuracy: 0.7667 - val_auc: 0.7111 - val_loss: 0.5074
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 201ms/step - accuracy: 0.7899 - auc: 0.8347 - loss: 0.4676 - val_accuracy: 0.7667 - val_auc: 0.7200 - val_loss: 0.4959
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 252ms/step - accuracy: 0.7899 - auc: 0.7925 - loss: 0.4626 - val_accuracy: 0.7667 - val_auc: 0.7222 - val_loss: 0.4808
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [28:11, 423.62s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 394ms/step - accuracy: 0.5382 - auc: 0.1523 - loss: 0.7515 - val_accuracy: 0.6333 - val_auc: 0.5978 - val_loss: 0.6899
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 243ms/step - accuracy: 0.7357 - auc: 0.7328 - loss: 0.6643 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.6243
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.8258 - auc: 0.8470 - loss: 0.5916 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5738
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 235ms/step - accuracy: 0.8258 - auc: 0.8076 - loss: 0.5204 - val_accuracy: 0.7333 - val_auc: 0.7733 - val_loss: 0.5342
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 239ms/step - accuracy: 0.8258 - auc: 0.8266 - loss: 0.4581 - val_accuracy: 0.7333 - val_auc: 0.7867 - val_loss: 0.5185
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [35:36, 427.30s/it]

CPU times: user 48min 48s, sys: 13min 2s, total: 1h 1min 50s
Wall time: 35min 36s





In [11]:
# Format the experiment-averaged metrics once and reuse the same lines for the
# notebook output and the log file, so the two can never drift apart.
# Iterating over the curve itself (not NUM_EPOCHS) keeps this cell valid
# even if the constant was edited after the experiments were run.
report_lines = []
for i in range(len(epoch_acc)):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

log_path = Path("./logs/GRUv2_emb1.txt")
# Create the log directory if needed; otherwise open() raises FileNotFoundError
# on a fresh checkout, after the long training cell has already finished.
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(f"{line}\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.586 Loss = 0.691 AUC = 0.517
Epoch 1: VAL Accuracy = 0.7 Loss = 0.637 AUC = 0.747
Epoch 2: TRAIN Accuracy = 0.757 Loss = 0.621 AUC = 0.769
Epoch 2: VAL Accuracy = 0.773 Loss = 0.575 AUC = 0.748
Epoch 3: TRAIN Accuracy = 0.774 Loss = 0.562 AUC = 0.776
Epoch 3: VAL Accuracy = 0.773 Loss = 0.527 AUC = 0.742
Epoch 4: TRAIN Accuracy = 0.777 Loss = 0.514 AUC = 0.773
Epoch 4: VAL Accuracy = 0.773 Loss = 0.499 AUC = 0.75
Epoch 5: TRAIN Accuracy = 0.774 Loss = 0.486 AUC = 0.777
Epoch 5: VAL Accuracy = 0.773 Loss = 0.484 AUC = 0.769
Epoch 6: TRAIN Accuracy = 0.783 Loss = 0.483 AUC = 0.759
Epoch 6: VAL Accuracy = 0.773 Loss = 0.469 AUC = 0.781
Epoch 7: TRAIN Accuracy = 0.777 Loss = 0.477 AUC = 0.769
Epoch 7: VAL Accuracy = 0.787 Loss = 0.455 AUC = 0.787
Epoch 8: TRAIN Accuracy = 0.797 Loss = 0.457 AUC = 0.795
Epoch 8: VAL Accuracy = 0.793 Loss = 0.444 AUC = 0.803
Epoch 9: TRAIN Accuracy = 0.803 Loss = 0.449 AUC = 0.816
Epoch 9: VAL Accuracy = 0.8 Loss = 0.436 AUC = 0.8