In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, GRU, Input, Masking
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-16 02:33:37.868947: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-16 02:33:37.875645: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-16 02:33:37.893111: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742081617.921394  264047 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742081617.931450  264047 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-16 02:33:37.995765: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300       # total training epochs per experiment
STEP = 5               # epochs per fit() call; averaged metrics are printed after every chunk
NUM_EXPERIMENTS = 10   # number of seeds, each with its own train/validation split and model

# The training loop runs NUM_EPOCHS // STEP chunks, so any remainder would be silently skipped.
assert NUM_EPOCHS % STEP == 0, "NUM_EPOCHS must be a multiple of STEP"

def create_model(train, *, gru_units=10, num_gru_layers=3, dropout=0.2,
                 dense_units=6, learning_rate=0.003, mask_value=None):
    """Build and compile a stacked-GRU binary classifier for padded sequences.

    The defaults reproduce the original architecture exactly:
    Dropout -> GRU(10) x3 (all but the last return sequences) -> Dropout -> Dense(6, relu)
    -> Dense(1, sigmoid), compiled with Adam(3e-3) and binary cross-entropy.

    Parameters
    ----------
    train : array-like
        Only ``train.shape[1]`` and ``train.shape[2]`` are read; they set the input
        shape (timesteps, n_features).
    gru_units : int
        Hidden size of every GRU layer.
    num_gru_layers : int
        Number of stacked GRU layers (>= 1). Every layer except the last returns the full
        sequence so that the next GRU can consume it.
    dropout : float
        Rate of the Dropout layer inserted before each GRU and before the dense head.
    dense_units : int
        Width of the ReLU hidden layer in the head.
    learning_rate : float
        Adam learning rate.
    mask_value : float or None
        If not None, a ``Masking(mask_value)`` layer is placed right after the input, so
        timesteps whose features all equal ``mask_value`` are skipped by the GRUs. Pass the
        value used by ``pad_sequences`` (0 by default) to ignore padding. The default,
        None, adds no masking, which matches the original behaviour.

    Returns
    -------
    A compiled ``Sequential`` model with a single sigmoid output and the metrics
    "accuracy" and "auc". The training loop reads these from ``history.history``.

    Raises
    ------
    ValueError
        If ``num_gru_layers`` < 1.
    """
    if num_gru_layers < 1:
        raise ValueError(f"num_gru_layers must be >= 1, got {num_gru_layers}")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    if mask_value is not None:
        model.add(Masking(mask_value=mask_value))

    for layer_idx in range(num_gru_layers):
        model.add(Dropout(dropout))
        # Only the last GRU collapses the sequence to its final hidden state.
        model.add(GRU(gru_units, return_sequences=layer_idx < num_gru_layers - 1))

    model.add(Dropout(dropout))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column selections are kept as lists so they can be concatenated, e.g. ID + FEATURES + TARGET.
ID = ["ID"]                      # composite per-sequence key derived from SubjectID and VideoID
IDS = ["SubjectID", "VideoID"]   # raw columns that jointly identify one sequence
TARGET = ["predefinedlabel"]     # label column; the model predicts it with a sigmoid output
FEATURES = [
    "Raw", "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
INIT_SEED = 5412                 # experiment j is seeded with INIT_SEED + j

In [4]:
# Directory containing EEG_data.csv. Set the EEG_DATA_DIR environment variable to run
# elsewhere; the fallback is the original author's location.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
raw_data = pd.read_csv(data_dir / "EEG_data.csv")

# A NaN key would otherwise be folded into a bogus -1 group by ngroup() below.
assert raw_data[IDS].notna().all().all(), "SubjectID/VideoID contain missing values"

# One integer ID per (SubjectID, VideoID) pair, numbered 0..k-1 in sorted key order.
# ngroup() cannot collide. The old `n_videos * SubjectID + VideoID` formula silently
# merged sequences unless VideoID ran 0..n_videos-1. Whenever that formula was
# collision-free, it ordered sequences the same way, so the seeded splits are unchanged.
data = raw_data.assign(ID=raw_data.groupby(IDS, sort=True).ngroup())[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert a long-format frame (one row per sample) into padded per-ID sequences.

    Rows are grouped by the ``ID`` column(s). Groups come out in ascending ID order, the
    same order ``np.unique`` produced. Within each group, rows keep their original order.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns
    -------
    features : np.ndarray, shape (n_ids, max_length, len(FEATURES))
        Each ID's FEATURES rows, stacked by ``pad_sequences`` with its default padding.
    target : np.ndarray of int, shape (n_ids,)
        Per-ID label: the mean of the TARGET values, truncated to int.

    Raises
    ------
    ValueError
        Propagated from ``pad_sequences`` when ``data`` has no rows.
    """
    features = []
    target = []
    # A single groupby pass replaces re-masking the whole frame once per ID. It also
    # avoids indexing the frame with a 2-D (n, 1) boolean array, which only worked by accident.
    for _, group in data.groupby(ID, sort=True):
        # NOTE(review): mean-then-truncate recovers the label only if it is constant within
        # an ID; a mixed group becomes 0 unless every row is 1. Confirm labels are per-ID.
        target.append(int(np.mean(group[TARGET].to_numpy())))
        features.append(group[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Stack variable-length arrays into one array by padding along the first axis.

    Parameters
    ----------
    arrays : iterable of np.ndarray
        Arrays whose shapes agree on every axis except the first. Any iterable is
        accepted; it is materialised once, so generators work too.
    pad_value : scalar
        Fill value for the padded positions.
    padding : {"post", "pre"}
        "post" (default, the original behaviour) appends padding after the data.
        "pre" inserts it before the data, so the real samples end each sequence. This
        helps RNNs that read only their final state and do not use masking.

    Returns
    -------
    np.ndarray of shape (n_arrays, max_length, *trailing_shape).

    Raises
    ------
    ValueError
        If ``arrays`` is empty, if ``padding`` is not "post"/"pre", or (from ``np.stack``)
        if the trailing shapes differ.
    """
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")

    arrays = list(arrays)  # the arrays are traversed twice below
    if not arrays:
        raise ValueError("pad_sequences() requires at least one array")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        missing = max_length - arr.shape[0]
        leading = (0, missing) if padding == "post" else (missing, 0)
        # Pad only the first axis; all trailing axes (if any) stay untouched.
        pad_width = [leading] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            np.pad(arr, pad_width, mode='constant', constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Build a throwaway model on the full padded dataset purely to inspect its architecture;
# the training loop creates its own models.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-03-16 02:33:44.381401: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [02:13, 13.38s/it]


Epoch 5: TRAIN Accuracy = 0.513 Loss = 0.677 AUC = 0.588
Epoch 5: VAL Accuracy = 0.65 Loss = 0.658 AUC = 0.71


10it [01:23,  8.39s/it]


Epoch 10: TRAIN Accuracy = 0.76 Loss = 0.553 AUC = 0.784
Epoch 10: VAL Accuracy = 0.773 Loss = 0.535 AUC = 0.771


10it [01:19,  7.93s/it]


Epoch 15: TRAIN Accuracy = 0.789 Loss = 0.473 AUC = 0.813
Epoch 15: VAL Accuracy = 0.77 Loss = 0.483 AUC = 0.795


10it [01:22,  8.26s/it]


Epoch 20: TRAIN Accuracy = 0.809 Loss = 0.436 AUC = 0.814
Epoch 20: VAL Accuracy = 0.783 Loss = 0.448 AUC = 0.821


10it [01:21,  8.11s/it]


Epoch 25: TRAIN Accuracy = 0.807 Loss = 0.421 AUC = 0.818
Epoch 25: VAL Accuracy = 0.797 Loss = 0.436 AUC = 0.834


10it [01:20,  8.01s/it]


Epoch 30: TRAIN Accuracy = 0.81 Loss = 0.41 AUC = 0.838
Epoch 30: VAL Accuracy = 0.803 Loss = 0.42 AUC = 0.851


10it [01:17,  7.71s/it]


Epoch 35: TRAIN Accuracy = 0.804 Loss = 0.401 AUC = 0.842
Epoch 35: VAL Accuracy = 0.807 Loss = 0.413 AUC = 0.864


10it [01:16,  7.64s/it]


Epoch 40: TRAIN Accuracy = 0.797 Loss = 0.399 AUC = 0.855
Epoch 40: VAL Accuracy = 0.807 Loss = 0.415 AUC = 0.875


10it [01:14,  7.44s/it]


Epoch 45: TRAIN Accuracy = 0.8 Loss = 0.391 AUC = 0.854
Epoch 45: VAL Accuracy = 0.807 Loss = 0.395 AUC = 0.886


10it [01:12,  7.21s/it]


Epoch 50: TRAIN Accuracy = 0.81 Loss = 0.378 AUC = 0.866
Epoch 50: VAL Accuracy = 0.797 Loss = 0.386 AUC = 0.892


10it [01:09,  6.97s/it]


Epoch 55: TRAIN Accuracy = 0.816 Loss = 0.372 AUC = 0.867
Epoch 55: VAL Accuracy = 0.793 Loss = 0.389 AUC = 0.884


10it [01:12,  7.29s/it]


Epoch 60: TRAIN Accuracy = 0.81 Loss = 0.369 AUC = 0.874
Epoch 60: VAL Accuracy = 0.793 Loss = 0.39 AUC = 0.889


10it [01:16,  7.61s/it]


Epoch 65: TRAIN Accuracy = 0.806 Loss = 0.37 AUC = 0.868
Epoch 65: VAL Accuracy = 0.8 Loss = 0.382 AUC = 0.898


10it [01:10,  7.05s/it]


Epoch 70: TRAIN Accuracy = 0.801 Loss = 0.353 AUC = 0.894
Epoch 70: VAL Accuracy = 0.793 Loss = 0.373 AUC = 0.897


10it [01:10,  7.09s/it]


Epoch 75: TRAIN Accuracy = 0.8 Loss = 0.362 AUC = 0.874
Epoch 75: VAL Accuracy = 0.797 Loss = 0.37 AUC = 0.906


10it [01:12,  7.23s/it]


Epoch 80: TRAIN Accuracy = 0.813 Loss = 0.358 AUC = 0.886
Epoch 80: VAL Accuracy = 0.797 Loss = 0.359 AUC = 0.897


10it [01:11,  7.12s/it]


Epoch 85: TRAIN Accuracy = 0.806 Loss = 0.354 AUC = 0.886
Epoch 85: VAL Accuracy = 0.797 Loss = 0.363 AUC = 0.903


10it [01:12,  7.25s/it]


Epoch 90: TRAIN Accuracy = 0.81 Loss = 0.348 AUC = 0.9
Epoch 90: VAL Accuracy = 0.807 Loss = 0.358 AUC = 0.907


10it [01:11,  7.18s/it]


Epoch 95: TRAIN Accuracy = 0.807 Loss = 0.348 AUC = 0.894
Epoch 95: VAL Accuracy = 0.797 Loss = 0.363 AUC = 0.909


10it [01:11,  7.10s/it]


Epoch 100: TRAIN Accuracy = 0.8 Loss = 0.339 AUC = 0.899
Epoch 100: VAL Accuracy = 0.807 Loss = 0.354 AUC = 0.908


10it [01:10,  7.09s/it]


Epoch 105: TRAIN Accuracy = 0.807 Loss = 0.338 AUC = 0.907
Epoch 105: VAL Accuracy = 0.777 Loss = 0.36 AUC = 0.902


10it [01:11,  7.17s/it]


Epoch 110: TRAIN Accuracy = 0.813 Loss = 0.331 AUC = 0.902
Epoch 110: VAL Accuracy = 0.79 Loss = 0.348 AUC = 0.904


10it [01:10,  7.09s/it]


Epoch 115: TRAIN Accuracy = 0.824 Loss = 0.319 AUC = 0.914
Epoch 115: VAL Accuracy = 0.8 Loss = 0.345 AUC = 0.908


10it [01:12,  7.29s/it]


Epoch 120: TRAIN Accuracy = 0.813 Loss = 0.338 AUC = 0.898
Epoch 120: VAL Accuracy = 0.793 Loss = 0.33 AUC = 0.908


10it [01:12,  7.24s/it]


Epoch 125: TRAIN Accuracy = 0.817 Loss = 0.341 AUC = 0.905
Epoch 125: VAL Accuracy = 0.803 Loss = 0.344 AUC = 0.91


10it [01:14,  7.47s/it]


Epoch 130: TRAIN Accuracy = 0.807 Loss = 0.335 AUC = 0.902
Epoch 130: VAL Accuracy = 0.79 Loss = 0.339 AUC = 0.91


10it [01:18,  7.85s/it]


Epoch 135: TRAIN Accuracy = 0.817 Loss = 0.325 AUC = 0.912
Epoch 135: VAL Accuracy = 0.763 Loss = 0.355 AUC = 0.894


10it [01:37,  9.78s/it]


Epoch 140: TRAIN Accuracy = 0.821 Loss = 0.32 AUC = 0.913
Epoch 140: VAL Accuracy = 0.78 Loss = 0.351 AUC = 0.91


10it [02:06, 12.61s/it]


Epoch 145: TRAIN Accuracy = 0.813 Loss = 0.322 AUC = 0.913
Epoch 145: VAL Accuracy = 0.783 Loss = 0.348 AUC = 0.905


10it [02:02, 12.28s/it]


Epoch 150: TRAIN Accuracy = 0.847 Loss = 0.304 AUC = 0.926
Epoch 150: VAL Accuracy = 0.777 Loss = 0.343 AUC = 0.905


10it [02:03, 12.32s/it]


Epoch 155: TRAIN Accuracy = 0.817 Loss = 0.315 AUC = 0.914
Epoch 155: VAL Accuracy = 0.793 Loss = 0.324 AUC = 0.91


10it [02:05, 12.50s/it]


Epoch 160: TRAIN Accuracy = 0.819 Loss = 0.314 AUC = 0.916
Epoch 160: VAL Accuracy = 0.793 Loss = 0.338 AUC = 0.909


10it [01:44, 10.42s/it]


Epoch 165: TRAIN Accuracy = 0.816 Loss = 0.317 AUC = 0.915
Epoch 165: VAL Accuracy = 0.797 Loss = 0.314 AUC = 0.915


10it [01:40, 10.05s/it]


Epoch 170: TRAIN Accuracy = 0.827 Loss = 0.308 AUC = 0.917
Epoch 170: VAL Accuracy = 0.797 Loss = 0.334 AUC = 0.899


10it [01:53, 11.35s/it]


Epoch 175: TRAIN Accuracy = 0.839 Loss = 0.311 AUC = 0.923
Epoch 175: VAL Accuracy = 0.793 Loss = 0.33 AUC = 0.917


10it [01:49, 10.96s/it]


Epoch 180: TRAIN Accuracy = 0.831 Loss = 0.301 AUC = 0.921
Epoch 180: VAL Accuracy = 0.817 Loss = 0.3 AUC = 0.924


10it [01:47, 10.71s/it]


Epoch 185: TRAIN Accuracy = 0.826 Loss = 0.313 AUC = 0.917
Epoch 185: VAL Accuracy = 0.787 Loss = 0.336 AUC = 0.914


10it [01:36,  9.64s/it]


Epoch 190: TRAIN Accuracy = 0.826 Loss = 0.305 AUC = 0.924
Epoch 190: VAL Accuracy = 0.81 Loss = 0.301 AUC = 0.922


10it [01:26,  8.63s/it]


Epoch 195: TRAIN Accuracy = 0.824 Loss = 0.304 AUC = 0.923
Epoch 195: VAL Accuracy = 0.81 Loss = 0.298 AUC = 0.921


10it [01:28,  8.88s/it]


Epoch 200: TRAIN Accuracy = 0.819 Loss = 0.324 AUC = 0.911
Epoch 200: VAL Accuracy = 0.793 Loss = 0.317 AUC = 0.916


10it [01:28,  8.88s/it]


Epoch 205: TRAIN Accuracy = 0.834 Loss = 0.294 AUC = 0.923
Epoch 205: VAL Accuracy = 0.81 Loss = 0.305 AUC = 0.921


10it [01:27,  8.74s/it]


Epoch 210: TRAIN Accuracy = 0.826 Loss = 0.31 AUC = 0.925
Epoch 210: VAL Accuracy = 0.803 Loss = 0.325 AUC = 0.909


10it [01:28,  8.83s/it]


Epoch 215: TRAIN Accuracy = 0.829 Loss = 0.31 AUC = 0.919
Epoch 215: VAL Accuracy = 0.827 Loss = 0.305 AUC = 0.921


10it [01:29,  8.95s/it]


Epoch 220: TRAIN Accuracy = 0.841 Loss = 0.299 AUC = 0.925
Epoch 220: VAL Accuracy = 0.817 Loss = 0.304 AUC = 0.924


10it [01:27,  8.79s/it]


Epoch 225: TRAIN Accuracy = 0.827 Loss = 0.299 AUC = 0.924
Epoch 225: VAL Accuracy = 0.82 Loss = 0.316 AUC = 0.924


10it [01:36,  9.63s/it]


Epoch 230: TRAIN Accuracy = 0.823 Loss = 0.304 AUC = 0.919
Epoch 230: VAL Accuracy = 0.807 Loss = 0.303 AUC = 0.916


10it [01:38,  9.82s/it]


Epoch 235: TRAIN Accuracy = 0.827 Loss = 0.29 AUC = 0.929
Epoch 235: VAL Accuracy = 0.817 Loss = 0.309 AUC = 0.923


10it [01:36,  9.65s/it]


Epoch 240: TRAIN Accuracy = 0.814 Loss = 0.311 AUC = 0.922
Epoch 240: VAL Accuracy = 0.823 Loss = 0.3 AUC = 0.921


10it [01:39,  9.92s/it]


Epoch 245: TRAIN Accuracy = 0.819 Loss = 0.306 AUC = 0.918
Epoch 245: VAL Accuracy = 0.817 Loss = 0.313 AUC = 0.916


10it [01:35,  9.51s/it]


Epoch 250: TRAIN Accuracy = 0.816 Loss = 0.314 AUC = 0.919
Epoch 250: VAL Accuracy = 0.803 Loss = 0.317 AUC = 0.911


10it [01:37,  9.77s/it]


Epoch 255: TRAIN Accuracy = 0.836 Loss = 0.303 AUC = 0.919
Epoch 255: VAL Accuracy = 0.82 Loss = 0.315 AUC = 0.916


10it [01:37,  9.77s/it]


Epoch 260: TRAIN Accuracy = 0.816 Loss = 0.315 AUC = 0.916
Epoch 260: VAL Accuracy = 0.81 Loss = 0.304 AUC = 0.921


10it [01:38,  9.83s/it]


Epoch 265: TRAIN Accuracy = 0.839 Loss = 0.308 AUC = 0.923
Epoch 265: VAL Accuracy = 0.827 Loss = 0.304 AUC = 0.922


10it [01:35,  9.58s/it]


Epoch 270: TRAIN Accuracy = 0.844 Loss = 0.288 AUC = 0.926
Epoch 270: VAL Accuracy = 0.817 Loss = 0.319 AUC = 0.914


10it [01:38,  9.83s/it]


Epoch 275: TRAIN Accuracy = 0.837 Loss = 0.292 AUC = 0.923
Epoch 275: VAL Accuracy = 0.813 Loss = 0.332 AUC = 0.904


10it [01:37,  9.78s/it]


Epoch 280: TRAIN Accuracy = 0.836 Loss = 0.309 AUC = 0.915
Epoch 280: VAL Accuracy = 0.823 Loss = 0.305 AUC = 0.913


10it [01:39,  9.92s/it]


Epoch 285: TRAIN Accuracy = 0.817 Loss = 0.32 AUC = 0.905
Epoch 285: VAL Accuracy = 0.787 Loss = 0.346 AUC = 0.901


10it [01:38,  9.83s/it]


Epoch 290: TRAIN Accuracy = 0.844 Loss = 0.304 AUC = 0.922
Epoch 290: VAL Accuracy = 0.82 Loss = 0.306 AUC = 0.916


10it [01:37,  9.79s/it]


Epoch 295: TRAIN Accuracy = 0.843 Loss = 0.293 AUC = 0.924
Epoch 295: VAL Accuracy = 0.827 Loss = 0.31 AUC = 0.904


10it [01:38,  9.80s/it]

Epoch 300: TRAIN Accuracy = 0.853 Loss = 0.282 AUC = 0.926
Epoch 300: VAL Accuracy = 0.807 Loss = 0.306 AUC = 0.91
CPU times: user 2h 46min 3s, sys: 53min 2s, total: 3h 39min 6s
Wall time: 1h 29min 40s



