In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-02 23:45:45.494362: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-02 23:45:45.501420: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-02 23:45:45.528415: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743626745.574670  153528 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743626745.589619  153528 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-02 23:45:45.640362: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration and model definition

In [2]:
# Training budget: epochs per run and number of independent runs (one seed each).
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5

def create_model(train, dropout_rate=0.2, learning_rate=0.002, mask_padding=False):
    """Build and compile the stacked-GRU binary classifier.

    Architecture: [optional Masking] -> 4 x (Dropout -> GRU) with 20, 20, 10, 10
    units -> Dense(1, sigmoid). Only the last GRU drops the time axis.

    Args:
        train: array used only for its shape; ``train.shape[1]`` is taken as the
            number of timesteps and ``train.shape[2]`` as the number of features
            (presumably the padded output of ``reshape_dataset`` — confirm).
        dropout_rate: rate of every Dropout layer placed before each GRU.
        learning_rate: Adam learning rate.
        mask_padding: if True, insert ``Masking(mask_value=0.0)`` so timesteps whose
            features are all zero are skipped by the GRUs. ``pad_sequences`` pads
            shorter sequences with zeros at the END, and the final GRU returns only
            its last state, so without a mask that state is read after running
            over the padding. Defaults to False to reproduce the original model.

    Returns:
        A compiled ``Sequential`` model with binary cross-entropy loss and
        accuracy / AUC (named "auc") metrics.
    """
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    if mask_padding:
        # Local import keeps this cell runnable without touching the import cell.
        from tensorflow.keras.layers import Masking
        model.add(Masking(mask_value=0.0))

    # (units, return_sequences) per recurrent layer; dropout is applied to each
    # GRU's input, including the raw features before the first one.
    for units, return_sequences in ((20, True), (20, True), (10, True), (10, False)):
        model.add(Dropout(dropout_rate))
        model.add(GRU(units, return_sequences=return_sequences, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate, weight_decay=1e-7, use_ema=True),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment setup: column names, lags, and seed

In [None]:
# Column names are kept as one-element (or multi-element) lists so that
# `data[ID]`, `data[TARGET]`, etc. select DataFrame columns.
ID = ["ID"]  # per-sequence key column, built in the data-loading cell from SubjectID and VideoID
USER = ["SubjectID"]  # subject column; train/test splits are drawn by subject
IDS = ["SubjectID", "VideoID"]  # grouping keys for lag shifts, so lags stay within one subject/video
TARGET = ["predefinedlabel"]  # label column for the sigmoid / binary_crossentropy model (presumably 0/1 — confirm)
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]  # base input columns; the data-loading cell adds lagged column names to this list
LAGS = [1, 2]  # row shifts used to build the lagged feature columns
INIT_SEED = 5412  # seed of the first experiment; the experiment loop uses INIT_SEED + k for run k

In [4]:
import os

# Data location; set EEG_DATA_DIR to override this machine-specific default
# instead of editing the notebook.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# Lagged copies of every base feature, shifted within each (SubjectID, VideoID)
# sequence so values never leak across sequences; the first `lag` rows of a
# sequence have no predecessor and are filled with 0.
# Base names are recovered by dropping names ending in "_<digits>", so re-running
# this cell neither creates "Delta_1_1"-style columns nor grows FEATURES.
base_features = [name for name in FEATURES if not name.rpartition("_")[2].isdigit()]
sequences = data.groupby(IDS)
new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = sequences[feature_name].shift(lag).fillna(0)
# Assign in place (as the original `.extend` did) so every holder of the FEATURES
# list sees the lagged columns, but without duplicating them on re-run.
FEATURES[:] = base_features + new_features

# One integer key per (SubjectID, VideoID) sequence. `ngroup` numbers groups
# 0..k-1 in sorted key order and cannot collide, unlike
# n_videos * SubjectID + VideoID when VideoIDs are not exactly 0..n_videos-1.
data["ID"] = data.groupby(IDS).ngroup().astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [None]:
def reshape_dataset(data):
    """Convert a long-format frame into padded per-sequence arrays.

    Rows are grouped by the ``ID`` column(s) in ascending key order (the same
    order ``np.unique`` gives); within a group, the original row order is kept.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        A ``(features, target)`` tuple:
        - ``features``: array of shape (n_sequences, max_len, len(FEATURES)),
          with shorter sequences padded at the end by ``pad_sequences``.
        - ``target``: array of shape (n_sequences,). Each entry is the mean
          ``TARGET`` value of the sequence truncated to int. With 0/1 labels this
          is 1 only when every row is labelled 1 (assumes the label is constant
          within a sequence — confirm).

    Raises:
        ValueError: if ``data`` has no rows.
    """
    if data.empty:
        raise ValueError("reshape_dataset() received an empty DataFrame")

    features = []
    target = []
    # One groupby pass with 1-D selections replaces masking the whole frame with
    # a 2-D (n, 1) boolean array once per ID (O(rows * ids)).
    for _, sequence in data.groupby(ID, sort=True):
        target.append(np.mean(sequence[TARGET].to_numpy()).astype("int"))
        features.append(sequence[FEATURES].to_numpy())

    # pad_sequences already returns a stacked ndarray; no extra copy needed.
    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Right-pad arrays along axis 0 to a common length and stack them.

    Args:
        arrays: non-empty iterable of ndarrays with ndim >= 1 whose shapes agree
            on every axis except the first.
        pad_value: constant written into the padded positions.

    Returns:
        ndarray of shape (len(arrays), max_len, *trailing_shape), where max_len
        is the largest ``arr.shape[0]``.

    Raises:
        ValueError: if ``arrays`` is empty.
    """
    # Materialize so a generator is not exhausted by max() before padding.
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() requires at least one array")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            # Pad only the end of axis 0; leave every trailing axis untouched.
            ((0, max_length - arr.shape[0]),) + ((0, 0),) * (arr.ndim - 1),
            mode="constant",
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build the padded full-dataset tensor only to infer the model's input shape for
# the architecture summary. Indexing [0] avoids rebinding IPython's `_`
# output-history variable, which `X, _ = ...` would shadow for the session.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-04-02 23:45:53.295066: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [None]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 253ms/step - accuracy: 0.5722 - auc: 0.6567 - loss: 0.6602 - val_accuracy: 0.7000 - val_auc: 0.7000 - val_loss: 0.5922
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 224ms/step - accuracy: 0.7652 - auc: 0.7842 - loss: 0.5217 - val_accuracy: 0.7000 - val_auc: 0.7022 - val_loss: 0.5838
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 239ms/step - accuracy: 0.7776 - auc: 0.7817 - loss: 0.5114 - val_accuracy: 0.7000 - val_auc: 0.6800 - val_loss: 0.5894
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 198ms/step - accuracy: 0.7835 - auc: 0.7911 - loss: 0.4961 - val_accuracy: 0.7000 - val_auc: 0.6844 - val_loss: 0.6233
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 231ms/step - accuracy: 0.7685 - auc: 0.7851 - loss: 0.5017 - val_accuracy: 0.7000 - val_auc: 0.6778 - val_loss: 0.6198
Epoch 6/200
[1m42/42

1it [35:00, 2100.72s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 273ms/step - accuracy: 0.6184 - auc: 0.7298 - loss: 0.6230 - val_accuracy: 0.7000 - val_auc: 0.6867 - val_loss: 0.5487
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 232ms/step - accuracy: 0.7681 - auc: 0.7493 - loss: 0.5165 - val_accuracy: 0.7000 - val_auc: 0.6933 - val_loss: 0.5477
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 177ms/step - accuracy: 0.7722 - auc: 0.7682 - loss: 0.4983 - val_accuracy: 0.7667 - val_auc: 0.7467 - val_loss: 0.5352
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 217ms/step - accuracy: 0.7834 - auc: 0.7784 - loss: 0.4908 - val_accuracy: 0.7667 - val_auc: 0.8000 - val_loss: 0.5325
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 231ms/step - accuracy: 0.7733 - auc: 0.7917 - loss: 0.4916 - val_accuracy: 0.7667 - val_auc: 0.8133 - val_loss: 0.5367
Epoch 6/200
[1m42/42

2it [1:17:52, 2378.00s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 347ms/step - accuracy: 0.6694 - auc: 0.6672 - loss: 0.6362 - val_accuracy: 0.7000 - val_auc: 0.7244 - val_loss: 0.6618
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 278ms/step - accuracy: 0.7545 - auc: 0.7495 - loss: 0.5441 - val_accuracy: 0.7000 - val_auc: 0.7356 - val_loss: 0.6243
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 286ms/step - accuracy: 0.7567 - auc: 0.7549 - loss: 0.5322 - val_accuracy: 0.7000 - val_auc: 0.7333 - val_loss: 0.6438
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 324ms/step - accuracy: 0.7428 - auc: 0.7773 - loss: 0.5138 - val_accuracy: 0.7000 - val_auc: 0.7200 - val_loss: 0.6097
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 264ms/step - accuracy: 0.7565 - auc: 0.7697 - loss: 0.5269 - val_accuracy: 0.7000 - val_auc: 0.7444 - val_loss: 0.6327
Epoch 6/200
[1m42/4

3it [2:13:11, 2807.45s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 342ms/step - accuracy: 0.5575 - auc: 0.5965 - loss: 0.6560 - val_accuracy: 0.6333 - val_auc: 0.7289 - val_loss: 0.6439
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 290ms/step - accuracy: 0.7486 - auc: 0.7124 - loss: 0.5420 - val_accuracy: 0.7333 - val_auc: 0.7244 - val_loss: 0.6093
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 306ms/step - accuracy: 0.7533 - auc: 0.7070 - loss: 0.5302 - val_accuracy: 0.7000 - val_auc: 0.7333 - val_loss: 0.6466
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 275ms/step - accuracy: 0.7532 - auc: 0.7258 - loss: 0.5139 - val_accuracy: 0.7000 - val_auc: 0.7933 - val_loss: 0.6433
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 329ms/step - accuracy: 0.7592 - auc: 0.7392 - loss: 0.5033 - val_accuracy: 0.6667 - val_auc: 0.7778 - val_loss: 0.7066
Epoch 6/200
[1m42/4

4it [3:35:16, 3643.70s/it]

Epoch 1/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m206s[0m 313ms/step - accuracy: 0.7137 - auc: 0.6457 - loss: 0.6391 - val_accuracy: 0.6667 - val_auc: 0.7111 - val_loss: 0.6046
Epoch 2/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 274ms/step - accuracy: 0.8001 - auc: 0.7870 - loss: 0.4783 - val_accuracy: 0.7333 - val_auc: 0.7333 - val_loss: 0.5765
Epoch 3/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 291ms/step - accuracy: 0.7969 - auc: 0.8073 - loss: 0.4621 - val_accuracy: 0.7333 - val_auc: 0.7556 - val_loss: 0.5850
Epoch 4/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 294ms/step - accuracy: 0.8050 - auc: 0.8008 - loss: 0.4578 - val_accuracy: 0.7000 - val_auc: 0.7956 - val_loss: 0.5891
Epoch 5/200
[1m42/42[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 277ms/step - accuracy: 0.8115 - auc: 0.8123 - loss: 0.4460 - val_accuracy: 0.7333 - val_auc: 0.8089 - val_loss: 0.5683
Epoch 6/200
[1m42/

5it [4:35:54, 3310.96s/it]

CPU times: user 6h 10min 16s, sys: 1h 45min 35s, total: 7h 55min 52s
Wall time: 4h 35min 54s





In [8]:
# Per-epoch mean curves over all experiments: print them and save a copy to ./logs.
# Iterate over the epochs actually averaged (len(epoch_acc)) rather than
# NUM_EPOCHS, so editing NUM_EPOCHS after training cannot raise IndexError.
log_lines = []
for i in range(len(epoch_acc)):
    epoch = i + 1
    log_lines.append(f"Epoch {epoch}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    log_lines.append(f"Epoch {epoch}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in log_lines:
    print(line)

log_path = Path("./logs/GRUv2_emb2.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)  # open(..., "w") fails if ./logs is missing
with open(log_path, "w", encoding="utf-8") as f:
    f.writelines(line + "\n" for line in log_lines)

Epoch 1: TRAIN Accuracy = 0.69 Loss = 0.596 AUC = 0.739
Epoch 1: VAL Accuracy = 0.68 Loss = 0.61 AUC = 0.71
Epoch 2: TRAIN Accuracy = 0.766 Loss = 0.517 AUC = 0.762
Epoch 2: VAL Accuracy = 0.713 Loss = 0.588 AUC = 0.718
Epoch 3: TRAIN Accuracy = 0.772 Loss = 0.503 AUC = 0.765
Epoch 3: VAL Accuracy = 0.72 Loss = 0.6 AUC = 0.73
Epoch 4: TRAIN Accuracy = 0.773 Loss = 0.494 AUC = 0.776
Epoch 4: VAL Accuracy = 0.713 Loss = 0.6 AUC = 0.759
Epoch 5: TRAIN Accuracy = 0.77 Loss = 0.496 AUC = 0.772
Epoch 5: VAL Accuracy = 0.713 Loss = 0.613 AUC = 0.764
Epoch 6: TRAIN Accuracy = 0.771 Loss = 0.49 AUC = 0.773
Epoch 6: VAL Accuracy = 0.707 Loss = 0.626 AUC = 0.76
Epoch 7: TRAIN Accuracy = 0.774 Loss = 0.488 AUC = 0.776
Epoch 7: VAL Accuracy = 0.713 Loss = 0.601 AUC = 0.782
Epoch 8: TRAIN Accuracy = 0.773 Loss = 0.489 AUC = 0.78
Epoch 8: VAL Accuracy = 0.707 Loss = 0.62 AUC = 0.78
Epoch 9: TRAIN Accuracy = 0.77 Loss = 0.483 AUC = 0.786
Epoch 9: VAL Accuracy = 0.72 Loss = 0.617 AUC = 0.78
Epoch 10: T

Epoch 194: TRAIN Accuracy = 0.795 Loss = 0.407 AUC = 0.88
Epoch 194: VAL Accuracy = 0.72 Loss = 0.679 AUC = 0.818
Epoch 195: TRAIN Accuracy = 0.795 Loss = 0.392 AUC = 0.886
Epoch 195: VAL Accuracy = 0.707 Loss = 0.68 AUC = 0.812
Epoch 196: TRAIN Accuracy = 0.79 Loss = 0.4 AUC = 0.88
Epoch 196: VAL Accuracy = 0.707 Loss = 0.684 AUC = 0.814
Epoch 197: TRAIN Accuracy = 0.797 Loss = 0.403 AUC = 0.877
Epoch 197: VAL Accuracy = 0.713 Loss = 0.659 AUC = 0.824
Epoch 198: TRAIN Accuracy = 0.8 Loss = 0.394 AUC = 0.887
Epoch 198: VAL Accuracy = 0.727 Loss = 0.641 AUC = 0.83
Epoch 199: TRAIN Accuracy = 0.794 Loss = 0.408 AUC = 0.877
Epoch 199: VAL Accuracy = 0.713 Loss = 0.658 AUC = 0.828
Epoch 200: TRAIN Accuracy = 0.789 Loss = 0.398 AUC = 0.885
Epoch 200: VAL Accuracy = 0.72 Loss = 0.634 AUC = 0.827
