In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Concatenate, Dense, LSTM, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-17 02:51:50.536776: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-17 02:51:50.675055: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-17 02:51:50.791869: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742169110.871245   30260 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742169110.893723   30260 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-17 02:51:51.106170: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

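# Configuration and model

Training-schedule constants and `create_model`, the LSTM + GRU + LTC binary classifier trained below.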

In [2]:
NUM_EPOCHS = 300      # total training epochs per model
STEP = 5              # epochs per fit() call; metrics are printed every STEP epochs
NUM_EXPERIMENTS = 10  # number of models / random splits whose metrics are averaged

# The training loop runs NUM_EPOCHS // STEP chunks, so any remainder would be silently dropped.
assert NUM_EPOCHS % STEP == 0, "NUM_EPOCHS must be a multiple of STEP"

def _recurrent_branch(x, layer_cls, units, depth, dropout):
    """Stack `depth` recurrent layers of type `layer_cls`, with Dropout before and after each.

    Every layer except the last returns full sequences; the last returns only its final
    output, so the branch turns a 3-D sequence tensor into a 2-D (batch, units) tensor.
    """
    x = Dropout(dropout)(x)
    for k in range(depth):
        x = layer_cls(units, return_sequences=k < depth - 1)(x)
        x = Dropout(dropout)(x)
    return x


def create_model(train, rnn_units=10, rnn_depth=3, dropout=0.2, learning_rate=0.003):
    """Build and compile a three-branch (LSTM / GRU / LTC) binary sequence classifier.

    Parameters
    ----------
    train : np.ndarray
        3-D array (n_sequences, time_steps, n_features). Only its shape is used, to size
        the Input layer.
    rnn_units : int
        Units in every LSTM and GRU layer.
    rnn_depth : int
        Number of stacked layers in each of the LSTM and GRU branches (must be >= 1).
    dropout : float
        Dropout rate applied around every LSTM / GRU layer.
    learning_rate : float
        Adam learning rate.

    Returns
    -------
    tf.keras.Model
        Compiled with binary cross-entropy, tracking accuracy and AUC (reported as "auc").
        Outputs one sigmoid probability per sequence.

    Raises
    ------
    ValueError
        If `rnn_depth` < 1.

    NOTE(review): sequences are zero-padded at the end by pad_sequences and there is no
    Masking layer, so every branch also consumes the padding steps. No feature scaling is
    visible in this notebook either. Consider a Masking / Normalization layer.
    """
    if rnn_depth < 1:
        raise ValueError(f"rnn_depth must be >= 1, got {rnn_depth}")

    input_layer = Input(shape=(train.shape[1], train.shape[2]))

    # Built in the same order as before: LSTM branch, then GRU branch, then LTC branch.
    lstm_out = _recurrent_branch(input_layer, LSTM, rnn_units, rnn_depth, dropout)
    gru_out = _recurrent_branch(input_layer, GRU, rnn_units, rnn_depth, dropout)

    # Liquid time-constant branch: per-step Dense projection, then two AutoNCP-wired LTC layers.
    ltc = TimeDistributed(Dense(32, activation='relu'))(input_layer)
    ltc = LTC(AutoNCP(32, 25), return_sequences=True)(ltc)
    ltc_out = LTC(AutoNCP(25, 16), return_sequences=False)(ltc)

    concatenated = Concatenate()([lstm_out, gru_out, ltc_out])

    hidden = Dense(16, activation='relu')(concatenated)
    output = Dense(1, activation='sigmoid')(hidden)

    model = Model(inputs=input_layer, outputs=output)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )

    return model

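# Experiment

Load `EEG_data.csv`, turn each (SubjectID, VideoID) recording into one padded sequence, and train `NUM_EXPERIMENTS` models, each on its own random split of 70 training recordings.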

In [3]:
# Column groups used throughout the experiment.
ID = ["ID"]                     # per-recording identifier derived from IDS when loading
IDS = ["SubjectID", "VideoID"]  # raw columns that together identify a recording
TARGET = ["predefinedlabel"]    # binary label column
FEATURES = [
    "Raw", "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]

INIT_SEED = 5412  # experiment k (0-based) is seeded with INIT_SEED + k

In [4]:
# Directory holding EEG_data.csv. The EEG_DATA_DIR environment variable overrides this
# machine-specific default so the notebook can run elsewhere without editing code.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
raw_data = pd.read_csv(data_dir / "EEG_data.csv")

# One dense integer ID per (SubjectID, VideoID) recording, numbered 0..k-1 in sorted key
# order. Unlike `n_videos * SubjectID + VideoID`, ngroup() cannot produce colliding IDs.
# The two schemes agree when both columns are consecutive integers starting at 0 and
# every (subject, video) pair is present.
data = raw_data.assign(ID=raw_data.groupby(IDS).ngroup())[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn a long per-sample frame into one padded sequence per recording ID.

    Rows are grouped by the ID column. Groups come out in ascending ID order, and rows
    within a group keep their original order.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the ID, FEATURES and TARGET columns.

    Returns
    -------
    features : np.ndarray, shape (n_ids, max_len, len(FEATURES))
        Per-ID feature sequences, zero-padded at the end to the longest one.
    target : np.ndarray of int, shape (n_ids,)
        Per-ID label: the mean of TARGET over the ID's rows, truncated to int. Assuming
        0/1 labels, a mixed-label ID therefore maps to 0 unless all its rows are 1.
    """
    features = []
    target = []
    # A single groupby pass, instead of a full boolean scan of `data` for every ID.
    for _, group in data.groupby(ID, sort=True):
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad arrays along axis 0 to a common length and stack them into one array.

    Note: unlike `keras.utils.pad_sequences`, which pre-pads by default, this pads at the
    end unless `padding="pre"` is given.

    Parameters
    ----------
    arrays : sequence of array-like
        Arrays that agree on every dimension except the first, e.g. (time, features).
    pad_value : scalar
        Fill value for padded positions.
    padding : {"post", "pre"}
        Pad after the data (default) or before it.

    Returns
    -------
    np.ndarray
        Shape (len(arrays), max_length, *trailing_dims).

    Raises
    ------
    ValueError
        If `arrays` is empty or `padding` is not "post" / "pre".
    """
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences needs at least one array")

    max_length = max(arr.shape[0] for arr in arrays)

    def _pad_width(arr):
        # Pad only the first axis; leave every trailing axis untouched.
        gap = max_length - arr.shape[0]
        lead = (gap, 0) if padding == "pre" else (0, gap)
        return (lead,) + ((0, 0),) * (arr.ndim - 1)

    return np.stack([
        np.pad(arr, _pad_width(arr), mode="constant", constant_values=pad_value)
        for arr in arrays
    ])

In [6]:
# Build a model sized from the full padded dataset purely to inspect its architecture;
# the training loop below creates its own models.
X = reshape_dataset(data)[0]
model = create_model(X)
model.summary()

2025-03-17 02:51:54.708416: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [11:19, 67.91s/it]


Epoch 5: TRAIN Accuracy = 0.537 Loss = 0.688 AUC = 0.594
Epoch 5: VAL Accuracy = 0.617 Loss = 0.658 AUC = 0.715


10it [07:04, 42.47s/it]


Epoch 10: TRAIN Accuracy = 0.776 Loss = 0.511 AUC = 0.778
Epoch 10: VAL Accuracy = 0.767 Loss = 0.512 AUC = 0.769


10it [07:21, 44.14s/it]


Epoch 15: TRAIN Accuracy = 0.804 Loss = 0.434 AUC = 0.821
Epoch 15: VAL Accuracy = 0.793 Loss = 0.45 AUC = 0.831


10it [06:45, 40.55s/it]


Epoch 20: TRAIN Accuracy = 0.81 Loss = 0.394 AUC = 0.86
Epoch 20: VAL Accuracy = 0.77 Loss = 0.442 AUC = 0.846


10it [06:48, 40.80s/it]


Epoch 25: TRAIN Accuracy = 0.806 Loss = 0.382 AUC = 0.863
Epoch 25: VAL Accuracy = 0.777 Loss = 0.439 AUC = 0.858


10it [06:48, 40.85s/it]


Epoch 30: TRAIN Accuracy = 0.807 Loss = 0.362 AUC = 0.893
Epoch 30: VAL Accuracy = 0.763 Loss = 0.419 AUC = 0.867


10it [06:42, 40.27s/it]


Epoch 35: TRAIN Accuracy = 0.796 Loss = 0.353 AUC = 0.901
Epoch 35: VAL Accuracy = 0.767 Loss = 0.439 AUC = 0.854


10it [07:36, 45.65s/it]


Epoch 40: TRAIN Accuracy = 0.804 Loss = 0.357 AUC = 0.883
Epoch 40: VAL Accuracy = 0.757 Loss = 0.441 AUC = 0.849


10it [06:38, 39.81s/it]


Epoch 45: TRAIN Accuracy = 0.826 Loss = 0.34 AUC = 0.905
Epoch 45: VAL Accuracy = 0.77 Loss = 0.426 AUC = 0.872


10it [05:32, 33.29s/it]


Epoch 50: TRAIN Accuracy = 0.819 Loss = 0.332 AUC = 0.92
Epoch 50: VAL Accuracy = 0.773 Loss = 0.412 AUC = 0.875


10it [05:36, 33.66s/it]


Epoch 55: TRAIN Accuracy = 0.821 Loss = 0.336 AUC = 0.91
Epoch 55: VAL Accuracy = 0.777 Loss = 0.429 AUC = 0.878


10it [05:30, 33.10s/it]


Epoch 60: TRAIN Accuracy = 0.811 Loss = 0.337 AUC = 0.904
Epoch 60: VAL Accuracy = 0.783 Loss = 0.411 AUC = 0.887


10it [05:31, 33.14s/it]


Epoch 65: TRAIN Accuracy = 0.827 Loss = 0.313 AUC = 0.923
Epoch 65: VAL Accuracy = 0.78 Loss = 0.398 AUC = 0.899


10it [05:28, 32.80s/it]


Epoch 70: TRAIN Accuracy = 0.819 Loss = 0.333 AUC = 0.907
Epoch 70: VAL Accuracy = 0.777 Loss = 0.416 AUC = 0.896


10it [05:29, 32.91s/it]


Epoch 75: TRAIN Accuracy = 0.839 Loss = 0.317 AUC = 0.92
Epoch 75: VAL Accuracy = 0.783 Loss = 0.427 AUC = 0.897


10it [05:29, 33.00s/it]


Epoch 80: TRAIN Accuracy = 0.82 Loss = 0.333 AUC = 0.912
Epoch 80: VAL Accuracy = 0.787 Loss = 0.427 AUC = 0.893


10it [05:27, 32.79s/it]


Epoch 85: TRAIN Accuracy = 0.83 Loss = 0.324 AUC = 0.921
Epoch 85: VAL Accuracy = 0.797 Loss = 0.427 AUC = 0.896


10it [05:30, 33.08s/it]


Epoch 90: TRAIN Accuracy = 0.829 Loss = 0.313 AUC = 0.919
Epoch 90: VAL Accuracy = 0.807 Loss = 0.398 AUC = 0.893


10it [05:30, 33.06s/it]


Epoch 95: TRAIN Accuracy = 0.849 Loss = 0.3 AUC = 0.934
Epoch 95: VAL Accuracy = 0.8 Loss = 0.422 AUC = 0.889


10it [05:31, 33.11s/it]


Epoch 100: TRAIN Accuracy = 0.827 Loss = 0.329 AUC = 0.915
Epoch 100: VAL Accuracy = 0.803 Loss = 0.396 AUC = 0.898


10it [05:30, 33.05s/it]


Epoch 105: TRAIN Accuracy = 0.831 Loss = 0.298 AUC = 0.929
Epoch 105: VAL Accuracy = 0.79 Loss = 0.42 AUC = 0.89


10it [05:30, 33.05s/it]


Epoch 110: TRAIN Accuracy = 0.844 Loss = 0.296 AUC = 0.93
Epoch 110: VAL Accuracy = 0.787 Loss = 0.423 AUC = 0.903


10it [05:33, 33.33s/it]


Epoch 115: TRAIN Accuracy = 0.843 Loss = 0.305 AUC = 0.926
Epoch 115: VAL Accuracy = 0.79 Loss = 0.393 AUC = 0.91


10it [05:34, 33.44s/it]


Epoch 120: TRAIN Accuracy = 0.839 Loss = 0.304 AUC = 0.928
Epoch 120: VAL Accuracy = 0.807 Loss = 0.397 AUC = 0.913


10it [05:31, 33.13s/it]


Epoch 125: TRAIN Accuracy = 0.844 Loss = 0.296 AUC = 0.933
Epoch 125: VAL Accuracy = 0.793 Loss = 0.399 AUC = 0.915


10it [05:28, 32.82s/it]


Epoch 130: TRAIN Accuracy = 0.829 Loss = 0.298 AUC = 0.93
Epoch 130: VAL Accuracy = 0.793 Loss = 0.405 AUC = 0.914


10it [05:29, 32.99s/it]


Epoch 135: TRAIN Accuracy = 0.844 Loss = 0.299 AUC = 0.936
Epoch 135: VAL Accuracy = 0.803 Loss = 0.384 AUC = 0.913


10it [05:29, 32.94s/it]


Epoch 140: TRAIN Accuracy = 0.839 Loss = 0.284 AUC = 0.937
Epoch 140: VAL Accuracy = 0.773 Loss = 0.413 AUC = 0.905


10it [05:28, 32.90s/it]


Epoch 145: TRAIN Accuracy = 0.823 Loss = 0.301 AUC = 0.931
Epoch 145: VAL Accuracy = 0.8 Loss = 0.401 AUC = 0.907


10it [05:30, 33.02s/it]


Epoch 150: TRAIN Accuracy = 0.854 Loss = 0.295 AUC = 0.932
Epoch 150: VAL Accuracy = 0.8 Loss = 0.381 AUC = 0.904


10it [05:29, 32.91s/it]


Epoch 155: TRAIN Accuracy = 0.841 Loss = 0.289 AUC = 0.932
Epoch 155: VAL Accuracy = 0.79 Loss = 0.428 AUC = 0.897


10it [05:32, 33.29s/it]


Epoch 160: TRAIN Accuracy = 0.846 Loss = 0.295 AUC = 0.935
Epoch 160: VAL Accuracy = 0.783 Loss = 0.412 AUC = 0.907


10it [05:32, 33.25s/it]


Epoch 165: TRAIN Accuracy = 0.843 Loss = 0.289 AUC = 0.937
Epoch 165: VAL Accuracy = 0.807 Loss = 0.423 AUC = 0.91


10it [05:33, 33.36s/it]


Epoch 170: TRAIN Accuracy = 0.849 Loss = 0.283 AUC = 0.935
Epoch 170: VAL Accuracy = 0.793 Loss = 0.382 AUC = 0.918


10it [05:36, 33.69s/it]


Epoch 175: TRAIN Accuracy = 0.849 Loss = 0.297 AUC = 0.937
Epoch 175: VAL Accuracy = 0.787 Loss = 0.433 AUC = 0.901


10it [05:41, 34.13s/it]


Epoch 180: TRAIN Accuracy = 0.84 Loss = 0.285 AUC = 0.933
Epoch 180: VAL Accuracy = 0.807 Loss = 0.383 AUC = 0.908


10it [05:32, 33.26s/it]


Epoch 185: TRAIN Accuracy = 0.854 Loss = 0.281 AUC = 0.941
Epoch 185: VAL Accuracy = 0.8 Loss = 0.362 AUC = 0.921


10it [05:33, 33.31s/it]


Epoch 190: TRAIN Accuracy = 0.837 Loss = 0.29 AUC = 0.934
Epoch 190: VAL Accuracy = 0.797 Loss = 0.407 AUC = 0.918


10it [05:31, 33.11s/it]


Epoch 195: TRAIN Accuracy = 0.854 Loss = 0.27 AUC = 0.948
Epoch 195: VAL Accuracy = 0.793 Loss = 0.398 AUC = 0.914


10it [05:32, 33.26s/it]


Epoch 200: TRAIN Accuracy = 0.849 Loss = 0.273 AUC = 0.945
Epoch 200: VAL Accuracy = 0.8 Loss = 0.398 AUC = 0.917


10it [05:33, 33.34s/it]


Epoch 205: TRAIN Accuracy = 0.846 Loss = 0.272 AUC = 0.946
Epoch 205: VAL Accuracy = 0.797 Loss = 0.381 AUC = 0.913


10it [05:32, 33.29s/it]


Epoch 210: TRAIN Accuracy = 0.856 Loss = 0.278 AUC = 0.947
Epoch 210: VAL Accuracy = 0.803 Loss = 0.401 AUC = 0.912


10it [05:36, 33.68s/it]


Epoch 215: TRAIN Accuracy = 0.856 Loss = 0.268 AUC = 0.948
Epoch 215: VAL Accuracy = 0.817 Loss = 0.384 AUC = 0.915


10it [05:34, 33.46s/it]


Epoch 220: TRAIN Accuracy = 0.859 Loss = 0.266 AUC = 0.948
Epoch 220: VAL Accuracy = 0.823 Loss = 0.388 AUC = 0.908


10it [05:33, 33.31s/it]


Epoch 225: TRAIN Accuracy = 0.854 Loss = 0.282 AUC = 0.946
Epoch 225: VAL Accuracy = 0.787 Loss = 0.408 AUC = 0.901


10it [05:46, 34.63s/it]


Epoch 230: TRAIN Accuracy = 0.873 Loss = 0.243 AUC = 0.956
Epoch 230: VAL Accuracy = 0.81 Loss = 0.394 AUC = 0.903


10it [05:34, 33.48s/it]


Epoch 235: TRAIN Accuracy = 0.869 Loss = 0.251 AUC = 0.955
Epoch 235: VAL Accuracy = 0.807 Loss = 0.426 AUC = 0.899


10it [05:33, 33.37s/it]


Epoch 240: TRAIN Accuracy = 0.866 Loss = 0.254 AUC = 0.954
Epoch 240: VAL Accuracy = 0.803 Loss = 0.46 AUC = 0.897


10it [05:33, 33.35s/it]


Epoch 245: TRAIN Accuracy = 0.859 Loss = 0.284 AUC = 0.946
Epoch 245: VAL Accuracy = 0.823 Loss = 0.408 AUC = 0.906


10it [05:35, 33.56s/it]


Epoch 250: TRAIN Accuracy = 0.87 Loss = 0.266 AUC = 0.952
Epoch 250: VAL Accuracy = 0.82 Loss = 0.405 AUC = 0.907


10it [05:34, 33.49s/it]


Epoch 255: TRAIN Accuracy = 0.853 Loss = 0.266 AUC = 0.953
Epoch 255: VAL Accuracy = 0.807 Loss = 0.411 AUC = 0.91


10it [05:34, 33.42s/it]


Epoch 260: TRAIN Accuracy = 0.866 Loss = 0.249 AUC = 0.957
Epoch 260: VAL Accuracy = 0.81 Loss = 0.378 AUC = 0.909


10it [05:34, 33.45s/it]


Epoch 265: TRAIN Accuracy = 0.873 Loss = 0.248 AUC = 0.955
Epoch 265: VAL Accuracy = 0.793 Loss = 0.395 AUC = 0.909


10it [05:36, 33.65s/it]


Epoch 270: TRAIN Accuracy = 0.861 Loss = 0.257 AUC = 0.95
Epoch 270: VAL Accuracy = 0.817 Loss = 0.36 AUC = 0.917


10it [05:38, 33.82s/it]


Epoch 275: TRAIN Accuracy = 0.864 Loss = 0.275 AUC = 0.948
Epoch 275: VAL Accuracy = 0.813 Loss = 0.383 AUC = 0.913


10it [05:36, 33.62s/it]


Epoch 280: TRAIN Accuracy = 0.869 Loss = 0.254 AUC = 0.956
Epoch 280: VAL Accuracy = 0.83 Loss = 0.368 AUC = 0.921


10it [05:38, 33.83s/it]


Epoch 285: TRAIN Accuracy = 0.863 Loss = 0.253 AUC = 0.958
Epoch 285: VAL Accuracy = 0.823 Loss = 0.378 AUC = 0.914


10it [05:35, 33.55s/it]


Epoch 290: TRAIN Accuracy = 0.881 Loss = 0.246 AUC = 0.959
Epoch 290: VAL Accuracy = 0.83 Loss = 0.356 AUC = 0.923


10it [05:36, 33.69s/it]


Epoch 295: TRAIN Accuracy = 0.85 Loss = 0.275 AUC = 0.948
Epoch 295: VAL Accuracy = 0.803 Loss = 0.4 AUC = 0.917


10it [05:35, 33.51s/it]

Epoch 300: TRAIN Accuracy = 0.869 Loss = 0.25 AUC = 0.956
Epoch 300: VAL Accuracy = 0.81 Loss = 0.361 AUC = 0.918
CPU times: user 15h 25min 43s, sys: 5h 2min 36s, total: 20h 28min 19s
Wall time: 5h 50min 21s
Parser   : 138 ms



