In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-19 01:31:18.817716: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-19 01:31:18.822511: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-19 01:31:18.835703: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742337078.858512 2178303 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742337078.865071 2178303 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-19 01:31:18.893924: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

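# Configuration

Training schedule (`NUM_EPOCHS`, `STEP`, `NUM_EXPERIMENTS`) and the stacked-GRU model factory `create_model` shared by every experiment.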

In [2]:
NUM_EPOCHS = 300        # total training epochs per model
STEP = 5                # epochs per fit() call; metrics are reported after each chunk
NUM_EXPERIMENTS = 10    # number of independently seeded models / train-validation splits

def create_model(train, units=10, num_gru_layers=3, dropout=0.4, dense_units=6, learning_rate=0.003):
    """Build and compile the stacked-GRU binary classifier.

    Args:
        train: Array whose ``shape[1]`` and ``shape[2]`` give the
            (timesteps, features) input shape. Only its shape is used.
        units: Hidden size of every GRU layer.
        num_gru_layers: Number of stacked Dropout -> GRU -> BatchNormalization
            groups. All but the last GRU return full sequences; the last one
            returns only its final state.
        dropout: Dropout rate applied before each GRU layer.
        dense_units: Width of the ReLU layer before the sigmoid output.
        learning_rate: Adam learning rate.

    Returns:
        A compiled ``Sequential`` model with one sigmoid output, trained with
        binary cross-entropy and tracking accuracy and AUC (named "auc").

    Raises:
        ValueError: if ``num_gru_layers`` is less than 1.
    """
    if num_gru_layers < 1:
        raise ValueError("num_gru_layers must be at least 1")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    # Normalize the raw inputs: the EEG band powers span very different scales.
    model.add(BatchNormalization())

    for layer_idx in range(num_gru_layers):
        is_last = layer_idx == num_gru_layers - 1
        model.add(Dropout(dropout))
        # Only the final GRU collapses the sequence into a single vector.
        model.add(GRU(units, return_sequences=not is_last))
        model.add(BatchNormalization())

    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

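# Experiment

Load the EEG table and turn each (SubjectID, VideoID) session into one padded sequence. Then train `NUM_EXPERIMENTS` models, each on its own seeded random split of sessions into training and validation sets.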

In [3]:
# Column groups of the EEG table, kept as lists so they can be concatenated.
ID = ["ID"]                          # synthetic per-session key column
IDS = ["SubjectID", "VideoID"]       # raw columns that together identify a session
TARGET = ["predefinedlabel"]         # binary label predicted by the model
FEATURES = "Raw Delta Theta Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2".split()

# Base seed: experiment j is seeded with INIT_SEED + j.
INIT_SEED = 5412

In [4]:
# Data location. Set BRAIN_SIGNALS_DATA_DIR to override it instead of editing
# this absolute, machine-specific default path.
data_dir = Path(os.environ.get("BRAIN_SIGNALS_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer ID per (SubjectID, VideoID) session. ngroup() numbers groups in
# sorted key order, so it equals n_videos * SubjectID + VideoID whenever VideoID
# runs over 0..n_videos-1, but it cannot collide if VideoID values are sparse.
data["ID"] = data.groupby(IDS).ngroup().astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long per-sample table into one padded sequence per session.

    Rows are grouped by the ``ID`` column in ascending ID order (the same order
    ``np.unique`` yields). Each group keeps its original row order and becomes a
    (timesteps, len(FEATURES)) array. Shorter groups are padded at the end by
    ``pad_sequences``.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        features: array of shape (n_sessions, max_timesteps, len(FEATURES)).
        target: int array of shape (n_sessions,). Each entry is the mean of the
            session's TARGET values truncated to int, i.e. the label itself when
            it is a constant 0/1 within the session.
    """
    features = []
    target = []
    # One groupby pass instead of re-masking the whole frame for every ID
    # (previously O(n_ids * n_rows)); sort=True preserves ascending ID order.
    for _, session in data.groupby(ID, sort=True):
        target.append(np.mean(session[TARGET].to_numpy()).astype("int"))
        features.append(session[FEATURES].to_numpy())

    # pad_sequences already returns a stacked ndarray.
    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Right-pad arrays along axis 0 to a common length and stack them.

    Args:
        arrays: Non-empty sequence of arrays whose trailing dimensions match;
            only the first dimension (sequence length) may differ.
        pad_value: Constant written into the padded positions.

    Returns:
        Array of shape (len(arrays), max_length, *trailing_dims).

    Raises:
        ValueError: if ``arrays`` is empty.
    """
    if len(arrays) == 0:
        raise ValueError("pad_sequences() needs at least one array")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            # Pad only the end of axis 0; any further axes are left untouched,
            # so 1-D and N-D inputs work, not just 2-D ones.
            [(0, max_length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full padded dataset purely to inspect its architecture;
# the training cell creates its own models per experiment.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-03-19 01:31:22.977778: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [02:48, 16.85s/it]


Epoch 5: TRAIN Accuracy = 0.483 Loss = 0.825 AUC = 0.466
Epoch 5: VAL Accuracy = 0.543 Loss = 0.694 AUC = 0.568


10it [01:36,  9.64s/it]


Epoch 10: TRAIN Accuracy = 0.596 Loss = 0.674 AUC = 0.624
Epoch 10: VAL Accuracy = 0.633 Loss = 0.664 AUC = 0.75


10it [01:35,  9.50s/it]


Epoch 15: TRAIN Accuracy = 0.653 Loss = 0.607 AUC = 0.711
Epoch 15: VAL Accuracy = 0.627 Loss = 0.645 AUC = 0.743


10it [01:36,  9.66s/it]


Epoch 20: TRAIN Accuracy = 0.693 Loss = 0.581 AUC = 0.718
Epoch 20: VAL Accuracy = 0.603 Loss = 0.66 AUC = 0.764


10it [01:33,  9.36s/it]


Epoch 25: TRAIN Accuracy = 0.684 Loss = 0.567 AUC = 0.725
Epoch 25: VAL Accuracy = 0.59 Loss = 0.718 AUC = 0.77


10it [01:36,  9.65s/it]


Epoch 30: TRAIN Accuracy = 0.686 Loss = 0.557 AUC = 0.737
Epoch 30: VAL Accuracy = 0.573 Loss = 0.813 AUC = 0.767


10it [01:34,  9.44s/it]


Epoch 35: TRAIN Accuracy = 0.723 Loss = 0.539 AUC = 0.745
Epoch 35: VAL Accuracy = 0.577 Loss = 1.086 AUC = 0.769


10it [01:35,  9.56s/it]


Epoch 40: TRAIN Accuracy = 0.72 Loss = 0.519 AUC = 0.759
Epoch 40: VAL Accuracy = 0.533 Loss = 1.15 AUC = 0.75


10it [01:34,  9.44s/it]


Epoch 45: TRAIN Accuracy = 0.746 Loss = 0.491 AUC = 0.796
Epoch 45: VAL Accuracy = 0.527 Loss = 1.548 AUC = 0.728


10it [01:35,  9.51s/it]


Epoch 50: TRAIN Accuracy = 0.751 Loss = 0.505 AUC = 0.754
Epoch 50: VAL Accuracy = 0.59 Loss = 1.538 AUC = 0.7


10it [01:33,  9.39s/it]


Epoch 55: TRAIN Accuracy = 0.769 Loss = 0.471 AUC = 0.804
Epoch 55: VAL Accuracy = 0.547 Loss = 1.746 AUC = 0.733


10it [01:35,  9.58s/it]


Epoch 60: TRAIN Accuracy = 0.783 Loss = 0.451 AUC = 0.811
Epoch 60: VAL Accuracy = 0.533 Loss = 1.87 AUC = 0.709


10it [01:34,  9.48s/it]


Epoch 65: TRAIN Accuracy = 0.78 Loss = 0.462 AUC = 0.789
Epoch 65: VAL Accuracy = 0.637 Loss = 1.02 AUC = 0.755


10it [01:34,  9.49s/it]


Epoch 70: TRAIN Accuracy = 0.76 Loss = 0.468 AUC = 0.81
Epoch 70: VAL Accuracy = 0.573 Loss = 1.248 AUC = 0.77


10it [01:35,  9.56s/it]


Epoch 75: TRAIN Accuracy = 0.776 Loss = 0.463 AUC = 0.808
Epoch 75: VAL Accuracy = 0.65 Loss = 1.203 AUC = 0.74


10it [01:36,  9.67s/it]


Epoch 80: TRAIN Accuracy = 0.783 Loss = 0.441 AUC = 0.824
Epoch 80: VAL Accuracy = 0.617 Loss = 1.36 AUC = 0.734


10it [01:36,  9.70s/it]


Epoch 85: TRAIN Accuracy = 0.796 Loss = 0.433 AUC = 0.836
Epoch 85: VAL Accuracy = 0.673 Loss = 1.313 AUC = 0.75


10it [01:36,  9.69s/it]


Epoch 90: TRAIN Accuracy = 0.796 Loss = 0.426 AUC = 0.836
Epoch 90: VAL Accuracy = 0.647 Loss = 1.285 AUC = 0.756


10it [01:37,  9.75s/it]


Epoch 95: TRAIN Accuracy = 0.786 Loss = 0.423 AUC = 0.845
Epoch 95: VAL Accuracy = 0.647 Loss = 1.604 AUC = 0.734


10it [01:40, 10.05s/it]


Epoch 100: TRAIN Accuracy = 0.794 Loss = 0.418 AUC = 0.844
Epoch 100: VAL Accuracy = 0.703 Loss = 0.557 AUC = 0.804


10it [01:32,  9.21s/it]


Epoch 105: TRAIN Accuracy = 0.791 Loss = 0.411 AUC = 0.851
Epoch 105: VAL Accuracy = 0.657 Loss = 1.22 AUC = 0.786


10it [01:34,  9.41s/it]


Epoch 110: TRAIN Accuracy = 0.787 Loss = 0.415 AUC = 0.836
Epoch 110: VAL Accuracy = 0.697 Loss = 0.724 AUC = 0.807


10it [01:33,  9.31s/it]


Epoch 115: TRAIN Accuracy = 0.804 Loss = 0.399 AUC = 0.857
Epoch 115: VAL Accuracy = 0.713 Loss = 0.601 AUC = 0.792


10it [01:31,  9.14s/it]


Epoch 120: TRAIN Accuracy = 0.814 Loss = 0.391 AUC = 0.875
Epoch 120: VAL Accuracy = 0.747 Loss = 0.624 AUC = 0.828


10it [01:31,  9.12s/it]


Epoch 125: TRAIN Accuracy = 0.817 Loss = 0.389 AUC = 0.85
Epoch 125: VAL Accuracy = 0.717 Loss = 0.925 AUC = 0.796


10it [01:30,  9.07s/it]


Epoch 130: TRAIN Accuracy = 0.806 Loss = 0.399 AUC = 0.861
Epoch 130: VAL Accuracy = 0.743 Loss = 0.617 AUC = 0.814


10it [01:30,  9.06s/it]


Epoch 135: TRAIN Accuracy = 0.811 Loss = 0.382 AUC = 0.876
Epoch 135: VAL Accuracy = 0.763 Loss = 0.586 AUC = 0.801


10it [01:31,  9.17s/it]


Epoch 140: TRAIN Accuracy = 0.814 Loss = 0.388 AUC = 0.869
Epoch 140: VAL Accuracy = 0.707 Loss = 0.727 AUC = 0.798


10it [01:32,  9.29s/it]


Epoch 145: TRAIN Accuracy = 0.821 Loss = 0.364 AUC = 0.89
Epoch 145: VAL Accuracy = 0.747 Loss = 0.642 AUC = 0.795


10it [01:33,  9.36s/it]


Epoch 150: TRAIN Accuracy = 0.809 Loss = 0.363 AUC = 0.896
Epoch 150: VAL Accuracy = 0.697 Loss = 0.641 AUC = 0.827


10it [01:31,  9.19s/it]


Epoch 155: TRAIN Accuracy = 0.83 Loss = 0.354 AUC = 0.877
Epoch 155: VAL Accuracy = 0.737 Loss = 0.658 AUC = 0.818


10it [01:35,  9.55s/it]


Epoch 160: TRAIN Accuracy = 0.83 Loss = 0.35 AUC = 0.888
Epoch 160: VAL Accuracy = 0.71 Loss = 0.918 AUC = 0.79


10it [01:32,  9.23s/it]


Epoch 165: TRAIN Accuracy = 0.836 Loss = 0.349 AUC = 0.898
Epoch 165: VAL Accuracy = 0.737 Loss = 0.703 AUC = 0.812


10it [01:30,  9.05s/it]


Epoch 170: TRAIN Accuracy = 0.836 Loss = 0.344 AUC = 0.901
Epoch 170: VAL Accuracy = 0.723 Loss = 0.796 AUC = 0.809


10it [01:30,  9.05s/it]


Epoch 175: TRAIN Accuracy = 0.85 Loss = 0.313 AUC = 0.92
Epoch 175: VAL Accuracy = 0.703 Loss = 0.821 AUC = 0.819


10it [01:29,  8.99s/it]


Epoch 180: TRAIN Accuracy = 0.849 Loss = 0.321 AUC = 0.919
Epoch 180: VAL Accuracy = 0.72 Loss = 0.684 AUC = 0.815


10it [01:31,  9.12s/it]


Epoch 185: TRAIN Accuracy = 0.847 Loss = 0.319 AUC = 0.907
Epoch 185: VAL Accuracy = 0.713 Loss = 0.731 AUC = 0.82


10it [01:31,  9.12s/it]


Epoch 190: TRAIN Accuracy = 0.836 Loss = 0.314 AUC = 0.916
Epoch 190: VAL Accuracy = 0.723 Loss = 0.836 AUC = 0.819


10it [01:30,  9.09s/it]


Epoch 195: TRAIN Accuracy = 0.847 Loss = 0.305 AUC = 0.919
Epoch 195: VAL Accuracy = 0.72 Loss = 0.743 AUC = 0.846


10it [01:36,  9.66s/it]


Epoch 200: TRAIN Accuracy = 0.854 Loss = 0.297 AUC = 0.934
Epoch 200: VAL Accuracy = 0.65 Loss = 0.901 AUC = 0.816


10it [01:37,  9.75s/it]


Epoch 205: TRAIN Accuracy = 0.851 Loss = 0.303 AUC = 0.928
Epoch 205: VAL Accuracy = 0.7 Loss = 0.839 AUC = 0.821


10it [01:36,  9.62s/it]


Epoch 210: TRAIN Accuracy = 0.859 Loss = 0.298 AUC = 0.932
Epoch 210: VAL Accuracy = 0.703 Loss = 0.727 AUC = 0.819


10it [01:34,  9.45s/it]


Epoch 215: TRAIN Accuracy = 0.864 Loss = 0.281 AUC = 0.936
Epoch 215: VAL Accuracy = 0.663 Loss = 0.788 AUC = 0.814


10it [01:34,  9.49s/it]


Epoch 220: TRAIN Accuracy = 0.859 Loss = 0.3 AUC = 0.929
Epoch 220: VAL Accuracy = 0.71 Loss = 0.992 AUC = 0.79


10it [01:31,  9.16s/it]


Epoch 225: TRAIN Accuracy = 0.876 Loss = 0.268 AUC = 0.94
Epoch 225: VAL Accuracy = 0.727 Loss = 0.805 AUC = 0.827


10it [01:35,  9.51s/it]


Epoch 230: TRAIN Accuracy = 0.863 Loss = 0.286 AUC = 0.929
Epoch 230: VAL Accuracy = 0.717 Loss = 0.774 AUC = 0.822


10it [01:31,  9.14s/it]


Epoch 235: TRAIN Accuracy = 0.874 Loss = 0.266 AUC = 0.937
Epoch 235: VAL Accuracy = 0.703 Loss = 0.871 AUC = 0.81


10it [01:30,  9.07s/it]


Epoch 240: TRAIN Accuracy = 0.879 Loss = 0.26 AUC = 0.946
Epoch 240: VAL Accuracy = 0.693 Loss = 1.086 AUC = 0.808


10it [01:30,  9.06s/it]


Epoch 245: TRAIN Accuracy = 0.869 Loss = 0.27 AUC = 0.934
Epoch 245: VAL Accuracy = 0.707 Loss = 0.856 AUC = 0.81


10it [01:31,  9.14s/it]


Epoch 250: TRAIN Accuracy = 0.893 Loss = 0.251 AUC = 0.945
Epoch 250: VAL Accuracy = 0.7 Loss = 1.222 AUC = 0.783


10it [01:30,  9.09s/it]


Epoch 255: TRAIN Accuracy = 0.879 Loss = 0.25 AUC = 0.953
Epoch 255: VAL Accuracy = 0.703 Loss = 0.828 AUC = 0.802


10it [01:32,  9.25s/it]


Epoch 260: TRAIN Accuracy = 0.864 Loss = 0.249 AUC = 0.948
Epoch 260: VAL Accuracy = 0.71 Loss = 0.897 AUC = 0.785


10it [01:30,  9.07s/it]


Epoch 265: TRAIN Accuracy = 0.891 Loss = 0.23 AUC = 0.96
Epoch 265: VAL Accuracy = 0.663 Loss = 0.862 AUC = 0.824


10it [01:31,  9.12s/it]


Epoch 270: TRAIN Accuracy = 0.891 Loss = 0.235 AUC = 0.956
Epoch 270: VAL Accuracy = 0.667 Loss = 0.937 AUC = 0.807


10it [01:33,  9.37s/it]


Epoch 275: TRAIN Accuracy = 0.877 Loss = 0.249 AUC = 0.951
Epoch 275: VAL Accuracy = 0.66 Loss = 1.232 AUC = 0.772


10it [01:35,  9.58s/it]


Epoch 280: TRAIN Accuracy = 0.901 Loss = 0.224 AUC = 0.959
Epoch 280: VAL Accuracy = 0.65 Loss = 1.033 AUC = 0.778


10it [01:32,  9.21s/it]


Epoch 285: TRAIN Accuracy = 0.906 Loss = 0.2 AUC = 0.966
Epoch 285: VAL Accuracy = 0.683 Loss = 1.135 AUC = 0.798


10it [01:30,  9.06s/it]


Epoch 290: TRAIN Accuracy = 0.886 Loss = 0.243 AUC = 0.957
Epoch 290: VAL Accuracy = 0.69 Loss = 0.993 AUC = 0.815


10it [01:31,  9.11s/it]


Epoch 295: TRAIN Accuracy = 0.889 Loss = 0.219 AUC = 0.958
Epoch 295: VAL Accuracy = 0.673 Loss = 0.992 AUC = 0.785


10it [01:30,  9.10s/it]

Epoch 300: TRAIN Accuracy = 0.886 Loss = 0.239 AUC = 0.958
Epoch 300: VAL Accuracy = 0.69 Loss = 0.836 AUC = 0.818
CPU times: user 3h 20min 22s, sys: 1h 9min 25s, total: 4h 29min 48s
Wall time: 1h 34min 45s



