In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-16 03:08:58.632900: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-16 03:08:58.637843: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-16 03:08:58.655338: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742083738.683228  495713 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742083738.689512  495713 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-16 03:08:58.712192: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(TimeDistributed(Dense(32, activation='relu')))

    model.add(LTC(AutoNCP(32, 25), return_sequences=True))
    model.add(LTC(AutoNCP(25, 16), return_sequences=False))

    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-16 03:09:07.634060: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

0it [00:00, ?it/s]

10it [13:07, 78.77s/it]


Epoch 5: TRAIN Accuracy = 0.487 Loss = 0.703 AUC = 0.437
Epoch 5: VAL Accuracy = 0.49 Loss = 0.698 AUC = 0.589


10it [07:10, 43.02s/it]


Epoch 10: TRAIN Accuracy = 0.574 Loss = 0.675 AUC = 0.668
Epoch 10: VAL Accuracy = 0.557 Loss = 0.675 AUC = 0.736


10it [06:14, 37.48s/it]


Epoch 15: TRAIN Accuracy = 0.719 Loss = 0.606 AUC = 0.702
Epoch 15: VAL Accuracy = 0.717 Loss = 0.599 AUC = 0.717


10it [06:31, 39.14s/it]


Epoch 20: TRAIN Accuracy = 0.749 Loss = 0.509 AUC = 0.765
Epoch 20: VAL Accuracy = 0.773 Loss = 0.51 AUC = 0.785


10it [07:08, 42.81s/it]


Epoch 25: TRAIN Accuracy = 0.777 Loss = 0.457 AUC = 0.782
Epoch 25: VAL Accuracy = 0.803 Loss = 0.458 AUC = 0.803


10it [07:05, 42.54s/it]


Epoch 30: TRAIN Accuracy = 0.811 Loss = 0.424 AUC = 0.798
Epoch 30: VAL Accuracy = 0.807 Loss = 0.431 AUC = 0.801


10it [07:00, 42.08s/it]


Epoch 35: TRAIN Accuracy = 0.811 Loss = 0.41 AUC = 0.805
Epoch 35: VAL Accuracy = 0.807 Loss = 0.42 AUC = 0.829


10it [05:02, 30.22s/it]


Epoch 40: TRAIN Accuracy = 0.811 Loss = 0.404 AUC = 0.828
Epoch 40: VAL Accuracy = 0.807 Loss = 0.415 AUC = 0.841


10it [05:02, 30.25s/it]


Epoch 45: TRAIN Accuracy = 0.811 Loss = 0.398 AUC = 0.843
Epoch 45: VAL Accuracy = 0.807 Loss = 0.409 AUC = 0.856


10it [05:04, 30.41s/it]


Epoch 50: TRAIN Accuracy = 0.811 Loss = 0.393 AUC = 0.858
Epoch 50: VAL Accuracy = 0.807 Loss = 0.403 AUC = 0.872


10it [05:01, 30.17s/it]


Epoch 55: TRAIN Accuracy = 0.811 Loss = 0.386 AUC = 0.866
Epoch 55: VAL Accuracy = 0.807 Loss = 0.398 AUC = 0.883


10it [05:02, 30.30s/it]


Epoch 60: TRAIN Accuracy = 0.811 Loss = 0.379 AUC = 0.872
Epoch 60: VAL Accuracy = 0.807 Loss = 0.394 AUC = 0.882


10it [05:04, 30.46s/it]


Epoch 65: TRAIN Accuracy = 0.807 Loss = 0.374 AUC = 0.875
Epoch 65: VAL Accuracy = 0.803 Loss = 0.391 AUC = 0.882


10it [05:03, 30.33s/it]


Epoch 70: TRAIN Accuracy = 0.81 Loss = 0.37 AUC = 0.875
Epoch 70: VAL Accuracy = 0.807 Loss = 0.387 AUC = 0.881


10it [05:04, 30.48s/it]


Epoch 75: TRAIN Accuracy = 0.804 Loss = 0.363 AUC = 0.878
Epoch 75: VAL Accuracy = 0.8 Loss = 0.385 AUC = 0.883


10it [04:58, 29.89s/it]


Epoch 80: TRAIN Accuracy = 0.811 Loss = 0.36 AUC = 0.879
Epoch 80: VAL Accuracy = 0.797 Loss = 0.387 AUC = 0.875


10it [04:57, 29.75s/it]


Epoch 85: TRAIN Accuracy = 0.804 Loss = 0.355 AUC = 0.885
Epoch 85: VAL Accuracy = 0.787 Loss = 0.387 AUC = 0.877


10it [04:59, 29.93s/it]


Epoch 90: TRAIN Accuracy = 0.81 Loss = 0.35 AUC = 0.888
Epoch 90: VAL Accuracy = 0.787 Loss = 0.398 AUC = 0.895


10it [05:01, 30.12s/it]


Epoch 95: TRAIN Accuracy = 0.803 Loss = 0.346 AUC = 0.888
Epoch 95: VAL Accuracy = 0.78 Loss = 0.394 AUC = 0.884


10it [05:02, 30.26s/it]


Epoch 100: TRAIN Accuracy = 0.803 Loss = 0.344 AUC = 0.892
Epoch 100: VAL Accuracy = 0.783 Loss = 0.402 AUC = 0.887


10it [05:03, 30.33s/it]


Epoch 105: TRAIN Accuracy = 0.803 Loss = 0.339 AUC = 0.893
Epoch 105: VAL Accuracy = 0.767 Loss = 0.393 AUC = 0.883


10it [05:03, 30.37s/it]


Epoch 110: TRAIN Accuracy = 0.806 Loss = 0.337 AUC = 0.895
Epoch 110: VAL Accuracy = 0.77 Loss = 0.412 AUC = 0.888


10it [05:02, 30.28s/it]


Epoch 115: TRAIN Accuracy = 0.811 Loss = 0.331 AUC = 0.901
Epoch 115: VAL Accuracy = 0.78 Loss = 0.425 AUC = 0.888


10it [05:02, 30.24s/it]


Epoch 120: TRAIN Accuracy = 0.806 Loss = 0.347 AUC = 0.886
Epoch 120: VAL Accuracy = 0.77 Loss = 0.512 AUC = 0.856


10it [05:02, 30.27s/it]


Epoch 125: TRAIN Accuracy = 0.813 Loss = 0.341 AUC = 0.89
Epoch 125: VAL Accuracy = 0.79 Loss = 0.433 AUC = 0.881


10it [05:01, 30.16s/it]


Epoch 130: TRAIN Accuracy = 0.813 Loss = 0.325 AUC = 0.9
Epoch 130: VAL Accuracy = 0.77 Loss = 0.425 AUC = 0.881


10it [05:03, 30.31s/it]


Epoch 135: TRAIN Accuracy = 0.817 Loss = 0.318 AUC = 0.909
Epoch 135: VAL Accuracy = 0.767 Loss = 0.474 AUC = 0.885


10it [05:03, 30.35s/it]


Epoch 140: TRAIN Accuracy = 0.82 Loss = 0.312 AUC = 0.915
Epoch 140: VAL Accuracy = 0.773 Loss = 0.514 AUC = 0.88


10it [05:03, 30.33s/it]


Epoch 145: TRAIN Accuracy = 0.826 Loss = 0.308 AUC = 0.914
Epoch 145: VAL Accuracy = 0.793 Loss = 0.529 AUC = 0.882


10it [05:03, 30.30s/it]


Epoch 150: TRAIN Accuracy = 0.826 Loss = 0.304 AUC = 0.918
Epoch 150: VAL Accuracy = 0.78 Loss = 0.525 AUC = 0.88


10it [05:03, 30.34s/it]


Epoch 155: TRAIN Accuracy = 0.814 Loss = 0.323 AUC = 0.91
Epoch 155: VAL Accuracy = 0.773 Loss = 0.482 AUC = 0.883


10it [05:02, 30.29s/it]


Epoch 160: TRAIN Accuracy = 0.823 Loss = 0.305 AUC = 0.912
Epoch 160: VAL Accuracy = 0.78 Loss = 0.485 AUC = 0.887


10it [05:06, 30.63s/it]


Epoch 165: TRAIN Accuracy = 0.83 Loss = 0.303 AUC = 0.914
Epoch 165: VAL Accuracy = 0.787 Loss = 0.507 AUC = 0.888


10it [04:59, 29.97s/it]


Epoch 170: TRAIN Accuracy = 0.827 Loss = 0.296 AUC = 0.916
Epoch 170: VAL Accuracy = 0.79 Loss = 0.493 AUC = 0.893


10it [05:00, 30.01s/it]


Epoch 175: TRAIN Accuracy = 0.84 Loss = 0.304 AUC = 0.911
Epoch 175: VAL Accuracy = 0.783 Loss = 0.505 AUC = 0.886


10it [05:02, 30.30s/it]


Epoch 180: TRAIN Accuracy = 0.839 Loss = 0.29 AUC = 0.923
Epoch 180: VAL Accuracy = 0.803 Loss = 0.441 AUC = 0.902


10it [05:13, 31.35s/it]


Epoch 185: TRAIN Accuracy = 0.839 Loss = 0.284 AUC = 0.921
Epoch 185: VAL Accuracy = 0.807 Loss = 0.471 AUC = 0.9


10it [05:03, 30.33s/it]


Epoch 190: TRAIN Accuracy = 0.834 Loss = 0.292 AUC = 0.919
Epoch 190: VAL Accuracy = 0.797 Loss = 0.455 AUC = 0.902


10it [05:03, 30.39s/it]


Epoch 195: TRAIN Accuracy = 0.833 Loss = 0.288 AUC = 0.909
Epoch 195: VAL Accuracy = 0.783 Loss = 0.529 AUC = 0.882


10it [05:04, 30.45s/it]


Epoch 200: TRAIN Accuracy = 0.817 Loss = 0.304 AUC = 0.911
Epoch 200: VAL Accuracy = 0.8 Loss = 0.589 AUC = 0.874


10it [05:04, 30.42s/it]


Epoch 205: TRAIN Accuracy = 0.833 Loss = 0.294 AUC = 0.92
Epoch 205: VAL Accuracy = 0.793 Loss = 0.488 AUC = 0.892


10it [05:04, 30.45s/it]


Epoch 210: TRAIN Accuracy = 0.839 Loss = 0.296 AUC = 0.919
Epoch 210: VAL Accuracy = 0.793 Loss = 0.522 AUC = 0.88


10it [05:04, 30.44s/it]


Epoch 215: TRAIN Accuracy = 0.821 Loss = 0.301 AUC = 0.916
Epoch 215: VAL Accuracy = 0.797 Loss = 0.477 AUC = 0.896


10it [05:03, 30.39s/it]


Epoch 220: TRAIN Accuracy = 0.837 Loss = 0.286 AUC = 0.918
Epoch 220: VAL Accuracy = 0.797 Loss = 0.481 AUC = 0.895


10it [05:07, 30.72s/it]


Epoch 225: TRAIN Accuracy = 0.837 Loss = 0.293 AUC = 0.91
Epoch 225: VAL Accuracy = 0.8 Loss = 0.481 AUC = 0.892


10it [05:04, 30.41s/it]


Epoch 230: TRAIN Accuracy = 0.834 Loss = 0.289 AUC = 0.91
Epoch 230: VAL Accuracy = 0.793 Loss = 0.509 AUC = 0.889


10it [05:05, 30.59s/it]


Epoch 235: TRAIN Accuracy = 0.834 Loss = 0.285 AUC = 0.921
Epoch 235: VAL Accuracy = 0.79 Loss = 0.538 AUC = 0.883


10it [05:06, 30.63s/it]


Epoch 240: TRAIN Accuracy = 0.827 Loss = 0.311 AUC = 0.907
Epoch 240: VAL Accuracy = 0.79 Loss = 0.488 AUC = 0.883


10it [05:04, 30.44s/it]


Epoch 245: TRAIN Accuracy = 0.824 Loss = 0.302 AUC = 0.906
Epoch 245: VAL Accuracy = 0.787 Loss = 0.471 AUC = 0.886


10it [05:05, 30.53s/it]


Epoch 250: TRAIN Accuracy = 0.821 Loss = 0.316 AUC = 0.902
Epoch 250: VAL Accuracy = 0.783 Loss = 0.498 AUC = 0.878


10it [05:08, 30.87s/it]


Epoch 255: TRAIN Accuracy = 0.824 Loss = 0.299 AUC = 0.909
Epoch 255: VAL Accuracy = 0.79 Loss = 0.498 AUC = 0.89


10it [04:56, 29.69s/it]


Epoch 260: TRAIN Accuracy = 0.821 Loss = 0.309 AUC = 0.902
Epoch 260: VAL Accuracy = 0.78 Loss = 0.442 AUC = 0.892


10it [05:02, 30.21s/it]


Epoch 265: TRAIN Accuracy = 0.81 Loss = 0.312 AUC = 0.909
Epoch 265: VAL Accuracy = 0.78 Loss = 0.448 AUC = 0.891


10it [04:59, 29.95s/it]


Epoch 270: TRAIN Accuracy = 0.826 Loss = 0.299 AUC = 0.914
Epoch 270: VAL Accuracy = 0.793 Loss = 0.543 AUC = 0.89


10it [05:01, 30.12s/it]


Epoch 275: TRAIN Accuracy = 0.817 Loss = 0.319 AUC = 0.899
Epoch 275: VAL Accuracy = 0.813 Loss = 0.463 AUC = 0.895


10it [05:00, 30.07s/it]


Epoch 280: TRAIN Accuracy = 0.82 Loss = 0.318 AUC = 0.911
Epoch 280: VAL Accuracy = 0.807 Loss = 0.474 AUC = 0.901


10it [05:02, 30.28s/it]


Epoch 285: TRAIN Accuracy = 0.821 Loss = 0.309 AUC = 0.911
Epoch 285: VAL Accuracy = 0.797 Loss = 0.502 AUC = 0.876


10it [05:00, 30.04s/it]


Epoch 290: TRAIN Accuracy = 0.824 Loss = 0.296 AUC = 0.913
Epoch 290: VAL Accuracy = 0.78 Loss = 0.555 AUC = 0.873


10it [04:59, 29.99s/it]


Epoch 295: TRAIN Accuracy = 0.821 Loss = 0.292 AUC = 0.915
Epoch 295: VAL Accuracy = 0.783 Loss = 0.564 AUC = 0.869


10it [05:49, 34.99s/it]

Epoch 300: TRAIN Accuracy = 0.813 Loss = 0.302 AUC = 0.916
Epoch 300: VAL Accuracy = 0.777 Loss = 0.498 AUC = 0.871
CPU times: user 11h 52min 57s, sys: 4h 44min 18s, total: 16h 37min 16s
Wall time: 5h 22min 46s



