In [6]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

# Configuration

In [7]:
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=True))

    model.add(Dropout(0.2))
    model.add(GRU(10, return_sequences=False))
    
    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [8]:
ID = ["ID"]
USER = ["SubjectID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
LAGS = []
USE_DIFF = True
INIT_SEED = 5412

In [9]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

new_features = []
for lag in LAGS:
    for feature_name in FEATURES:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
FEATURES.extend(new_features)

if USE_DIFF:
    for feature_name in FEATURES:
        data[feature_name] = data[feature_name] - data.groupby(IDS)[feature_name].shift(1).fillna(0)

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-328.0,-228176.0,-62529.0,-32296.0,-21751.0,-25200.0,-41410.0,-27935.0,-5553.0,0.0
2,0,0.0,151.0,684566.0,355662.0,200560.0,59867.0,33547.0,126849.0,51950.0,22614.0,0.0


In [10]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [11]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 20:08:06.503452: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [12]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 202ms/step - accuracy: 0.5165 - auc: 0.6831 - loss: 0.6314 - val_accuracy: 0.8000 - val_auc: 0.8556 - val_loss: 0.5633
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 159ms/step - accuracy: 0.7285 - auc: 0.6867 - loss: 0.6116 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.5125
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 160ms/step - accuracy: 0.7285 - auc: 0.7768 - loss: 0.5852 - val_accuracy: 0.8000 - val_auc: 0.8222 - val_loss: 0.4858
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 148ms/step - accuracy: 0.7130 - auc: 0.7094 - loss: 0.5667 - val_accuracy: 0.8000 - val_auc: 0.8556 - val_loss: 0.4686
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 187ms/step - accuracy: 0.7203 - auc: 0.7272 - loss: 0.5579 - val_accuracy: 0.8000 - val_auc: 0.8111 - val_loss: 0.4516
Epoch 6/200
[1m7/7[0m [32m━━━━━━

1it [06:03, 363.35s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 454ms/step - accuracy: 0.6410 - auc: 0.6847 - loss: 0.6753 - val_accuracy: 0.7667 - val_auc: 0.7800 - val_loss: 0.6583
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 249ms/step - accuracy: 0.7287 - auc: 0.7464 - loss: 0.6594 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.6346
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 252ms/step - accuracy: 0.7287 - auc: 0.7790 - loss: 0.6467 - val_accuracy: 0.7667 - val_auc: 0.7778 - val_loss: 0.6069
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 233ms/step - accuracy: 0.7400 - auc: 0.8046 - loss: 0.6066 - val_accuracy: 0.7667 - val_auc: 0.7244 - val_loss: 0.5712
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 239ms/step - accuracy: 0.7441 - auc: 0.7214 - loss: 0.5806 - val_accuracy: 0.7667 - val_auc: 0.7556 - val_loss: 0.5398
Epoch 6/200
[1m7/7[0m [32m━━━━━

2it [12:21, 371.96s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 412ms/step - accuracy: 0.5152 - auc: 0.7926 - loss: 0.6633 - val_accuracy: 0.7000 - val_auc: 0.8044 - val_loss: 0.6244
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 244ms/step - accuracy: 0.5505 - auc: 0.7377 - loss: 0.6287 - val_accuracy: 0.5000 - val_auc: 0.8022 - val_loss: 0.5821
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.4741 - auc: 0.7856 - loss: 0.6038 - val_accuracy: 0.7667 - val_auc: 0.8022 - val_loss: 0.5429
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 238ms/step - accuracy: 0.6416 - auc: 0.7289 - loss: 0.5843 - val_accuracy: 0.7667 - val_auc: 0.8022 - val_loss: 0.5324
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 229ms/step - accuracy: 0.7407 - auc: 0.7796 - loss: 0.5595 - val_accuracy: 0.7667 - val_auc: 0.8022 - val_loss: 0.5394
Epoch 6/200
[1m7/7[0m [32m━━━━━━

3it [18:37, 373.71s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 413ms/step - accuracy: 0.5907 - auc: 0.5340 - loss: 0.6914 - val_accuracy: 0.8000 - val_auc: 0.7978 - val_loss: 0.6610
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 297ms/step - accuracy: 0.7203 - auc: 0.7539 - loss: 0.6632 - val_accuracy: 0.8000 - val_auc: 0.8000 - val_loss: 0.6234
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 258ms/step - accuracy: 0.7787 - auc: 0.8172 - loss: 0.6253 - val_accuracy: 0.8000 - val_auc: 0.7667 - val_loss: 0.5841
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 236ms/step - accuracy: 0.7682 - auc: 0.7452 - loss: 0.6051 - val_accuracy: 0.8000 - val_auc: 0.7511 - val_loss: 0.5466
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.7739 - auc: 0.8072 - loss: 0.5444 - val_accuracy: 0.8000 - val_auc: 0.7400 - val_loss: 0.5218
Epoch 6/200
[1m7/7[0m [32m━━━━━━

4it [24:57, 376.23s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 458ms/step - accuracy: 0.5934 - auc: 0.7341 - loss: 0.6672 - val_accuracy: 0.7333 - val_auc: 0.7067 - val_loss: 0.6362
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 255ms/step - accuracy: 0.8475 - auc: 0.8364 - loss: 0.6313 - val_accuracy: 0.7333 - val_auc: 0.7533 - val_loss: 0.6063
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 240ms/step - accuracy: 0.8258 - auc: 0.8947 - loss: 0.5726 - val_accuracy: 0.7333 - val_auc: 0.7511 - val_loss: 0.5705
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 234ms/step - accuracy: 0.8258 - auc: 0.8414 - loss: 0.5399 - val_accuracy: 0.7333 - val_auc: 0.7378 - val_loss: 0.5475
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 262ms/step - accuracy: 0.8258 - auc: 0.8111 - loss: 0.5162 - val_accuracy: 0.7333 - val_auc: 0.7222 - val_loss: 0.5278
Epoch 6/200
[1m7/7[0m [32m━━━━━

5it [31:10, 374.16s/it]

CPU times: user 1h 3min 1s, sys: 22min 57s, total: 1h 25min 59s
Wall time: 31min 10s





In [13]:
for i in range(NUM_EPOCHS):
    print(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    print(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

Epoch 1: TRAIN Accuracy = 0.637 Loss = 0.659 AUC = 0.694
Epoch 1: VAL Accuracy = 0.76 Loss = 0.629 AUC = 0.789
Epoch 2: TRAIN Accuracy = 0.737 Loss = 0.63 AUC = 0.764
Epoch 2: VAL Accuracy = 0.72 Loss = 0.592 AUC = 0.791
Epoch 3: TRAIN Accuracy = 0.729 Loss = 0.596 AUC = 0.778
Epoch 3: VAL Accuracy = 0.773 Loss = 0.558 AUC = 0.784
Epoch 4: TRAIN Accuracy = 0.749 Loss = 0.57 AUC = 0.754
Epoch 4: VAL Accuracy = 0.773 Loss = 0.533 AUC = 0.774
Epoch 5: TRAIN Accuracy = 0.766 Loss = 0.54 AUC = 0.768
Epoch 5: VAL Accuracy = 0.773 Loss = 0.516 AUC = 0.766
Epoch 6: TRAIN Accuracy = 0.76 Loss = 0.525 AUC = 0.794
Epoch 6: VAL Accuracy = 0.773 Loss = 0.499 AUC = 0.779
Epoch 7: TRAIN Accuracy = 0.771 Loss = 0.499 AUC = 0.779
Epoch 7: VAL Accuracy = 0.787 Loss = 0.483 AUC = 0.773
Epoch 8: TRAIN Accuracy = 0.771 Loss = 0.484 AUC = 0.787
Epoch 8: VAL Accuracy = 0.787 Loss = 0.46 AUC = 0.788
Epoch 9: TRAIN Accuracy = 0.791 Loss = 0.475 AUC = 0.778
Epoch 9: VAL Accuracy = 0.793 Loss = 0.447 AUC = 0.777