In [None]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-18 09:17:42.883801: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 09:17:42.889983: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 09:17:42.909130: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739859462.937748 1868146 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739859462.954165 1868146 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-18 09:17:43.029341: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Dropout(0.2))
    model.add(LTC(AutoNCP(15, 12), return_sequences=True))

    model.add(Dropout(0.2))
    model.add(LTC(AutoNCP(12, 9), return_sequences=True))

    model.add(Dropout(0.2))
    model.add(LTC(AutoNCP(9, 6), return_sequences=False))

    model.add(Dense(6, activation='relu'))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.0003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-02-18 09:17:46.741946: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [13:24, 80.48s/it]


Epoch 5: TRAIN Accuracy = 0.509 Loss = 0.693 AUC = 0.481
Epoch 5: VAL Accuracy = 0.457 Loss = 0.695 AUC = 0.5


10it [05:43, 34.30s/it]


Epoch 10: TRAIN Accuracy = 0.519 Loss = 0.692 AUC = 0.505
Epoch 10: VAL Accuracy = 0.457 Loss = 0.696 AUC = 0.5


10it [05:40, 34.08s/it]


Epoch 15: TRAIN Accuracy = 0.519 Loss = 0.692 AUC = 0.506
Epoch 15: VAL Accuracy = 0.457 Loss = 0.696 AUC = 0.5


10it [05:43, 34.37s/it]


Epoch 20: TRAIN Accuracy = 0.519 Loss = 0.692 AUC = 0.497
Epoch 20: VAL Accuracy = 0.457 Loss = 0.697 AUC = 0.5


10it [06:16, 37.60s/it]


Epoch 25: TRAIN Accuracy = 0.521 Loss = 0.692 AUC = 0.489
Epoch 25: VAL Accuracy = 0.45 Loss = 0.697 AUC = 0.5


10it [06:21, 38.14s/it]


Epoch 30: TRAIN Accuracy = 0.521 Loss = 0.692 AUC = 0.498
Epoch 30: VAL Accuracy = 0.45 Loss = 0.697 AUC = 0.5


10it [06:18, 37.89s/it]


Epoch 35: TRAIN Accuracy = 0.52 Loss = 0.692 AUC = 0.496
Epoch 35: VAL Accuracy = 0.45 Loss = 0.697 AUC = 0.5


10it [06:24, 38.45s/it]


Epoch 40: TRAIN Accuracy = 0.516 Loss = 0.692 AUC = 0.508
Epoch 40: VAL Accuracy = 0.463 Loss = 0.697 AUC = 0.542


10it [06:16, 37.70s/it]


Epoch 45: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.495
Epoch 45: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.5


10it [06:14, 37.43s/it]


Epoch 50: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.507
Epoch 50: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.5


10it [06:15, 37.55s/it]


Epoch 55: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.51
Epoch 55: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.5


10it [06:18, 37.82s/it]


Epoch 60: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.537
Epoch 60: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.533


10it [06:20, 38.03s/it]


Epoch 65: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.515
Epoch 65: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.53


10it [06:22, 38.26s/it]


Epoch 70: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.528
Epoch 70: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.539


10it [06:22, 38.23s/it]


Epoch 75: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.524
Epoch 75: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.538


10it [06:21, 38.11s/it]


Epoch 80: TRAIN Accuracy = 0.51 Loss = 0.691 AUC = 0.558
Epoch 80: VAL Accuracy = 0.477 Loss = 0.697 AUC = 0.589


10it [06:22, 38.28s/it]


Epoch 85: TRAIN Accuracy = 0.51 Loss = 0.69 AUC = 0.565
Epoch 85: VAL Accuracy = 0.477 Loss = 0.696 AUC = 0.604


10it [06:20, 38.02s/it]


Epoch 90: TRAIN Accuracy = 0.536 Loss = 0.689 AUC = 0.603
Epoch 90: VAL Accuracy = 0.493 Loss = 0.693 AUC = 0.616


10it [06:20, 38.00s/it]


Epoch 95: TRAIN Accuracy = 0.541 Loss = 0.685 AUC = 0.598
Epoch 95: VAL Accuracy = 0.493 Loss = 0.689 AUC = 0.619


10it [06:19, 37.97s/it]


Epoch 100: TRAIN Accuracy = 0.553 Loss = 0.68 AUC = 0.598
Epoch 100: VAL Accuracy = 0.527 Loss = 0.684 AUC = 0.624


10it [06:08, 36.81s/it]


Epoch 105: TRAIN Accuracy = 0.566 Loss = 0.673 AUC = 0.595
Epoch 105: VAL Accuracy = 0.53 Loss = 0.678 AUC = 0.624


10it [06:12, 37.28s/it]


Epoch 110: TRAIN Accuracy = 0.563 Loss = 0.665 AUC = 0.603
Epoch 110: VAL Accuracy = 0.53 Loss = 0.669 AUC = 0.62


10it [06:16, 37.67s/it]


Epoch 115: TRAIN Accuracy = 0.583 Loss = 0.656 AUC = 0.599
Epoch 115: VAL Accuracy = 0.573 Loss = 0.661 AUC = 0.623


10it [06:22, 38.22s/it]


Epoch 120: TRAIN Accuracy = 0.586 Loss = 0.648 AUC = 0.605
Epoch 120: VAL Accuracy = 0.573 Loss = 0.653 AUC = 0.626


10it [06:21, 38.11s/it]


Epoch 125: TRAIN Accuracy = 0.589 Loss = 0.64 AUC = 0.617
Epoch 125: VAL Accuracy = 0.573 Loss = 0.648 AUC = 0.643


10it [06:24, 38.43s/it]


Epoch 130: TRAIN Accuracy = 0.589 Loss = 0.635 AUC = 0.619
Epoch 130: VAL Accuracy = 0.573 Loss = 0.642 AUC = 0.653


10it [06:21, 38.18s/it]


Epoch 135: TRAIN Accuracy = 0.589 Loss = 0.63 AUC = 0.635
Epoch 135: VAL Accuracy = 0.573 Loss = 0.637 AUC = 0.652


10it [06:22, 38.25s/it]


Epoch 140: TRAIN Accuracy = 0.597 Loss = 0.624 AUC = 0.634
Epoch 140: VAL Accuracy = 0.573 Loss = 0.632 AUC = 0.655


10it [06:27, 38.77s/it]


Epoch 145: TRAIN Accuracy = 0.594 Loss = 0.621 AUC = 0.64
Epoch 145: VAL Accuracy = 0.61 Loss = 0.625 AUC = 0.652


10it [06:20, 38.07s/it]


Epoch 150: TRAIN Accuracy = 0.616 Loss = 0.616 AUC = 0.645
Epoch 150: VAL Accuracy = 0.613 Loss = 0.619 AUC = 0.653


10it [06:21, 38.13s/it]


Epoch 155: TRAIN Accuracy = 0.636 Loss = 0.611 AUC = 0.652
Epoch 155: VAL Accuracy = 0.66 Loss = 0.615 AUC = 0.656


10it [06:23, 38.36s/it]


Epoch 160: TRAIN Accuracy = 0.636 Loss = 0.605 AUC = 0.641
Epoch 160: VAL Accuracy = 0.66 Loss = 0.609 AUC = 0.652


10it [06:24, 38.42s/it]


Epoch 165: TRAIN Accuracy = 0.643 Loss = 0.599 AUC = 0.639
Epoch 165: VAL Accuracy = 0.67 Loss = 0.602 AUC = 0.656


10it [06:23, 38.33s/it]


Epoch 170: TRAIN Accuracy = 0.641 Loss = 0.594 AUC = 0.652
Epoch 170: VAL Accuracy = 0.66 Loss = 0.599 AUC = 0.652


10it [06:23, 38.35s/it]


Epoch 175: TRAIN Accuracy = 0.639 Loss = 0.591 AUC = 0.638
Epoch 175: VAL Accuracy = 0.67 Loss = 0.591 AUC = 0.656


10it [06:26, 38.60s/it]


Epoch 180: TRAIN Accuracy = 0.64 Loss = 0.585 AUC = 0.645
Epoch 180: VAL Accuracy = 0.66 Loss = 0.587 AUC = 0.656


10it [06:20, 38.07s/it]


Epoch 185: TRAIN Accuracy = 0.643 Loss = 0.579 AUC = 0.657
Epoch 185: VAL Accuracy = 0.673 Loss = 0.58 AUC = 0.661


10it [06:20, 38.07s/it]


Epoch 190: TRAIN Accuracy = 0.647 Loss = 0.575 AUC = 0.639
Epoch 190: VAL Accuracy = 0.673 Loss = 0.577 AUC = 0.667


10it [06:23, 38.37s/it]


Epoch 195: TRAIN Accuracy = 0.644 Loss = 0.573 AUC = 0.651
Epoch 195: VAL Accuracy = 0.673 Loss = 0.574 AUC = 0.664


10it [06:22, 38.24s/it]


Epoch 200: TRAIN Accuracy = 0.647 Loss = 0.57 AUC = 0.649
Epoch 200: VAL Accuracy = 0.673 Loss = 0.57 AUC = 0.661


10it [06:21, 38.18s/it]


Epoch 205: TRAIN Accuracy = 0.644 Loss = 0.568 AUC = 0.642
Epoch 205: VAL Accuracy = 0.673 Loss = 0.568 AUC = 0.661


10it [06:23, 38.38s/it]


Epoch 210: TRAIN Accuracy = 0.646 Loss = 0.567 AUC = 0.659
Epoch 210: VAL Accuracy = 0.673 Loss = 0.566 AUC = 0.661


10it [06:19, 37.94s/it]


Epoch 215: TRAIN Accuracy = 0.646 Loss = 0.565 AUC = 0.652
Epoch 215: VAL Accuracy = 0.673 Loss = 0.563 AUC = 0.661


10it [06:17, 37.75s/it]


Epoch 220: TRAIN Accuracy = 0.644 Loss = 0.565 AUC = 0.657
Epoch 220: VAL Accuracy = 0.673 Loss = 0.563 AUC = 0.665


10it [06:16, 37.68s/it]


Epoch 225: TRAIN Accuracy = 0.641 Loss = 0.564 AUC = 0.653
Epoch 225: VAL Accuracy = 0.673 Loss = 0.562 AUC = 0.667


10it [06:24, 38.45s/it]


Epoch 230: TRAIN Accuracy = 0.646 Loss = 0.563 AUC = 0.652
Epoch 230: VAL Accuracy = 0.673 Loss = 0.56 AUC = 0.667


10it [06:22, 38.21s/it]


Epoch 235: TRAIN Accuracy = 0.647 Loss = 0.562 AUC = 0.661
Epoch 235: VAL Accuracy = 0.673 Loss = 0.559 AUC = 0.665


10it [06:23, 38.33s/it]


Epoch 240: TRAIN Accuracy = 0.647 Loss = 0.562 AUC = 0.661
Epoch 240: VAL Accuracy = 0.673 Loss = 0.558 AUC = 0.665


10it [06:21, 38.15s/it]


Epoch 245: TRAIN Accuracy = 0.647 Loss = 0.561 AUC = 0.647
Epoch 245: VAL Accuracy = 0.673 Loss = 0.56 AUC = 0.667


10it [06:22, 38.28s/it]


Epoch 250: TRAIN Accuracy = 0.647 Loss = 0.56 AUC = 0.646
Epoch 250: VAL Accuracy = 0.673 Loss = 0.556 AUC = 0.665


10it [06:21, 38.11s/it]


Epoch 255: TRAIN Accuracy = 0.647 Loss = 0.56 AUC = 0.663
Epoch 255: VAL Accuracy = 0.673 Loss = 0.555 AUC = 0.676


10it [06:23, 38.39s/it]


Epoch 260: TRAIN Accuracy = 0.647 Loss = 0.56 AUC = 0.664
Epoch 260: VAL Accuracy = 0.673 Loss = 0.556 AUC = 0.668


10it [06:24, 38.48s/it]


Epoch 265: TRAIN Accuracy = 0.647 Loss = 0.56 AUC = 0.655
Epoch 265: VAL Accuracy = 0.673 Loss = 0.555 AUC = 0.668


10it [06:21, 38.18s/it]


Epoch 270: TRAIN Accuracy = 0.647 Loss = 0.559 AUC = 0.65
Epoch 270: VAL Accuracy = 0.673 Loss = 0.556 AUC = 0.668


10it [06:26, 38.70s/it]


Epoch 275: TRAIN Accuracy = 0.647 Loss = 0.559 AUC = 0.647
Epoch 275: VAL Accuracy = 0.673 Loss = 0.554 AUC = 0.665


10it [06:25, 38.51s/it]


Epoch 280: TRAIN Accuracy = 0.646 Loss = 0.561 AUC = 0.647
Epoch 280: VAL Accuracy = 0.673 Loss = 0.554 AUC = 0.668


10it [06:16, 37.63s/it]


Epoch 285: TRAIN Accuracy = 0.647 Loss = 0.558 AUC = 0.655
Epoch 285: VAL Accuracy = 0.673 Loss = 0.554 AUC = 0.668


10it [06:16, 37.61s/it]


Epoch 290: TRAIN Accuracy = 0.647 Loss = 0.558 AUC = 0.666
Epoch 290: VAL Accuracy = 0.673 Loss = 0.553 AUC = 0.678


10it [06:18, 37.83s/it]


Epoch 295: TRAIN Accuracy = 0.647 Loss = 0.558 AUC = 0.653
Epoch 295: VAL Accuracy = 0.673 Loss = 0.552 AUC = 0.678


10it [06:23, 38.38s/it]

Epoch 300: TRAIN Accuracy = 0.646 Loss = 0.558 AUC = 0.663
Epoch 300: VAL Accuracy = 0.673 Loss = 0.553 AUC = 0.67
CPU times: user 15h 2min 43s, sys: 6h 41min 17s, total: 21h 44min 1s
Wall time: 6h 26min 6s



