In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-19 15:48:56.871377: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 15:48:56.953888: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 15:48:57.064188: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739969337.158773 3901987 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739969337.187209 3901987 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 15:48:57.417622: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))

    model.add(TimeDistributed(Dense(16, activation='relu')))

    model.add(LTC(AutoNCP(32, 25), return_sequences=True))
    model.add(LTC(AutoNCP(25, 16), return_sequences=False))

    model.add(Dense(16, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=Adam(learning_rate=0.0003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])
    return model

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data_confusion")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
model = create_model(X)
model.summary()

2025-02-19 15:49:01.243916: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [None]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [06:19, 37.99s/it]


Epoch 5: TRAIN Accuracy = 0.49 Loss = 0.705 AUC = 0.436
Epoch 5: VAL Accuracy = 0.523 Loss = 0.7 AUC = 0.535


10it [04:19, 25.94s/it]


Epoch 10: TRAIN Accuracy = 0.519 Loss = 0.696 AUC = 0.612
Epoch 10: VAL Accuracy = 0.567 Loss = 0.695 AUC = 0.602


10it [04:24, 26.48s/it]


Epoch 15: TRAIN Accuracy = 0.557 Loss = 0.691 AUC = 0.656
Epoch 15: VAL Accuracy = 0.56 Loss = 0.692 AUC = 0.676


10it [04:21, 26.14s/it]


Epoch 20: TRAIN Accuracy = 0.593 Loss = 0.686 AUC = 0.665
Epoch 20: VAL Accuracy = 0.563 Loss = 0.688 AUC = 0.69


10it [04:23, 26.40s/it]


Epoch 25: TRAIN Accuracy = 0.594 Loss = 0.68 AUC = 0.684
Epoch 25: VAL Accuracy = 0.553 Loss = 0.684 AUC = 0.671


10it [04:28, 26.83s/it]


Epoch 30: TRAIN Accuracy = 0.593 Loss = 0.673 AUC = 0.678
Epoch 30: VAL Accuracy = 0.553 Loss = 0.679 AUC = 0.707


10it [04:33, 27.31s/it]


Epoch 35: TRAIN Accuracy = 0.617 Loss = 0.666 AUC = 0.723
Epoch 35: VAL Accuracy = 0.583 Loss = 0.672 AUC = 0.728


10it [04:44, 28.42s/it]


Epoch 40: TRAIN Accuracy = 0.647 Loss = 0.657 AUC = 0.711
Epoch 40: VAL Accuracy = 0.59 Loss = 0.665 AUC = 0.722


10it [04:50, 29.01s/it]


Epoch 45: TRAIN Accuracy = 0.65 Loss = 0.646 AUC = 0.742
Epoch 45: VAL Accuracy = 0.587 Loss = 0.656 AUC = 0.747


10it [04:36, 27.70s/it]


Epoch 50: TRAIN Accuracy = 0.65 Loss = 0.635 AUC = 0.738
Epoch 50: VAL Accuracy = 0.587 Loss = 0.646 AUC = 0.748


10it [04:24, 26.49s/it]


Epoch 55: TRAIN Accuracy = 0.654 Loss = 0.624 AUC = 0.735
Epoch 55: VAL Accuracy = 0.593 Loss = 0.636 AUC = 0.746


10it [04:37, 27.70s/it]


Epoch 60: TRAIN Accuracy = 0.654 Loss = 0.613 AUC = 0.759
Epoch 60: VAL Accuracy = 0.593 Loss = 0.627 AUC = 0.755


10it [04:37, 27.78s/it]


Epoch 65: TRAIN Accuracy = 0.674 Loss = 0.601 AUC = 0.77
Epoch 65: VAL Accuracy = 0.613 Loss = 0.615 AUC = 0.767


10it [04:49, 28.97s/it]


Epoch 70: TRAIN Accuracy = 0.681 Loss = 0.589 AUC = 0.769
Epoch 70: VAL Accuracy = 0.62 Loss = 0.603 AUC = 0.766


10it [04:25, 26.53s/it]


Epoch 75: TRAIN Accuracy = 0.709 Loss = 0.575 AUC = 0.772
Epoch 75: VAL Accuracy = 0.7 Loss = 0.589 AUC = 0.774


10it [04:14, 25.43s/it]


Epoch 80: TRAIN Accuracy = 0.717 Loss = 0.562 AUC = 0.768
Epoch 80: VAL Accuracy = 0.703 Loss = 0.575 AUC = 0.776


10it [04:11, 25.16s/it]


Epoch 85: TRAIN Accuracy = 0.75 Loss = 0.548 AUC = 0.773
Epoch 85: VAL Accuracy = 0.723 Loss = 0.562 AUC = 0.767


10it [04:10, 25.06s/it]


Epoch 90: TRAIN Accuracy = 0.749 Loss = 0.536 AUC = 0.764
Epoch 90: VAL Accuracy = 0.733 Loss = 0.549 AUC = 0.76


10it [04:17, 25.74s/it]


Epoch 95: TRAIN Accuracy = 0.771 Loss = 0.524 AUC = 0.771
Epoch 95: VAL Accuracy = 0.75 Loss = 0.537 AUC = 0.765


10it [03:17, 19.71s/it]


Epoch 100: TRAIN Accuracy = 0.786 Loss = 0.512 AUC = 0.773
Epoch 100: VAL Accuracy = 0.777 Loss = 0.525 AUC = 0.77


10it [03:20, 20.09s/it]


Epoch 105: TRAIN Accuracy = 0.793 Loss = 0.501 AUC = 0.777
Epoch 105: VAL Accuracy = 0.78 Loss = 0.515 AUC = 0.772


10it [03:24, 20.47s/it]


Epoch 110: TRAIN Accuracy = 0.801 Loss = 0.49 AUC = 0.783
Epoch 110: VAL Accuracy = 0.787 Loss = 0.504 AUC = 0.768


10it [03:44, 22.49s/it]


Epoch 115: TRAIN Accuracy = 0.801 Loss = 0.481 AUC = 0.789
Epoch 115: VAL Accuracy = 0.787 Loss = 0.495 AUC = 0.787


10it [03:47, 22.79s/it]


Epoch 120: TRAIN Accuracy = 0.806 Loss = 0.473 AUC = 0.798
Epoch 120: VAL Accuracy = 0.79 Loss = 0.487 AUC = 0.778


10it [03:51, 23.17s/it]


Epoch 125: TRAIN Accuracy = 0.807 Loss = 0.465 AUC = 0.803
Epoch 125: VAL Accuracy = 0.79 Loss = 0.479 AUC = 0.789


10it [04:07, 24.76s/it]


Epoch 130: TRAIN Accuracy = 0.807 Loss = 0.458 AUC = 0.806
Epoch 130: VAL Accuracy = 0.79 Loss = 0.473 AUC = 0.785


10it [04:16, 25.63s/it]


Epoch 135: TRAIN Accuracy = 0.807 Loss = 0.453 AUC = 0.807
Epoch 135: VAL Accuracy = 0.793 Loss = 0.467 AUC = 0.798


10it [03:37, 21.79s/it]


Epoch 140: TRAIN Accuracy = 0.807 Loss = 0.447 AUC = 0.809
Epoch 140: VAL Accuracy = 0.8 Loss = 0.462 AUC = 0.798


10it [03:37, 21.73s/it]


Epoch 145: TRAIN Accuracy = 0.81 Loss = 0.443 AUC = 0.806
Epoch 145: VAL Accuracy = 0.8 Loss = 0.457 AUC = 0.798


10it [03:31, 21.12s/it]


Epoch 150: TRAIN Accuracy = 0.81 Loss = 0.438 AUC = 0.81
Epoch 150: VAL Accuracy = 0.803 Loss = 0.453 AUC = 0.792


10it [03:23, 20.39s/it]


Epoch 155: TRAIN Accuracy = 0.81 Loss = 0.435 AUC = 0.818
Epoch 155: VAL Accuracy = 0.803 Loss = 0.449 AUC = 0.793


10it [03:24, 20.49s/it]


Epoch 160: TRAIN Accuracy = 0.811 Loss = 0.432 AUC = 0.821
Epoch 160: VAL Accuracy = 0.803 Loss = 0.446 AUC = 0.803


10it [03:23, 20.31s/it]


Epoch 165: TRAIN Accuracy = 0.811 Loss = 0.429 AUC = 0.817
Epoch 165: VAL Accuracy = 0.803 Loss = 0.444 AUC = 0.8


10it [03:24, 20.48s/it]


Epoch 170: TRAIN Accuracy = 0.811 Loss = 0.426 AUC = 0.823
Epoch 170: VAL Accuracy = 0.803 Loss = 0.442 AUC = 0.8


10it [03:22, 20.25s/it]


Epoch 175: TRAIN Accuracy = 0.811 Loss = 0.424 AUC = 0.821
Epoch 175: VAL Accuracy = 0.803 Loss = 0.441 AUC = 0.804


10it [03:23, 20.36s/it]


Epoch 180: TRAIN Accuracy = 0.811 Loss = 0.422 AUC = 0.837
Epoch 180: VAL Accuracy = 0.803 Loss = 0.441 AUC = 0.812


10it [03:24, 20.50s/it]


Epoch 185: TRAIN Accuracy = 0.811 Loss = 0.42 AUC = 0.827
Epoch 185: VAL Accuracy = 0.803 Loss = 0.44 AUC = 0.815


10it [03:24, 20.43s/it]


Epoch 190: TRAIN Accuracy = 0.811 Loss = 0.419 AUC = 0.822
Epoch 190: VAL Accuracy = 0.797 Loss = 0.44 AUC = 0.814


10it [03:50, 23.09s/it]


Epoch 195: TRAIN Accuracy = 0.811 Loss = 0.417 AUC = 0.831
Epoch 195: VAL Accuracy = 0.797 Loss = 0.437 AUC = 0.826


10it [04:40, 28.03s/it]


Epoch 200: TRAIN Accuracy = 0.811 Loss = 0.416 AUC = 0.836
Epoch 200: VAL Accuracy = 0.797 Loss = 0.438 AUC = 0.822


10it [04:40, 28.07s/it]


Epoch 205: TRAIN Accuracy = 0.811 Loss = 0.415 AUC = 0.833
Epoch 205: VAL Accuracy = 0.793 Loss = 0.438 AUC = 0.838


10it [04:45, 28.59s/it]


Epoch 210: TRAIN Accuracy = 0.811 Loss = 0.414 AUC = 0.83
Epoch 210: VAL Accuracy = 0.793 Loss = 0.438 AUC = 0.822


10it [04:42, 28.27s/it]


Epoch 215: TRAIN Accuracy = 0.811 Loss = 0.413 AUC = 0.841
Epoch 215: VAL Accuracy = 0.793 Loss = 0.438 AUC = 0.818


10it [04:42, 28.23s/it]


Epoch 220: TRAIN Accuracy = 0.811 Loss = 0.412 AUC = 0.839
Epoch 220: VAL Accuracy = 0.793 Loss = 0.437 AUC = 0.822


10it [04:41, 28.13s/it]


Epoch 225: TRAIN Accuracy = 0.811 Loss = 0.411 AUC = 0.841
Epoch 225: VAL Accuracy = 0.793 Loss = 0.437 AUC = 0.824


10it [05:20, 32.08s/it]


Epoch 230: TRAIN Accuracy = 0.811 Loss = 0.41 AUC = 0.844
Epoch 230: VAL Accuracy = 0.793 Loss = 0.436 AUC = 0.823


10it [05:22, 32.22s/it]


Epoch 235: TRAIN Accuracy = 0.811 Loss = 0.409 AUC = 0.846
Epoch 235: VAL Accuracy = 0.793 Loss = 0.436 AUC = 0.842


10it [05:22, 32.26s/it]


Epoch 240: TRAIN Accuracy = 0.811 Loss = 0.409 AUC = 0.847
Epoch 240: VAL Accuracy = 0.793 Loss = 0.435 AUC = 0.839


10it [05:16, 31.62s/it]


Epoch 245: TRAIN Accuracy = 0.811 Loss = 0.408 AUC = 0.851
Epoch 245: VAL Accuracy = 0.793 Loss = 0.435 AUC = 0.837


10it [05:18, 31.83s/it]


Epoch 250: TRAIN Accuracy = 0.811 Loss = 0.407 AUC = 0.853
Epoch 250: VAL Accuracy = 0.793 Loss = 0.434 AUC = 0.848


10it [05:19, 31.91s/it]


Epoch 255: TRAIN Accuracy = 0.811 Loss = 0.406 AUC = 0.855
Epoch 255: VAL Accuracy = 0.797 Loss = 0.433 AUC = 0.847


10it [05:17, 31.79s/it]


Epoch 260: TRAIN Accuracy = 0.811 Loss = 0.406 AUC = 0.851
Epoch 260: VAL Accuracy = 0.797 Loss = 0.432 AUC = 0.837


10it [05:14, 31.48s/it]


Epoch 265: TRAIN Accuracy = 0.811 Loss = 0.405 AUC = 0.854
Epoch 265: VAL Accuracy = 0.797 Loss = 0.432 AUC = 0.838


10it [05:18, 31.87s/it]


Epoch 270: TRAIN Accuracy = 0.811 Loss = 0.404 AUC = 0.857
Epoch 270: VAL Accuracy = 0.797 Loss = 0.432 AUC = 0.853


10it [05:23, 32.39s/it]


Epoch 275: TRAIN Accuracy = 0.811 Loss = 0.404 AUC = 0.86
Epoch 275: VAL Accuracy = 0.797 Loss = 0.432 AUC = 0.852


10it [05:18, 31.84s/it]


Epoch 280: TRAIN Accuracy = 0.811 Loss = 0.403 AUC = 0.869
Epoch 280: VAL Accuracy = 0.797 Loss = 0.432 AUC = 0.855


10it [05:19, 31.91s/it]


Epoch 285: TRAIN Accuracy = 0.811 Loss = 0.403 AUC = 0.871
Epoch 285: VAL Accuracy = 0.797 Loss = 0.433 AUC = 0.853


10it [05:18, 31.87s/it]


Epoch 290: TRAIN Accuracy = 0.811 Loss = 0.402 AUC = 0.872
Epoch 290: VAL Accuracy = 0.797 Loss = 0.433 AUC = 0.853


10it [05:22, 32.25s/it]


Epoch 295: TRAIN Accuracy = 0.811 Loss = 0.401 AUC = 0.875
Epoch 295: VAL Accuracy = 0.797 Loss = 0.433 AUC = 0.86


10it [05:15, 31.54s/it]

Epoch 300: TRAIN Accuracy = 0.811 Loss = 0.401 AUC = 0.871
Epoch 300: VAL Accuracy = 0.797 Loss = 0.434 AUC = 0.86
CPU times: user 9h 17min 4s, sys: 3h 25min 1s, total: 12h 42min 6s
Wall time: 4h 26min 33s





: 