In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 02:12:33.942110: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 02:12:33.950483: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 02:12:33.978255: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742512354.031859  273698 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742512354.048198  273698 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 02:12:34.105227: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training budget: epochs per run and number of independent runs (one seed each).
NUM_EPOCHS = 200
NUM_EXPERIMENTS = 5
# Adam step size used by create_model unless a caller overrides it.
LEARNING_RATE = 0.003


def create_model(train, learning_rate=LEARNING_RATE):
    """Build and compile the TimeDistributed-Dense -> LTC -> LTC -> Dense classifier.

    Args:
        train: array whose ``shape[1]`` and ``shape[2]`` give (timesteps,
            n_features); presumably the 3-D output of ``reshape_dataset``.
            Only its shape is read — no data is used here.
        learning_rate: Adam learning rate (defaults to ``LEARNING_RATE``).

    Returns:
        A compiled Keras ``Sequential`` model that outputs one sigmoid
        probability per sequence, trained with binary cross-entropy and
        reporting accuracy and AUC (metric name ``"auc"``).
    """
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    # Per-timestep projection of the input features to 32 channels.
    model.add(TimeDistributed(Dense(32, activation='relu')))

    # Two stacked LTC layers with AutoNCP wirings: the first keeps the whole
    # sequence for the second, which emits only its final step.
    model.add(LTC(AutoNCP(32, 25), return_sequences=True))
    model.add(LTC(AutoNCP(25, 16), return_sequences=False))

    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Per-recording sequence identifier (derived in the loading cell from SubjectID/VideoID).
ID = ["ID"]
# Column used to split subjects between training and validation.
USER = ["SubjectID"]
# Columns that jointly delimit one recording; lagged features are shifted within these groups.
IDS = ["SubjectID", "VideoID"]
# Label column (the model treats it as a binary target).
TARGET = ["predefinedlabel"]
# Base signal columns fed to the model; lagged copies are appended by the loading cell.
FEATURES = [
    "Raw",
    "Delta",
    "Theta",
    "Alpha1",
    "Alpha2",
    "Beta1",
    "Beta2",
    "Gamma1",
    "Gamma2",
]
# Row shifts used to build lagged feature copies.
LAGS = [1]
# First seed; experiment k uses INIT_SEED + k.
INIT_SEED = 5412

In [4]:
# Directory holding EEG_data.csv. Set the EEG_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# FEATURES is updated in place at the end of this cell, so on a re-run it already
# contains the lagged names from the previous run. Keep only the base signals so
# the cell is idempotent (no "Raw_1_1"-style columns on re-execution).
lagged_names = {f"{feature_name}_{lag}" for feature_name in FEATURES for lag in LAGS}
base_features = [feature_name for feature_name in FEATURES if feature_name not in lagged_names]

# Lagged copies of each base signal, shifted within every (SubjectID, VideoID)
# group so values never cross group boundaries; NaNs from the shift become 0.
new_features = []
for lag in LAGS:
    for feature_name in base_features:
        new_feature_name = f"{feature_name}_{lag}"
        new_features.append(new_feature_name)
        data[new_feature_name] = data.groupby(IDS)[feature_name].shift(lag).fillna(0)
FEATURES[:] = base_features + new_features

# One integer ID per (SubjectID, VideoID) pair: n_videos * subject + video.
# NOTE(review): guaranteed collision-free when VideoID lies in 0..n_videos-1;
# the assertion below verifies the mapping on the actual data.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
assert data[ID + IDS].drop_duplicates()["ID"].is_unique, "ID does not map to a single (SubjectID, VideoID) pair"

data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,...,Raw_1,Delta_1,Theta_1,Alpha1_1,Alpha2_1,Beta1_1,Beta2_1,Gamma1_1,Gamma2_1,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,...,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,...,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert a long per-sample frame into padded per-recording sequences.

    Rows are grouped by the ``ID`` column (one group per recording). Groups
    are emitted in ascending ID order, and each group keeps the frame's
    original row order.

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        ``(features, target)``. ``features`` has shape
        (n_recordings, longest_recording, len(FEATURES)) and is zero-padded at
        the end of each sequence by ``pad_sequences``. ``target`` has shape
        (n_recordings,) with one integer label per recording.
    """
    features = []
    target = []
    # A single groupby pass instead of re-scanning the whole frame for every ID.
    for _, recording in data.groupby(ID[0], sort=True):
        # Mean label over the recording, truncated to int.
        # NOTE(review): assumes the label is constant within a recording; a
        # mixed recording would be truncated toward 0 — confirm against the data.
        target.append(np.mean(recording[TARGET].to_numpy()).astype("int"))
        features.append(recording[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Stack 2-D arrays of varying length into one 3-D array by padding at the end.

    Each input has shape (length_i, n_features), and all inputs must share
    n_features. Rows filled with ``pad_value`` are appended after each
    array's data.

    Args:
        arrays: non-empty iterable of 2-D numpy arrays.
        pad_value: fill value for the padded rows.
        max_length: target number of rows per sequence. Defaults to the
            length of the longest input.

    Returns:
        Array of shape (len(arrays), max_length, n_features).

    Raises:
        ValueError: if ``arrays`` is empty, or any array is longer than
            ``max_length``.
    """
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif longest > max_length:
        raise ValueError(f"sequence of length {longest} exceeds max_length={max_length}")

    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value,
        )
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Reshape the full dataset (all recordings, so the time axis is padded to the
# overall longest recording) purely to give create_model an input shape; the
# targets are discarded. Then print the resulting architecture.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-03-21 02:12:42.230855: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Per-epoch metric curves, one list appended per experiment (seed).
all_acc = []
all_loss = []
all_auc = []

all_val_acc = []
all_val_loss = []
all_val_auc = []

# Subjects drawn (without replacement) for training in each run; the rest validate.
N_TRAIN_SUBJECTS = 7

subjects = np.unique(data[USER[0]].to_numpy())
seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED

for seed in tqdm(seeds, desc="experiments"):
    # Seed every RNG in use so the subject split and weight init are repeatable.
    seed = int(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-level split: no subject appears in both training and validation.
    train_id = np.random.choice(subjects, N_TRAIN_SUBJECTS, replace=False)
    train_index = data[USER[0]].isin(train_id).to_numpy()  # 1-D boolean row mask

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # reshape_dataset pads each split only to that split's longest sequence, so
    # train and validation could differ in time length while the model is built
    # from X_train's shape. Pad both to the common maximum.
    # NOTE(review): padding is appended after the data and the model has no
    # masking, so the final LTC step is read after the padded zeros — consider
    # pre-padding or masking.
    seq_len = max(X_train.shape[1], X_test.shape[1])
    X_train = np.pad(X_train, ((0, 0), (0, seq_len - X_train.shape[1]), (0, 0)))
    X_test = np.pad(X_test, ((0, 0), (0, seq_len - X_test.shape[1]), (0, 0)))

    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    acc = history.history['accuracy']
    loss = history.history['loss']
    auc = history.history['auc']

    val_acc = history.history['val_accuracy']
    val_loss = history.history['val_loss']
    val_auc = history.history['val_auc']

    all_acc.append(acc)
    all_loss.append(loss)
    all_auc.append(auc)

    all_val_acc.append(val_acc)
    all_val_loss.append(val_loss)
    all_val_auc.append(val_auc)

# Average each metric across experiments, epoch by epoch (every run trains NUM_EPOCHS epochs).
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 2s/step - accuracy: 0.5335 - auc: 0.4357 - loss: 0.6926 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6949
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 1s/step - accuracy: 0.5335 - auc: 0.5794 - loss: 0.6906 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6922
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5335 - auc: 0.6790 - loss: 0.6889 - val_accuracy: 0.5000 - val_auc: 0.8222 - val_loss: 0.6843
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5335 - auc: 0.7362 - loss: 0.6831 - val_accuracy: 0.5000 - val_auc: 0.8222 - val_loss: 0.6765
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5443 - auc: 0.7143 - loss: 0.6770 - val_accuracy: 0.8333 - val_auc: 0.8556 - val_loss: 0.6662
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━

1it [31:32, 1892.07s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 2s/step - accuracy: 0.5494 - auc: 0.3219 - loss: 0.6943 - val_accuracy: 0.5000 - val_auc: 0.5667 - val_loss: 0.6941
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5494 - auc: 0.5326 - loss: 0.6895 - val_accuracy: 0.5000 - val_auc: 0.7000 - val_loss: 0.6926
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5494 - auc: 0.5978 - loss: 0.6898 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6906
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5494 - auc: 0.7550 - loss: 0.6891 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6821
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5494 - auc: 0.7229 - loss: 0.6807 - val_accuracy: 0.7667 - val_auc: 0.7867 - val_loss: 0.6684
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━

2it [53:22, 1549.64s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.4979 - auc: 0.4517 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6945
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 856ms/step - accuracy: 0.5663 - auc: 0.5153 - loss: 0.6876 - val_accuracy: 0.5000 - val_auc: 0.7667 - val_loss: 0.6933
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 840ms/step - accuracy: 0.5663 - auc: 0.6070 - loss: 0.6887 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6922
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 830ms/step - accuracy: 0.5663 - auc: 0.6720 - loss: 0.6892 - val_accuracy: 0.5000 - val_auc: 0.7000 - val_loss: 0.6903
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 826ms/step - accuracy: 0.7487 - auc: 0.7394 - loss: 0.6883 - val_accuracy: 0.7667 - val_auc: 0.8111 - val_loss: 0.6851
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

3it [1:12:45, 1373.39s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.5113 - auc: 0.4915 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6943
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 834ms/step - accuracy: 0.5113 - auc: 0.4870 - loss: 0.6933 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6931
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 833ms/step - accuracy: 0.4888 - auc: 0.4890 - loss: 0.6935 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 838ms/step - accuracy: 0.4888 - auc: 0.4753 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6931
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 840ms/step - accuracy: 0.4888 - auc: 0.4808 - loss: 0.6935 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6928
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

4it [1:32:11, 1291.54s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.5346 - auc: 0.5124 - loss: 0.6949 - val_accuracy: 0.5000 - val_auc: 0.7333 - val_loss: 0.6942
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 824ms/step - accuracy: 0.5346 - auc: 0.6666 - loss: 0.6879 - val_accuracy: 0.8000 - val_auc: 0.7333 - val_loss: 0.6906
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 832ms/step - accuracy: 0.8093 - auc: 0.7357 - loss: 0.6892 - val_accuracy: 0.5333 - val_auc: 0.7444 - val_loss: 0.6865
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 815ms/step - accuracy: 0.6377 - auc: 0.8142 - loss: 0.6820 - val_accuracy: 0.5333 - val_auc: 0.7467 - val_loss: 0.6750
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 819ms/step - accuracy: 0.7078 - auc: 0.8169 - loss: 0.6616 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.6569
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

5it [1:51:24, 1336.87s/it]

CPU times: user 3h 53min 17s, sys: 1h 35min 34s, total: 5h 28min 51s
Wall time: 1h 51min 24s





In [8]:
# Report the per-epoch metrics averaged over all experiments, rounded to 3 decimals.
for epoch_idx in range(NUM_EPOCHS):
    epoch_no = epoch_idx + 1

    tr_acc = np.round(epoch_acc[epoch_idx], 3)
    tr_loss = np.round(epoch_loss[epoch_idx], 3)
    tr_auc = np.round(epoch_auc[epoch_idx], 3)
    print(f"Epoch {epoch_no}: TRAIN Accuracy = {tr_acc} Loss = {tr_loss} AUC = {tr_auc}")

    va_acc = np.round(epoch_val_acc[epoch_idx], 3)
    va_loss = np.round(epoch_val_loss[epoch_idx], 3)
    va_auc = np.round(epoch_val_auc[epoch_idx], 3)
    print(f"Epoch {epoch_no}: VAL Accuracy = {va_acc} Loss = {va_loss} AUC = {va_auc}")

Epoch 1: TRAIN Accuracy = 0.494 Loss = 0.698 AUC = 0.413
Epoch 1: VAL Accuracy = 0.5 Loss = 0.694 AUC = 0.56
Epoch 2: TRAIN Accuracy = 0.5 Loss = 0.693 AUC = 0.582
Epoch 2: VAL Accuracy = 0.56 Loss = 0.692 AUC = 0.7
Epoch 3: TRAIN Accuracy = 0.537 Loss = 0.692 AUC = 0.626
Epoch 3: VAL Accuracy = 0.507 Loss = 0.689 AUC = 0.673
Epoch 4: TRAIN Accuracy = 0.517 Loss = 0.689 AUC = 0.692
Epoch 4: VAL Accuracy = 0.507 Loss = 0.683 AUC = 0.714
Epoch 5: TRAIN Accuracy = 0.609 Loss = 0.681 AUC = 0.69
Epoch 5: VAL Accuracy = 0.72 Loss = 0.674 AUC = 0.735
Epoch 6: TRAIN Accuracy = 0.717 Loss = 0.669 AUC = 0.718
Epoch 6: VAL Accuracy = 0.74 Loss = 0.657 AUC = 0.769
Epoch 7: TRAIN Accuracy = 0.734 Loss = 0.651 AUC = 0.717
Epoch 7: VAL Accuracy = 0.733 Loss = 0.637 AUC = 0.795
Epoch 8: TRAIN Accuracy = 0.786 Loss = 0.627 AUC = 0.735
Epoch 8: VAL Accuracy = 0.787 Loss = 0.61 AUC = 0.787
Epoch 9: TRAIN Accuracy = 0.786 Loss = 0.598 AUC = 0.758
Epoch 9: VAL Accuracy = 0.793 Loss = 0.58 AUC = 0.77
Epoch 