In [4]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-03 10:57:54.978333: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:57:55.074615: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:57:55.182226: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743667075.269473  634184 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743667075.296773  634184 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-03 10:57:55.526938: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [5]:
# Training budget shared by every experiment in this notebook.
NUM_EPOCHS = 200  # epochs passed to model.fit for each experiment
NUM_EXPERIMENTS = 5  # number of seeded train/validation repetitions that get averaged

def create_model(train):
    """Build and compile the stacked-GRU binary classifier.

    Parameters
    ----------
    train : array-like with a ``.shape`` of at least three dimensions
        Only ``train.shape[1]`` and ``train.shape[2]`` are read; they define
        the per-sample input shape of the network.

    Returns
    -------
    A compiled ``Sequential`` model: four Dropout(0.2) + GRU stages
    (20, 20, 10, 10 units; only the last one collapses the sequence axis)
    followed by a single sigmoid unit. It is compiled with Adam and binary
    cross-entropy, and reports accuracy and AUC.
    """
    steps_per_sample = train.shape[1]
    values_per_step = train.shape[2]

    # (units, return_sequences) for each recurrent stage, in stacking order.
    recurrent_stages = (
        (20, True),
        (20, True),
        (10, True),
        (10, False),
    )

    model = Sequential()
    model.add(Input(shape=(steps_per_sample, values_per_step)))

    for units, keep_sequence in recurrent_stages:
        model.add(Dropout(0.2))
        model.add(GRU(units, return_sequences=keep_sequence, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    tracked_metrics = ["accuracy", AUC(name="auc")]
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=tracked_metrics)
    return model

# Experiment

In [6]:
# Column-name groups. They are kept as lists so they can be concatenated with
# `+` and used directly as DataFrame column selectors.
ID = ["ID"]  # per-(SubjectID, VideoID) sequence key; the column itself is created in the data-loading cell
USER = ["SubjectID"]  # column used to split train/validation by subject
IDS = ["SubjectID", "VideoID"]  # groupby keys for the per-sequence shift operations
TARGET = ["predefinedlabel"]  # label column fed to the sigmoid / binary cross-entropy model
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]  # model input columns; the data-loading cell adds lag column names to this list in place
LAGS = []  # shift offsets for extra lagged feature columns (empty list = no lag features)
USE_DIFF = True  # if True, each feature is replaced by (value - previous value) within its IDS group
INIT_SEED = 5412  # experiment k (0-based) is seeded with INIT_SEED + k

In [7]:
# Load the raw EEG table.
# NOTE(review): this is an absolute, machine-specific path; point it at your
# own copy of the data before running the notebook elsewhere.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Lagged copies are derived only from feature columns that exist in the
# freshly loaded file. FEATURES is a notebook-level list that this cell grows
# in place, so on a re-run it may already contain lag names from the previous
# execution; looking those up in the new frame would raise KeyError and
# re-registering them would duplicate model inputs.
base_features = [name for name in FEATURES if name in data.columns]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (SubjectID, VideoID) sequence; the first `lag`
        # rows of every sequence have no predecessor and are filled with 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Register only names that are not already known, keeping the cell idempotent.
for new_feature_name in new_features:
    if new_feature_name not in FEATURES:
        FEATURES.append(new_feature_name)

if USE_DIFF:
    # Replace every feature (including lag features) by its first difference
    # within the sequence. Because the missing predecessor is filled with 0,
    # the first row of each sequence keeps its raw value.
    for feature_name in FEATURES:
        previous_value = data.groupby(IDS)[feature_name].shift(1).fillna(0)
        data[feature_name] = data[feature_name] - previous_value

# One integer key per (SubjectID, VideoID) pair: subject * n_videos + video.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-228176.0,-62529.0,-32296.0,-21751.0,-25200.0,-41410.0,-27935.0,-5553.0,0.0
2,0,0.0,684566.0,355662.0,200560.0,59867.0,33547.0,126849.0,51950.0,22614.0,0.0


In [8]:
def reshape_dataset(data):
    """Turn the long-format table into per-sequence arrays.

    Rows are grouped by the value of the ``ID`` column(s), visiting the
    distinct IDs in the sorted order returned by ``np.unique``.

    Returns
    -------
    (features, target)
        ``features`` stacks one ``FEATURES`` matrix per ID, end-padded by
        ``pad_sequences`` to the longest sequence. ``target`` holds, per ID,
        the mean of the ``TARGET`` values cast to int (i.e. truncated).
    """
    id_column = data[ID].to_numpy()

    sequences = []
    labels = []
    for sequence_id in np.unique(id_column):
        rows = data[id_column == sequence_id]
        sequences.append(rows[FEATURES].to_numpy())
        labels.append(np.mean(rows[TARGET].to_numpy()).astype("int"))

    padded = pad_sequences(sequences)
    return np.array(padded), np.array(labels)

def pad_sequences(arrays, pad_value=0):
    """Stack 2-D arrays of differing row counts into one 3-D array.

    Every array is extended along axis 0, at the end, with rows filled with
    ``pad_value`` until it is as long as the longest input; axis 1 is left
    untouched. The padded arrays are then stacked along a new leading axis.

    Parameters
    ----------
    arrays : sequence of 2-D numpy arrays
    pad_value : scalar used for the appended rows (default 0)

    Returns
    -------
    numpy array of shape (len(arrays), longest row count, columns)
    """
    row_counts = [seq.shape[0] for seq in arrays]
    longest = max(row_counts)

    padded = []
    for seq, rows in zip(arrays, row_counts):
        missing_rows = longest - rows
        padded.append(
            np.pad(
                seq,
                ((0, missing_rows), (0, 0)),
                mode='constant',
                constant_values=pad_value,
            )
        )
    return np.stack(padded)

In [9]:
# Sanity check: build the network once from the full dataset's padded shape
# and print its layer summary. The labels are not needed here, and `model`
# is rebound for every experiment in the training loop.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-04-03 10:57:59.285447: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [10]:
%%time

# Number of subjects drawn (without replacement) into the training split of
# each experiment; all remaining subjects form the validation split.
NUM_TRAIN_SUBJECTS = 7

# Per-experiment learning curves: one list of per-epoch values per experiment.
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

experiment_seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED
subject_ids = np.ravel(data[USER])  # 1-D subject id per row, aligned with `data`

# tqdm wraps the seed array itself rather than an enumerate() object, so it
# can read the length and display a real total / ETA instead of "0it [..]".
for seed in tqdm(experiment_seeds, desc="experiments"):
    seed = int(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-wise split: all rows of a subject land on the same side.
    train_id = np.random.choice(np.unique(subject_ids), NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(subject_ids, train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # One label column per sequence, matching the single sigmoid output.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    curves = history.history
    all_acc.append(curves['accuracy'])
    all_loss.append(curves['loss'])
    all_auc.append(curves['auc'])

    all_val_acc.append(curves['val_accuracy'])
    all_val_loss.append(curves['val_loss'])
    all_val_auc.append(curves['val_auc'])

# Average each metric over experiments, leaving one value per epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 393ms/step - accuracy: 0.6715 - auc: 0.3941 - loss: 0.7080 - val_accuracy: 0.6000 - val_auc: 0.4978 - val_loss: 0.6937
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 253ms/step - accuracy: 0.6740 - auc: 0.6318 - loss: 0.6752 - val_accuracy: 0.8000 - val_auc: 0.7444 - val_loss: 0.6405
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 331ms/step - accuracy: 0.7285 - auc: 0.7591 - loss: 0.6408 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5811
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 313ms/step - accuracy: 0.7285 - auc: 0.6176 - loss: 0.6112 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5239
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 335ms/step - accuracy: 0.7341 - auc: 0.7444 - loss: 0.5717 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.4792
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [07:12, 432.90s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 418ms/step - accuracy: 0.5095 - auc: 0.5536 - loss: 0.6733 - val_accuracy: 0.5000 - val_auc: 0.7600 - val_loss: 0.6069
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 259ms/step - accuracy: 0.6148 - auc: 0.7426 - loss: 0.6057 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5530
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 263ms/step - accuracy: 0.7132 - auc: 0.7301 - loss: 0.5646 - val_accuracy: 0.7667 - val_auc: 0.7200 - val_loss: 0.5058
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 212ms/step - accuracy: 0.7287 - auc: 0.8017 - loss: 0.5192 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4863
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 213ms/step - accuracy: 0.7287 - auc: 0.7531 - loss: 0.5209 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4810
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [14:30, 435.78s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 430ms/step - accuracy: 0.5719 - auc: 0.2512 - loss: 0.7297 - val_accuracy: 0.5000 - val_auc: 0.3667 - val_loss: 0.7171
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 208ms/step - accuracy: 0.6303 - auc: 0.4275 - loss: 0.6901 - val_accuracy: 0.7000 - val_auc: 0.7756 - val_loss: 0.6658
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.8058 - auc: 0.7482 - loss: 0.6494 - val_accuracy: 0.7667 - val_auc: 0.7978 - val_loss: 0.6225
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 213ms/step - accuracy: 0.7785 - auc: 0.6770 - loss: 0.6292 - val_accuracy: 0.7667 - val_auc: 0.7889 - val_loss: 0.5726
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 212ms/step - accuracy: 0.7954 - auc: 0.7951 - loss: 0.5736 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.5221
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [21:36, 431.27s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 419ms/step - accuracy: 0.5596 - auc: 0.3393 - loss: 0.7442 - val_accuracy: 0.7333 - val_auc: 0.7378 - val_loss: 0.6605
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 212ms/step - accuracy: 0.6907 - auc: 0.6391 - loss: 0.6525 - val_accuracy: 0.8000 - val_auc: 0.8111 - val_loss: 0.5928
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.7956 - auc: 0.7686 - loss: 0.5960 - val_accuracy: 0.8000 - val_auc: 0.7444 - val_loss: 0.5321
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 228ms/step - accuracy: 0.7899 - auc: 0.8360 - loss: 0.5125 - val_accuracy: 0.8000 - val_auc: 0.7111 - val_loss: 0.4884
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 221ms/step - accuracy: 0.7899 - auc: 0.7861 - loss: 0.4797 - val_accuracy: 0.8000 - val_auc: 0.7556 - val_loss: 0.4730
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [28:40, 428.21s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 448ms/step - accuracy: 0.6757 - auc: 0.5117 - loss: 0.6902 - val_accuracy: 0.7667 - val_auc: 0.7800 - val_loss: 0.6566
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 264ms/step - accuracy: 0.7762 - auc: 0.8034 - loss: 0.6306 - val_accuracy: 0.7333 - val_auc: 0.7067 - val_loss: 0.6117
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.8413 - auc: 0.8353 - loss: 0.5713 - val_accuracy: 0.7333 - val_auc: 0.7067 - val_loss: 0.5719
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 258ms/step - accuracy: 0.8331 - auc: 0.8041 - loss: 0.4999 - val_accuracy: 0.7333 - val_auc: 0.7067 - val_loss: 0.5531
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 224ms/step - accuracy: 0.8258 - auc: 0.8126 - loss: 0.4564 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5496
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [35:59, 431.89s/it]

CPU times: user 49min 51s, sys: 13min 18s, total: 1h 3min 9s
Wall time: 35min 59s





In [11]:
# Build the per-epoch report once and reuse it for both the notebook output
# and the log file, so the two can never drift apart. The loop length comes
# from the computed curves rather than NUM_EPOCHS, so the report stays correct
# even if the constant was edited after training.
report_lines = []
for i in range(len(epoch_acc)):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

# Create the log directory if needed: otherwise open() raises
# FileNotFoundError and the results of a long training run are not saved.
log_path = Path("./logs/GRUv2_emb4.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)
with log_path.open("w") as f:
    for line in report_lines:
        f.write(line + "\n")

Epoch 1: TRAIN Accuracy = 0.569 Loss = 0.702 AUC = 0.441
Epoch 1: VAL Accuracy = 0.62 Loss = 0.667 AUC = 0.628
Epoch 2: TRAIN Accuracy = 0.7 Loss = 0.646 AUC = 0.689
Epoch 2: VAL Accuracy = 0.76 Loss = 0.613 AUC = 0.76
Epoch 3: TRAIN Accuracy = 0.78 Loss = 0.596 AUC = 0.778
Epoch 3: VAL Accuracy = 0.773 Loss = 0.563 AUC = 0.752
Epoch 4: TRAIN Accuracy = 0.769 Loss = 0.55 AUC = 0.754
Epoch 4: VAL Accuracy = 0.773 Loss = 0.525 AUC = 0.751
Epoch 5: TRAIN Accuracy = 0.777 Loss = 0.516 AUC = 0.77
Epoch 5: VAL Accuracy = 0.773 Loss = 0.501 AUC = 0.766
Epoch 6: TRAIN Accuracy = 0.78 Loss = 0.497 AUC = 0.761
Epoch 6: VAL Accuracy = 0.773 Loss = 0.487 AUC = 0.75
Epoch 7: TRAIN Accuracy = 0.777 Loss = 0.497 AUC = 0.75
Epoch 7: VAL Accuracy = 0.787 Loss = 0.476 AUC = 0.768
Epoch 8: TRAIN Accuracy = 0.791 Loss = 0.473 AUC = 0.784
Epoch 8: VAL Accuracy = 0.787 Loss = 0.467 AUC = 0.786
Epoch 9: TRAIN Accuracy = 0.791 Loss = 0.459 AUC = 0.813
Epoch 9: VAL Accuracy = 0.78 Loss = 0.458 AUC = 0.823
Epoc