In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from ncps.keras import LTC
from ncps.wirings import AutoNCP
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, Flatten
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm

2025-05-22 02:34:03.242831: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-22 02:34:03.247340: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-22 02:34:03.257904: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747870443.274748  503277 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747870443.279764  503277 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-22 02:34:03.302001: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Number of epochs each model is trained for (passed to `model.fit`).
NUM_EPOCHS = 400
# Number of independent runs; each run uses its own seed, and the per-epoch
# metrics are averaged across runs.
NUM_EXPERIMENTS = 5

def create_model(train):
    """Build and compile the dense binary classifier used by the experiments.

    Each input sample is flattened and passed through four
    BatchNormalization -> Dropout -> Dense(relu) stages, followed by a single
    sigmoid output unit.

    Args:
        train: Array-like whose ``shape[1]`` and ``shape[2]`` define the
            per-sample input shape. Only the shape is read, never the values.

    Returns:
        A compiled ``Sequential`` model using Adam, binary cross-entropy loss,
        and the "accuracy" and "auc" metrics.
    """
    # (dropout rate, dense units) for every hidden stage, in network order.
    hidden_stages = ((0.75, 48), (0.5, 32), (0.25, 16), (0.25, 8))

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    model.add(Flatten())

    for drop_rate, units in hidden_stages:
        model.add(BatchNormalization())
        model.add(Dropout(drop_rate))
        model.add(Dense(units, activation='relu'))

    # Single probability output for the binary target.
    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.002, weight_decay=1e-7, use_ema=True)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column lists are kept as lists so they can be concatenated and used directly
# for DataFrame column selection.

# Identifier of a single recording.
ID = ["ID"]
# Subject (participant) identifier; the train/test split is made on this column.
USER = ["SubjectID"]
# Raw identifier columns of the dataset.
IDS = [*USER, "VideoID"]
# Binary label column.
TARGET = ["predefinedlabel"]
# EEG band columns used as model inputs.
FEATURES = [
    "Delta",
    "Theta",
    "Alpha1",
    "Alpha2",
    "Beta1",
    "Beta2",
    "Gamma1",
    "Gamma2",
]
# Seed of the first experiment; later experiments use consecutive seeds.
INIT_SEED = 5412

In [4]:
# Directory containing EEG_data.csv. It can be overridden with the EEG_DATA_DIR
# environment variable so the notebook is not tied to a single machine; when
# the variable is unset the original location is used.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# Combine subject and video into one integer recording ID:
# number_of_distinct_videos * SubjectID + VideoID.
# NOTE(review): the IDs are only collision-free if VideoID values lie in
# 0..number_of_distinct_videos-1 — confirm against the dataset.
data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
# Keep only the columns the experiments use.
data = data[ID + USER + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,SubjectID,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,0.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,0.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,0.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Convert the long-format frame into one padded sequence per recording.

    Rows are grouped by the ``ID`` column (in sorted ID order). For each group
    the ``FEATURES`` columns form one sequence, and the label is the mean of
    the ``TARGET`` column cast to int (i.e. truncated).

    Args:
        data: DataFrame containing the ``ID``, ``FEATURES`` and ``TARGET`` columns.

    Returns:
        Tuple ``(features, target)``: ``features`` stacks the sequences after
        padding them to the longest one in ``data``; ``target`` holds one
        integer label per recording.
    """
    id_column = data[ID].to_numpy()
    recordings = [data[id_column == rec_id] for rec_id in np.unique(id_column)]

    labels = [np.mean(rec[TARGET].to_numpy()).astype("int") for rec in recordings]
    sequences = [rec[FEATURES].to_numpy() for rec in recordings]

    return np.array(pad_sequences(sequences)), np.array(labels)

def pad_sequences(arrays, pad_value=0):
    """Pad 2-D arrays at the end of axis 0 to a common length and stack them.

    Args:
        arrays: Sequence of 2-D arrays; it is iterated twice, and the arrays
            must share their second dimension for the final stack to succeed.
        pad_value: Constant written into the appended rows.

    Returns:
        Array of shape ``(len(arrays), longest_length, n_columns)``.
    """
    target_len = max(item.shape[0] for item in arrays)

    padded = []
    for item in arrays:
        missing_rows = target_len - item.shape[0]
        # Only append rows after the data; columns are left untouched.
        widths = ((0, missing_rows), (0, 0))
        padded.append(np.pad(item, widths, mode='constant', constant_values=pad_value))

    return np.stack(padded)

In [6]:
# Build the model once on the full dataset only to display its architecture;
# the labels returned by reshape_dataset are discarded here.
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-05-22 02:34:05.565239: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

# Per-experiment training curves: one list of per-epoch values per run.
all_acc = []
all_loss = []
all_auc = []

# Per-experiment validation curves, same layout as above.
all_val_acc = []
all_val_loss = []
all_val_auc = []

# Number of subjects drawn for training; all remaining subjects form the test split.
NUM_TRAIN_SUBJECTS = 7

# The set of subjects does not change between runs, so compute it once.
subject_ids = np.unique(np.ravel(data[USER]))

# Wrapping a sized range (instead of an enumerate iterator) lets tqdm show the total.
for seed in tqdm(range(INIT_SEED, INIT_SEED + NUM_EXPERIMENTS), desc="experiments"):
    # Release Keras state left behind by the previous run's model so memory
    # does not accumulate across experiments. Done before seeding so the
    # seeds below are what the new model sees.
    tf.keras.backend.clear_session()

    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # Split by subject so no subject contributes to both train and test.
    train_id = np.random.choice(subject_ids, NUM_TRAIN_SUBJECTS, replace=False)
    # data[USER] is a one-column frame; flatten the mask to 1-D for row selection.
    train_index = np.ravel(np.isin(data[USER], train_id))

    train = data.iloc[train_index]
    test = data.iloc[~train_index]

    X_train, y_train = reshape_dataset(train)
    X_test, y_test = reshape_dataset(test)

    # reshape_dataset pads each split to its own longest recording, so the two
    # splits can disagree on sequence length. Zero-pad both to the common
    # maximum; otherwise the validation inputs would not match the input shape
    # the model is built with. This is a no-op when the lengths already agree.
    seq_len = max(X_train.shape[1], X_test.shape[1])
    X_train = np.pad(X_train, ((0, 0), (0, seq_len - X_train.shape[1]), (0, 0)))
    X_test = np.pad(X_test, ((0, 0), (0, seq_len - X_test.shape[1]), (0, 0)))

    # Column vectors to match the model's single sigmoid output.
    y_train = y_train.reshape(-1, 1)
    y_test = y_test.reshape(-1, 1)

    model = create_model(X_train)

    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        epochs=NUM_EPOCHS,
        batch_size=10,
        verbose=1,
    )

    all_acc.append(history.history['accuracy'])
    all_loss.append(history.history['loss'])
    all_auc.append(history.history['auc'])

    all_val_acc.append(history.history['val_accuracy'])
    all_val_loss.append(history.history['val_loss'])
    all_val_auc.append(history.history['val_auc'])

# Average every curve across experiments, leaving one value per epoch.
epoch_acc = np.mean(all_acc, axis=0)
epoch_loss = np.mean(all_loss, axis=0)
epoch_auc = np.mean(all_auc, axis=0)

epoch_val_acc = np.mean(all_val_acc, axis=0)
epoch_val_loss = np.mean(all_val_loss, axis=0)
epoch_val_auc = np.mean(all_val_auc, axis=0)

0it [00:00, ?it/s]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 54ms/step - accuracy: 0.4230 - auc: 0.4123 - loss: 0.8550 - val_accuracy: 0.5000 - val_auc: 0.5867 - val_loss: 2.1972
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.5482 - auc: 0.6285 - loss: 0.7428 - val_accuracy: 0.5000 - val_auc: 0.6533 - val_loss: 1.4968
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.5591 - auc: 0.6271 - loss: 0.7098 - val_accuracy: 0.5000 - val_auc: 0.6200 - val_loss: 0.9977
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.5256 - auc: 0.6031 - loss: 0.6896 - val_accuracy: 0.4667 - val_auc: 0.6178 - val_loss: 0.7644
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.5294 - auc: 0.6722 - loss: 0.7042 - val_accuracy: 0.4667 - val_auc: 0.6644 - val_loss: 0.6630
Epoch 6/400
[1m7/7[0m [32m━━━━━━━━━━━

1it [00:39, 39.19s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 50ms/step - accuracy: 0.5734 - auc: 0.5866 - loss: 0.9467 - val_accuracy: 0.5000 - val_auc: 0.7378 - val_loss: 2.0713
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.6091 - auc: 0.5909 - loss: 0.8837 - val_accuracy: 0.5000 - val_auc: 0.7178 - val_loss: 1.3994
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.5510 - auc: 0.6627 - loss: 0.7621 - val_accuracy: 0.5000 - val_auc: 0.7111 - val_loss: 0.8930
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.6133 - auc: 0.5987 - loss: 0.7807 - val_accuracy: 0.5000 - val_auc: 0.7156 - val_loss: 0.7158
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.5521 - auc: 0.5130 - loss: 0.8207 - val_accuracy: 0.5667 - val_auc: 0.6622 - val_loss: 0.6680
Epoch 6/400
[1m7/7[0m [32m━━━━━━━━━━━

2it [01:19, 39.96s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 49ms/step - accuracy: 0.5190 - auc: 0.4985 - loss: 0.8990 - val_accuracy: 0.5000 - val_auc: 0.6289 - val_loss: 1.5639
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.4995 - auc: 0.5175 - loss: 0.8921 - val_accuracy: 0.5000 - val_auc: 0.6333 - val_loss: 1.1880
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.4510 - auc: 0.4610 - loss: 0.8891 - val_accuracy: 0.5333 - val_auc: 0.6222 - val_loss: 0.9844
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.5285 - auc: 0.5170 - loss: 0.8590 - val_accuracy: 0.6000 - val_auc: 0.6556 - val_loss: 0.8409
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.6343 - auc: 0.5985 - loss: 0.7284 - val_accuracy: 0.6000 - val_auc: 0.6644 - val_loss: 0.7354
Epoch 6/400
[1m7/7[0m [32m━━━━━━━━━━━

3it [02:00, 40.37s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 51ms/step - accuracy: 0.4383 - auc: 0.3285 - loss: 1.1577 - val_accuracy: 0.5000 - val_auc: 0.5089 - val_loss: 1.7209
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.3546 - auc: 0.3817 - loss: 1.0967 - val_accuracy: 0.5000 - val_auc: 0.4978 - val_loss: 1.2248
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.5240 - auc: 0.6316 - loss: 0.7603 - val_accuracy: 0.5000 - val_auc: 0.5133 - val_loss: 1.0080
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.5098 - auc: 0.5224 - loss: 0.8143 - val_accuracy: 0.5000 - val_auc: 0.5044 - val_loss: 0.8574
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.4748 - auc: 0.6047 - loss: 0.7051 - val_accuracy: 0.4667 - val_auc: 0.4689 - val_loss: 0.7962
Epoch 6/400
[1m7/7[0m [32m━━━━━━━━━━━

4it [02:42, 41.13s/it]

Epoch 1/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 68ms/step - accuracy: 0.5079 - auc: 0.3922 - loss: 0.9816 - val_accuracy: 0.4333 - val_auc: 0.4000 - val_loss: 0.9625
Epoch 2/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.5563 - auc: 0.6235 - loss: 0.7801 - val_accuracy: 0.4333 - val_auc: 0.4600 - val_loss: 0.8005
Epoch 3/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.4891 - auc: 0.5199 - loss: 0.8296 - val_accuracy: 0.5000 - val_auc: 0.5156 - val_loss: 0.7479
Epoch 4/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.4771 - auc: 0.4924 - loss: 0.7958 - val_accuracy: 0.5333 - val_auc: 0.5622 - val_loss: 0.7171
Epoch 5/400
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.5769 - auc: 0.6102 - loss: 0.7343 - val_accuracy: 0.5667 - val_auc: 0.6267 - val_loss: 0.6778
Epoch 6/400
[1m7/7[0m [32m━━━━━━━━━━━

5it [03:25, 41.18s/it]

CPU times: user 3min 39s, sys: 46.4 s, total: 4min 25s
Wall time: 3min 25s





In [8]:
# Format the averaged per-epoch metrics once and reuse the same text for the
# notebook output and the log file, so the two cannot drift apart.
# Iterating over the computed curves (rather than NUM_EPOCHS) keeps this cell
# valid even if the constant was edited after training.
report_lines = []
for i in range(len(epoch_acc)):
    report_lines.append(f"Epoch {(i + 1)}: TRAIN Accuracy = {np.round(epoch_acc[i], 3)} Loss = {np.round(epoch_loss[i], 3)} AUC = {np.round(epoch_auc[i], 3)}")
    report_lines.append(f"Epoch {(i + 1)}: VAL Accuracy = {np.round(epoch_val_acc[i], 3)} Loss = {np.round(epoch_val_loss[i], 3)} AUC = {np.round(epoch_val_auc[i], 3)}")

for line in report_lines:
    print(line)

log_path = Path("./logs/Linearv1_expanded.txt")
# Create the log directory on first use instead of failing with FileNotFoundError.
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "w") as f:
    f.writelines(line + "\n" for line in report_lines)

Epoch 1: TRAIN Accuracy = 0.494 Loss = 0.968 AUC = 0.464
Epoch 1: VAL Accuracy = 0.487 Loss = 1.703 AUC = 0.572
Epoch 2: TRAIN Accuracy = 0.517 Loss = 0.862 AUC = 0.549
Epoch 2: VAL Accuracy = 0.487 Loss = 1.222 AUC = 0.592
Epoch 3: TRAIN Accuracy = 0.514 Loss = 0.82 AUC = 0.556
Epoch 3: VAL Accuracy = 0.507 Loss = 0.926 AUC = 0.596
Epoch 4: TRAIN Accuracy = 0.543 Loss = 0.768 AUC = 0.588
Epoch 4: VAL Accuracy = 0.52 Loss = 0.779 AUC = 0.611
Epoch 5: TRAIN Accuracy = 0.537 Loss = 0.769 AUC = 0.575
Epoch 5: VAL Accuracy = 0.533 Loss = 0.708 AUC = 0.617
Epoch 6: TRAIN Accuracy = 0.574 Loss = 0.707 AUC = 0.644
Epoch 6: VAL Accuracy = 0.6 Loss = 0.665 AUC = 0.633
Epoch 7: TRAIN Accuracy = 0.551 Loss = 0.707 AUC = 0.631
Epoch 7: VAL Accuracy = 0.627 Loss = 0.64 AUC = 0.66
Epoch 8: TRAIN Accuracy = 0.6 Loss = 0.674 AUC = 0.671
Epoch 8: VAL Accuracy = 0.6 Loss = 0.623 AUC = 0.699
Epoch 9: TRAIN Accuracy = 0.629 Loss = 0.657 AUC = 0.684
Epoch 9: VAL Accuracy = 0.62 Loss = 0.615 AUC = 0.739
Epo