In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-19 14:06:04.563090: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 14:06:04.570657: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 14:06:04.583219: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739963164.605292 3543600 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739963164.612652 3543600 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 14:06:04.641663: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration and model definition

In [2]:
NUM_EPOCHS = 300  # total training epochs per experiment
STEP = 5  # epochs per chunk; averaged metrics are printed after every chunk
NUM_EXPERIMENTS = 10  # number of independently seeded split + model pairs

# The training loop runs NUM_EPOCHS // STEP chunks, so a remainder would be
# silently dropped; require an exact multiple instead.
assert NUM_EPOCHS % STEP == 0, "NUM_EPOCHS must be a multiple of STEP"

def create_model(train, units=10, num_gru_layers=5, dense_units=10, learning_rate=0.0003):
    """Build and compile a stacked-GRU binary classifier.

    Parameters
    ----------
    train : array-like
        Used only for its shape: ``train.shape[1]`` is taken as the number of
        time steps and ``train.shape[2]`` as the number of features per step.
    units : int
        Hidden size of every GRU layer.
    num_gru_layers : int
        Number of stacked GRU layers. All but the last return full sequences
        so the next GRU can consume them; the last returns one vector.
    dense_units : int
        Width of the ReLU hidden layer placed before the sigmoid output.
    learning_rate : float
        Learning rate of the Adam optimizer.

    Returns
    -------
    A compiled ``Sequential`` model with one sigmoid output unit, trained with
    binary cross-entropy and reporting accuracy and AUC (metric name "auc").

    Raises
    ------
    ValueError
        If ``num_gru_layers`` is less than 1.

    NOTE(review): there is no Masking layer, so zero-padded time steps are fed
    to the GRUs exactly like real measurements.
    """
    if num_gru_layers < 1:
        raise ValueError(f"num_gru_layers must be >= 1, got {num_gru_layers}")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    for layer_idx in range(num_gru_layers):
        # Only the final GRU collapses the sequence into a single vector.
        model.add(GRU(units, return_sequences=layer_idx < num_gru_layers - 1))

    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
# Column-name groups are kept as lists so they can be concatenated and used
# directly to select DataFrame columns (e.g. ID + FEATURES + TARGET).
ID = ["ID"]  # synthetic per-(SubjectID, VideoID) recording id, built at load time
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = "Raw Delta Theta Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2".split()
INIT_SEED = 5412  # experiment j is seeded with INIT_SEED + j

In [4]:
# Data location: override with the BRAIN_SIGNALS_DATA_DIR environment variable;
# the default is the path this notebook was originally run against.
data_dir = Path(os.environ.get("BRAIN_SIGNALS_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer id per (subject, video) recording: n_videos * SubjectID + VideoID.
n_videos = len(np.unique(data["VideoID"]))
data["ID"] = (n_videos * data["SubjectID"] + data["VideoID"]).astype("int")
# The formula is guaranteed collision-free only when VideoID values lie in
# [0, n_videos); check that distinct recordings really got distinct ids.
assert data.drop_duplicates(IDS)["ID"].is_unique, "ID collision between (SubjectID, VideoID) pairs"
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long-format EEG table into a padded 3-D tensor, one entry per id.

    Rows are grouped by the ``ID`` column. Groups are visited in ascending id
    order, and each group keeps its original row order as its time axis.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns
    -------
    features : np.ndarray
        Shape ``(n_ids, max_len, len(FEATURES))``. Recordings shorter than the
        longest one in ``data`` are padded with ``pad_sequences`` defaults
        (zeros appended at the end).
    target : np.ndarray
        Shape ``(n_ids,)`` integer labels: the mean of each group's ``TARGET``
        values, truncated to int.

    Raises
    ------
    ValueError
        If ``data`` has no rows (nothing to pad).
    """
    features = []
    target = []
    # groupby(sort=True) yields ids in the same ascending order as np.unique and
    # preserves row order inside each group, but needs only one pass over the
    # table instead of one boolean scan per id.
    for _, group in data.groupby(ID, sort=True):
        # NOTE(review): astype("int") truncates, so a group with mixed 0/1 labels
        # becomes 0 unless every label is 1 -- presumably labels are constant per id.
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad arrays along axis 0 to a common length and stack them.

    Parameters
    ----------
    arrays : iterable of np.ndarray
        Arrays that may differ in length along axis 0 but share all remaining
        dimensions. Any iterable is accepted; it is materialised once.
    pad_value : scalar
        Fill value for the padded positions.
    padding : {"post", "pre"}
        "post" (default, the original behaviour) appends padding after each
        array; "pre" prepends it. "pre" keeps the real data adjacent to the
        final time step, which is usually preferable for an RNN that only
        returns its last state.

    Returns
    -------
    np.ndarray
        Shape ``(len(arrays), max_length, *arrays[0].shape[1:])``.

    Raises
    ------
    ValueError
        If ``arrays`` is empty or ``padding`` is not "post" or "pre".
    """
    arrays = list(arrays)  # iterated twice below, so a generator must be materialised
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        missing = max_length - arr.shape[0]
        lead = (missing, 0) if padding == "pre" else (0, missing)
        # Pad only axis 0; trailing axes (e.g. the feature axis) are untouched.
        pad_width = [lead] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            np.pad(arr, pad_width, mode='constant', constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Pad the whole table once only to give create_model an input shape, then
# display the resulting architecture.
X, _ = reshape_dataset(data)
model = create_model(train=X)
model.summary()

2025-02-19 14:06:07.911473: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [04:08, 24.83s/it]


Epoch 5: TRAIN Accuracy = 0.484 Loss = 0.696 AUC = 0.472
Epoch 5: VAL Accuracy = 0.557 Loss = 0.691 AUC = 0.569


10it [01:29,  8.93s/it]


Epoch 10: TRAIN Accuracy = 0.646 Loss = 0.676 AUC = 0.65
Epoch 10: VAL Accuracy = 0.603 Loss = 0.673 AUC = 0.677


10it [01:28,  8.81s/it]


Epoch 15: TRAIN Accuracy = 0.699 Loss = 0.658 AUC = 0.722
Epoch 15: VAL Accuracy = 0.657 Loss = 0.653 AUC = 0.727


10it [01:28,  8.85s/it]


Epoch 20: TRAIN Accuracy = 0.734 Loss = 0.635 AUC = 0.763
Epoch 20: VAL Accuracy = 0.713 Loss = 0.629 AUC = 0.769


10it [01:29,  8.92s/it]


Epoch 25: TRAIN Accuracy = 0.757 Loss = 0.606 AUC = 0.761
Epoch 25: VAL Accuracy = 0.743 Loss = 0.6 AUC = 0.76


10it [01:32,  9.27s/it]


Epoch 30: TRAIN Accuracy = 0.76 Loss = 0.575 AUC = 0.765
Epoch 30: VAL Accuracy = 0.757 Loss = 0.571 AUC = 0.776


10it [01:30,  9.03s/it]


Epoch 35: TRAIN Accuracy = 0.773 Loss = 0.546 AUC = 0.763
Epoch 35: VAL Accuracy = 0.767 Loss = 0.546 AUC = 0.759


10it [01:30,  9.03s/it]


Epoch 40: TRAIN Accuracy = 0.774 Loss = 0.52 AUC = 0.776
Epoch 40: VAL Accuracy = 0.767 Loss = 0.526 AUC = 0.777


10it [01:29,  8.97s/it]


Epoch 45: TRAIN Accuracy = 0.781 Loss = 0.497 AUC = 0.788
Epoch 45: VAL Accuracy = 0.77 Loss = 0.507 AUC = 0.781


10it [01:30,  9.00s/it]


Epoch 50: TRAIN Accuracy = 0.787 Loss = 0.478 AUC = 0.801
Epoch 50: VAL Accuracy = 0.773 Loss = 0.491 AUC = 0.794


10it [01:29,  8.96s/it]


Epoch 55: TRAIN Accuracy = 0.794 Loss = 0.463 AUC = 0.807
Epoch 55: VAL Accuracy = 0.777 Loss = 0.478 AUC = 0.799


10it [01:29,  8.95s/it]


Epoch 60: TRAIN Accuracy = 0.799 Loss = 0.449 AUC = 0.826
Epoch 60: VAL Accuracy = 0.777 Loss = 0.467 AUC = 0.81


10it [01:31,  9.11s/it]


Epoch 65: TRAIN Accuracy = 0.806 Loss = 0.438 AUC = 0.841
Epoch 65: VAL Accuracy = 0.773 Loss = 0.46 AUC = 0.827


10it [01:29,  8.95s/it]


Epoch 70: TRAIN Accuracy = 0.81 Loss = 0.43 AUC = 0.84
Epoch 70: VAL Accuracy = 0.78 Loss = 0.455 AUC = 0.817


10it [01:30,  9.07s/it]


Epoch 75: TRAIN Accuracy = 0.811 Loss = 0.423 AUC = 0.848
Epoch 75: VAL Accuracy = 0.78 Loss = 0.451 AUC = 0.826


10it [01:27,  8.80s/it]


Epoch 80: TRAIN Accuracy = 0.811 Loss = 0.417 AUC = 0.86
Epoch 80: VAL Accuracy = 0.787 Loss = 0.448 AUC = 0.855


10it [01:29,  8.92s/it]


Epoch 85: TRAIN Accuracy = 0.811 Loss = 0.412 AUC = 0.871
Epoch 85: VAL Accuracy = 0.79 Loss = 0.445 AUC = 0.858


10it [01:28,  8.89s/it]


Epoch 90: TRAIN Accuracy = 0.811 Loss = 0.407 AUC = 0.873
Epoch 90: VAL Accuracy = 0.79 Loss = 0.443 AUC = 0.866


10it [01:29,  8.96s/it]


Epoch 95: TRAIN Accuracy = 0.811 Loss = 0.403 AUC = 0.872
Epoch 95: VAL Accuracy = 0.79 Loss = 0.441 AUC = 0.857


10it [01:30,  9.09s/it]


Epoch 100: TRAIN Accuracy = 0.811 Loss = 0.399 AUC = 0.883
Epoch 100: VAL Accuracy = 0.79 Loss = 0.441 AUC = 0.864


10it [01:29,  8.99s/it]


Epoch 105: TRAIN Accuracy = 0.811 Loss = 0.395 AUC = 0.885
Epoch 105: VAL Accuracy = 0.793 Loss = 0.44 AUC = 0.865


10it [01:29,  8.91s/it]


Epoch 110: TRAIN Accuracy = 0.811 Loss = 0.391 AUC = 0.891
Epoch 110: VAL Accuracy = 0.793 Loss = 0.44 AUC = 0.87


10it [01:30,  9.05s/it]


Epoch 115: TRAIN Accuracy = 0.811 Loss = 0.387 AUC = 0.891
Epoch 115: VAL Accuracy = 0.793 Loss = 0.44 AUC = 0.868


10it [01:29,  8.96s/it]


Epoch 120: TRAIN Accuracy = 0.811 Loss = 0.384 AUC = 0.895
Epoch 120: VAL Accuracy = 0.79 Loss = 0.441 AUC = 0.878


10it [01:33,  9.35s/it]


Epoch 125: TRAIN Accuracy = 0.811 Loss = 0.381 AUC = 0.895
Epoch 125: VAL Accuracy = 0.79 Loss = 0.442 AUC = 0.876


10it [01:31,  9.17s/it]


Epoch 130: TRAIN Accuracy = 0.811 Loss = 0.377 AUC = 0.897
Epoch 130: VAL Accuracy = 0.79 Loss = 0.444 AUC = 0.879


10it [01:31,  9.15s/it]


Epoch 135: TRAIN Accuracy = 0.811 Loss = 0.374 AUC = 0.895
Epoch 135: VAL Accuracy = 0.79 Loss = 0.445 AUC = 0.88


10it [01:32,  9.27s/it]


Epoch 140: TRAIN Accuracy = 0.811 Loss = 0.371 AUC = 0.895
Epoch 140: VAL Accuracy = 0.79 Loss = 0.447 AUC = 0.875


10it [01:31,  9.20s/it]


Epoch 145: TRAIN Accuracy = 0.811 Loss = 0.368 AUC = 0.897
Epoch 145: VAL Accuracy = 0.79 Loss = 0.449 AUC = 0.877


10it [01:32,  9.27s/it]


Epoch 150: TRAIN Accuracy = 0.811 Loss = 0.364 AUC = 0.9
Epoch 150: VAL Accuracy = 0.79 Loss = 0.45 AUC = 0.88


10it [01:32,  9.20s/it]


Epoch 155: TRAIN Accuracy = 0.811 Loss = 0.361 AUC = 0.898
Epoch 155: VAL Accuracy = 0.787 Loss = 0.45 AUC = 0.881


10it [01:31,  9.16s/it]


Epoch 160: TRAIN Accuracy = 0.811 Loss = 0.357 AUC = 0.901
Epoch 160: VAL Accuracy = 0.787 Loss = 0.448 AUC = 0.881


10it [01:32,  9.22s/it]


Epoch 165: TRAIN Accuracy = 0.811 Loss = 0.354 AUC = 0.902
Epoch 165: VAL Accuracy = 0.787 Loss = 0.447 AUC = 0.876


10it [01:32,  9.20s/it]


Epoch 170: TRAIN Accuracy = 0.811 Loss = 0.351 AUC = 0.902
Epoch 170: VAL Accuracy = 0.79 Loss = 0.45 AUC = 0.874


10it [01:43, 10.32s/it]


Epoch 175: TRAIN Accuracy = 0.811 Loss = 0.347 AUC = 0.905
Epoch 175: VAL Accuracy = 0.783 Loss = 0.454 AUC = 0.869


10it [01:33,  9.34s/it]


Epoch 180: TRAIN Accuracy = 0.811 Loss = 0.344 AUC = 0.907
Epoch 180: VAL Accuracy = 0.783 Loss = 0.459 AUC = 0.868


10it [01:42, 10.23s/it]


Epoch 185: TRAIN Accuracy = 0.809 Loss = 0.34 AUC = 0.913
Epoch 185: VAL Accuracy = 0.783 Loss = 0.464 AUC = 0.865


10it [01:57, 11.75s/it]


Epoch 190: TRAIN Accuracy = 0.814 Loss = 0.337 AUC = 0.915
Epoch 190: VAL Accuracy = 0.777 Loss = 0.47 AUC = 0.858


10it [01:51, 11.13s/it]


Epoch 195: TRAIN Accuracy = 0.816 Loss = 0.333 AUC = 0.919
Epoch 195: VAL Accuracy = 0.777 Loss = 0.473 AUC = 0.86


10it [01:29,  8.90s/it]


Epoch 200: TRAIN Accuracy = 0.814 Loss = 0.33 AUC = 0.919
Epoch 200: VAL Accuracy = 0.773 Loss = 0.476 AUC = 0.86


10it [01:11,  7.13s/it]


Epoch 205: TRAIN Accuracy = 0.814 Loss = 0.326 AUC = 0.922
Epoch 205: VAL Accuracy = 0.767 Loss = 0.48 AUC = 0.858


10it [01:08,  6.89s/it]


Epoch 210: TRAIN Accuracy = 0.827 Loss = 0.322 AUC = 0.925
Epoch 210: VAL Accuracy = 0.757 Loss = 0.482 AUC = 0.858


10it [01:08,  6.90s/it]


Epoch 215: TRAIN Accuracy = 0.831 Loss = 0.318 AUC = 0.926
Epoch 215: VAL Accuracy = 0.753 Loss = 0.486 AUC = 0.858


10it [01:08,  6.86s/it]


Epoch 220: TRAIN Accuracy = 0.834 Loss = 0.314 AUC = 0.929
Epoch 220: VAL Accuracy = 0.757 Loss = 0.489 AUC = 0.861


10it [01:08,  6.84s/it]


Epoch 225: TRAIN Accuracy = 0.836 Loss = 0.309 AUC = 0.93
Epoch 225: VAL Accuracy = 0.753 Loss = 0.497 AUC = 0.858


10it [01:08,  6.90s/it]


Epoch 230: TRAIN Accuracy = 0.843 Loss = 0.304 AUC = 0.932
Epoch 230: VAL Accuracy = 0.75 Loss = 0.51 AUC = 0.854


10it [01:07,  6.74s/it]


Epoch 235: TRAIN Accuracy = 0.837 Loss = 0.3 AUC = 0.933
Epoch 235: VAL Accuracy = 0.75 Loss = 0.522 AUC = 0.847


10it [01:06,  6.68s/it]


Epoch 240: TRAIN Accuracy = 0.84 Loss = 0.295 AUC = 0.935
Epoch 240: VAL Accuracy = 0.747 Loss = 0.531 AUC = 0.845


10it [01:06,  6.68s/it]


Epoch 245: TRAIN Accuracy = 0.84 Loss = 0.29 AUC = 0.936
Epoch 245: VAL Accuracy = 0.743 Loss = 0.539 AUC = 0.837


10it [01:06,  6.68s/it]


Epoch 250: TRAIN Accuracy = 0.843 Loss = 0.285 AUC = 0.937
Epoch 250: VAL Accuracy = 0.73 Loss = 0.546 AUC = 0.833


10it [01:06,  6.64s/it]


Epoch 255: TRAIN Accuracy = 0.85 Loss = 0.281 AUC = 0.94
Epoch 255: VAL Accuracy = 0.73 Loss = 0.553 AUC = 0.834


10it [01:06,  6.62s/it]


Epoch 260: TRAIN Accuracy = 0.849 Loss = 0.277 AUC = 0.939
Epoch 260: VAL Accuracy = 0.73 Loss = 0.558 AUC = 0.832


10it [01:06,  6.65s/it]


Epoch 265: TRAIN Accuracy = 0.85 Loss = 0.273 AUC = 0.94
Epoch 265: VAL Accuracy = 0.733 Loss = 0.562 AUC = 0.832


10it [01:05,  6.59s/it]


Epoch 270: TRAIN Accuracy = 0.851 Loss = 0.27 AUC = 0.942
Epoch 270: VAL Accuracy = 0.733 Loss = 0.566 AUC = 0.831


10it [01:06,  6.62s/it]


Epoch 275: TRAIN Accuracy = 0.851 Loss = 0.267 AUC = 0.941
Epoch 275: VAL Accuracy = 0.733 Loss = 0.571 AUC = 0.829


10it [01:06,  6.63s/it]


Epoch 280: TRAIN Accuracy = 0.854 Loss = 0.264 AUC = 0.943
Epoch 280: VAL Accuracy = 0.733 Loss = 0.574 AUC = 0.833


10it [01:06,  6.65s/it]


Epoch 285: TRAIN Accuracy = 0.857 Loss = 0.261 AUC = 0.944
Epoch 285: VAL Accuracy = 0.733 Loss = 0.58 AUC = 0.83


10it [01:06,  6.70s/it]


Epoch 290: TRAIN Accuracy = 0.857 Loss = 0.259 AUC = 0.945
Epoch 290: VAL Accuracy = 0.74 Loss = 0.586 AUC = 0.829


10it [01:06,  6.68s/it]


Epoch 295: TRAIN Accuracy = 0.86 Loss = 0.256 AUC = 0.944
Epoch 295: VAL Accuracy = 0.743 Loss = 0.591 AUC = 0.831


10it [01:06,  6.65s/it]

Epoch 300: TRAIN Accuracy = 0.86 Loss = 0.254 AUC = 0.945
Epoch 300: VAL Accuracy = 0.743 Loss = 0.595 AUC = 0.829
CPU times: user 2h 42min 36s, sys: 53min 37s, total: 3h 36min 14s
Wall time: 1h 26min 38s



