In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-02 23:45:47.311812: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-02 23:45:47.317696: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-02 23:45:47.348903: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743626747.414678  153586 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743626747.430892  153586 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-02 23:45:47.493139: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of epochs passed to `model.fit` for every experiment; the reporting
# cell also iterates over this many epochs.
NUM_EPOCHS = 200
# Number of repeated train/evaluate runs; each run gets its own consecutive seed.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-GRU binary classifier.

    Parameters
    ----------
    train : array-like
        Only ``train.shape[1]`` and ``train.shape[2]`` are read; they become
        the model's input shape (in this notebook the array comes from
        ``reshape_dataset`` and is laid out as sequences x timesteps x features).

    Returns
    -------
    A compiled ``Sequential`` model: four ``Dropout(0.2)`` + ``GRU`` stages
    (20, 20, 10, 10 units; only the last one collapses the time axis) followed
    by a single sigmoid unit. Compiled with Adam and binary cross-entropy,
    tracking accuracy and AUC.
    """
    timesteps, n_features = train.shape[1], train.shape[2]

    # (units, return_sequences) for each recurrent stage, in stacking order.
    gru_stages = (
        (20, True),
        (20, True),
        (10, True),
        (10, False),
    )

    model = Sequential()
    model.add(Input(shape=(timesteps, n_features)))

    for units, keep_time_axis in gru_stages:
        # Every recurrent stage is preceded by its own dropout layer.
        model.add(Dropout(0.2))
        model.add(GRU(units, return_sequences=keep_time_axis, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [None]:
# Column holding the per-sequence identifier; the data-loading cell derives it
# from SubjectID and VideoID, and `reshape_dataset` groups rows by it.
ID = ["ID"]
# Column used to split subjects into train and held-out sets.
USER = ["SubjectID"]
# Grouping keys used when shifting features to build the lagged columns.
IDS = ["SubjectID", "VideoID"]
# Label column; `reshape_dataset` reduces it to one integer per sequence.
TARGET = ["predefinedlabel"]
# Model input columns. NOTE: the data-loading cell appends the lagged column
# names to this list in place.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Row offsets passed to `shift` when creating lagged copies of each feature.
LAGS = [2]
# Seed of the first experiment; later experiments use INIT_SEED + 1, + 2, ...
INIT_SEED = 5412

In [4]:
# NOTE(review): absolute, machine-specific path — consider moving it to a
# configurable DATA_DIR so the notebook runs on other machines.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Lag only the features that exist in the raw CSV. FEATURES is extended in
# place below, so on a re-run of this cell it already contains the lagged
# names; without this filter they would be lagged again (e.g. "Delta_2_2").
raw_columns = set(data.columns)
base_features = [name for name in FEATURES if name in raw_columns]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each (SubjectID, VideoID) group so values never leak
        # across recordings; the first `lag` rows of a group are filled with 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Append only names that are not registered yet, keeping the cell idempotent.
FEATURES.extend(name for name in new_features if name not in FEATURES)

# One integer identifier per (SubjectID, VideoID) pair.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [None]:
def reshape_dataset(data):
    """Turn the long-format frame into padded per-sequence arrays.

    Rows are grouped by the identifier column named in ``ID``; groups are
    visited in ascending identifier order and keep their original row order.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns listed in ``ID``, ``FEATURES`` and ``TARGET``.

    Returns
    -------
    features : numpy.ndarray
        One entry per identifier, holding that group's ``FEATURES`` rows
        padded to the longest group by ``pad_sequences``.
    target : numpy.ndarray
        One integer per identifier: the mean of the group's ``TARGET`` values
        truncated to int.

    Raises
    ------
    ValueError
        If ``data`` has no rows (raised by ``pad_sequences``).
    """
    # ID is a one-element list of column names; group on the column itself so
    # the rows are split in a single pass instead of re-filtering the whole
    # frame once per identifier with a 2-D boolean mask.
    id_column = ID[0]

    features = []
    target = []
    for _, cur_id_data in data.groupby(id_column, sort=True):
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Pad arrays along their first axis to a common length and stack them.

    Parameters
    ----------
    arrays : iterable of array-like
        Arrays with at least one dimension. They may differ in length along
        axis 0 but must agree on all remaining axes.
    pad_value : scalar, default 0
        Value appended after the end of the shorter arrays.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(arrays), max_length, *trailing_dims)``.

    Raises
    ------
    ValueError
        If ``arrays`` is empty.
    """
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        # Fail with an explicit message instead of the opaque error from max().
        raise ValueError("pad_sequences() needs at least one array to pad")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            # Pad only at the end of axis 0; leave every trailing axis as is.
            [(0, max_length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Stack the full dataset; only the resulting array's shape is needed here,
# because `create_model` reads nothing but `shape[1]` and `shape[2]`.
X, _ = reshape_dataset(data)
# This instance is built just to print the architecture; the experiment cell
# creates a fresh model for every run.
model = create_model(X)
model.summary()

2025-04-02 23:45:54.803709: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [None]:
%%time

# Number of subjects drawn (without replacement) for the training split of
# each experiment; all remaining subjects form the held-out split.
NUM_TRAIN_SUBJECTS = 7

# Per-experiment training curves (one list of per-epoch values per run).
all_acc = []
all_loss = []
all_auc = []

# Per-experiment curves on the held-out subjects.
all_val_acc = []
all_val_loss = []
all_val_auc = []

# 1-D subject column: yields a 1-D boolean row mask below instead of the
# (n, 1) mask produced by selecting with a list of column names.
subject_ids = data[USER[0]].to_numpy()
unique_subjects = np.unique(subject_ids)

# One seed per experiment: INIT_SEED, INIT_SEED + 1, ...
seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED

# Wrapping the sized array directly lets tqdm report the total and an ETA.
for seed in tqdm(seeds, desc="experiments"):
    # Drop Keras global state left over from the previous run so models built
    # in this loop do not accumulate. Done before seeding on purpose.
    tf.keras.backend.clear_session()

    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    # Subject-wise split: a subject's rows go entirely to one side.
    train_id = np.random.choice(unique_subjects, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = np.isin(subject_ids, train_id)

    train = data.loc[train_index]
    test = data.loc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Column vectors to match the single sigmoid output unit.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    all_acc.append(history.history['accuracy'])
    all_loss.append(history.history['loss'])
    all_auc.append(history.history['auc'])

    all_val_acc.append(history.history['val_accuracy'])
    all_val_loss.append(history.history['val_loss'])
    all_val_auc.append(history.history['val_auc'])

# Average each metric over the experiments, epoch by epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 261ms/step - accuracy: 0.6118 - auc: 0.6400 - loss: 0.6626 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.6010
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 180ms/step - accuracy: 0.7638 - auc: 0.7659 - loss: 0.5333 - val_accuracy: 0.7000 - val_auc: 0.7022 - val_loss: 0.5818
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 248ms/step - accuracy: 0.7640 - auc: 0.7730 - loss: 0.5198 - val_accuracy: 0.7000 - val_auc: 0.6933 - val_loss: 0.5923
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 237ms/step - accuracy: 0.7803 - auc: 0.7739 - loss: 0.5134 - val_accuracy: 0.7000 - val_auc: 0.6778 - val_loss: 0.5983
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 202ms/step - accuracy: 0.7729 - auc: 0.7634 - loss: 0.5034 - val_accuracy: 0.7000 - val_auc: 0.7222 - val_loss: 0.5983
Epoch 6/200
[1m42/42

1it [35:31, 2131.93s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 277ms/step - accuracy: 0.6165 - auc: 0.7019 - loss: 0.6288 - val_accuracy: 0.7000 - val_auc: 0.6956 - val_loss: 0.5490
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 220ms/step - accuracy: 0.7711 - auc: 0.7524 - loss: 0.5162 - val_accuracy: 0.7000 - val_auc: 0.6956 - val_loss: 0.5470
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 195ms/step - accuracy: 0.7609 - auc: 0.7606 - loss: 0.5078 - val_accuracy: 0.7000 - val_auc: 0.7556 - val_loss: 0.5432
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 231ms/step - accuracy: 0.7601 - auc: 0.7596 - loss: 0.5026 - val_accuracy: 0.7333 - val_auc: 0.7822 - val_loss: 0.5466
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 220ms/step - accuracy: 0.7640 - auc: 0.7861 - loss: 0.4981 - val_accuracy: 0.7667 - val_auc: 0.8244 - val_loss: 0.5346
Epoch 6/200
[1m42/42[

2it [1:17:03, 2343.50s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 308ms/step - accuracy: 0.6683 - auc: 0.6833 - loss: 0.6338 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.6675
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 318ms/step - accuracy: 0.7560 - auc: 0.7453 - loss: 0.5495 - val_accuracy: 0.7000 - val_auc: 0.7333 - val_loss: 0.6448
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 271ms/step - accuracy: 0.7489 - auc: 0.7466 - loss: 0.5395 - val_accuracy: 0.7000 - val_auc: 0.7333 - val_loss: 0.6301
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 297ms/step - accuracy: 0.7522 - auc: 0.7709 - loss: 0.5273 - val_accuracy: 0.7000 - val_auc: 0.7533 - val_loss: 0.6298
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 272ms/step - accuracy: 0.7605 - auc: 0.7672 - loss: 0.5274 - val_accuracy: 0.7333 - val_auc: 0.7622 - val_loss: 0.6397
Epoch 6/200
[1m42/4

3it [2:13:36, 2822.89s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 331ms/step - accuracy: 0.4943 - auc: 0.5895 - loss: 0.6501 - val_accuracy: 0.7667 - val_auc: 0.6956 - val_loss: 0.5992
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 310ms/step - accuracy: 0.7537 - auc: 0.7057 - loss: 0.5454 - val_accuracy: 0.7667 - val_auc: 0.7489 - val_loss: 0.5993
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 272ms/step - accuracy: 0.7549 - auc: 0.7127 - loss: 0.5326 - val_accuracy: 0.7667 - val_auc: 0.7911 - val_loss: 0.6148
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 337ms/step - accuracy: 0.7511 - auc: 0.7297 - loss: 0.5151 - val_accuracy: 0.7667 - val_auc: 0.7489 - val_loss: 0.6078
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 306ms/step - accuracy: 0.7410 - auc: 0.7363 - loss: 0.5244 - val_accuracy: 0.7667 - val_auc: 0.8044 - val_loss: 0.6044
Epoch 6/200
[1m42/4

4it [3:41:19, 3785.88s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 358ms/step - accuracy: 0.6938 - auc: 0.6550 - loss: 0.6351 - val_accuracy: 0.7333 - val_auc: 0.6911 - val_loss: 0.5958
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 322ms/step - accuracy: 0.7987 - auc: 0.7834 - loss: 0.4770 - val_accuracy: 0.7000 - val_auc: 0.7467 - val_loss: 0.5789
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 275ms/step - accuracy: 0.8034 - auc: 0.8064 - loss: 0.4621 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5859
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 277ms/step - accuracy: 0.7976 - auc: 0.7798 - loss: 0.4656 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.6028
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 309ms/step - accuracy: 0.7922 - auc: 0.8058 - loss: 0.4592 - val_accuracy: 0.7000 - val_auc: 0.7667 - val_loss: 0.6233
Epoch 6/200
[1m42/4

5it [4:36:38, 3319.74s/it]

CPU times: user 6h 10min 32s, sys: 1h 46min 13s, total: 7h 56min 45s
Wall time: 4h 36min 38s





In [8]:
# Build the per-epoch report once so the console output and the log file
# cannot drift apart.
report_lines = []
# Iterate over the averaged curves themselves so the report stays in sync with
# the arrays even if NUM_EPOCHS was edited after training.
for i in range(len(epoch_acc)):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

log_path = Path("./logs/GRUv2_emb3.txt")
# Create the log directory if needed; open(..., "w") does not create parents.
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(f"{line}\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.693 Loss = 0.594 AUC = 0.74
Epoch 1: VAL Accuracy = 0.72 Loss = 0.603 AUC = 0.711
Epoch 2: TRAIN Accuracy = 0.769 Loss = 0.518 AUC = 0.758
Epoch 2: VAL Accuracy = 0.713 Loss = 0.59 AUC = 0.725
Epoch 3: TRAIN Accuracy = 0.768 Loss = 0.506 AUC = 0.762
Epoch 3: VAL Accuracy = 0.72 Loss = 0.593 AUC = 0.744
Epoch 4: TRAIN Accuracy = 0.771 Loss = 0.501 AUC = 0.767
Epoch 4: VAL Accuracy = 0.727 Loss = 0.597 AUC = 0.742
Epoch 5: TRAIN Accuracy = 0.771 Loss = 0.5 AUC = 0.764
Epoch 5: VAL Accuracy = 0.733 Loss = 0.6 AUC = 0.776
Epoch 6: TRAIN Accuracy = 0.772 Loss = 0.497 AUC = 0.769
Epoch 6: VAL Accuracy = 0.727 Loss = 0.605 AUC = 0.802
Epoch 7: TRAIN Accuracy = 0.775 Loss = 0.492 AUC = 0.77
Epoch 7: VAL Accuracy = 0.733 Loss = 0.605 AUC = 0.789
Epoch 8: TRAIN Accuracy = 0.774 Loss = 0.491 AUC = 0.78
Epoch 8: VAL Accuracy = 0.733 Loss = 0.598 AUC = 0.82
Epoch 9: TRAIN Accuracy = 0.774 Loss = 0.488 AUC = 0.787
Epoch 9: VAL Accuracy = 0.733 Loss = 0.599 AUC = 0.794
Epo