In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-05-03 09:29:17.621551: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-03 09:29:17.630418: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-03 09:29:17.661275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746253757.716852  436913 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746253757.734902  436913 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-03 09:29:17.788473: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of epochs each experiment trains for (passed to model.fit as `epochs`).
NUM_EPOCHS = 400
# Number of repeated train/validation runs; every run gets its own seed.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the stacked-GRU binary classifier.

    Args:
        train: Array-like with a 3-element ``shape``; only ``shape[1]`` and
            ``shape[2]`` are read, and they define the model's input shape.

    Returns:
        A compiled ``Sequential`` model: four Dropout(0.2) + GRU stages
        (20, 20, 10, 10 units; only the last stage collapses the sequence)
        followed by a single sigmoid unit. Compiled with Adam and binary
        cross-entropy, reporting accuracy and AUC (metric name ``"auc"``).
    """
    # (units, return_sequences) for each recurrent stage, in stacking order.
    recurrent_stages = (
        (20, True),
        (20, True),
        (10, True),
        (10, False),  # last stage emits only its final output
    )

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    # Every GRU is preceded by its own dropout layer.
    for units, keep_sequence in recurrent_stages:
        model.add(Dropout(0.2))
        model.add(GRU(units, return_sequences=keep_sequence, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column-name groups, kept as lists so they can be concatenated when selecting
# columns from the data table.
# Per-recording key column (built from SubjectID and VideoID in the loading cell).
ID = ["ID"]
# Subject column; the train/validation split is drawn over its unique values.
USER = ["SubjectID"]
# Raw key columns. NOTE(review): not referenced by any cell visible here.
IDS = ["SubjectID", "VideoID"]
# Label column.
TARGET = ["predefinedlabel"]
# Input feature columns fed to the model, in this order.
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
# Seed of the first experiment; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# NOTE(review): absolute, user-specific path - the notebook only runs on a
# machine where this directory exists.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Add one key per (subject, video) pair: subject * number_of_videos + video,
# then keep only the key, subject, feature and label columns, in that order.
data = data.assign(
    ID=(len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
).loc[:, ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the long-format table into one padded sequence per ``ID`` group.

    Rows are grouped by the ``ID`` column(s). Groups are visited in ascending
    key order and rows keep their original relative order inside each group.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET``
            columns.

    Returns:
        Tuple ``(features, target)``:
            features: array of shape ``(n_groups, longest_group, len(FEATURES))``;
                shorter groups are padded at the end by ``pad_sequences``.
            target: integer array of shape ``(n_groups,)`` holding the mean of
                each group's target values converted to int.

    Raises:
        ValueError: if ``data`` contains no rows to group.
    """
    features = []
    target = []
    # One sorted groupby pass instead of re-masking the whole frame per id.
    for _, group in data.groupby(ID, sort=True):
        # NOTE(review): the int cast truncates toward zero, so this is only the
        # group's label if the label is constant within a group - confirm.
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    if not features:
        raise ValueError("reshape_dataset: no rows to group into sequences")

    # pad_sequences already returns a stacked ndarray; no extra copy needed.
    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        arrays: Sequence of 2-D arrays that share the size of their second axis.
        pad_value: Constant appended as extra trailing rows (default 0).

    Returns:
        Array of shape ``(len(arrays), longest_first_axis, second_axis)``.
    """
    lengths = [arr.shape[0] for arr in arrays]
    target_length = max(lengths)

    padded = []
    for arr, length in zip(arrays, lengths):
        # Rows are only appended after the data; columns are left untouched.
        row_padding = (0, target_length - length)
        padded.append(
            np.pad(
                arr,
                (row_padding, (0, 0)),
                mode='constant',
                constant_values=pad_value,
            )
        )
    return np.stack(padded)

In [6]:
# Build one model on the full dataset only to display its architecture:
# create_model reads nothing but X's shape, and the labels are discarded.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-05-03 09:29:24.836035: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 604ms/step - accuracy: 0.5862 - auc: 0.4123 - loss: 0.7037 - val_accuracy: 0.5333 - val_auc: 0.4844 - val_loss: 0.6978
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 318ms/step - accuracy: 0.6338 - auc: 0.5856 - loss: 0.6782 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.6406
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 310ms/step - accuracy: 0.7285 - auc: 0.8000 - loss: 0.6400 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5786
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 310ms/step - accuracy: 0.7285 - auc: 0.6485 - loss: 0.6117 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.5288
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 317ms/step - accuracy: 0.7285 - auc: 0.8326 - loss: 0.5717 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.4836
Epoch 6/400
[1m7/7[0m [32m━━━━━

1it [16:14, 974.46s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 582ms/step - accuracy: 0.4043 - auc: 0.4572 - loss: 0.6931 - val_accuracy: 0.5000 - val_auc: 0.7644 - val_loss: 0.6484
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 301ms/step - accuracy: 0.7504 - auc: 0.8104 - loss: 0.6388 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.6029
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 305ms/step - accuracy: 0.7287 - auc: 0.7234 - loss: 0.5980 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.5541
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 309ms/step - accuracy: 0.7287 - auc: 0.8443 - loss: 0.5329 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.5181
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 369ms/step - accuracy: 0.7287 - auc: 0.7683 - loss: 0.5393 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4968
Epoch 6/400
[1m7/7[0m [32m━━━━━

2it [32:35, 978.54s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 533ms/step - accuracy: 0.5942 - auc: 0.3761 - loss: 0.7159 - val_accuracy: 0.6000 - val_auc: 0.5244 - val_loss: 0.6919
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 333ms/step - accuracy: 0.7463 - auc: 0.7506 - loss: 0.6451 - val_accuracy: 0.7333 - val_auc: 0.7644 - val_loss: 0.6107
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 378ms/step - accuracy: 0.7841 - auc: 0.6901 - loss: 0.5829 - val_accuracy: 0.7667 - val_auc: 0.7511 - val_loss: 0.5423
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 334ms/step - accuracy: 0.8058 - auc: 0.7063 - loss: 0.5271 - val_accuracy: 0.7667 - val_auc: 0.7400 - val_loss: 0.5039
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 380ms/step - accuracy: 0.7841 - auc: 0.7346 - loss: 0.4987 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.4937
Epoch 6/400
[1m7/7[0m [32m━━━━━

3it [49:34, 996.81s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 419ms/step - accuracy: 0.6012 - auc: 0.4306 - loss: 0.7144 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.6049
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 312ms/step - accuracy: 0.7899 - auc: 0.7706 - loss: 0.5857 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5305
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 337ms/step - accuracy: 0.7899 - auc: 0.7721 - loss: 0.5144 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.4815
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 381ms/step - accuracy: 0.7899 - auc: 0.8454 - loss: 0.4630 - val_accuracy: 0.8000 - val_auc: 0.8133 - val_loss: 0.4535
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 382ms/step - accuracy: 0.7899 - auc: 0.8165 - loss: 0.4582 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.4331
Epoch 6/400
[1m7/7[0m [32m━━━━━

4it [1:07:35, 1029.99s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 553ms/step - accuracy: 0.5593 - auc: 0.2060 - loss: 0.7374 - val_accuracy: 0.5000 - val_auc: 0.5356 - val_loss: 0.6929
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 334ms/step - accuracy: 0.7244 - auc: 0.6840 - loss: 0.6704 - val_accuracy: 0.7333 - val_auc: 0.7378 - val_loss: 0.6413
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 328ms/step - accuracy: 0.8258 - auc: 0.8323 - loss: 0.6092 - val_accuracy: 0.7333 - val_auc: 0.6800 - val_loss: 0.5926
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 330ms/step - accuracy: 0.8258 - auc: 0.8035 - loss: 0.5439 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5510
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 326ms/step - accuracy: 0.8258 - auc: 0.8158 - loss: 0.4789 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5260
Epoch 6/400
[1m7/7[0m [32m━━━━━

5it [1:25:32, 1026.43s/it]

CPU times: user 2h 50min 13s, sys: 59min 3s, total: 3h 49min 16s
Wall time: 1h 25min 32s





In [8]:
# Format the experiment-averaged curves once, then send the same lines to the
# notebook output and to the log file so the two cannot drift apart.
# Iterating the arrays themselves (rather than range(NUM_EPOCHS)) keeps this
# cell valid even if NUM_EPOCHS was edited after the curves were computed.
train_curves = zip(epoch_acc, epoch_loss, epoch_auc)
val_curves = zip(epoch_val_acc, epoch_val_loss, epoch_val_auc)

report_lines = []
for epoch, (train_row, val_row) in enumerate(zip(train_curves, val_curves), start=1):
    for split, (accuracy, loss_value, auc_value) in (("TRAIN", train_row), ("VAL", val_row)):
        report_lines.append(
            f"Epoch {epoch}: {split} Accuracy = {np.round(accuracy, 3)} "
            f"Loss = {np.round(loss_value, 3)} AUC = {np.round(auc_value, 3)}"
        )

for line in report_lines:
    print(line)

log_path = Path("./logs/GRUv2_expanded.txt")
# Create the log directory on first use instead of failing in open().
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(line + "\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.551 Loss = 0.703 AUC = 0.451
Epoch 1: VAL Accuracy = 0.587 Loss = 0.667 AUC = 0.622
Epoch 2: TRAIN Accuracy = 0.734 Loss = 0.636 AUC = 0.739
Epoch 2: VAL Accuracy = 0.76 Loss = 0.605 AUC = 0.765
Epoch 3: TRAIN Accuracy = 0.771 Loss = 0.578 AUC = 0.774
Epoch 3: VAL Accuracy = 0.767 Loss = 0.55 AUC = 0.748
Epoch 4: TRAIN Accuracy = 0.774 Loss = 0.53 AUC = 0.774
Epoch 4: VAL Accuracy = 0.773 Loss = 0.511 AUC = 0.769
Epoch 5: TRAIN Accuracy = 0.771 Loss = 0.505 AUC = 0.778
Epoch 5: VAL Accuracy = 0.773 Loss = 0.487 AUC = 0.779
Epoch 6: TRAIN Accuracy = 0.777 Loss = 0.489 AUC = 0.757
Epoch 6: VAL Accuracy = 0.78 Loss = 0.474 AUC = 0.794
Epoch 7: TRAIN Accuracy = 0.783 Loss = 0.48 AUC = 0.769
Epoch 7: VAL Accuracy = 0.78 Loss = 0.461 AUC = 0.796
Epoch 8: TRAIN Accuracy = 0.789 Loss = 0.465 AUC = 0.791
Epoch 8: VAL Accuracy = 0.78 Loss = 0.45 AUC = 0.822
Epoch 9: TRAIN Accuracy = 0.791 Loss = 0.455 AUC = 0.816
Epoch 9: VAL Accuracy = 0.793 Loss = 0.441 AUC = 0.847
