In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-19 12:39:25.456309: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 12:39:25.462129: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 12:39:25.475605: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739957965.501486 3137766 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739957965.507055 3137766 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 12:39:25.535624: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300       # total training epochs per model
STEP = 5               # epochs per training chunk; mean metrics are printed after each chunk
NUM_EXPERIMENTS = 10   # number of random train/test splits, one model per split

# Checkpoints are reported every STEP epochs; require NUM_EPOCHS to be a multiple
# of STEP so the last reported checkpoint lands exactly on NUM_EPOCHS.
assert NUM_EPOCHS > 0 and STEP > 0, "NUM_EPOCHS and STEP must be positive"
assert NUM_EPOCHS % STEP == 0, "NUM_EPOCHS must be a multiple of STEP"

def create_model(train, units=10, dropout=0.2, dense_units=6, learning_rate=0.0003, num_gru_layers=3):
    """Build and compile a stacked-GRU binary classifier.

    Architecture: Input -> BatchNormalization -> num_gru_layers x (Dropout -> GRU)
    -> Dense(dense_units, relu) -> Dense(1, sigmoid). The defaults reproduce the
    original fixed configuration (3 GRUs of 10 units, dropout 0.2, Dense(6), lr 3e-4).

    Args:
        train: 3-D array; only ``train.shape[1]`` (timesteps) and ``train.shape[2]``
            (features) are read to size the input layer.
        units: hidden size of every GRU layer.
        dropout: rate of the Dropout layer placed before each GRU.
        dense_units: width of the ReLU layer before the sigmoid output.
        learning_rate: Adam learning rate.
        num_gru_layers: number of stacked GRU layers (>= 1).

    Returns:
        A compiled ``Sequential`` model with binary cross-entropy loss and
        accuracy / AUC (named "auc") metrics.

    Raises:
        ValueError: if ``num_gru_layers`` < 1.
    """
    if num_gru_layers < 1:
        raise ValueError(f"num_gru_layers must be >= 1, got {num_gru_layers}")

    model = Sequential()
    # Batch axis is left free; only (timesteps, n_features) is fixed.
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    model.add(BatchNormalization())

    for layer_idx in range(num_gru_layers):
        model.add(Dropout(dropout))
        # Intermediate GRUs pass the whole sequence to the next GRU; only the last
        # one collapses the time axis before the dense head.
        is_last = layer_idx == num_gru_layers - 1
        model.add(GRU(units, return_sequences=not is_last))

    model.add(Dense(dense_units, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss="binary_crossentropy",
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

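# Experiment

Each (SubjectID, VideoID) recording session is one sample: its rows become a zero-padded sequence of the EEG `FEATURES`, labelled with `predefinedlabel`. For each of `NUM_EXPERIMENTS` seeds, 70 sessions are drawn for training and the rest are held out. Every model is trained in chunks of `STEP` epochs, and the train/validation metrics are averaged across the seeds after each chunk.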

In [3]:
# Column groups of EEG_data.csv used throughout the notebook.
ID = ["ID"]                      # synthetic per-session key, added in the data-loading cell
IDS = ["SubjectID", "VideoID"]   # raw columns that identify a recording session
TARGET = ["predefinedlabel"]     # label column (fit with a sigmoid / binary cross-entropy model)
FEATURES = [
    "Raw",
    "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
INIT_SEED = 5412                 # experiment j is seeded with INIT_SEED + j

In [4]:
# Data location; set EEG_DATA_DIR to run elsewhere without editing this cell.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data_confusion"))
data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer key per (SubjectID, VideoID) session. ngroup() numbers the sorted
# groups 0..k-1, so keys cannot collide; unlike `n_videos * SubjectID + VideoID`
# it does not assume VideoID takes exactly the values 0..n_videos-1.
data["ID"] = data.groupby(IDS, sort=True).ngroup()
data = data[ID + FEATURES + TARGET]

# reshape_dataset reduces each session to a single label, which is only lossless
# when the label is constant within a session.
assert data.groupby(ID[0])[TARGET[0]].nunique().le(1).all(), "label varies within a session"

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long per-sample frame into one padded sequence per session ID.

    Rows are grouped by the ID column in ascending key order; row order is kept
    within each group. Each group yields a (timesteps, len(FEATURES)) matrix and
    one label, and the matrices are padded with `pad_sequences` to a common length.

    Args:
        data: DataFrame containing the ID, FEATURES and TARGET columns.

    Returns:
        (features, target):
            features -- array of shape (n_ids, max_timesteps, len(FEATURES)).
            target -- array of shape (n_ids,) holding each ID's mean label cast to
                int (truncated), so it is only meaningful when the label is
                constant within an ID.
        An input with no rows yields arrays of shape (0, 0, len(FEATURES)) and (0,).
    """
    features = []
    target = []
    # A single groupby pass instead of one full boolean scan of `data` per ID.
    for _, group in data.groupby(ID, sort=True):
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    if not features:
        # Nothing to pad (pad_sequences cannot size an empty batch).
        return np.empty((0, 0, len(FEATURES))), np.empty((0,), dtype=int)

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad variable-length arrays along axis 0 to a common length and stack them.

    Args:
        arrays: non-empty sequence of arrays that agree on every dimension except
            the first (e.g. per-session (timesteps, n_features) matrices).
        pad_value: fill value for the padded positions.
        padding: "post" appends the padding after the data (the default, as
            before); "pre" prepends it, which keeps the real data next to the
            final timestep that a non-sequence-returning recurrent layer outputs.

    Returns:
        np.ndarray of shape (len(arrays), max_length, *trailing_dims).

    Raises:
        ValueError: if ``arrays`` is empty or ``padding`` is not "pre"/"post".
    """
    if padding not in ("pre", "post"):
        raise ValueError(f"padding must be 'pre' or 'post', got {padding!r}")
    arrays = [np.asarray(arr) for arr in arrays]
    if not arrays:
        raise ValueError("pad_sequences needs at least one array")

    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = []
    for arr in arrays:
        missing = max_length - arr.shape[0]
        leading = (missing, 0) if padding == "pre" else (0, missing)
        # Only axis 0 is padded; every trailing axis keeps its size.
        pad_width = [leading] + [(0, 0)] * (arr.ndim - 1)
        padded_arrays.append(
            np.pad(arr, pad_width, mode="constant", constant_values=pad_value)
        )
    return np.stack(padded_arrays)

In [6]:
# Build one throwaway model on the full dataset only to print the architecture;
# the experiment cell creates its own model per split.
X = reshape_dataset(data)[0]  # index instead of `X, _ = ...` so IPython's `_` (last output) is not clobbered
model = create_model(X)
model.summary()

2025-02-19 12:39:29.650910: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [02:02, 12.28s/it]


Epoch 5: TRAIN Accuracy = 0.451 Loss = 0.705 AUC = 0.383
Epoch 5: VAL Accuracy = 0.417 Loss = 0.706 AUC = 0.418


10it [01:05,  6.50s/it]


Epoch 10: TRAIN Accuracy = 0.527 Loss = 0.693 AUC = 0.536
Epoch 10: VAL Accuracy = 0.553 Loss = 0.692 AUC = 0.564


10it [01:08,  6.87s/it]


Epoch 15: TRAIN Accuracy = 0.554 Loss = 0.687 AUC = 0.62
Epoch 15: VAL Accuracy = 0.58 Loss = 0.684 AUC = 0.697


10it [01:09,  6.92s/it]


Epoch 20: TRAIN Accuracy = 0.603 Loss = 0.68 AUC = 0.684
Epoch 20: VAL Accuracy = 0.59 Loss = 0.681 AUC = 0.68


10it [01:10,  7.01s/it]


Epoch 25: TRAIN Accuracy = 0.603 Loss = 0.679 AUC = 0.656
Epoch 25: VAL Accuracy = 0.637 Loss = 0.677 AUC = 0.734


10it [01:10,  7.02s/it]


Epoch 30: TRAIN Accuracy = 0.609 Loss = 0.674 AUC = 0.702
Epoch 30: VAL Accuracy = 0.647 Loss = 0.673 AUC = 0.738


10it [01:10,  7.09s/it]


Epoch 35: TRAIN Accuracy = 0.647 Loss = 0.667 AUC = 0.735
Epoch 35: VAL Accuracy = 0.663 Loss = 0.668 AUC = 0.749


10it [01:13,  7.37s/it]


Epoch 40: TRAIN Accuracy = 0.666 Loss = 0.656 AUC = 0.753
Epoch 40: VAL Accuracy = 0.717 Loss = 0.662 AUC = 0.773


10it [01:08,  6.82s/it]


Epoch 45: TRAIN Accuracy = 0.663 Loss = 0.649 AUC = 0.757
Epoch 45: VAL Accuracy = 0.703 Loss = 0.655 AUC = 0.763


10it [01:12,  7.26s/it]


Epoch 50: TRAIN Accuracy = 0.669 Loss = 0.639 AUC = 0.763
Epoch 50: VAL Accuracy = 0.65 Loss = 0.647 AUC = 0.768


10it [01:09,  6.95s/it]


Epoch 55: TRAIN Accuracy = 0.704 Loss = 0.626 AUC = 0.763
Epoch 55: VAL Accuracy = 0.69 Loss = 0.637 AUC = 0.759


10it [01:12,  7.20s/it]


Epoch 60: TRAIN Accuracy = 0.719 Loss = 0.61 AUC = 0.788
Epoch 60: VAL Accuracy = 0.71 Loss = 0.626 AUC = 0.767


10it [01:12,  7.28s/it]


Epoch 65: TRAIN Accuracy = 0.713 Loss = 0.595 AUC = 0.776
Epoch 65: VAL Accuracy = 0.737 Loss = 0.616 AUC = 0.765


10it [01:10,  7.03s/it]


Epoch 70: TRAIN Accuracy = 0.733 Loss = 0.585 AUC = 0.774
Epoch 70: VAL Accuracy = 0.713 Loss = 0.606 AUC = 0.745


10it [01:11,  7.15s/it]


Epoch 75: TRAIN Accuracy = 0.723 Loss = 0.57 AUC = 0.777
Epoch 75: VAL Accuracy = 0.737 Loss = 0.595 AUC = 0.733


10it [01:11,  7.19s/it]


Epoch 80: TRAIN Accuracy = 0.737 Loss = 0.556 AUC = 0.786
Epoch 80: VAL Accuracy = 0.74 Loss = 0.589 AUC = 0.733


10it [01:09,  6.91s/it]


Epoch 85: TRAIN Accuracy = 0.74 Loss = 0.55 AUC = 0.785
Epoch 85: VAL Accuracy = 0.74 Loss = 0.58 AUC = 0.747


10it [01:12,  7.29s/it]


Epoch 90: TRAIN Accuracy = 0.759 Loss = 0.537 AUC = 0.8
Epoch 90: VAL Accuracy = 0.74 Loss = 0.572 AUC = 0.746


10it [01:10,  7.06s/it]


Epoch 95: TRAIN Accuracy = 0.741 Loss = 0.53 AUC = 0.786
Epoch 95: VAL Accuracy = 0.74 Loss = 0.567 AUC = 0.753


10it [01:12,  7.29s/it]


Epoch 100: TRAIN Accuracy = 0.754 Loss = 0.518 AUC = 0.801
Epoch 100: VAL Accuracy = 0.737 Loss = 0.561 AUC = 0.742


10it [01:11,  7.19s/it]


Epoch 105: TRAIN Accuracy = 0.776 Loss = 0.505 AUC = 0.797
Epoch 105: VAL Accuracy = 0.737 Loss = 0.558 AUC = 0.78


10it [01:11,  7.17s/it]


Epoch 110: TRAIN Accuracy = 0.759 Loss = 0.506 AUC = 0.805
Epoch 110: VAL Accuracy = 0.733 Loss = 0.556 AUC = 0.778


10it [01:12,  7.26s/it]


Epoch 115: TRAIN Accuracy = 0.76 Loss = 0.496 AUC = 0.814
Epoch 115: VAL Accuracy = 0.73 Loss = 0.557 AUC = 0.784


10it [01:13,  7.34s/it]


Epoch 120: TRAIN Accuracy = 0.769 Loss = 0.493 AUC = 0.805
Epoch 120: VAL Accuracy = 0.737 Loss = 0.551 AUC = 0.781


10it [01:12,  7.27s/it]


Epoch 125: TRAIN Accuracy = 0.761 Loss = 0.495 AUC = 0.803
Epoch 125: VAL Accuracy = 0.73 Loss = 0.555 AUC = 0.779


10it [01:16,  7.61s/it]


Epoch 130: TRAIN Accuracy = 0.77 Loss = 0.486 AUC = 0.82
Epoch 130: VAL Accuracy = 0.727 Loss = 0.555 AUC = 0.803


10it [01:18,  7.85s/it]


Epoch 135: TRAIN Accuracy = 0.771 Loss = 0.487 AUC = 0.806
Epoch 135: VAL Accuracy = 0.74 Loss = 0.547 AUC = 0.797


10it [01:20,  8.03s/it]


Epoch 140: TRAIN Accuracy = 0.771 Loss = 0.478 AUC = 0.809
Epoch 140: VAL Accuracy = 0.733 Loss = 0.551 AUC = 0.807


10it [01:18,  7.88s/it]


Epoch 145: TRAIN Accuracy = 0.773 Loss = 0.474 AUC = 0.817
Epoch 145: VAL Accuracy = 0.733 Loss = 0.551 AUC = 0.805


10it [01:21,  8.14s/it]


Epoch 150: TRAIN Accuracy = 0.766 Loss = 0.473 AUC = 0.826
Epoch 150: VAL Accuracy = 0.737 Loss = 0.551 AUC = 0.807


10it [01:19,  7.96s/it]


Epoch 155: TRAIN Accuracy = 0.769 Loss = 0.479 AUC = 0.804
Epoch 155: VAL Accuracy = 0.733 Loss = 0.544 AUC = 0.807


10it [01:22,  8.24s/it]


Epoch 160: TRAIN Accuracy = 0.783 Loss = 0.469 AUC = 0.827
Epoch 160: VAL Accuracy = 0.687 Loss = 0.551 AUC = 0.804


10it [01:20,  8.03s/it]


Epoch 165: TRAIN Accuracy = 0.78 Loss = 0.464 AUC = 0.823
Epoch 165: VAL Accuracy = 0.74 Loss = 0.543 AUC = 0.806


10it [01:20,  8.04s/it]


Epoch 170: TRAIN Accuracy = 0.779 Loss = 0.467 AUC = 0.821
Epoch 170: VAL Accuracy = 0.737 Loss = 0.547 AUC = 0.808


10it [01:19,  7.97s/it]


Epoch 175: TRAIN Accuracy = 0.786 Loss = 0.448 AUC = 0.844
Epoch 175: VAL Accuracy = 0.743 Loss = 0.549 AUC = 0.812


10it [01:18,  7.90s/it]


Epoch 180: TRAIN Accuracy = 0.793 Loss = 0.447 AUC = 0.844
Epoch 180: VAL Accuracy = 0.737 Loss = 0.548 AUC = 0.815


10it [01:20,  8.03s/it]


Epoch 185: TRAIN Accuracy = 0.786 Loss = 0.456 AUC = 0.833
Epoch 185: VAL Accuracy = 0.737 Loss = 0.54 AUC = 0.809


10it [01:18,  7.83s/it]


Epoch 190: TRAIN Accuracy = 0.787 Loss = 0.445 AUC = 0.842
Epoch 190: VAL Accuracy = 0.693 Loss = 0.558 AUC = 0.814


10it [01:20,  8.03s/it]


Epoch 195: TRAIN Accuracy = 0.789 Loss = 0.448 AUC = 0.833
Epoch 195: VAL Accuracy = 0.747 Loss = 0.545 AUC = 0.825


10it [01:19,  7.91s/it]


Epoch 200: TRAIN Accuracy = 0.8 Loss = 0.435 AUC = 0.847
Epoch 200: VAL Accuracy = 0.743 Loss = 0.544 AUC = 0.819


10it [01:09,  6.94s/it]


Epoch 205: TRAIN Accuracy = 0.793 Loss = 0.442 AUC = 0.837
Epoch 205: VAL Accuracy = 0.743 Loss = 0.549 AUC = 0.821


10it [01:09,  6.91s/it]


Epoch 210: TRAIN Accuracy = 0.793 Loss = 0.446 AUC = 0.842
Epoch 210: VAL Accuracy = 0.75 Loss = 0.544 AUC = 0.835


10it [01:09,  6.98s/it]


Epoch 215: TRAIN Accuracy = 0.793 Loss = 0.435 AUC = 0.846
Epoch 215: VAL Accuracy = 0.753 Loss = 0.548 AUC = 0.825


10it [01:08,  6.85s/it]


Epoch 220: TRAIN Accuracy = 0.804 Loss = 0.43 AUC = 0.845
Epoch 220: VAL Accuracy = 0.753 Loss = 0.537 AUC = 0.831


10it [01:09,  6.94s/it]


Epoch 225: TRAIN Accuracy = 0.79 Loss = 0.43 AUC = 0.858
Epoch 225: VAL Accuracy = 0.747 Loss = 0.543 AUC = 0.832


10it [01:09,  6.96s/it]


Epoch 230: TRAIN Accuracy = 0.784 Loss = 0.439 AUC = 0.85
Epoch 230: VAL Accuracy = 0.75 Loss = 0.546 AUC = 0.833


10it [01:07,  6.77s/it]


Epoch 235: TRAIN Accuracy = 0.789 Loss = 0.435 AUC = 0.843
Epoch 235: VAL Accuracy = 0.713 Loss = 0.551 AUC = 0.83


10it [01:09,  6.96s/it]


Epoch 240: TRAIN Accuracy = 0.793 Loss = 0.42 AUC = 0.862
Epoch 240: VAL Accuracy = 0.757 Loss = 0.534 AUC = 0.842


10it [01:10,  7.01s/it]


Epoch 245: TRAIN Accuracy = 0.793 Loss = 0.431 AUC = 0.85
Epoch 245: VAL Accuracy = 0.757 Loss = 0.542 AUC = 0.844


10it [01:10,  7.00s/it]


Epoch 250: TRAIN Accuracy = 0.79 Loss = 0.436 AUC = 0.837
Epoch 250: VAL Accuracy = 0.753 Loss = 0.543 AUC = 0.843


10it [01:07,  6.79s/it]


Epoch 255: TRAIN Accuracy = 0.799 Loss = 0.419 AUC = 0.859
Epoch 255: VAL Accuracy = 0.753 Loss = 0.547 AUC = 0.847


10it [01:09,  6.92s/it]


Epoch 260: TRAIN Accuracy = 0.794 Loss = 0.421 AUC = 0.866
Epoch 260: VAL Accuracy = 0.71 Loss = 0.545 AUC = 0.838


10it [01:09,  6.92s/it]


Epoch 265: TRAIN Accuracy = 0.799 Loss = 0.426 AUC = 0.849
Epoch 265: VAL Accuracy = 0.75 Loss = 0.542 AUC = 0.841


10it [01:08,  6.81s/it]


Epoch 270: TRAIN Accuracy = 0.8 Loss = 0.415 AUC = 0.857
Epoch 270: VAL Accuracy = 0.75 Loss = 0.542 AUC = 0.843


10it [01:10,  7.08s/it]


Epoch 275: TRAIN Accuracy = 0.819 Loss = 0.406 AUC = 0.867
Epoch 275: VAL Accuracy = 0.757 Loss = 0.549 AUC = 0.847


10it [01:10,  7.04s/it]


Epoch 280: TRAIN Accuracy = 0.8 Loss = 0.406 AUC = 0.865
Epoch 280: VAL Accuracy = 0.76 Loss = 0.555 AUC = 0.844


10it [01:10,  7.09s/it]


Epoch 285: TRAIN Accuracy = 0.801 Loss = 0.414 AUC = 0.863
Epoch 285: VAL Accuracy = 0.757 Loss = 0.546 AUC = 0.847


10it [01:01,  6.20s/it]


Epoch 290: TRAIN Accuracy = 0.789 Loss = 0.422 AUC = 0.852
Epoch 290: VAL Accuracy = 0.727 Loss = 0.546 AUC = 0.85


10it [01:01,  6.20s/it]


Epoch 295: TRAIN Accuracy = 0.796 Loss = 0.421 AUC = 0.855
Epoch 295: VAL Accuracy = 0.717 Loss = 0.542 AUC = 0.847


10it [01:01,  6.11s/it]

Epoch 300: TRAIN Accuracy = 0.799 Loss = 0.412 AUC = 0.861
Epoch 300: VAL Accuracy = 0.743 Loss = 0.541 AUC = 0.845
CPU times: user 2h 11min 49s, sys: 37min 49s, total: 2h 49min 39s
Wall time: 1h 13min 6s
Parser   : 210 ms



