In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-17 22:29:36.628964: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-17 22:29:36.632734: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-17 22:29:36.643055: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739820576.659385  924552 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739820576.663690  924552 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-17 22:29:36.681579: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(GRU(64, return_sequences=True))
    model.add(Dropout(0.3))

    model.add(GRU(32, return_sequences=False))
    model.add(Dropout(0.3))

    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.0003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-02-17 22:29:39.650508: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

0it [00:00, ?it/s]

10it [01:30,  9.08s/it]


Epoch 5: TRAIN Accuracy = 0.51 Loss = 0.719 AUC = 0.449
Epoch 5: VAL Accuracy = 0.58 Loss = 0.69 AUC = 0.492


10it [00:55,  5.55s/it]


Epoch 10: TRAIN Accuracy = 0.743 Loss = 0.621 AUC = 0.744
Epoch 10: VAL Accuracy = 0.753 Loss = 0.609 AUC = 0.764


10it [00:55,  5.51s/it]


Epoch 15: TRAIN Accuracy = 0.764 Loss = 0.558 AUC = 0.775
Epoch 15: VAL Accuracy = 0.767 Loss = 0.553 AUC = 0.781


10it [00:58,  5.89s/it]


Epoch 20: TRAIN Accuracy = 0.773 Loss = 0.517 AUC = 0.808
Epoch 20: VAL Accuracy = 0.767 Loss = 0.517 AUC = 0.781


10it [00:56,  5.62s/it]


Epoch 25: TRAIN Accuracy = 0.78 Loss = 0.482 AUC = 0.809
Epoch 25: VAL Accuracy = 0.77 Loss = 0.493 AUC = 0.8


10it [00:50,  5.04s/it]


Epoch 30: TRAIN Accuracy = 0.799 Loss = 0.459 AUC = 0.794
Epoch 30: VAL Accuracy = 0.773 Loss = 0.474 AUC = 0.809


10it [00:54,  5.49s/it]


Epoch 35: TRAIN Accuracy = 0.809 Loss = 0.438 AUC = 0.816
Epoch 35: VAL Accuracy = 0.777 Loss = 0.46 AUC = 0.826


10it [00:56,  5.69s/it]


Epoch 40: TRAIN Accuracy = 0.81 Loss = 0.425 AUC = 0.823
Epoch 40: VAL Accuracy = 0.783 Loss = 0.448 AUC = 0.832


10it [00:57,  5.77s/it]


Epoch 45: TRAIN Accuracy = 0.811 Loss = 0.408 AUC = 0.859
Epoch 45: VAL Accuracy = 0.79 Loss = 0.436 AUC = 0.855


10it [00:55,  5.54s/it]


Epoch 50: TRAIN Accuracy = 0.811 Loss = 0.402 AUC = 0.85
Epoch 50: VAL Accuracy = 0.79 Loss = 0.426 AUC = 0.87


10it [00:56,  5.61s/it]


Epoch 55: TRAIN Accuracy = 0.811 Loss = 0.387 AUC = 0.868
Epoch 55: VAL Accuracy = 0.79 Loss = 0.416 AUC = 0.871


10it [00:56,  5.69s/it]


Epoch 60: TRAIN Accuracy = 0.811 Loss = 0.376 AUC = 0.879
Epoch 60: VAL Accuracy = 0.79 Loss = 0.409 AUC = 0.88


10it [00:58,  5.86s/it]


Epoch 65: TRAIN Accuracy = 0.81 Loss = 0.363 AUC = 0.884
Epoch 65: VAL Accuracy = 0.793 Loss = 0.397 AUC = 0.887


10it [00:57,  5.75s/it]


Epoch 70: TRAIN Accuracy = 0.809 Loss = 0.345 AUC = 0.91
Epoch 70: VAL Accuracy = 0.797 Loss = 0.397 AUC = 0.887


10it [00:58,  5.86s/it]


Epoch 75: TRAIN Accuracy = 0.811 Loss = 0.323 AUC = 0.929
Epoch 75: VAL Accuracy = 0.797 Loss = 0.41 AUC = 0.889


10it [01:00,  6.09s/it]


Epoch 80: TRAIN Accuracy = 0.816 Loss = 0.308 AUC = 0.929
Epoch 80: VAL Accuracy = 0.79 Loss = 0.407 AUC = 0.888


10it [01:00,  6.06s/it]


Epoch 85: TRAIN Accuracy = 0.836 Loss = 0.289 AUC = 0.94
Epoch 85: VAL Accuracy = 0.787 Loss = 0.419 AUC = 0.885


10it [00:59,  5.96s/it]


Epoch 90: TRAIN Accuracy = 0.856 Loss = 0.26 AUC = 0.95
Epoch 90: VAL Accuracy = 0.793 Loss = 0.432 AUC = 0.884


10it [01:02,  6.22s/it]


Epoch 95: TRAIN Accuracy = 0.869 Loss = 0.233 AUC = 0.958
Epoch 95: VAL Accuracy = 0.793 Loss = 0.459 AUC = 0.88


10it [01:01,  6.12s/it]


Epoch 100: TRAIN Accuracy = 0.896 Loss = 0.204 AUC = 0.969
Epoch 100: VAL Accuracy = 0.783 Loss = 0.506 AUC = 0.872


10it [01:00,  6.08s/it]


Epoch 105: TRAIN Accuracy = 0.924 Loss = 0.177 AUC = 0.974
Epoch 105: VAL Accuracy = 0.79 Loss = 0.554 AUC = 0.869


10it [01:01,  6.17s/it]


Epoch 110: TRAIN Accuracy = 0.937 Loss = 0.152 AUC = 0.98
Epoch 110: VAL Accuracy = 0.76 Loss = 0.618 AUC = 0.864


10it [01:00,  6.09s/it]


Epoch 115: TRAIN Accuracy = 0.947 Loss = 0.139 AUC = 0.984
Epoch 115: VAL Accuracy = 0.767 Loss = 0.693 AUC = 0.865


10it [01:01,  6.15s/it]


Epoch 120: TRAIN Accuracy = 0.959 Loss = 0.114 AUC = 0.988
Epoch 120: VAL Accuracy = 0.77 Loss = 0.72 AUC = 0.863


10it [01:02,  6.28s/it]


Epoch 125: TRAIN Accuracy = 0.963 Loss = 0.102 AUC = 0.99
Epoch 125: VAL Accuracy = 0.767 Loss = 0.794 AUC = 0.864


10it [01:01,  6.18s/it]


Epoch 130: TRAIN Accuracy = 0.97 Loss = 0.089 AUC = 0.992
Epoch 130: VAL Accuracy = 0.767 Loss = 0.855 AUC = 0.863


10it [00:59,  5.95s/it]


Epoch 135: TRAIN Accuracy = 0.971 Loss = 0.077 AUC = 0.995
Epoch 135: VAL Accuracy = 0.763 Loss = 0.932 AUC = 0.855


10it [00:59,  5.92s/it]


Epoch 140: TRAIN Accuracy = 0.979 Loss = 0.068 AUC = 0.994
Epoch 140: VAL Accuracy = 0.773 Loss = 0.907 AUC = 0.863


10it [01:00,  6.05s/it]


Epoch 145: TRAIN Accuracy = 0.977 Loss = 0.06 AUC = 0.996
Epoch 145: VAL Accuracy = 0.763 Loss = 0.979 AUC = 0.861


10it [00:58,  5.88s/it]


Epoch 150: TRAIN Accuracy = 0.986 Loss = 0.043 AUC = 0.996
Epoch 150: VAL Accuracy = 0.773 Loss = 1.004 AUC = 0.855


10it [01:00,  6.01s/it]


Epoch 155: TRAIN Accuracy = 0.989 Loss = 0.042 AUC = 0.995
Epoch 155: VAL Accuracy = 0.763 Loss = 1.07 AUC = 0.844


10it [00:59,  5.94s/it]


Epoch 160: TRAIN Accuracy = 0.99 Loss = 0.036 AUC = 0.997
Epoch 160: VAL Accuracy = 0.777 Loss = 1.093 AUC = 0.841


10it [00:57,  5.79s/it]


Epoch 165: TRAIN Accuracy = 0.984 Loss = 0.044 AUC = 0.996
Epoch 165: VAL Accuracy = 0.77 Loss = 1.117 AUC = 0.838


10it [00:58,  5.85s/it]


Epoch 170: TRAIN Accuracy = 0.989 Loss = 0.03 AUC = 0.997
Epoch 170: VAL Accuracy = 0.767 Loss = 1.146 AUC = 0.842


10it [00:59,  6.00s/it]


Epoch 175: TRAIN Accuracy = 0.993 Loss = 0.027 AUC = 0.997
Epoch 175: VAL Accuracy = 0.767 Loss = 1.16 AUC = 0.84


10it [00:58,  5.88s/it]


Epoch 180: TRAIN Accuracy = 0.99 Loss = 0.023 AUC = 0.998
Epoch 180: VAL Accuracy = 0.773 Loss = 1.166 AUC = 0.836


10it [00:59,  5.98s/it]


Epoch 185: TRAIN Accuracy = 0.993 Loss = 0.02 AUC = 0.998
Epoch 185: VAL Accuracy = 0.777 Loss = 1.186 AUC = 0.832


10it [01:00,  6.06s/it]


Epoch 190: TRAIN Accuracy = 0.996 Loss = 0.02 AUC = 0.997
Epoch 190: VAL Accuracy = 0.773 Loss = 1.207 AUC = 0.833


10it [00:59,  5.93s/it]


Epoch 195: TRAIN Accuracy = 0.994 Loss = 0.02 AUC = 0.998
Epoch 195: VAL Accuracy = 0.777 Loss = 1.216 AUC = 0.831


10it [01:02,  6.29s/it]


Epoch 200: TRAIN Accuracy = 0.996 Loss = 0.015 AUC = 0.999
Epoch 200: VAL Accuracy = 0.77 Loss = 1.249 AUC = 0.821


10it [01:01,  6.13s/it]


Epoch 205: TRAIN Accuracy = 0.994 Loss = 0.018 AUC = 0.998
Epoch 205: VAL Accuracy = 0.763 Loss = 1.268 AUC = 0.824


10it [00:59,  5.91s/it]


Epoch 210: TRAIN Accuracy = 0.993 Loss = 0.019 AUC = 0.999
Epoch 210: VAL Accuracy = 0.763 Loss = 1.285 AUC = 0.824


10it [01:00,  6.00s/it]


Epoch 215: TRAIN Accuracy = 0.994 Loss = 0.017 AUC = 0.999
Epoch 215: VAL Accuracy = 0.767 Loss = 1.292 AUC = 0.822


10it [01:00,  6.05s/it]


Epoch 220: TRAIN Accuracy = 0.994 Loss = 0.016 AUC = 0.999
Epoch 220: VAL Accuracy = 0.76 Loss = 1.315 AUC = 0.82


10it [01:00,  6.03s/it]


Epoch 225: TRAIN Accuracy = 0.994 Loss = 0.013 AUC = 0.999
Epoch 225: VAL Accuracy = 0.763 Loss = 1.332 AUC = 0.822


10it [01:01,  6.12s/it]


Epoch 230: TRAIN Accuracy = 0.994 Loss = 0.012 AUC = 0.999
Epoch 230: VAL Accuracy = 0.763 Loss = 1.336 AUC = 0.825


10it [01:00,  6.10s/it]


Epoch 235: TRAIN Accuracy = 0.994 Loss = 0.014 AUC = 0.999
Epoch 235: VAL Accuracy = 0.76 Loss = 1.376 AUC = 0.82


10it [01:01,  6.13s/it]


Epoch 240: TRAIN Accuracy = 0.993 Loss = 0.013 AUC = 1.0
Epoch 240: VAL Accuracy = 0.763 Loss = 1.368 AUC = 0.822


10it [00:59,  5.95s/it]


Epoch 245: TRAIN Accuracy = 0.993 Loss = 0.011 AUC = 1.0
Epoch 245: VAL Accuracy = 0.767 Loss = 1.376 AUC = 0.823


10it [00:58,  5.90s/it]


Epoch 250: TRAIN Accuracy = 0.993 Loss = 0.012 AUC = 0.999
Epoch 250: VAL Accuracy = 0.767 Loss = 1.392 AUC = 0.82


10it [01:00,  6.09s/it]


Epoch 255: TRAIN Accuracy = 0.993 Loss = 0.01 AUC = 1.0
Epoch 255: VAL Accuracy = 0.767 Loss = 1.404 AUC = 0.818


10it [01:03,  6.32s/it]


Epoch 260: TRAIN Accuracy = 0.994 Loss = 0.01 AUC = 1.0
Epoch 260: VAL Accuracy = 0.773 Loss = 1.408 AUC = 0.82


10it [00:58,  5.87s/it]


Epoch 265: TRAIN Accuracy = 0.996 Loss = 0.009 AUC = 1.0
Epoch 265: VAL Accuracy = 0.763 Loss = 1.428 AUC = 0.818


10it [01:00,  6.09s/it]


Epoch 270: TRAIN Accuracy = 0.994 Loss = 0.009 AUC = 1.0
Epoch 270: VAL Accuracy = 0.763 Loss = 1.44 AUC = 0.818


10it [01:01,  6.14s/it]


Epoch 275: TRAIN Accuracy = 0.993 Loss = 0.009 AUC = 1.0
Epoch 275: VAL Accuracy = 0.77 Loss = 1.45 AUC = 0.818


10it [01:00,  6.10s/it]


Epoch 280: TRAIN Accuracy = 0.993 Loss = 0.008 AUC = 1.0
Epoch 280: VAL Accuracy = 0.767 Loss = 1.454 AUC = 0.818


10it [01:03,  6.35s/it]


Epoch 285: TRAIN Accuracy = 0.996 Loss = 0.007 AUC = 1.0
Epoch 285: VAL Accuracy = 0.773 Loss = 1.465 AUC = 0.817


10it [01:04,  6.49s/it]


Epoch 290: TRAIN Accuracy = 0.996 Loss = 0.007 AUC = 1.0
Epoch 290: VAL Accuracy = 0.77 Loss = 1.467 AUC = 0.816


10it [01:02,  6.24s/it]


Epoch 295: TRAIN Accuracy = 0.994 Loss = 0.008 AUC = 1.0
Epoch 295: VAL Accuracy = 0.767 Loss = 1.487 AUC = 0.816


10it [01:01,  6.19s/it]

Epoch 300: TRAIN Accuracy = 0.993 Loss = 0.01 AUC = 1.0
Epoch 300: VAL Accuracy = 0.763 Loss = 1.511 AUC = 0.813
CPU times: user 1h 41min 16s, sys: 25min 40s, total: 2h 6min 57s
Wall time: 1h 10s



