In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-21 00:51:08.977793: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 00:51:08.984442: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-21 00:51:09.004256: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742507469.033839   90386 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742507469.044612   90386 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-21 00:51:09.082422: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training budget: epochs per run, and number of independently seeded runs.
NUM_EPOCHS, NUM_EXPERIMENTS = 200, 5


def create_model(train):
    """Build and compile the LTC-based sequence classifier.

    Only ``train.shape`` is read: axes 1 and 2 fix the model's per-sample input
    shape (presumably (timesteps, features) as produced by ``reshape_dataset``).

    Returns:
        A compiled ``Sequential`` model ending in a single sigmoid unit, optimized
        with Adam (learning rate 0.003) on binary cross-entropy, and tracking
        accuracy plus an AUC metric named "auc".
    """
    input_shape = (train.shape[1], train.shape[2])

    stack = [
        Input(shape=input_shape),
        # Same dense projection (32 ReLU channels) applied independently at every timestep.
        TimeDistributed(Dense(32, activation='relu')),
        # Two stacked LTC layers with AutoNCP wiring: the first returns the full
        # sequence, the second only its final step.
        # AutoNCP args are presumably (units, output_size) — confirm against ncps docs.
        LTC(AutoNCP(32, 25), return_sequences=True),
        LTC(AutoNCP(25, 16), return_sequences=False),
        # Small dense head producing one probability per sample.
        Dense(16, activation='relu'),
        Dense(1, activation='sigmoid'),
    ]
    model = Sequential(stack)

    model.compile(
        optimizer=Adam(learning_rate=0.003),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of EEG_data.csv. Each is a list so groups can be concatenated
# into a single column selection.
ID = ["ID"]                      # session key; the loading cell derives it from SubjectID and VideoID
USER = ["SubjectID"]             # column used for the subject-level train/validation split
IDS = ["SubjectID", "VideoID"]   # the raw column pair behind the session key
TARGET = ["predefinedlabel"]     # label column; the model treats it as binary (sigmoid + BCE)
FEATURES = [
    "Raw",
    "Delta",
    "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
INIT_SEED = 5412                 # run k of the experiment loop is seeded with INIT_SEED + k

In [4]:
# NOTE(review): absolute, user-specific path — point this at your local copy of the data.
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
raw_data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer per (SubjectID, VideoID) session. ngroup() numbers the sorted key
# combinations 0..n_sessions-1, so distinct sessions can never share an id.
# (The former `n_videos * SubjectID + VideoID` formula only avoided collisions
# when VideoID was coded exactly 0..n_videos-1.)
session_id = raw_data.groupby(IDS).ngroup().astype("int")

# Keep only the columns the rest of the notebook uses; raw_data stays untouched.
data = raw_data.assign(**{ID[0]: session_id})[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data, id_cols=None, feature_cols=None, target_cols=None):
    """Turn a long-format frame (one row per time step) into a padded 3-D tensor.

    Rows are grouped by the session key. Each group becomes one sequence of
    feature vectors, in the group's original row order, and all sequences are
    padded with zeros at the end to the longest session.

    Args:
        data: DataFrame containing the id, feature and target columns.
        id_cols: list of session-key column names; defaults to ``ID``.
        feature_cols: list of per-step feature column names; defaults to ``FEATURES``.
        target_cols: list of label column names; defaults to ``TARGET``.

    Returns:
        ``(features, target)``: an array of shape
        (n_sessions, max_len, n_features) and an int array of shape (n_sessions,),
        both ordered by ascending session key.
    """
    id_cols = ID if id_cols is None else id_cols
    feature_cols = FEATURES if feature_cols is None else feature_cols
    target_cols = TARGET if target_cols is None else target_cols

    # Unwrap a single-column list so pandas does not produce/warn about tuple keys.
    group_key = id_cols[0] if len(id_cols) == 1 else list(id_cols)

    features = []
    target = []
    # One pass over the frame; sort=True gives the same ascending key order as
    # np.unique, instead of re-scanning the whole frame with a mask per session.
    for _, session in data.groupby(group_key, sort=True):
        # Label is averaged over the session and truncated to int; this assumes
        # the label is constant within a session — TODO confirm for new data.
        target.append(np.mean(session[target_cols].to_numpy()).astype("int"))
        features.append(session[feature_cols].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays at the end of axis 0 to a common length and stack them.

    Args:
        arrays: non-empty sequence of 2-D arrays shaped (length_i, n_columns)
            with a shared number of columns.
        pad_value: constant used for the appended rows.

    Returns:
        Array of shape (len(arrays), max_length, n_columns).

    NOTE(review): padding is appended after the real steps and create_model has
    no masking layer, so a ``return_sequences=False`` RNN reads its state after
    the padded rows — consider pre-padding or masking.
    """
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            # Pad only the tail of the time axis; leave the column axis alone.
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
        for arr in arrays
    ]
    return np.stack(padded_arrays)

In [6]:
# Build one model on the full padded tensor only to display the architecture;
# the experiment loop creates its own model for every run.
X, _ = reshape_dataset(data)
model = create_model(train=X)
model.summary()

2025-03-21 00:51:14.377967: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Number of subjects sampled into the training split per run; the remaining
# subjects form the validation split.
NUM_TRAIN_SUBJECTS = 7

# Per-run metric curves, keyed by the names Keras writes into History.history.
HISTORY_KEYS = ["accuracy", "loss", "auc", "val_accuracy", "val_loss", "val_auc"]
runs = {key: [] for key in HISTORY_KEYS}

seeds = np.arange(NUM_EXPERIMENTS) + INIT_SEED
subjects = np.unique(data[USER[0]].to_numpy())

# Iterate the seed array directly so tqdm knows the total (enumerate() hides len()).
for seed in tqdm(seeds):
    seed = int(seed)

    # Drop Keras global state left by earlier runs so each run depends only on its seed.
    tf.keras.backend.clear_session()
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Subject-level split: every session of a subject lands on the same side.
    train_id = np.random.choice(subjects, NUM_TRAIN_SUBJECTS, replace=False)
    train_index = data[USER[0]].isin(train_id).to_numpy()  # 1-D boolean mask

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    # NOTE(review): each split is padded to its own longest session, so X_train and
    # X_test may differ in timesteps; features are also fed unscaled (band powers
    # reach ~1e5 in the preview above) — consider standardizing with train statistics.
    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # Column vectors to match the model's single sigmoid output unit.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    for key in HISTORY_KEYS:
        runs[key].append(history.history[key])

all_acc, all_loss, all_auc = runs["accuracy"], runs["loss"], runs["auc"]
all_val_acc, all_val_loss, all_val_auc = runs["val_accuracy"], runs["val_loss"], runs["val_auc"]

# Average each metric across runs, epoch by epoch (every run trains NUM_EPOCHS epochs).
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 2s/step - accuracy: 0.5335 - auc: 0.4357 - loss: 0.6926 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6950
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5335 - auc: 0.5987 - loss: 0.6908 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6932
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5335 - auc: 0.4773 - loss: 0.6906 - val_accuracy: 0.5000 - val_auc: 0.8333 - val_loss: 0.6877
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5335 - auc: 0.7477 - loss: 0.6857 - val_accuracy: 0.5000 - val_auc: 0.8889 - val_loss: 0.6788
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 1s/step - accuracy: 0.5335 - auc: 0.7145 - loss: 0.6785 - val_accuracy: 0.7333 - val_auc: 0.8556 - val_loss: 0.6685
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━

1it [30:13, 1813.35s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 3s/step - accuracy: 0.5494 - auc: 0.3446 - loss: 0.6944 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6942
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 2s/step - accuracy: 0.5494 - auc: 0.4838 - loss: 0.6897 - val_accuracy: 0.5000 - val_auc: 0.6000 - val_loss: 0.6929
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5494 - auc: 0.5250 - loss: 0.6902 - val_accuracy: 0.5000 - val_auc: 0.8000 - val_loss: 0.6916
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5494 - auc: 0.6871 - loss: 0.6901 - val_accuracy: 0.5000 - val_auc: 0.7867 - val_loss: 0.6865
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.5494 - auc: 0.7364 - loss: 0.6857 - val_accuracy: 0.7667 - val_auc: 0.8000 - val_loss: 0.6792
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━

2it [50:05, 1448.23s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 1s/step - accuracy: 0.4979 - auc: 0.4517 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6946
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 808ms/step - accuracy: 0.5663 - auc: 0.4550 - loss: 0.6878 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6936
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 808ms/step - accuracy: 0.5663 - auc: 0.5069 - loss: 0.6886 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6926
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 825ms/step - accuracy: 0.5663 - auc: 0.5844 - loss: 0.6897 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6911
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 815ms/step - accuracy: 0.6226 - auc: 0.6977 - loss: 0.6892 - val_accuracy: 0.8333 - val_auc: 0.8111 - val_loss: 0.6869
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

3it [1:08:22, 1287.78s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 1s/step - accuracy: 0.5113 - auc: 0.4915 - loss: 0.6938 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6943
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 837ms/step - accuracy: 0.5113 - auc: 0.4836 - loss: 0.6932 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6930
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 818ms/step - accuracy: 0.4888 - auc: 0.4890 - loss: 0.6933 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6930
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 807ms/step - accuracy: 0.4888 - auc: 0.4753 - loss: 0.6935 - val_accuracy: 0.5000 - val_auc: 0.5000 - val_loss: 0.6927
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 813ms/step - accuracy: 0.4923 - auc: 0.5068 - loss: 0.6930 - val_accuracy: 0.5667 - val_auc: 0.5667 - val_loss: 0.6921
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━

4it [1:33:02, 1363.59s/it]

Epoch 1/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 2s/step - accuracy: 0.5346 - auc: 0.4708 - loss: 0.6947 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 0.6936
Epoch 2/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 1s/step - accuracy: 0.5346 - auc: 0.6914 - loss: 0.6870 - val_accuracy: 0.7333 - val_auc: 0.7644 - val_loss: 0.6894
Epoch 3/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.7996 - auc: 0.8002 - loss: 0.6871 - val_accuracy: 0.7667 - val_auc: 0.7067 - val_loss: 0.6843
Epoch 4/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 1s/step - accuracy: 0.8528 - auc: 0.8079 - loss: 0.6776 - val_accuracy: 0.7333 - val_auc: 0.7022 - val_loss: 0.6689
Epoch 5/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 1s/step - accuracy: 0.8754 - auc: 0.7988 - loss: 0.6513 - val_accuracy: 0.7333 - val_auc: 0.7200 - val_loss: 0.6535
Epoch 6/200
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━

5it [2:03:19, 1479.96s/it]

CPU times: user 4h 14min, sys: 1h 35min 48s, total: 5h 49min 49s
Wall time: 2h 3min 19s





In [8]:
# Run-averaged metrics per epoch: a TRAIN line followed by a VAL line.
curve_sets = (
    ("TRAIN", epoch_acc, epoch_loss, epoch_auc),
    ("VAL", epoch_val_acc, epoch_val_loss, epoch_val_auc),
)
for i in range(NUM_EPOCHS):
    for split_name, acc_curve, loss_curve, auc_curve in curve_sets:
        print(
            f"Epoch {i + 1}: {split_name} Accuracy = {np.round(acc_curve[i], 3)} "
            f"Loss = {np.round(loss_curve[i], 3)} AUC = {np.round(auc_curve[i], 3)}"
        )

Epoch 1: TRAIN Accuracy = 0.494 Loss = 0.698 AUC = 0.416
Epoch 1: VAL Accuracy = 0.5 Loss = 0.694 AUC = 0.527
Epoch 2: TRAIN Accuracy = 0.5 Loss = 0.693 AUC = 0.561
Epoch 2: VAL Accuracy = 0.547 Loss = 0.692 AUC = 0.573
Epoch 3: TRAIN Accuracy = 0.543 Loss = 0.692 AUC = 0.56
Epoch 3: VAL Accuracy = 0.553 Loss = 0.69 AUC = 0.668
Epoch 4: TRAIN Accuracy = 0.551 Loss = 0.688 AUC = 0.662
Epoch 4: VAL Accuracy = 0.547 Loss = 0.684 AUC = 0.676
Epoch 5: TRAIN Accuracy = 0.591 Loss = 0.681 AUC = 0.696
Epoch 5: VAL Accuracy = 0.727 Loss = 0.676 AUC = 0.751
Epoch 6: TRAIN Accuracy = 0.734 Loss = 0.671 AUC = 0.733
Epoch 6: VAL Accuracy = 0.753 Loss = 0.662 AUC = 0.804
Epoch 7: TRAIN Accuracy = 0.78 Loss = 0.654 AUC = 0.761
Epoch 7: VAL Accuracy = 0.793 Loss = 0.642 AUC = 0.793
Epoch 8: TRAIN Accuracy = 0.789 Loss = 0.631 AUC = 0.767
Epoch 8: VAL Accuracy = 0.773 Loss = 0.616 AUC = 0.775
Epoch 9: TRAIN Accuracy = 0.786 Loss = 0.602 AUC = 0.762
Epoch 9: VAL Accuracy = 0.787 Loss = 0.584 AUC = 0.776