In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from ncps.keras import LTC
from ncps.wirings import AutoNCP
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm

2025-02-19 12:38:35.661448: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 12:38:35.667059: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 12:38:35.677714: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739957915.692752 3135226 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739957915.697065 3135226 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 12:38:35.717047: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300       # total training epochs per model
STEP = 5               # epochs per model.fit call; metrics are printed after each call
NUM_EXPERIMENTS = 10   # number of seeded train/validation splits, one model per split

def create_model(train, lstm_units=10, dense_units=6, learning_rate=0.0003):
    """Build and compile the stacked-LSTM binary classifier.

    Parameters
    ----------
    train : numpy.ndarray
        Used only for its shape: ``train.shape[1]`` (timesteps) and
        ``train.shape[2]`` (features per timestep) fix the model input shape.
    lstm_units : int, default 10
        Units in each of the two stacked LSTM layers.
    dense_units : int, default 6
        Units in the hidden ReLU layer before the sigmoid output.
    learning_rate : float, default 0.0003
        Adam learning rate.

    Returns
    -------
    tensorflow.keras.Model
        A compiled ``Sequential`` model trained with binary cross-entropy and
        tracking accuracy plus AUC (logged under the name ``"auc"``). The
        default arguments reproduce the original hard-coded architecture.
    """
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    # NOTE(review): inputs produced by pad_sequences are zero-padded at the end
    # and no Masking layer is used, so padded steps contribute to the
    # BatchNormalization statistics and to the final LSTM state - confirm intended.
    model.add(BatchNormalization())

    # The first LSTM returns the whole sequence so the second can consume it;
    # the second returns only its last hidden state for classification.
    model.add(LSTM(lstm_units, return_sequences=True))
    model.add(LSTM(lstm_units, return_sequences=False))

    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column names used throughout the notebook, plus the base random seed.
ID = ["ID"]                      # integer key built from the (SubjectID, VideoID) pair in the loading cell
IDS = ["SubjectID", "VideoID"]   # raw columns that the ID key is derived from
TARGET = ["predefinedlabel"]     # label column (the model treats it as binary)
FEATURES = "Raw Delta Theta Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2".split()  # per-timestep EEG inputs
INIT_SEED = 5412                 # experiment j is seeded with INIT_SEED + j

In [4]:
# Directory holding EEG_data.csv. Set the EEG_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific absolute default.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer key per (SubjectID, VideoID) pair: n_videos * SubjectID + VideoID.
# The key is guaranteed unique when VideoID takes the values 0..n_videos-1, but
# not in general, so verify that no two distinct pairs share an ID. The key is
# a function of the pair, so equal counts mean the mapping is one-to-one.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
assert data["ID"].nunique() == len(data[IDS].drop_duplicates()), \
    "ID collision: different (SubjectID, VideoID) pairs were mapped to the same ID"

# Keep only the key, the EEG feature columns and the label.
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long per-timestep table into one padded sequence per ID.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the ``ID``, ``FEATURES`` and ``TARGET`` columns
        (module-level column lists).

    Returns
    -------
    features : numpy.ndarray, shape (n_ids, max_len, len(FEATURES))
        Rows of each ID's ``FEATURES`` columns in their original order,
        padded at the end by ``pad_sequences`` to the longest ID in ``data``.
        Row k belongs to the k-th smallest ID (same order as ``np.unique``).
    target : numpy.ndarray, shape (n_ids,)
        One integer label per ID: the mean of its ``TARGET`` values cast to
        int, i.e. truncated toward zero (e.g. an ID mixing 0 and 1 labels
        gets 0).
    """
    features = []
    target = []
    # A single groupby pass replaces one full-frame boolean mask per ID, which
    # scanned the whole table once per ID. sort=True visits IDs in ascending
    # order, and each group keeps the original row order.
    for _, id_rows in data.groupby(ID, sort=True):
        target.append(np.mean(id_rows[TARGET].to_numpy()).astype("int"))
        features.append(id_rows[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Stack 2-D arrays of varying length into one 3-D array.

    Every array is extended along its first axis to the length of the longest
    one by appending rows filled with ``pad_value``. The second axis is left
    unchanged.

    Parameters
    ----------
    arrays : sequence of numpy.ndarray
        2-D arrays that share the same number of columns.
    pad_value : scalar, default 0
        Fill value for the appended rows.

    Returns
    -------
    numpy.ndarray of shape (len(arrays), max_len, n_columns)
    """
    target_len = max(len(seq) for seq in arrays)

    def _pad_tail(seq):
        # Append (target_len - rows) rows at the end; leave columns untouched.
        missing_rows = target_len - seq.shape[0]
        return np.pad(seq, [(0, missing_rows), (0, 0)], mode="constant", constant_values=pad_value)

    return np.stack(list(map(_pad_tail, arrays)))

In [6]:
# Build one model from the padded sequences of every ID in `data` only to print
# its architecture; the labels are discarded. The training cell creates its own
# models from each experiment's training split.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-02-19 12:38:38.067768: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [01:08,  6.82s/it]


Epoch 5: TRAIN Accuracy = 0.497 Loss = 0.695 AUC = 0.505
Epoch 5: VAL Accuracy = 0.48 Loss = 0.698 AUC = 0.383


10it [00:46,  4.68s/it]


Epoch 10: TRAIN Accuracy = 0.554 Loss = 0.686 AUC = 0.668
Epoch 10: VAL Accuracy = 0.567 Loss = 0.689 AUC = 0.527


10it [00:46,  4.69s/it]


Epoch 15: TRAIN Accuracy = 0.589 Loss = 0.68 AUC = 0.708
Epoch 15: VAL Accuracy = 0.647 Loss = 0.681 AUC = 0.643


10it [00:43,  4.32s/it]


Epoch 20: TRAIN Accuracy = 0.624 Loss = 0.675 AUC = 0.756
Epoch 20: VAL Accuracy = 0.627 Loss = 0.677 AUC = 0.652


10it [00:46,  4.69s/it]


Epoch 25: TRAIN Accuracy = 0.637 Loss = 0.669 AUC = 0.767
Epoch 25: VAL Accuracy = 0.653 Loss = 0.672 AUC = 0.693


10it [00:48,  4.82s/it]


Epoch 30: TRAIN Accuracy = 0.646 Loss = 0.663 AUC = 0.77
Epoch 30: VAL Accuracy = 0.667 Loss = 0.668 AUC = 0.706


10it [00:48,  4.89s/it]


Epoch 35: TRAIN Accuracy = 0.664 Loss = 0.655 AUC = 0.773
Epoch 35: VAL Accuracy = 0.66 Loss = 0.665 AUC = 0.7


10it [00:48,  4.86s/it]


Epoch 40: TRAIN Accuracy = 0.681 Loss = 0.645 AUC = 0.775
Epoch 40: VAL Accuracy = 0.643 Loss = 0.661 AUC = 0.692


10it [00:47,  4.79s/it]


Epoch 45: TRAIN Accuracy = 0.696 Loss = 0.633 AUC = 0.781
Epoch 45: VAL Accuracy = 0.657 Loss = 0.657 AUC = 0.689


10it [00:49,  4.94s/it]


Epoch 50: TRAIN Accuracy = 0.727 Loss = 0.622 AUC = 0.785
Epoch 50: VAL Accuracy = 0.66 Loss = 0.656 AUC = 0.677


10it [00:46,  4.69s/it]


Epoch 55: TRAIN Accuracy = 0.753 Loss = 0.61 AUC = 0.787
Epoch 55: VAL Accuracy = 0.66 Loss = 0.654 AUC = 0.673


10it [00:48,  4.85s/it]


Epoch 60: TRAIN Accuracy = 0.769 Loss = 0.597 AUC = 0.79
Epoch 60: VAL Accuracy = 0.673 Loss = 0.649 AUC = 0.681


10it [00:48,  4.80s/it]


Epoch 65: TRAIN Accuracy = 0.783 Loss = 0.583 AUC = 0.798
Epoch 65: VAL Accuracy = 0.67 Loss = 0.649 AUC = 0.698


10it [00:47,  4.76s/it]


Epoch 70: TRAIN Accuracy = 0.784 Loss = 0.568 AUC = 0.799
Epoch 70: VAL Accuracy = 0.667 Loss = 0.649 AUC = 0.709


10it [00:48,  4.88s/it]


Epoch 75: TRAIN Accuracy = 0.796 Loss = 0.553 AUC = 0.807
Epoch 75: VAL Accuracy = 0.673 Loss = 0.651 AUC = 0.722


10it [00:48,  4.83s/it]


Epoch 80: TRAIN Accuracy = 0.797 Loss = 0.538 AUC = 0.813
Epoch 80: VAL Accuracy = 0.653 Loss = 0.654 AUC = 0.731


10it [00:49,  4.96s/it]


Epoch 85: TRAIN Accuracy = 0.801 Loss = 0.525 AUC = 0.821
Epoch 85: VAL Accuracy = 0.65 Loss = 0.654 AUC = 0.741


10it [00:50,  5.04s/it]


Epoch 90: TRAIN Accuracy = 0.803 Loss = 0.51 AUC = 0.833
Epoch 90: VAL Accuracy = 0.647 Loss = 0.649 AUC = 0.737


10it [00:49,  4.92s/it]


Epoch 95: TRAIN Accuracy = 0.813 Loss = 0.493 AUC = 0.841
Epoch 95: VAL Accuracy = 0.653 Loss = 0.648 AUC = 0.74


10it [00:48,  4.86s/it]


Epoch 100: TRAIN Accuracy = 0.823 Loss = 0.48 AUC = 0.851
Epoch 100: VAL Accuracy = 0.647 Loss = 0.647 AUC = 0.736


10it [00:50,  5.01s/it]


Epoch 105: TRAIN Accuracy = 0.833 Loss = 0.464 AUC = 0.859
Epoch 105: VAL Accuracy = 0.653 Loss = 0.646 AUC = 0.733


10it [00:49,  4.92s/it]


Epoch 110: TRAIN Accuracy = 0.836 Loss = 0.449 AUC = 0.87
Epoch 110: VAL Accuracy = 0.637 Loss = 0.658 AUC = 0.719


10it [00:49,  5.00s/it]


Epoch 115: TRAIN Accuracy = 0.837 Loss = 0.433 AUC = 0.879
Epoch 115: VAL Accuracy = 0.637 Loss = 0.666 AUC = 0.713


10it [00:48,  4.83s/it]


Epoch 120: TRAIN Accuracy = 0.846 Loss = 0.415 AUC = 0.887
Epoch 120: VAL Accuracy = 0.643 Loss = 0.665 AUC = 0.71


10it [00:48,  4.87s/it]


Epoch 125: TRAIN Accuracy = 0.847 Loss = 0.404 AUC = 0.888
Epoch 125: VAL Accuracy = 0.653 Loss = 0.671 AUC = 0.715


10it [00:48,  4.86s/it]


Epoch 130: TRAIN Accuracy = 0.856 Loss = 0.391 AUC = 0.894
Epoch 130: VAL Accuracy = 0.643 Loss = 0.678 AUC = 0.701


10it [00:48,  4.84s/it]


Epoch 135: TRAIN Accuracy = 0.861 Loss = 0.377 AUC = 0.898
Epoch 135: VAL Accuracy = 0.65 Loss = 0.69 AUC = 0.708


10it [00:52,  5.22s/it]


Epoch 140: TRAIN Accuracy = 0.864 Loss = 0.366 AUC = 0.901
Epoch 140: VAL Accuracy = 0.647 Loss = 0.697 AUC = 0.708


10it [00:49,  4.99s/it]


Epoch 145: TRAIN Accuracy = 0.867 Loss = 0.364 AUC = 0.904
Epoch 145: VAL Accuracy = 0.65 Loss = 0.708 AUC = 0.705


10it [00:50,  5.08s/it]


Epoch 150: TRAIN Accuracy = 0.869 Loss = 0.352 AUC = 0.91
Epoch 150: VAL Accuracy = 0.63 Loss = 0.714 AUC = 0.707


10it [00:50,  5.04s/it]


Epoch 155: TRAIN Accuracy = 0.876 Loss = 0.339 AUC = 0.914
Epoch 155: VAL Accuracy = 0.643 Loss = 0.729 AUC = 0.705


10it [00:48,  4.90s/it]


Epoch 160: TRAIN Accuracy = 0.876 Loss = 0.335 AUC = 0.913
Epoch 160: VAL Accuracy = 0.647 Loss = 0.731 AUC = 0.707


10it [00:50,  5.09s/it]


Epoch 165: TRAIN Accuracy = 0.883 Loss = 0.322 AUC = 0.918
Epoch 165: VAL Accuracy = 0.65 Loss = 0.748 AUC = 0.691


10it [00:47,  4.71s/it]


Epoch 170: TRAIN Accuracy = 0.884 Loss = 0.314 AUC = 0.921
Epoch 170: VAL Accuracy = 0.647 Loss = 0.761 AUC = 0.716


10it [00:49,  4.95s/it]


Epoch 175: TRAIN Accuracy = 0.887 Loss = 0.306 AUC = 0.924
Epoch 175: VAL Accuracy = 0.65 Loss = 0.772 AUC = 0.724


10it [00:52,  5.22s/it]


Epoch 180: TRAIN Accuracy = 0.89 Loss = 0.3 AUC = 0.926
Epoch 180: VAL Accuracy = 0.65 Loss = 0.782 AUC = 0.726


10it [00:49,  4.95s/it]


Epoch 185: TRAIN Accuracy = 0.889 Loss = 0.294 AUC = 0.928
Epoch 185: VAL Accuracy = 0.643 Loss = 0.793 AUC = 0.728


10it [00:49,  4.96s/it]


Epoch 190: TRAIN Accuracy = 0.889 Loss = 0.288 AUC = 0.932
Epoch 190: VAL Accuracy = 0.64 Loss = 0.804 AUC = 0.731


10it [00:50,  5.08s/it]


Epoch 195: TRAIN Accuracy = 0.896 Loss = 0.282 AUC = 0.934
Epoch 195: VAL Accuracy = 0.63 Loss = 0.812 AUC = 0.722


10it [00:51,  5.19s/it]


Epoch 200: TRAIN Accuracy = 0.896 Loss = 0.275 AUC = 0.936
Epoch 200: VAL Accuracy = 0.64 Loss = 0.822 AUC = 0.72


10it [00:53,  5.32s/it]


Epoch 205: TRAIN Accuracy = 0.889 Loss = 0.278 AUC = 0.935
Epoch 205: VAL Accuracy = 0.64 Loss = 0.829 AUC = 0.712


10it [00:54,  5.44s/it]


Epoch 210: TRAIN Accuracy = 0.897 Loss = 0.268 AUC = 0.938
Epoch 210: VAL Accuracy = 0.627 Loss = 0.844 AUC = 0.716


10it [00:55,  5.54s/it]


Epoch 215: TRAIN Accuracy = 0.889 Loss = 0.291 AUC = 0.929
Epoch 215: VAL Accuracy = 0.61 Loss = 0.862 AUC = 0.726


10it [00:54,  5.50s/it]


Epoch 220: TRAIN Accuracy = 0.893 Loss = 0.271 AUC = 0.938
Epoch 220: VAL Accuracy = 0.627 Loss = 0.844 AUC = 0.725


10it [00:56,  5.62s/it]


Epoch 225: TRAIN Accuracy = 0.9 Loss = 0.256 AUC = 0.943
Epoch 225: VAL Accuracy = 0.62 Loss = 0.843 AUC = 0.722


10it [00:56,  5.68s/it]


Epoch 230: TRAIN Accuracy = 0.901 Loss = 0.26 AUC = 0.944
Epoch 230: VAL Accuracy = 0.633 Loss = 0.845 AUC = 0.741


10it [00:54,  5.47s/it]


Epoch 235: TRAIN Accuracy = 0.899 Loss = 0.252 AUC = 0.943
Epoch 235: VAL Accuracy = 0.62 Loss = 0.855 AUC = 0.731


10it [00:53,  5.39s/it]


Epoch 240: TRAIN Accuracy = 0.906 Loss = 0.245 AUC = 0.947
Epoch 240: VAL Accuracy = 0.617 Loss = 0.86 AUC = 0.733


10it [00:55,  5.55s/it]


Epoch 245: TRAIN Accuracy = 0.907 Loss = 0.239 AUC = 0.946
Epoch 245: VAL Accuracy = 0.61 Loss = 0.871 AUC = 0.715


10it [00:54,  5.47s/it]


Epoch 250: TRAIN Accuracy = 0.909 Loss = 0.235 AUC = 0.947
Epoch 250: VAL Accuracy = 0.603 Loss = 0.876 AUC = 0.715


10it [00:54,  5.44s/it]


Epoch 255: TRAIN Accuracy = 0.907 Loss = 0.237 AUC = 0.945
Epoch 255: VAL Accuracy = 0.61 Loss = 0.883 AUC = 0.714


10it [00:55,  5.51s/it]


Epoch 260: TRAIN Accuracy = 0.91 Loss = 0.232 AUC = 0.947
Epoch 260: VAL Accuracy = 0.613 Loss = 0.889 AUC = 0.71


10it [00:54,  5.46s/it]


Epoch 265: TRAIN Accuracy = 0.91 Loss = 0.228 AUC = 0.949
Epoch 265: VAL Accuracy = 0.603 Loss = 0.898 AUC = 0.712


10it [00:54,  5.43s/it]


Epoch 270: TRAIN Accuracy = 0.91 Loss = 0.224 AUC = 0.948
Epoch 270: VAL Accuracy = 0.607 Loss = 0.91 AUC = 0.71


10it [00:55,  5.53s/it]


Epoch 275: TRAIN Accuracy = 0.913 Loss = 0.22 AUC = 0.951
Epoch 275: VAL Accuracy = 0.61 Loss = 0.929 AUC = 0.695


10it [00:53,  5.37s/it]


Epoch 280: TRAIN Accuracy = 0.909 Loss = 0.231 AUC = 0.946
Epoch 280: VAL Accuracy = 0.61 Loss = 0.913 AUC = 0.712


10it [00:52,  5.27s/it]


Epoch 285: TRAIN Accuracy = 0.904 Loss = 0.232 AUC = 0.945
Epoch 285: VAL Accuracy = 0.613 Loss = 0.918 AUC = 0.72


10it [00:53,  5.33s/it]


Epoch 290: TRAIN Accuracy = 0.909 Loss = 0.226 AUC = 0.949
Epoch 290: VAL Accuracy = 0.623 Loss = 0.909 AUC = 0.711


10it [00:54,  5.42s/it]


Epoch 295: TRAIN Accuracy = 0.911 Loss = 0.22 AUC = 0.949
Epoch 295: VAL Accuracy = 0.627 Loss = 0.897 AUC = 0.714


10it [00:53,  5.36s/it]

Epoch 300: TRAIN Accuracy = 0.916 Loss = 0.217 AUC = 0.949
Epoch 300: VAL Accuracy = 0.633 Loss = 0.898 AUC = 0.717
CPU times: user 1h 13min 43s, sys: 17min 45s, total: 1h 31min 28s
Wall time: 51min 9s



