In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 11:22:34.292612: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 11:22:34.305432: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 11:22:34.335742: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742545354.373779  659010 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742545354.385552  659010 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 11:22:34.437845: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training budget: epochs per `model.fit` call, and how many independently
# seeded train/test splits the experiment loop runs.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5

def create_model(train, learning_rate=0.003):
    """Build and compile the LTC-based binary sequence classifier.

    Args:
        train: array whose ``shape[1]`` / ``shape[2]`` give the number of
            timesteps and features per sample; only its shape is read
            (assumed 3-D: samples x time x features).
        learning_rate: Adam step size. Defaults to 0.003, the value previously
            hard-coded here.

    Returns:
        A compiled ``Sequential`` model ending in a single sigmoid unit,
        trained with binary cross-entropy and reporting accuracy and AUC.
    """
    n_timesteps, n_features = train.shape[1], train.shape[2]

    model = Sequential()
    model.add(Input(shape=(n_timesteps, n_features)))

    # Per-timestep projection of the input features to 32 dims.
    model.add(TimeDistributed(Dense(32, activation='relu')))

    # Two stacked ncps LTC layers wired by AutoNCP(units, output_size); only the
    # final timestep of the second layer is passed on to the dense head.
    # NOTE(review): sequences are zero-padded at the end and there is no
    # Masking layer, so that final state also consumes padding steps — confirm
    # this is intended.
    model.add(LTC(AutoNCP(32, 25), return_sequences=True))
    model.add(LTC(AutoNCP(25, 16), return_sequences=False))

    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of EEG_data.csv used throughout the notebook.
ID = ["ID"]                   # synthetic key derived from SubjectID + VideoID in the loading cell
USER = ["SubjectID"]          # split unit: train/test are separated by subject
IDS = [*USER, "VideoID"]      # keys of one recording
TARGET = ["predefinedlabel"]
FEATURES = "Raw Delta Theta Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2".split()
LAGS = list(range(1, 3))      # row shifts used to build lagged copies of FEATURES
INIT_SEED = 5412              # seed of the first experiment; later runs use INIT_SEED + k

In [4]:
# Data location; override with the EEG_DATA_DIR environment variable instead of
# editing the hard-coded default.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# Recover the raw signal columns. On a re-run FEATURES already contains the lag
# columns created last time, which are not in the CSV. Deriving lags only from
# this list, and rebinding FEATURES below instead of extending it in place,
# makes the cell idempotent.
raw_features = [name for name in FEATURES if name in data.columns]
derived_names = {f"{name}_{lag}" for name in raw_features for lag in LAGS}
unknown = [name for name in FEATURES if name not in data.columns and name not in derived_names]
if unknown:
    raise KeyError(f"Feature columns missing from EEG_data.csv: {unknown}")

# Lagged copies of every raw feature, shifted within each (SubjectID, VideoID)
# group so values never cross recordings; rows without a predecessor get 0.
grouped = data.groupby(IDS)
lagged_blocks = [
    grouped[raw_features].shift(lag).fillna(0).add_suffix(f"_{lag}")
    for lag in LAGS
]
new_features = [name for block in lagged_blocks for name in block.columns]
data = pd.concat([data, *lagged_blocks], axis=1)
FEATURES = raw_features + new_features

# One integer ID per (SubjectID, VideoID) pair. ngroup() numbers the sorted key
# pairs 0..k-1, so IDs are unique even if VideoID is not coded as 0..n-1.
data["ID"] = grouped.ngroup()
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_2,Delta_2,Theta_2,Alpha1_2,Alpha2_2,Beta1_2,Beta2_2,Gamma1_2,Gamma2_2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the long per-row frame into one padded sequence per recording.

    Rows are grouped by the ``ID`` column and groups are visited in ascending
    ID order. Each group's ``FEATURES`` rows become one sequence, and the
    sequences are brought to a common length by ``pad_sequences``.

    Args:
        data: DataFrame containing the ID, FEATURES and TARGET columns.

    Returns:
        ``(features, target)``, where ``features`` has shape
        (n_groups, max_len, len(FEATURES)) and ``target`` has shape
        (n_groups,) holding each group's mean TARGET value cast to int.
    """
    features = []
    target = []
    # A single sorted groupby visits IDs in the same ascending order as
    # np.unique and keeps row order inside each group. It avoids re-scanning
    # the whole frame (with a 2-D boolean mask) once per ID.
    for _, group in data.groupby(ID[0], sort=True):
        # NOTE(review): the int cast truncates, so for 0/1 labels a group with
        # mixed labels becomes 0. This is fine if the label is constant per recording.
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None, padding="post"):
    """Stack 2-D arrays of differing length into one 3-D array.

    Args:
        arrays: iterable of arrays shaped (length_i, n_features). All must
            share n_features so they can be stacked.
        pad_value: value written into the padding rows (default 0).
        max_length: target length. ``None`` (default) uses the longest input.
            Inputs longer than ``max_length`` keep only their first
            ``max_length`` rows.
        padding: ``"post"`` (default) appends padding rows after the data;
            ``"pre"`` inserts them before it.

    Returns:
        ndarray of shape (number of arrays, max_length, n_features).

    Raises:
        ValueError: if ``arrays`` is empty or ``padding`` is not "post"/"pre".
    """
    # Materialize once: the lengths and the padding pass both iterate, which
    # would silently exhaust a generator.
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one sequence")
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")

    if max_length is None:
        max_length = max(arr.shape[0] for arr in arrays)

    padded_arrays = []
    for arr in arrays:
        arr = arr[:max_length]
        missing = max_length - arr.shape[0]
        time_pad = (missing, 0) if padding == "pre" else (0, missing)
        padded_arrays.append(
            np.pad(arr, (time_pad, (0, 0)), mode='constant', constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full dataset purely to inspect its architecture;
# the experiment loop builds a fresh model for every split.
X, _ = reshape_dataset(data)  # targets are not needed for the summary
model = create_model(train=X)
model.summary()

2025-03-21 11:22:40.667866: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
    np.random.seed(int(seed))
    random.seed(int(seed))
    tf.random.set_seed(int(seed))

    train_id = np.random.choice(np.unique(np.ravel(data[USER])), 7, replace=False)
    train_index = np.isin(data[USER], train_id)

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 2s/step - accuracy: 0.5335 - auc: 0.4357 - loss: 0.6926 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6950
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5335 - auc: 0.5497 - loss: 0.6908 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6931
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5335 - auc: 0.6442 - loss: 0.6902 - val_accuracy: 0.5000 - val_auc: 0.8200 - val_loss: 0.6871
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 1s/step - accuracy: 0.5335 - auc: 0.7475 - loss: 0.6852 - val_accuracy: 0.5000 - val_auc: 0.8889 - val_loss: 0.6801
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5335 - auc: 0.7305 - loss: 0.6791 - val_accuracy: 0.5000 - val_auc: 0.8222 - val_loss: 0.6704
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━

1it [28:16, 1697.00s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 2s/step - accuracy: 0.5494 - auc: 0.3698 - loss: 0.6942 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6938
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5494 - auc: 0.5437 - loss: 0.6891 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6908
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5494 - auc: 0.6816 - loss: 0.6881 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6879
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 1s/step - accuracy: 0.5494 - auc: 0.7672 - loss: 0.6846 - val_accuracy: 0.7000 - val_auc: 0.7867 - val_loss: 0.6784
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.7287 - auc: 0.7766 - loss: 0.6756 - val_accuracy: 0.8000 - val_auc: 0.7867 - val_loss: 0.6614
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━

2it [49:41, 1454.10s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.4979 - auc: 0.4517 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6945
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 901ms/step - accuracy: 0.5663 - auc: 0.4920 - loss: 0.6876 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6934
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 873ms/step - accuracy: 0.5663 - auc: 0.6177 - loss: 0.6887 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6923
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 900ms/step - accuracy: 0.5663 - auc: 0.6984 - loss: 0.6895 - val_accuracy: 0.5000 - val_auc: 0.6667 - val_loss: 0.6908
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 884ms/step - accuracy: 0.6766 - auc: 0.7463 - loss: 0.6892 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.6871
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

3it [1:09:27, 1332.05s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.5113 - auc: 0.4915 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6943
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 871ms/step - accuracy: 0.5113 - auc: 0.5034 - loss: 0.6933 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6931
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 865ms/step - accuracy: 0.4888 - auc: 0.4890 - loss: 0.6934 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6931
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 884ms/step - accuracy: 0.4888 - auc: 0.4753 - loss: 0.6937 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6930
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 869ms/step - accuracy: 0.4923 - auc: 0.4808 - loss: 0.6934 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6927
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

4it [1:29:50, 1288.67s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 1s/step - accuracy: 0.5346 - auc: 0.5124 - loss: 0.6948 - val_accuracy: 0.5000 - val_auc: 0.7000 - val_loss: 0.6940
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 868ms/step - accuracy: 0.5346 - auc: 0.6666 - loss: 0.6877 - val_accuracy: 0.7667 - val_auc: 0.7333 - val_loss: 0.6898
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 913ms/step - accuracy: 0.8430 - auc: 0.7932 - loss: 0.6881 - val_accuracy: 0.8000 - val_auc: 0.7467 - val_loss: 0.6858
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 874ms/step - accuracy: 0.8600 - auc: 0.8114 - loss: 0.6826 - val_accuracy: 0.8000 - val_auc: 0.7867 - val_loss: 0.6730
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 890ms/step - accuracy: 0.8535 - auc: 0.8092 - loss: 0.6617 - val_accuracy: 0.8000 - val_auc: 0.7467 - val_loss: 0.6548
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

5it [1:50:14, 1322.97s/it]

CPU times: user 3h 58min 38s, sys: 1h 42min 57s, total: 5h 41min 35s
Wall time: 1h 50min 14s





In [8]:
# Print the experiment-averaged metrics for every epoch. Iterating the averaged
# curves themselves (instead of range(NUM_EPOCHS)) keeps this cell correct if
# NUM_EPOCHS is edited after training.
epoch_metrics = zip(epoch_acc, epoch_loss, epoch_auc, epoch_val_acc, epoch_val_loss, epoch_val_auc)
for epoch, (tr_acc, tr_loss, tr_auc, va_acc, va_loss, va_auc) in enumerate(epoch_metrics, start=1):
    print(f"Epoch {epoch}: TRAIN Accuracy = {np.round(tr_acc, 3)} Loss = {np.round(tr_loss, 3)} AUC = {np.round(tr_auc, 3)}")
    print(f"Epoch {epoch}: VAL Accuracy = {np.round(va_acc, 3)} Loss = {np.round(va_loss, 3)} AUC = {np.round(va_auc, 3)}")

Epoch 1: TRAIN Accuracy = 0.494 Loss = 0.698 AUC = 0.424
Epoch 1: VAL Accuracy = 0.5 Loss = 0.694 AUC = 0.54
Epoch 2: TRAIN Accuracy = 0.5 Loss = 0.693 AUC = 0.587
Epoch 2: VAL Accuracy = 0.553 Loss = 0.692 AUC = 0.667
Epoch 3: TRAIN Accuracy = 0.551 Loss = 0.692 AUC = 0.649
Epoch 3: VAL Accuracy = 0.56 Loss = 0.689 AUC = 0.673
Epoch 4: TRAIN Accuracy = 0.56 Loss = 0.688 AUC = 0.705
Epoch 4: VAL Accuracy = 0.6 Loss = 0.683 AUC = 0.726
Epoch 5: TRAIN Accuracy = 0.649 Loss = 0.68 AUC = 0.709
Epoch 5: VAL Accuracy = 0.673 Loss = 0.673 AUC = 0.733
Epoch 6: TRAIN Accuracy = 0.706 Loss = 0.668 AUC = 0.714
Epoch 6: VAL Accuracy = 0.727 Loss = 0.658 AUC = 0.801
Epoch 7: TRAIN Accuracy = 0.757 Loss = 0.65 AUC = 0.735
Epoch 7: VAL Accuracy = 0.787 Loss = 0.638 AUC = 0.798
Epoch 8: TRAIN Accuracy = 0.789 Loss = 0.625 AUC = 0.764
Epoch 8: VAL Accuracy = 0.787 Loss = 0.608 AUC = 0.783
Epoch 9: TRAIN Accuracy = 0.786 Loss = 0.596 AUC = 0.762
Epoch 9: VAL Accuracy = 0.793 Loss = 0.578 AUC = 0.786
Epo