In [4]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-03 10:58:00.697187: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:58:00.700900: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-03 10:58:00.710653: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743667080.727134  634027 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743667080.732049  634027 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-03 10:58:00.752056: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [5]:
# Number of epochs passed to `model.fit` for every experiment; the reporting
# cell also loops over this many per-epoch averages.
NUM_EPOCHS = 200
# Number of repeated experiments; each one uses its own seed (INIT_SEED + k).
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-GRU binary classifier.

    Parameters
    ----------
    train : array-like
        Only ``train.shape[1]`` and ``train.shape[2]`` are read; together they
        define the model's input shape.

    Returns
    -------
    Sequential
        A compiled model: four GRU layers (20, 20, 10, 10 units), each
        preceded by ``Dropout(0.2)``, followed by one sigmoid output unit.
        Compiled with Adam and binary cross-entropy, tracking accuracy and AUC.
    """
    # (units, return_sequences) for each GRU layer, in stacking order.
    # Only the last GRU uses return_sequences=False.
    gru_stack = [(20, True), (20, True), (10, True), (10, False)]

    sequence_shape = (train.shape[1], train.shape[2])

    model = Sequential()
    model.add(Input(shape=sequence_shape))

    for units, keep_sequence in gru_stack:
        model.add(Dropout(0.2))
        model.add(GRU(units, return_sequences=keep_sequence, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [6]:
# Column-name constants are kept as lists so they can be concatenated
# (ID + USER + FEATURES + TARGET) and used directly as DataFrame column selectors.

# Per-recording identifier column; it is created in the data-loading cell from
# SubjectID and VideoID and used to group rows into one sequence per ID.
ID = ["ID"]
# Subject column; the train/test split is drawn over its unique values.
USER = ["SubjectID"]
# Group keys used when computing lagged features.
IDS = ["SubjectID", "VideoID"]
# Label column; averaged per ID and cast to int to obtain one target per sequence.
TARGET = ["predefinedlabel"]
# Input feature columns. NOTE: this list is extended in place with the lagged
# feature names by the data-loading cell.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Row offsets passed to `shift` when building lagged copies of each feature.
LAGS = [2]
# Seed of the first experiment; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [7]:
# Load the raw EEG table.
# NOTE(review): absolute, user-specific path — this cell only runs on a machine
# that has this directory; consider a configurable DATA_DIR.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# FEATURES is updated in place at the end of this cell, so on a re-run it
# already contains the lagged names. Strip them first; otherwise lags of lags
# (e.g. "Delta_2_2") would be generated and FEATURES would grow on every run.
lag_suffixes = tuple(f"_{lag}" for lag in LAGS)
base_features = [name for name in FEATURES if not name.endswith(lag_suffixes)]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each IDS group so values never cross group boundaries;
        # rows left without a predecessor by the shift are filled with 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)

# Slice assignment keeps the same list object, so code that already holds a
# reference to FEATURES sees the lagged names as well.
FEATURES[:] = base_features + new_features

# One integer ID per (SubjectID, VideoID) pair: subject * n_videos + video.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [8]:
def reshape_dataset(data):
    """Turn the long-format table into one padded sequence per ID.

    Rows are grouped by the ``ID`` column (groups are visited in sorted ID
    order). For each group the ``FEATURES`` columns form one 2-D array and the
    mean of the ``TARGET`` column, cast to int, forms the label.

    Returns
    -------
    (features, target)
        ``features`` is the stacked result of ``pad_sequences`` (one entry per
        ID, zero-padded to the longest group); ``target`` is a 1-D array with
        one integer label per ID.
    """
    id_values = data[ID].to_numpy()
    per_id_frames = [data[id_values == value] for value in np.unique(id_values)]

    # astype("int") truncates the mean toward zero.
    labels = [
        np.mean(frame[TARGET].to_numpy()).astype("int")
        for frame in per_id_frames
    ]
    sequences = [frame[FEATURES].to_numpy() for frame in per_id_frames]

    padded = pad_sequences(sequences)
    return np.array(padded), np.array(labels)

def pad_sequences(arrays, pad_value=0):
    """Pad arrays at the end of axis 0 to a common length and stack them.

    Parameters
    ----------
    arrays : iterable of array-like
        Sequences of rank >= 1. They may differ in length along axis 0 but
        must agree on all remaining axes (required by ``np.stack``).
    pad_value : scalar, default 0
        Constant written into the padded positions.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(arrays), max_length, *trailing_dims)``.

    Raises
    ------
    ValueError
        If ``arrays`` is empty, or the trailing dimensions do not match.
    """
    # Materialise once: accepts lists/generators and lets us iterate twice.
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array to pad")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            # Pad only after the last element of axis 0; leave other axes as is.
            [(0, max_length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [9]:
# Build one model on the full dataset only to display the architecture.
# Index the returned tuple instead of unpacking into `_`: in IPython `_` holds
# the last cell output, and rebinding it stops IPython from updating it.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-04-03 10:58:03.821690: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [10]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 595ms/step - accuracy: 0.6301 - auc: 0.5507 - loss: 0.6882 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.6294
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 321ms/step - accuracy: 0.7216 - auc: 0.6951 - loss: 0.6477 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5737
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 321ms/step - accuracy: 0.7341 - auc: 0.7918 - loss: 0.6143 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.5277
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 325ms/step - accuracy: 0.7285 - auc: 0.6683 - loss: 0.5823 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.4919
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 374ms/step - accuracy: 0.7341 - auc: 0.8037 - loss: 0.5452 - val_accuracy: 0.8000 - val_auc: 0.7778 - val_loss: 0.4708
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [07:02, 422.17s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 389ms/step - accuracy: 0.4129 - auc: 0.4792 - loss: 0.6884 - val_accuracy: 0.5000 - val_auc: 0.7600 - val_loss: 0.6489
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 216ms/step - accuracy: 0.7037 - auc: 0.8074 - loss: 0.6262 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5762
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 193ms/step - accuracy: 0.7287 - auc: 0.7407 - loss: 0.5749 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.5099
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.7287 - auc: 0.8227 - loss: 0.5099 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.4716
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 210ms/step - accuracy: 0.7287 - auc: 0.7485 - loss: 0.5273 - val_accuracy: 0.8000 - val_auc: 0.7600 - val_loss: 0.4650
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [13:57, 418.19s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 430ms/step - accuracy: 0.6324 - auc: 0.5248 - loss: 0.6827 - val_accuracy: 0.7667 - val_auc: 0.7756 - val_loss: 0.6149
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.8058 - auc: 0.7720 - loss: 0.6106 - val_accuracy: 0.7667 - val_auc: 0.7178 - val_loss: 0.5523
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 251ms/step - accuracy: 0.7728 - auc: 0.6573 - loss: 0.5597 - val_accuracy: 0.7667 - val_auc: 0.7467 - val_loss: 0.5127
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.8058 - auc: 0.7088 - loss: 0.5163 - val_accuracy: 0.7667 - val_auc: 0.7311 - val_loss: 0.5077
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 215ms/step - accuracy: 0.7841 - auc: 0.7383 - loss: 0.4909 - val_accuracy: 0.7667 - val_auc: 0.7444 - val_loss: 0.5011
Epoch 6/200
[1m7/7[0m [32m━━━━━

3it [21:12, 425.85s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 417ms/step - accuracy: 0.6205 - auc: 0.4746 - loss: 0.7010 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.6021
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 227ms/step - accuracy: 0.7899 - auc: 0.7722 - loss: 0.5893 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5338
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 236ms/step - accuracy: 0.7899 - auc: 0.7698 - loss: 0.5281 - val_accuracy: 0.7667 - val_auc: 0.7222 - val_loss: 0.4985
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 198ms/step - accuracy: 0.7899 - auc: 0.8390 - loss: 0.4608 - val_accuracy: 0.7667 - val_auc: 0.7444 - val_loss: 0.4918
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 260ms/step - accuracy: 0.7981 - auc: 0.7984 - loss: 0.4608 - val_accuracy: 0.7667 - val_auc: 0.7333 - val_loss: 0.4849
Epoch 6/200
[1m7/7[0m [32m━━━━━

4it [28:26, 428.87s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 424ms/step - accuracy: 0.5551 - auc: 0.2068 - loss: 0.7352 - val_accuracy: 0.6333 - val_auc: 0.6200 - val_loss: 0.6770
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 214ms/step - accuracy: 0.7307 - auc: 0.6814 - loss: 0.6573 - val_accuracy: 0.7333 - val_auc: 0.7111 - val_loss: 0.6090
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 228ms/step - accuracy: 0.8258 - auc: 0.8332 - loss: 0.5808 - val_accuracy: 0.7333 - val_auc: 0.7067 - val_loss: 0.5563
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 262ms/step - accuracy: 0.8258 - auc: 0.8036 - loss: 0.5070 - val_accuracy: 0.7333 - val_auc: 0.7067 - val_loss: 0.5242
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 213ms/step - accuracy: 0.8258 - auc: 0.8269 - loss: 0.4499 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.5134
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [35:39, 427.95s/it]

CPU times: user 49min, sys: 13min 14s, total: 1h 2min 14s
Wall time: 35min 39s





In [11]:
# Build the per-epoch report once, then send the same lines to stdout and to
# the log file so the two can never drift apart.
report_lines = []
for i in range(NUM_EPOCHS):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

# Create the log directory first: without it `open` raises FileNotFoundError
# and the results of the long training run above would not be saved.
log_path = Path("./logs/GRUv2_emb3.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    for line in report_lines:
        f.write(f"{line}\n")

Epoch 1: TRAIN Accuracy = 0.583 Loss = 0.687 AUC = 0.54
Epoch 1: VAL Accuracy = 0.693 Loss = 0.634 AUC = 0.753
Epoch 2: TRAIN Accuracy = 0.766 Loss = 0.616 AUC = 0.759
Epoch 2: VAL Accuracy = 0.767 Loss = 0.569 AUC = 0.749
Epoch 3: TRAIN Accuracy = 0.771 Loss = 0.56 AUC = 0.777
Epoch 3: VAL Accuracy = 0.767 Loss = 0.521 AUC = 0.747
Epoch 4: TRAIN Accuracy = 0.774 Loss = 0.511 AUC = 0.771
Epoch 4: VAL Accuracy = 0.773 Loss = 0.497 AUC = 0.74
Epoch 5: TRAIN Accuracy = 0.777 Loss = 0.491 AUC = 0.778
Epoch 5: VAL Accuracy = 0.773 Loss = 0.487 AUC = 0.747
Epoch 6: TRAIN Accuracy = 0.777 Loss = 0.483 AUC = 0.757
Epoch 6: VAL Accuracy = 0.773 Loss = 0.474 AUC = 0.779
Epoch 7: TRAIN Accuracy = 0.789 Loss = 0.474 AUC = 0.772
Epoch 7: VAL Accuracy = 0.78 Loss = 0.462 AUC = 0.79
Epoch 8: TRAIN Accuracy = 0.789 Loss = 0.46 AUC = 0.792
Epoch 8: VAL Accuracy = 0.787 Loss = 0.453 AUC = 0.814
Epoch 9: TRAIN Accuracy = 0.791 Loss = 0.454 AUC = 0.815
Epoch 9: VAL Accuracy = 0.787 Loss = 0.444 AUC = 0.82