In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-22 00:04:39.041846: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-22 00:04:39.049927: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-22 00:04:39.069856: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742591079.104677 1359486 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742591079.113363 1359486 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-22 00:04:39.146990: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration and model definition

In [2]:
NUM_EPOCHS = 200  # training epochs per experiment run
NUM_EXPERIMENTS = 5  # independent runs, each with its own seed and subject split

def create_model(train):
    """Build and compile the stacked-LTC binary classifier.

    Only the shape of ``train`` is used: axes 1 and 2 give the (timesteps, features)
    input shape. Nothing is fitted here.

    Layers: three LTC layers of 10 units, each preceded by Dropout(0.2); the first two
    return full sequences, the last returns only its final output. These are followed by
    Dropout(0.2), Dense(6, relu) and a single sigmoid unit. The model is compiled with
    Adam (learning rate 0.003), binary cross-entropy, and accuracy + AUC (logged as "auc").
    """
    timesteps, n_features = train.shape[1], train.shape[2]

    model = Sequential()
    model.add(Input(shape=(timesteps, n_features)))

    # Only the last recurrent layer collapses the sequence into a single vector.
    # NOTE(review): pad_sequences zero-pads sequences at the END and there is no Masking
    # layer, so the final LTC state is computed after running over that padding —
    # consider pre-padding or masking if ncps' LTC supports it.
    for keep_sequences in (True, True, False):
        model.add(Dropout(0.2))
        model.add(LTC(10, return_sequences=keep_sequences))

    model.add(Dropout(0.2))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=0.003),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [None]:
# Column groups used throughout the notebook.
ID = ["ID"]                 # one integer id per (SubjectID, VideoID) pair, built in the loading cell
USER = ["SubjectID"]        # key for the subject-wise train/validation split
IDS = USER + ["VideoID"]    # keys whose combination identifies one sequence
TARGET = ["predefinedlabel"]
FEATURES = [
    "Raw", "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
LAGS = [2]                  # row offsets for the shifted copies of each feature
INIT_SEED = 5412            # seed of the first experiment; run j uses INIT_SEED + j

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

# Keep this cell idempotent. A previous run already appended "<feature>_<lag>" names
# to FEATURES, so strip them before deriving lags again. Otherwise a re-run would
# silently add lags of lags (e.g. "Raw_2_2") and double the feature set.
lagged_names = {f"{name}_{lag}" for name in FEATURES for lag in LAGS}
base_features = [name for name in FEATURES if name not in lagged_names]

new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        # Shift within each IDS group so a lag never reads another group's rows;
        # positions with no value `lag` rows back become 0.
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
FEATURES = base_features + new_features

# Number every (SubjectID, VideoID) group 0..K-1 in sorted key order. Unlike an
# arithmetic n_videos * SubjectID + VideoID key, this cannot collide when VideoID
# values are not exactly 0..n_videos-1.
data["ID"] = data.groupby(IDS, sort=True).ngroup().astype("int")
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long per-row frame into one padded sequence and one label per ID.

    Rows are grouped by the ``ID`` column in ascending ID order; within a group the
    original row order is kept.

    Returns
    -------
    features : np.ndarray of shape (n_ids, longest_group, len(FEATURES))
        Each group's FEATURES rows, padded at the end by ``pad_sequences``.
    target : np.ndarray of shape (n_ids,)
        Mean of the TARGET column over each group's rows, truncated to int.
    """
    features = []
    target = []
    # Single grouped pass instead of a full boolean scan of `data` for every ID.
    for _, group in data.groupby(ID, sort=True):
        # NOTE(review): truncation means a group with mixed 0/1 labels maps to 0 unless
        # every row is 1 — presumably the label is constant per ID; confirm.
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.asarray(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Stack variable-length arrays into one array, padding each at the end of axis 0.

    Parameters
    ----------
    arrays : sequence of array-likes
        Arrays that agree on every axis except the first, e.g. (timesteps, features).
    pad_value : scalar, default 0
        Value written into the padded positions.
    max_length : int, optional
        Target length of axis 0. Defaults to the longest input. Passing it lets
        separately built batches (e.g. train and test) share one length.

    Returns
    -------
    np.ndarray of shape (len(arrays), max_length, *arrays[0].shape[1:]).

    Raises
    ------
    ValueError
        If ``arrays`` is empty or any array is longer than ``max_length``.
    """
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif longest > max_length:
        raise ValueError(f"sequence of length {longest} exceeds max_length={max_length}")

    padded_arrays = [
        np.pad(
            arr,
            # Pad only the end of axis 0; leave every trailing axis untouched.
            [(0, max_length - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1),
            mode="constant",
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full dataset only to display its architecture; its timestep
# dimension comes from the longest sequence in the full data. The experiment loop
# below builds a fresh model for every run.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-22 00:04:43.458931: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 1s/step - accuracy: 0.5471 - auc: 0.6503 - loss: 0.6916 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6958
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 770ms/step - accuracy: 0.4160 - auc: 0.5337 - loss: 0.7033 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6935
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 783ms/step - accuracy: 0.5258 - auc: 0.4950 - loss: 0.6956 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 743ms/step - accuracy: 0.4297 - auc: 0.4031 - loss: 0.7112 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 716ms/step - accuracy: 0.4331 - auc: 0.4270 - loss: 0.7003 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6933
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

1it [24:23, 1463.71s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 1s/step - accuracy: 0.5668 - auc: 0.6065 - loss: 0.6795 - val_accuracy: 0.5000 - val_auc: 0.6000 - val_loss: 0.6957
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 842ms/step - accuracy: 0.5753 - auc: 0.5779 - loss: 0.6776 - val_accuracy: 0.6000 - val_auc: 0.6000 - val_loss: 0.6874
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 807ms/step - accuracy: 0.5608 - auc: 0.5956 - loss: 0.6750 - val_accuracy: 0.6000 - val_auc: 0.6000 - val_loss: 0.6909
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 918ms/step - accuracy: 0.4177 - auc: 0.3733 - loss: 0.7084 - val_accuracy: 0.6000 - val_auc: 0.6000 - val_loss: 0.6870
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 812ms/step - accuracy: 0.6362 - auc: 0.6134 - loss: 0.6781 - val_accuracy: 0.5000 - val_auc: 0.6000 - val_loss: 0.6844
Epoch 6/200
[1m7/7[0m [32m━━━━━━━

2it [48:48, 1464.54s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 1s/step - accuracy: 0.5663 - auc: 0.6511 - loss: 0.6618 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.7006
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 897ms/step - accuracy: 0.5641 - auc: 0.5026 - loss: 0.6886 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6900
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 926ms/step - accuracy: 0.6212 - auc: 0.6134 - loss: 0.6738 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6905
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 1s/step - accuracy: 0.5356 - auc: 0.5606 - loss: 0.6834 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6939
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 983ms/step - accuracy: 0.4574 - auc: 0.5048 - loss: 0.7010 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6934
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

3it [1:14:24, 1497.09s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 1s/step - accuracy: 0.5005 - auc: 0.4742 - loss: 0.7139 - val_accuracy: 0.5000 - val_auc: 0.6667 - val_loss: 0.6928
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 788ms/step - accuracy: 0.4103 - auc: 0.3243 - loss: 0.7357 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6954
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 852ms/step - accuracy: 0.5099 - auc: 0.5439 - loss: 0.6894 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6917
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 819ms/step - accuracy: 0.5040 - auc: 0.5259 - loss: 0.6917 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6918
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 844ms/step - accuracy: 0.4895 - auc: 0.5116 - loss: 0.6962 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6900
Epoch 6/200
[1m7/7[0m [32m━━━━━━

4it [1:39:41, 1505.04s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 1s/step - accuracy: 0.4897 - auc: 0.4525 - loss: 0.6980 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6931
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 859ms/step - accuracy: 0.5523 - auc: 0.5403 - loss: 0.6791 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 852ms/step - accuracy: 0.6484 - auc: 0.5159 - loss: 0.6912 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 949ms/step - accuracy: 0.4209 - auc: 0.5470 - loss: 0.6910 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 925ms/step - accuracy: 0.5004 - auc: 0.5879 - loss: 0.6950 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

5it [2:04:46, 1497.36s/it]

CPU times: user 3h 38min 19s, sys: 1h 1min 29s, total: 4h 39min 48s
Wall time: 2h 4min 46s





In [None]:
# Per-epoch averages across all runs: a TRAIN line, then a VAL line, for every epoch.
for epoch_idx in range(NUM_EPOCHS):
    epoch_no = epoch_idx + 1
    tr_acc, tr_loss, tr_auc = (np.round(s[epoch_idx], 3) for s in (epoch_acc, epoch_loss, epoch_auc))
    print(f"Epoch {epoch_no}: TRAIN Accuracy = {tr_acc} Loss = {tr_loss} AUC = {tr_auc}")
    va_acc, va_loss, va_auc = (np.round(s[epoch_idx], 3) for s in (epoch_val_acc, epoch_val_loss, epoch_val_auc))
    print(f"Epoch {epoch_no}: VAL Accuracy = {va_acc} Loss = {va_loss} AUC = {va_auc}")

Epoch 1: TRAIN Accuracy = 0.503 Loss = 0.702 AUC = 0.537
Epoch 1: VAL Accuracy = 0.5 Loss = 0.696 AUC = 0.607
Epoch 2: TRAIN Accuracy = 0.466 Loss = 0.703 AUC = 0.482
Epoch 2: VAL Accuracy = 0.52 Loss = 0.692 AUC = 0.547
Epoch 3: TRAIN Accuracy = 0.537 Loss = 0.691 AUC = 0.536
Epoch 3: VAL Accuracy = 0.52 Loss = 0.692 AUC = 0.573
Epoch 4: TRAIN Accuracy = 0.469 Loss = 0.696 AUC = 0.487
Epoch 4: VAL Accuracy = 0.52 Loss = 0.692 AUC = 0.573
Epoch 5: TRAIN Accuracy = 0.511 Loss = 0.695 AUC = 0.52
Epoch 5: VAL Accuracy = 0.5 Loss = 0.691 AUC = 0.573
Epoch 6: TRAIN Accuracy = 0.486 Loss = 0.698 AUC = 0.49
Epoch 6: VAL Accuracy = 0.5 Loss = 0.691 AUC = 0.573
Epoch 7: TRAIN Accuracy = 0.54 Loss = 0.694 AUC = 0.538
Epoch 7: VAL Accuracy = 0.5 Loss = 0.689 AUC = 0.6
Epoch 8: TRAIN Accuracy = 0.517 Loss = 0.691 AUC = 0.518
Epoch 8: VAL Accuracy = 0.547 Loss = 0.689 AUC = 0.593
Epoch 9: TRAIN Accuracy = 0.56 Loss = 0.687 AUC = 0.554
Epoch 9: VAL Accuracy = 0.547 Loss = 0.688 AUC = 0.593
Epoch 10: