In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-18 21:49:28.608324: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 21:49:28.611997: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 21:49:28.623495: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739904568.642751 2148908 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739904568.648542 2148908 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-18 21:49:28.671255: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(Dropout(0.2))
    model.add(LTC(AutoNCP(15, 12), return_sequences=True))

    model.add(Dropout(0.2))
    model.add(LTC(AutoNCP(12, 9), return_sequences=True))

    model.add(Dropout(0.2))
    model.add(LTC(AutoNCP(9, 6), return_sequences=False))

    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.0003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-02-18 21:49:31.754853: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [10:46, 64.64s/it]


Epoch 5: TRAIN Accuracy = 0.49 Loss = 0.696 AUC = 0.473
Epoch 5: VAL Accuracy = 0.523 Loss = 0.691 AUC = 0.5


10it [07:37, 45.73s/it]


Epoch 10: TRAIN Accuracy = 0.49 Loss = 0.695 AUC = 0.497
Epoch 10: VAL Accuracy = 0.523 Loss = 0.692 AUC = 0.5


10it [07:56, 47.69s/it]


Epoch 15: TRAIN Accuracy = 0.49 Loss = 0.695 AUC = 0.499
Epoch 15: VAL Accuracy = 0.523 Loss = 0.692 AUC = 0.5


10it [08:06, 48.65s/it]


Epoch 20: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.498
Epoch 20: VAL Accuracy = 0.49 Loss = 0.693 AUC = 0.5


10it [08:32, 51.23s/it]


Epoch 25: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.498
Epoch 25: VAL Accuracy = 0.49 Loss = 0.693 AUC = 0.5


10it [08:28, 50.81s/it]


Epoch 30: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.5
Epoch 30: VAL Accuracy = 0.49 Loss = 0.693 AUC = 0.5


10it [08:04, 48.47s/it]


Epoch 35: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.505
Epoch 35: VAL Accuracy = 0.49 Loss = 0.694 AUC = 0.5


10it [07:52, 47.24s/it]


Epoch 40: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.493
Epoch 40: VAL Accuracy = 0.49 Loss = 0.694 AUC = 0.5


10it [07:49, 46.99s/it]


Epoch 45: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.5
Epoch 45: VAL Accuracy = 0.49 Loss = 0.694 AUC = 0.5


10it [07:52, 47.27s/it]


Epoch 50: TRAIN Accuracy = 0.504 Loss = 0.694 AUC = 0.489
Epoch 50: VAL Accuracy = 0.49 Loss = 0.694 AUC = 0.5


10it [07:45, 46.52s/it]


Epoch 55: TRAIN Accuracy = 0.504 Loss = 0.693 AUC = 0.49
Epoch 55: VAL Accuracy = 0.49 Loss = 0.695 AUC = 0.532


10it [07:43, 46.35s/it]


Epoch 60: TRAIN Accuracy = 0.496 Loss = 0.693 AUC = 0.512
Epoch 60: VAL Accuracy = 0.477 Loss = 0.695 AUC = 0.5


10it [07:38, 45.86s/it]


Epoch 65: TRAIN Accuracy = 0.513 Loss = 0.693 AUC = 0.522
Epoch 65: VAL Accuracy = 0.477 Loss = 0.695 AUC = 0.5


10it [07:39, 45.99s/it]


Epoch 70: TRAIN Accuracy = 0.51 Loss = 0.693 AUC = 0.534
Epoch 70: VAL Accuracy = 0.477 Loss = 0.695 AUC = 0.557


10it [07:44, 46.50s/it]


Epoch 75: TRAIN Accuracy = 0.51 Loss = 0.693 AUC = 0.544
Epoch 75: VAL Accuracy = 0.477 Loss = 0.695 AUC = 0.525


10it [07:43, 46.31s/it]


Epoch 80: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.536
Epoch 80: VAL Accuracy = 0.477 Loss = 0.694 AUC = 0.529


10it [07:41, 46.20s/it]


Epoch 85: TRAIN Accuracy = 0.51 Loss = 0.692 AUC = 0.506
Epoch 85: VAL Accuracy = 0.477 Loss = 0.693 AUC = 0.532


10it [07:41, 46.17s/it]


Epoch 90: TRAIN Accuracy = 0.51 Loss = 0.69 AUC = 0.529
Epoch 90: VAL Accuracy = 0.477 Loss = 0.692 AUC = 0.529


10it [07:39, 45.99s/it]


Epoch 95: TRAIN Accuracy = 0.51 Loss = 0.689 AUC = 0.502
Epoch 95: VAL Accuracy = 0.477 Loss = 0.69 AUC = 0.529


10it [07:50, 47.06s/it]


Epoch 100: TRAIN Accuracy = 0.51 Loss = 0.687 AUC = 0.528
Epoch 100: VAL Accuracy = 0.503 Loss = 0.689 AUC = 0.531


10it [08:01, 48.13s/it]


Epoch 105: TRAIN Accuracy = 0.521 Loss = 0.685 AUC = 0.527
Epoch 105: VAL Accuracy = 0.503 Loss = 0.687 AUC = 0.532


10it [08:18, 49.84s/it]


Epoch 110: TRAIN Accuracy = 0.523 Loss = 0.683 AUC = 0.545
Epoch 110: VAL Accuracy = 0.503 Loss = 0.685 AUC = 0.562


10it [08:43, 52.34s/it]


Epoch 115: TRAIN Accuracy = 0.526 Loss = 0.682 AUC = 0.561
Epoch 115: VAL Accuracy = 0.513 Loss = 0.684 AUC = 0.599


10it [07:57, 47.75s/it]


Epoch 120: TRAIN Accuracy = 0.531 Loss = 0.679 AUC = 0.608
Epoch 120: VAL Accuracy = 0.513 Loss = 0.682 AUC = 0.615


10it [07:53, 47.33s/it]


Epoch 125: TRAIN Accuracy = 0.533 Loss = 0.676 AUC = 0.629
Epoch 125: VAL Accuracy = 0.513 Loss = 0.679 AUC = 0.66


10it [07:54, 47.42s/it]


Epoch 130: TRAIN Accuracy = 0.547 Loss = 0.672 AUC = 0.64
Epoch 130: VAL Accuracy = 0.557 Loss = 0.675 AUC = 0.657


10it [07:54, 47.41s/it]


Epoch 135: TRAIN Accuracy = 0.561 Loss = 0.665 AUC = 0.659
Epoch 135: VAL Accuracy = 0.593 Loss = 0.669 AUC = 0.655


10it [07:44, 46.42s/it]


Epoch 140: TRAIN Accuracy = 0.583 Loss = 0.658 AUC = 0.67
Epoch 140: VAL Accuracy = 0.6 Loss = 0.663 AUC = 0.65


10it [07:33, 45.37s/it]


Epoch 145: TRAIN Accuracy = 0.596 Loss = 0.65 AUC = 0.66
Epoch 145: VAL Accuracy = 0.613 Loss = 0.652 AUC = 0.656


10it [07:41, 46.13s/it]


Epoch 150: TRAIN Accuracy = 0.627 Loss = 0.637 AUC = 0.679
Epoch 150: VAL Accuracy = 0.64 Loss = 0.638 AUC = 0.661


10it [07:42, 46.28s/it]


Epoch 155: TRAIN Accuracy = 0.63 Loss = 0.627 AUC = 0.667
Epoch 155: VAL Accuracy = 0.64 Loss = 0.628 AUC = 0.659


10it [08:19, 49.97s/it]


Epoch 160: TRAIN Accuracy = 0.639 Loss = 0.615 AUC = 0.663
Epoch 160: VAL Accuracy = 0.637 Loss = 0.618 AUC = 0.659


10it [08:29, 50.98s/it]


Epoch 165: TRAIN Accuracy = 0.64 Loss = 0.603 AUC = 0.665
Epoch 165: VAL Accuracy = 0.637 Loss = 0.608 AUC = 0.655


10it [08:24, 50.48s/it]


Epoch 170: TRAIN Accuracy = 0.644 Loss = 0.594 AUC = 0.665
Epoch 170: VAL Accuracy = 0.637 Loss = 0.6 AUC = 0.664


10it [08:25, 50.56s/it]


Epoch 175: TRAIN Accuracy = 0.64 Loss = 0.586 AUC = 0.668
Epoch 175: VAL Accuracy = 0.637 Loss = 0.593 AUC = 0.666


10it [08:28, 50.87s/it]


Epoch 180: TRAIN Accuracy = 0.641 Loss = 0.577 AUC = 0.667
Epoch 180: VAL Accuracy = 0.637 Loss = 0.588 AUC = 0.674


10it [08:26, 50.69s/it]


Epoch 185: TRAIN Accuracy = 0.644 Loss = 0.57 AUC = 0.681
Epoch 185: VAL Accuracy = 0.637 Loss = 0.581 AUC = 0.662


10it [08:26, 50.66s/it]


Epoch 190: TRAIN Accuracy = 0.669 Loss = 0.565 AUC = 0.673
Epoch 190: VAL Accuracy = 0.657 Loss = 0.577 AUC = 0.656


10it [08:29, 51.00s/it]


Epoch 195: TRAIN Accuracy = 0.676 Loss = 0.561 AUC = 0.678
Epoch 195: VAL Accuracy = 0.66 Loss = 0.573 AUC = 0.664


10it [07:03, 42.35s/it]


Epoch 200: TRAIN Accuracy = 0.679 Loss = 0.556 AUC = 0.675
Epoch 200: VAL Accuracy = 0.66 Loss = 0.57 AUC = 0.657


10it [06:03, 36.30s/it]


Epoch 205: TRAIN Accuracy = 0.671 Loss = 0.553 AUC = 0.675
Epoch 205: VAL Accuracy = 0.66 Loss = 0.566 AUC = 0.664


10it [06:02, 36.22s/it]


Epoch 210: TRAIN Accuracy = 0.681 Loss = 0.55 AUC = 0.687
Epoch 210: VAL Accuracy = 0.67 Loss = 0.562 AUC = 0.665


10it [06:06, 36.69s/it]


Epoch 215: TRAIN Accuracy = 0.683 Loss = 0.548 AUC = 0.689
Epoch 215: VAL Accuracy = 0.673 Loss = 0.559 AUC = 0.668


10it [06:05, 36.51s/it]


Epoch 220: TRAIN Accuracy = 0.684 Loss = 0.545 AUC = 0.682
Epoch 220: VAL Accuracy = 0.663 Loss = 0.559 AUC = 0.669


10it [06:05, 36.60s/it]


Epoch 225: TRAIN Accuracy = 0.679 Loss = 0.545 AUC = 0.685
Epoch 225: VAL Accuracy = 0.673 Loss = 0.557 AUC = 0.673


10it [06:03, 36.38s/it]


Epoch 230: TRAIN Accuracy = 0.684 Loss = 0.541 AUC = 0.69
Epoch 230: VAL Accuracy = 0.673 Loss = 0.555 AUC = 0.671


10it [06:03, 36.35s/it]


Epoch 235: TRAIN Accuracy = 0.686 Loss = 0.54 AUC = 0.686
Epoch 235: VAL Accuracy = 0.673 Loss = 0.553 AUC = 0.668


10it [06:07, 36.75s/it]


Epoch 240: TRAIN Accuracy = 0.689 Loss = 0.539 AUC = 0.7
Epoch 240: VAL Accuracy = 0.673 Loss = 0.552 AUC = 0.677


10it [06:04, 36.47s/it]


Epoch 245: TRAIN Accuracy = 0.69 Loss = 0.537 AUC = 0.695
Epoch 245: VAL Accuracy = 0.673 Loss = 0.552 AUC = 0.674


10it [06:09, 36.95s/it]


Epoch 250: TRAIN Accuracy = 0.691 Loss = 0.535 AUC = 0.697
Epoch 250: VAL Accuracy = 0.673 Loss = 0.549 AUC = 0.673


10it [06:06, 36.60s/it]


Epoch 255: TRAIN Accuracy = 0.689 Loss = 0.534 AUC = 0.691
Epoch 255: VAL Accuracy = 0.673 Loss = 0.548 AUC = 0.67


10it [06:04, 36.43s/it]


Epoch 260: TRAIN Accuracy = 0.691 Loss = 0.534 AUC = 0.705
Epoch 260: VAL Accuracy = 0.673 Loss = 0.548 AUC = 0.674


10it [06:05, 36.58s/it]


Epoch 265: TRAIN Accuracy = 0.691 Loss = 0.532 AUC = 0.71
Epoch 265: VAL Accuracy = 0.673 Loss = 0.547 AUC = 0.67


10it [06:03, 36.36s/it]


Epoch 270: TRAIN Accuracy = 0.69 Loss = 0.531 AUC = 0.696
Epoch 270: VAL Accuracy = 0.673 Loss = 0.547 AUC = 0.674


10it [06:05, 36.56s/it]


Epoch 275: TRAIN Accuracy = 0.691 Loss = 0.53 AUC = 0.706
Epoch 275: VAL Accuracy = 0.673 Loss = 0.546 AUC = 0.673


10it [06:04, 36.40s/it]


Epoch 280: TRAIN Accuracy = 0.69 Loss = 0.529 AUC = 0.702
Epoch 280: VAL Accuracy = 0.673 Loss = 0.545 AUC = 0.674


10it [06:03, 36.32s/it]


Epoch 285: TRAIN Accuracy = 0.691 Loss = 0.529 AUC = 0.706
Epoch 285: VAL Accuracy = 0.673 Loss = 0.544 AUC = 0.679


10it [06:08, 36.89s/it]


Epoch 290: TRAIN Accuracy = 0.691 Loss = 0.528 AUC = 0.726
Epoch 290: VAL Accuracy = 0.673 Loss = 0.544 AUC = 0.701


10it [06:04, 36.48s/it]


Epoch 295: TRAIN Accuracy = 0.691 Loss = 0.526 AUC = 0.734
Epoch 295: VAL Accuracy = 0.673 Loss = 0.542 AUC = 0.7


10it [06:06, 36.62s/it]

Epoch 300: TRAIN Accuracy = 0.69 Loss = 0.525 AUC = 0.731
Epoch 300: VAL Accuracy = 0.673 Loss = 0.542 AUC = 0.699
CPU times: user 16h 46min 10s, sys: 6h 41min 34s, total: 23h 27min 45s
Wall time: 7h 24min 1s



