In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv1D, Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-07 23:41:16.049596: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 23:41:16.054651: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 23:41:16.076701: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744058476.116732  895684 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744058476.129348  895684 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 23:41:16.173207: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Conv1D(filters=50, kernel_size=5, activation='relu', padding='causal'))
    model.add(Conv1D(filters=50, kernel_size=3, activation='relu', padding='causal'))

    model.add(Dropout(0.2))
    model.add(GRU(20, return_sequences=True, activation='tanh'))

    model.add(Dropout(0.2))
    model.add(GRU(20, return_sequences=True, activation='tanh'))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True, activation='tanh'))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=False, activation='tanh'))

    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
USER = ["SubjectID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-04-07 23:41:20.681891: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 443ms/step - accuracy: 0.5408 - auc: 0.5398 - loss: 0.6892 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5748
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 203ms/step - accuracy: 0.7341 - auc: 0.6943 - loss: 0.6281 - val_accuracy: 0.8000 - val_auc: 0.8400 - val_loss: 0.5279
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7285 - auc: 0.8005 - loss: 0.5734 - val_accuracy: 0.8000 - val_auc: 0.8444 - val_loss: 0.4919
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 223ms/step - accuracy: 0.7203 - auc: 0.7076 - loss: 0.5593 - val_accuracy: 0.8000 - val_auc: 0.8600 - val_loss: 0.4742
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 230ms/step - accuracy: 0.7285 - auc: 0.6947 - loss: 0.5666 - val_accuracy: 0.8000 - val_auc: 0.8844 - val_loss: 0.4644
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [05:41, 341.97s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 340ms/step - accuracy: 0.5164 - auc: 0.5138 - loss: 0.6905 - val_accuracy: 0.7000 - val_auc: 0.7156 - val_loss: 0.6450
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 214ms/step - accuracy: 0.6808 - auc: 0.7210 - loss: 0.6436 - val_accuracy: 0.7667 - val_auc: 0.8044 - val_loss: 0.6157
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 195ms/step - accuracy: 0.7070 - auc: 0.7058 - loss: 0.6239 - val_accuracy: 0.7667 - val_auc: 0.8178 - val_loss: 0.5873
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 185ms/step - accuracy: 0.7230 - auc: 0.7561 - loss: 0.5836 - val_accuracy: 0.7667 - val_auc: 0.8444 - val_loss: 0.5469
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 196ms/step - accuracy: 0.7057 - auc: 0.7070 - loss: 0.5880 - val_accuracy: 0.7667 - val_auc: 0.8133 - val_loss: 0.5247
Epoch 6/200
[1m7/7[0m [32m━━━━━━

2it [10:32, 311.80s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 329ms/step - accuracy: 0.7017 - auc: 0.7315 - loss: 0.6746 - val_accuracy: 0.7333 - val_auc: 0.7911 - val_loss: 0.6308
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 199ms/step - accuracy: 0.7841 - auc: 0.7513 - loss: 0.6336 - val_accuracy: 0.7333 - val_auc: 0.7578 - val_loss: 0.5896
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 217ms/step - accuracy: 0.7624 - auc: 0.6835 - loss: 0.5889 - val_accuracy: 0.7667 - val_auc: 0.7756 - val_loss: 0.5497
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 193ms/step - accuracy: 0.7841 - auc: 0.6929 - loss: 0.5572 - val_accuracy: 0.7667 - val_auc: 0.8000 - val_loss: 0.5339
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 203ms/step - accuracy: 0.7841 - auc: 0.6813 - loss: 0.5301 - val_accuracy: 0.7667 - val_auc: 0.8378 - val_loss: 0.5241
Epoch 6/200
[1m7/7[0m [32m━━━━━━

3it [15:20, 300.71s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 335ms/step - accuracy: 0.5653 - auc: 0.4553 - loss: 0.6969 - val_accuracy: 0.7333 - val_auc: 0.6867 - val_loss: 0.6519
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 222ms/step - accuracy: 0.7379 - auc: 0.7572 - loss: 0.6313 - val_accuracy: 0.6667 - val_auc: 0.5644 - val_loss: 0.6459
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 211ms/step - accuracy: 0.7899 - auc: 0.7603 - loss: 0.5877 - val_accuracy: 0.7333 - val_auc: 0.6933 - val_loss: 0.6112
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 203ms/step - accuracy: 0.7899 - auc: 0.8113 - loss: 0.5340 - val_accuracy: 0.7667 - val_auc: 0.7156 - val_loss: 0.5718
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 202ms/step - accuracy: 0.7899 - auc: 0.8624 - loss: 0.5025 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.5476
Epoch 6/200
[1m7/7[0m [32m━━━━━━

4it [20:14, 298.06s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 345ms/step - accuracy: 0.6637 - auc: 0.5165 - loss: 0.6941 - val_accuracy: 0.7333 - val_auc: 0.7244 - val_loss: 0.6193
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.8258 - auc: 0.8290 - loss: 0.5937 - val_accuracy: 0.7333 - val_auc: 0.7044 - val_loss: 0.5954
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.8258 - auc: 0.8319 - loss: 0.5183 - val_accuracy: 0.7333 - val_auc: 0.7644 - val_loss: 0.5545
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 208ms/step - accuracy: 0.8258 - auc: 0.7798 - loss: 0.4846 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5622
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 209ms/step - accuracy: 0.8258 - auc: 0.8210 - loss: 0.4655 - val_accuracy: 0.7333 - val_auc: 0.8311 - val_loss: 0.5118
Epoch 6/200
[1m7/7[0m [32m━━━━━━

5it [24:07, 289.46s/it]

CPU times: user 51min 55s, sys: 20min 51s, total: 1h 12min 47s
Wall time: 24min 7s





In [8]:
for i in range(NUM_EPOCHS):
    print(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    print(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

with open("./logs/GRUv2_conv2.txt", "w") as f:
    for i in range(NUM_EPOCHS):
        f.write(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}\n")
        f.write(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}\n")

Epoch 1: TRAIN Accuracy = 0.623 Loss = 0.676 AUC = 0.589
Epoch 1: VAL Accuracy = 0.74 Loss = 0.624 AUC = 0.748
Epoch 2: TRAIN Accuracy = 0.749 Loss = 0.617 AUC = 0.745
Epoch 2: VAL Accuracy = 0.74 Loss = 0.595 AUC = 0.734
Epoch 3: TRAIN Accuracy = 0.766 Loss = 0.573 AUC = 0.758
Epoch 3: VAL Accuracy = 0.76 Loss = 0.559 AUC = 0.779
Epoch 4: TRAIN Accuracy = 0.766 Loss = 0.541 AUC = 0.751
Epoch 4: VAL Accuracy = 0.767 Loss = 0.538 AUC = 0.793
Epoch 5: TRAIN Accuracy = 0.763 Loss = 0.531 AUC = 0.766
Epoch 5: VAL Accuracy = 0.767 Loss = 0.515 AUC = 0.827
Epoch 6: TRAIN Accuracy = 0.771 Loss = 0.512 AUC = 0.786
Epoch 6: VAL Accuracy = 0.773 Loss = 0.508 AUC = 0.841
Epoch 7: TRAIN Accuracy = 0.771 Loss = 0.502 AUC = 0.801
Epoch 7: VAL Accuracy = 0.76 Loss = 0.515 AUC = 0.846
Epoch 8: TRAIN Accuracy = 0.771 Loss = 0.492 AUC = 0.818
Epoch 8: VAL Accuracy = 0.767 Loss = 0.496 AUC = 0.881
Epoch 9: TRAIN Accuracy = 0.771 Loss = 0.492 AUC = 0.813
Epoch 9: VAL Accuracy = 0.767 Loss = 0.479 AUC = 0.