In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv1D, Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-04-07 22:52:30.789371: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 22:52:30.794300: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-07 22:52:30.812574: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744055550.838623  733455 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744055550.846793  733455 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-07 22:52:30.872300: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training budget: epochs per fit, and how many independently seeded
# train/test experiments are run and averaged.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5

def create_model(train):
    """Build and compile the Conv1D + stacked-GRU binary classifier.

    Only ``train.shape[1]`` (time steps) and ``train.shape[2]`` (features per
    step) are read; they fix the model's input shape.

    Architecture: causal Conv1D(50, k=3, relu), then four Dropout(0.2) + GRU
    stages (20, 20, 10, 10 units, tanh). Only the last GRU drops the time
    axis. A sigmoid Dense(1) head follows.

    Returns
    -------
    A compiled Keras ``Sequential`` model. It uses Adam (lr=0.002,
    weight_decay=1e-7, EMA on) and binary cross-entropy, and reports
    accuracy and AUC (metric name "auc").
    """
    timesteps, channels = train.shape[1], train.shape[2]

    net = Sequential()
    net.add(Input(shape=(timesteps, channels)))
    net.add(Conv1D(filters=50, kernel_size=3, activation='relu', padding='causal'))

    # (units, return_sequences) for each recurrent stage, in order.
    gru_stages = ((20, True), (20, True), (10, True), (10, False))
    for units, keep_sequence in gru_stages:
        net.add(Dropout(0.2))
        net.add(GRU(units, return_sequences=keep_sequence, activation='tanh'))

    net.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    net.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return net

# Experiment

In [3]:
# Column-name groups used throughout the notebook (lists, so they can be
# concatenated to select columns).
ID = ["ID"]                       # per-trial key, derived in the loading cell
USER = ["SubjectID"]              # grouping used for the train/test split
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = [
    "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]

# Base seed; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# Location of EEG_data.csv. Set the EEG_DATA_DIR environment variable to run
# this notebook on another machine; the original author's path is the fallback.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer key per (SubjectID, VideoID) trial. ngroup() numbers the sorted
# key pairs 0..n_trials-1, so keys are unique whatever values VideoID takes.
# The previous `n_videos * SubjectID + VideoID` formula silently assumed
# VideoID runs exactly 0..n_videos-1; when it does, both give the same IDs.
data["ID"] = data.groupby(IDS).ngroup().astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data, id_cols=None, feature_cols=None, target_cols=None):
    """Turn a long per-sample frame into padded per-trial sequences.

    Rows are grouped by the trial key. Each group becomes one sequence of
    feature vectors, keeping rows in their original order. Sequences are then
    zero-padded at the end to the longest trial via ``pad_sequences``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the id, feature and target columns.
    id_cols, feature_cols, target_cols : list of str, optional
        Column names to use. They default to the notebook-level ``ID``,
        ``FEATURES`` and ``TARGET`` lists.

    Returns
    -------
    features : numpy.ndarray, shape (n_trials, max_len, n_features)
        Trials ordered by ascending key. Empty input yields shape
        ``(0, 0, n_features)``.
    target : numpy.ndarray of int, shape (n_trials,)
        Mean of the target column(s) per trial, truncated to int. With 0/1
        labels, a trial is labelled 1 only if all of its rows are 1.
    """
    id_cols = ID if id_cols is None else id_cols
    feature_cols = FEATURES if feature_cols is None else feature_cols
    target_cols = TARGET if target_cols is None else target_cols

    features = []
    target = []
    # A single pass over the frame instead of one boolean scan per id.
    # groupby(sort=True) yields keys in the same ascending order as
    # np.unique and preserves row order within each group.
    for _, trial in data.groupby(id_cols, sort=True):
        target.append(np.mean(trial[target_cols].to_numpy()).astype("int"))
        features.append(trial[feature_cols].to_numpy())

    if not features:
        return np.empty((0, 0, len(feature_cols))), np.empty((0,), dtype="int")
    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Stack 2-D arrays of different lengths into one 3-D array.

    Each ``(length, n_cols)`` array is padded at the END of axis 0 with
    ``pad_value`` up to a common length.

    Parameters
    ----------
    arrays : iterable of numpy.ndarray
        2-D arrays that share the same number of columns.
    pad_value : scalar, default 0
        Fill value for the padded rows.
    max_length : int, optional
        Target length. Defaults to the longest input. Arrays longer than an
        explicit ``max_length`` are truncated, keeping their first rows.

    Returns
    -------
    numpy.ndarray, shape (n_arrays, length, n_cols)

    Raises
    ------
    ValueError
        If ``arrays`` is empty or ``max_length`` is negative.
    """
    arrays = list(arrays)  # materialise so generators can be iterated twice
    if not arrays:
        raise ValueError("pad_sequences() requires at least one array")
    if max_length is not None and max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    length = max(arr.shape[0] for arr in arrays) if max_length is None else max_length
    padded_arrays = [
        np.pad(
            arr[:length],
            ((0, length - min(arr.shape[0], length)), (0, 0)),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full padded dataset, only to inspect its architecture.
# The target array is not bound to `_`: in IPython that name holds the last
# cell output, and assigning it by hand stops IPython from updating it.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-04-07 22:52:35.049199: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 398ms/step - accuracy: 0.6346 - auc: 0.6745 - loss: 0.6655 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.5651
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 207ms/step - accuracy: 0.7285 - auc: 0.7426 - loss: 0.6016 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.4981
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 206ms/step - accuracy: 0.7249 - auc: 0.7663 - loss: 0.5636 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.4561
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 237ms/step - accuracy: 0.7285 - auc: 0.7374 - loss: 0.5635 - val_accuracy: 0.8000 - val_auc: 0.7889 - val_loss: 0.4456
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 236ms/step - accuracy: 0.7341 - auc: 0.7266 - loss: 0.5425 - val_accuracy: 0.8000 - val_auc: 0.8111 - val_loss: 0.4351
Epoch 6/200
[1m7/7[0m [32m━━━━━

1it [04:55, 295.27s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 321ms/step - accuracy: 0.6191 - auc: 0.7217 - loss: 0.6474 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.5845
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 201ms/step - accuracy: 0.7287 - auc: 0.7898 - loss: 0.5788 - val_accuracy: 0.7667 - val_auc: 0.7200 - val_loss: 0.5235
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 200ms/step - accuracy: 0.7441 - auc: 0.7928 - loss: 0.5256 - val_accuracy: 0.7667 - val_auc: 0.7600 - val_loss: 0.4892
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 192ms/step - accuracy: 0.7554 - auc: 0.8477 - loss: 0.5085 - val_accuracy: 0.7667 - val_auc: 0.7467 - val_loss: 0.4821
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 195ms/step - accuracy: 0.7441 - auc: 0.7710 - loss: 0.5269 - val_accuracy: 0.7667 - val_auc: 0.7822 - val_loss: 0.4739
Epoch 6/200
[1m7/7[0m [32m━━━━━━

2it [09:41, 289.71s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 329ms/step - accuracy: 0.6077 - auc: 0.4459 - loss: 0.7024 - val_accuracy: 0.7333 - val_auc: 0.7222 - val_loss: 0.6210
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 205ms/step - accuracy: 0.7841 - auc: 0.7562 - loss: 0.6220 - val_accuracy: 0.7667 - val_auc: 0.7911 - val_loss: 0.5639
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 200ms/step - accuracy: 0.7785 - auc: 0.6711 - loss: 0.5784 - val_accuracy: 0.7667 - val_auc: 0.7333 - val_loss: 0.5209
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 196ms/step - accuracy: 0.7841 - auc: 0.7644 - loss: 0.5382 - val_accuracy: 0.7667 - val_auc: 0.7889 - val_loss: 0.5019
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 199ms/step - accuracy: 0.7841 - auc: 0.6889 - loss: 0.5132 - val_accuracy: 0.7667 - val_auc: 0.7667 - val_loss: 0.4964
Epoch 6/200
[1m7/7[0m [32m━━━━━━

3it [14:21, 285.44s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 316ms/step - accuracy: 0.6160 - auc: 0.7231 - loss: 0.6598 - val_accuracy: 0.7667 - val_auc: 0.7711 - val_loss: 0.6170
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 202ms/step - accuracy: 0.7899 - auc: 0.8174 - loss: 0.5876 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.5442
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 197ms/step - accuracy: 0.7899 - auc: 0.7781 - loss: 0.5406 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.4928
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 199ms/step - accuracy: 0.7899 - auc: 0.8234 - loss: 0.4981 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.4725
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 197ms/step - accuracy: 0.7899 - auc: 0.8006 - loss: 0.4871 - val_accuracy: 0.7667 - val_auc: 0.8222 - val_loss: 0.4577
Epoch 6/200
[1m7/7[0m [32m━━━━━━

4it [18:58, 282.30s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 210ms/step - accuracy: 0.6114 - auc: 0.5835 - loss: 0.6838 - val_accuracy: 0.7333 - val_auc: 0.7333 - val_loss: 0.6280
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 125ms/step - accuracy: 0.8413 - auc: 0.8336 - loss: 0.6159 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5895
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 124ms/step - accuracy: 0.8258 - auc: 0.8233 - loss: 0.5640 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.5499
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 122ms/step - accuracy: 0.8258 - auc: 0.8454 - loss: 0.4880 - val_accuracy: 0.7333 - val_auc: 0.7267 - val_loss: 0.5286
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 121ms/step - accuracy: 0.8258 - auc: 0.8462 - loss: 0.4317 - val_accuracy: 0.7333 - val_auc: 0.7467 - val_loss: 0.5304
Epoch 6/200
[1m7/7[0m [32m━━━━━━

5it [21:58, 263.61s/it]

CPU times: user 47min 58s, sys: 20min 7s, total: 1h 8min 5s
Wall time: 21min 58s





In [8]:
# Format the averaged per-epoch curves once, then show them and save them.
log_lines = []
for i in range(NUM_EPOCHS):
    log_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    log_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in log_lines:
    print(line)

# Create ./logs on first use; open() would raise FileNotFoundError otherwise.
log_path = Path("./logs/GRUv2_conv.txt")
log_path.parent.mkdir(parents=True, exist_ok=True)
log_path.write_text("".join(f"{line}\n" for line in log_lines), encoding="utf-8")

Epoch 1: TRAIN Accuracy = 0.671 Loss = 0.657 AUC = 0.678
Epoch 1: VAL Accuracy = 0.753 Loss = 0.603 AUC = 0.749
Epoch 2: TRAIN Accuracy = 0.774 Loss = 0.592 AUC = 0.772
Epoch 2: VAL Accuracy = 0.767 Loss = 0.544 AUC = 0.772
Epoch 3: TRAIN Accuracy = 0.769 Loss = 0.547 AUC = 0.768
Epoch 3: VAL Accuracy = 0.767 Loss = 0.502 AUC = 0.756
Epoch 4: TRAIN Accuracy = 0.777 Loss = 0.511 AUC = 0.789
Epoch 4: VAL Accuracy = 0.767 Loss = 0.486 AUC = 0.766
Epoch 5: TRAIN Accuracy = 0.777 Loss = 0.492 AUC = 0.763
Epoch 5: VAL Accuracy = 0.767 Loss = 0.479 AUC = 0.786
Epoch 6: TRAIN Accuracy = 0.774 Loss = 0.482 AUC = 0.787
Epoch 6: VAL Accuracy = 0.773 Loss = 0.471 AUC = 0.818
Epoch 7: TRAIN Accuracy = 0.78 Loss = 0.474 AUC = 0.788
Epoch 7: VAL Accuracy = 0.78 Loss = 0.456 AUC = 0.837
Epoch 8: TRAIN Accuracy = 0.783 Loss = 0.463 AUC = 0.769
Epoch 8: VAL Accuracy = 0.787 Loss = 0.444 AUC = 0.844
Epoch 9: TRAIN Accuracy = 0.789 Loss = 0.453 AUC = 0.82
Epoch 9: VAL Accuracy = 0.8 Loss = 0.434 AUC = 0.8