In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, Masking
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-18 00:58:24.421612: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 00:58:24.451641: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-18 00:58:24.481150: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739829504.520175 1417638 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739829504.530729 1417638 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-18 00:58:24.591941: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
# Training schedule: models are trained in chunks of STEP epochs up to
# NUM_EPOCHS, and reported metrics are averaged over NUM_EXPERIMENTS seeded runs.
NUM_EPOCHS, STEP, NUM_EXPERIMENTS = 300, 5, 10

def create_model(train, units=10, num_gru_layers=3, dropout=0.2, dense_units=6,
                 learning_rate=0.0003, mask_value=None):
    """Build and compile a stacked-GRU binary classifier.

    Architecture: [optional Masking] -> num_gru_layers x (Dropout -> GRU) ->
    Dropout -> Dense(relu) -> Dense(1, sigmoid). The defaults reproduce the
    original fixed architecture (three 10-unit GRUs, dropout 0.2, a 6-unit
    hidden Dense layer, Adam with learning rate 3e-4).

    Args:
        train: array whose shape[1] and shape[2] give the timestep and feature
            dimensions of the model input (presumably the padded
            (samples, timesteps, features) output of reshape_dataset).
        units: GRU units per layer.
        num_gru_layers: number of stacked GRU layers; must be >= 1.
        dropout: rate used by every Dropout layer.
        dense_units: width of the hidden ReLU Dense layer.
        learning_rate: Adam learning rate.
        mask_value: if not None, a Masking layer is added first, so timesteps whose
            features all equal this value (e.g. 0 for the zero padding added by
            pad_sequences) are masked out for the GRUs. The default None keeps the
            original behaviour, in which padded timesteps are fed to the GRUs.

    Returns:
        A compiled keras Sequential model (binary cross-entropy loss, accuracy
        and AUC metrics).

    Raises:
        ValueError: if num_gru_layers < 1.
    """
    if num_gru_layers < 1:
        raise ValueError(f"num_gru_layers must be >= 1, got {num_gru_layers}")

    model = Sequential()
    model.add(Input(shape=(train.shape[1], train.shape[2])))
    if mask_value is not None:
        model.add(Masking(mask_value=mask_value))

    for layer_idx in range(num_gru_layers):
        model.add(Dropout(dropout))
        # Intermediate GRUs pass the whole sequence on; the last one returns only
        # its final state for the Dense head.
        model.add(GRU(units, return_sequences=layer_idx < num_gru_layers - 1))

    model.add(Dropout(dropout))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of the EEG table used throughout the notebook.
ID = ["ID"]                       # derived sequence id (built from IDS when loading)
IDS = ["SubjectID", "VideoID"]    # raw columns combined into ID
TARGET = ["predefinedlabel"]      # label predicted by the sigmoid output
FEATURES = [
    "Raw",
    "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
INIT_SEED = 5412                  # experiment j is seeded with INIT_SEED + j

In [4]:
# Directory containing EEG_data.csv. Set EEG_DATA_DIR to point elsewhere instead
# of editing the hard-coded default.
data_dir = Path(os.environ.get(
    "EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data_confusion"
))
raw_data = pd.read_csv(data_dir / "EEG_data.csv")

# One sequence id per (SubjectID, VideoID) pair. ngroup() numbers the distinct
# pairs 0..k-1 in sorted order. This matches n_videos * SubjectID + VideoID when
# both ids run over 0..n-1 with every pair present, but unlike that formula it
# cannot collide when VideoID values are not exactly 0..n_videos-1.
raw_data["ID"] = raw_data.groupby(IDS).ngroup().astype("int")
data = raw_data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn per-row samples into one padded sequence per ID.

    Rows are grouped by the ID column(s). Each group's FEATURES rows, in their
    original order, form one sequence. The sequences are padded to a common
    length by pad_sequences (trailing pad_value rows with its defaults).

    Args:
        data: DataFrame containing the ID, FEATURES and TARGET columns.

    Returns:
        (features, target), where
        - features has shape (n_ids, max_len, len(FEATURES));
        - target holds one integer label per ID, in ascending ID order.
    """
    features = []
    target = []
    # A single groupby pass (keys sorted like np.unique, row order kept inside
    # each group) replaces one full-frame boolean scan per ID.
    for _, group in data.groupby(ID, sort=True):
        # Mean label of the group truncated to int. For 0/1 labels this gives 1
        # only when every row of the group is labelled 1.
        target.append(np.mean(group[TARGET].to_numpy()).astype("int"))
        features.append(group[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        arrays: non-empty sequence (not a one-shot iterator) of 2-D arrays of
            shape (length_i, n_features) that all share n_features.
        pad_value: constant written into the added rows.
        padding: "post" (default, the original behaviour) appends the padding
            after each sequence. "pre" prepends it, so every sequence's real
            data ends at the final timestep.

    Returns:
        Array of shape (len(arrays), max_length, n_features).

    Raises:
        ValueError: if arrays is empty or padding is not "post"/"pre".
    """
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")
    if len(arrays) == 0:
        raise ValueError("pad_sequences needs at least one array")

    max_length = max(arr.shape[0] for arr in arrays)

    def _pad_width(arr):
        # Rows to add before/after along axis 0; never pad the feature axis.
        missing = max_length - arr.shape[0]
        rows = (0, missing) if padding == "post" else (missing, 0)
        return (rows, (0, 0))

    return np.stack([
        np.pad(arr, _pad_width(arr), mode='constant', constant_values=pad_value)
        for arr in arrays
    ])

In [6]:
# Inspect the architecture on the full padded dataset; the labels are not needed here.
X, _ = reshape_dataset(data=data)
model = create_model(train=X)
model.summary()

2025-02-18 00:58:31.781513: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [02:45, 16.58s/it]


Epoch 5: TRAIN Accuracy = 0.503 Loss = 0.682 AUC = 0.579
Epoch 5: VAL Accuracy = 0.547 Loss = 0.681 AUC = 0.564


10it [01:25,  8.59s/it]


Epoch 10: TRAIN Accuracy = 0.643 Loss = 0.671 AUC = 0.631
Epoch 10: VAL Accuracy = 0.64 Loss = 0.668 AUC = 0.644


10it [01:21,  8.14s/it]


Epoch 15: TRAIN Accuracy = 0.673 Loss = 0.658 AUC = 0.707
Epoch 15: VAL Accuracy = 0.66 Loss = 0.654 AUC = 0.72


10it [01:21,  8.19s/it]


Epoch 20: TRAIN Accuracy = 0.699 Loss = 0.646 AUC = 0.735
Epoch 20: VAL Accuracy = 0.68 Loss = 0.64 AUC = 0.765


10it [01:26,  8.64s/it]


Epoch 25: TRAIN Accuracy = 0.729 Loss = 0.628 AUC = 0.785
Epoch 25: VAL Accuracy = 0.707 Loss = 0.623 AUC = 0.761


10it [01:26,  8.63s/it]


Epoch 30: TRAIN Accuracy = 0.744 Loss = 0.611 AUC = 0.794
Epoch 30: VAL Accuracy = 0.75 Loss = 0.606 AUC = 0.774


10it [01:25,  8.53s/it]


Epoch 35: TRAIN Accuracy = 0.771 Loss = 0.588 AUC = 0.794
Epoch 35: VAL Accuracy = 0.757 Loss = 0.589 AUC = 0.772


10it [01:23,  8.35s/it]


Epoch 40: TRAIN Accuracy = 0.757 Loss = 0.58 AUC = 0.784
Epoch 40: VAL Accuracy = 0.76 Loss = 0.572 AUC = 0.782


10it [01:19,  7.99s/it]


Epoch 45: TRAIN Accuracy = 0.77 Loss = 0.557 AUC = 0.785
Epoch 45: VAL Accuracy = 0.767 Loss = 0.558 AUC = 0.778


10it [01:25,  8.60s/it]


Epoch 50: TRAIN Accuracy = 0.779 Loss = 0.546 AUC = 0.794
Epoch 50: VAL Accuracy = 0.773 Loss = 0.544 AUC = 0.787


10it [01:23,  8.36s/it]


Epoch 55: TRAIN Accuracy = 0.774 Loss = 0.533 AUC = 0.776
Epoch 55: VAL Accuracy = 0.777 Loss = 0.533 AUC = 0.786


10it [01:23,  8.33s/it]


Epoch 60: TRAIN Accuracy = 0.776 Loss = 0.521 AUC = 0.8
Epoch 60: VAL Accuracy = 0.777 Loss = 0.524 AUC = 0.773


10it [01:22,  8.22s/it]


Epoch 65: TRAIN Accuracy = 0.78 Loss = 0.517 AUC = 0.772
Epoch 65: VAL Accuracy = 0.777 Loss = 0.515 AUC = 0.788


10it [01:25,  8.50s/it]


Epoch 70: TRAIN Accuracy = 0.777 Loss = 0.505 AUC = 0.798
Epoch 70: VAL Accuracy = 0.777 Loss = 0.509 AUC = 0.79


10it [01:25,  8.53s/it]


Epoch 75: TRAIN Accuracy = 0.79 Loss = 0.497 AUC = 0.793
Epoch 75: VAL Accuracy = 0.777 Loss = 0.504 AUC = 0.793


10it [01:23,  8.36s/it]


Epoch 80: TRAIN Accuracy = 0.791 Loss = 0.496 AUC = 0.774
Epoch 80: VAL Accuracy = 0.78 Loss = 0.499 AUC = 0.793


10it [01:28,  8.82s/it]


Epoch 85: TRAIN Accuracy = 0.791 Loss = 0.48 AUC = 0.798
Epoch 85: VAL Accuracy = 0.777 Loss = 0.495 AUC = 0.793


10it [01:28,  8.80s/it]


Epoch 90: TRAIN Accuracy = 0.791 Loss = 0.48 AUC = 0.79
Epoch 90: VAL Accuracy = 0.773 Loss = 0.491 AUC = 0.803


10it [01:23,  8.33s/it]


Epoch 95: TRAIN Accuracy = 0.794 Loss = 0.477 AUC = 0.781
Epoch 95: VAL Accuracy = 0.77 Loss = 0.486 AUC = 0.803


10it [01:22,  8.25s/it]


Epoch 100: TRAIN Accuracy = 0.791 Loss = 0.472 AUC = 0.777
Epoch 100: VAL Accuracy = 0.77 Loss = 0.483 AUC = 0.799


10it [01:23,  8.40s/it]


Epoch 105: TRAIN Accuracy = 0.794 Loss = 0.466 AUC = 0.793
Epoch 105: VAL Accuracy = 0.777 Loss = 0.479 AUC = 0.8


10it [01:20,  8.08s/it]


Epoch 110: TRAIN Accuracy = 0.803 Loss = 0.463 AUC = 0.785
Epoch 110: VAL Accuracy = 0.777 Loss = 0.474 AUC = 0.816


10it [01:24,  8.46s/it]


Epoch 115: TRAIN Accuracy = 0.803 Loss = 0.454 AUC = 0.805
Epoch 115: VAL Accuracy = 0.777 Loss = 0.47 AUC = 0.816


10it [01:21,  8.10s/it]


Epoch 120: TRAIN Accuracy = 0.801 Loss = 0.457 AUC = 0.792
Epoch 120: VAL Accuracy = 0.78 Loss = 0.467 AUC = 0.81


10it [01:21,  8.16s/it]


Epoch 125: TRAIN Accuracy = 0.804 Loss = 0.45 AUC = 0.807
Epoch 125: VAL Accuracy = 0.78 Loss = 0.463 AUC = 0.81


10it [01:26,  8.63s/it]


Epoch 130: TRAIN Accuracy = 0.803 Loss = 0.441 AUC = 0.829
Epoch 130: VAL Accuracy = 0.783 Loss = 0.461 AUC = 0.828


10it [01:22,  8.29s/it]


Epoch 135: TRAIN Accuracy = 0.804 Loss = 0.442 AUC = 0.812
Epoch 135: VAL Accuracy = 0.787 Loss = 0.457 AUC = 0.821


10it [01:29,  8.94s/it]


Epoch 140: TRAIN Accuracy = 0.807 Loss = 0.438 AUC = 0.817
Epoch 140: VAL Accuracy = 0.79 Loss = 0.453 AUC = 0.816


10it [01:22,  8.21s/it]


Epoch 145: TRAIN Accuracy = 0.81 Loss = 0.436 AUC = 0.815
Epoch 145: VAL Accuracy = 0.793 Loss = 0.452 AUC = 0.823


10it [01:25,  8.51s/it]


Epoch 150: TRAIN Accuracy = 0.804 Loss = 0.434 AUC = 0.826
Epoch 150: VAL Accuracy = 0.793 Loss = 0.448 AUC = 0.825


10it [01:28,  8.81s/it]


Epoch 155: TRAIN Accuracy = 0.806 Loss = 0.43 AUC = 0.83
Epoch 155: VAL Accuracy = 0.793 Loss = 0.444 AUC = 0.828


10it [01:34,  9.43s/it]


Epoch 160: TRAIN Accuracy = 0.804 Loss = 0.439 AUC = 0.795
Epoch 160: VAL Accuracy = 0.793 Loss = 0.442 AUC = 0.854


10it [02:01, 12.11s/it]


Epoch 165: TRAIN Accuracy = 0.81 Loss = 0.43 AUC = 0.808
Epoch 165: VAL Accuracy = 0.793 Loss = 0.441 AUC = 0.85


10it [02:02, 12.30s/it]


Epoch 170: TRAIN Accuracy = 0.806 Loss = 0.428 AUC = 0.825
Epoch 170: VAL Accuracy = 0.793 Loss = 0.44 AUC = 0.844


10it [02:01, 12.13s/it]


Epoch 175: TRAIN Accuracy = 0.811 Loss = 0.425 AUC = 0.812
Epoch 175: VAL Accuracy = 0.793 Loss = 0.439 AUC = 0.851


10it [01:58, 11.85s/it]


Epoch 180: TRAIN Accuracy = 0.807 Loss = 0.426 AUC = 0.816
Epoch 180: VAL Accuracy = 0.797 Loss = 0.437 AUC = 0.847


10it [02:04, 12.50s/it]


Epoch 185: TRAIN Accuracy = 0.807 Loss = 0.42 AUC = 0.832
Epoch 185: VAL Accuracy = 0.797 Loss = 0.435 AUC = 0.848


10it [02:00, 12.00s/it]


Epoch 190: TRAIN Accuracy = 0.806 Loss = 0.419 AUC = 0.838
Epoch 190: VAL Accuracy = 0.797 Loss = 0.433 AUC = 0.856


10it [02:04, 12.47s/it]


Epoch 195: TRAIN Accuracy = 0.807 Loss = 0.422 AUC = 0.828
Epoch 195: VAL Accuracy = 0.797 Loss = 0.433 AUC = 0.863


10it [02:01, 12.20s/it]


Epoch 200: TRAIN Accuracy = 0.807 Loss = 0.411 AUC = 0.848
Epoch 200: VAL Accuracy = 0.8 Loss = 0.431 AUC = 0.856


10it [02:10, 13.00s/it]


Epoch 205: TRAIN Accuracy = 0.804 Loss = 0.418 AUC = 0.83
Epoch 205: VAL Accuracy = 0.8 Loss = 0.429 AUC = 0.858


10it [02:07, 12.72s/it]


Epoch 210: TRAIN Accuracy = 0.809 Loss = 0.416 AUC = 0.833
Epoch 210: VAL Accuracy = 0.8 Loss = 0.426 AUC = 0.862


10it [02:05, 12.59s/it]


Epoch 215: TRAIN Accuracy = 0.813 Loss = 0.409 AUC = 0.848
Epoch 215: VAL Accuracy = 0.8 Loss = 0.425 AUC = 0.866


10it [02:02, 12.26s/it]


Epoch 220: TRAIN Accuracy = 0.806 Loss = 0.413 AUC = 0.832
Epoch 220: VAL Accuracy = 0.8 Loss = 0.425 AUC = 0.862


10it [02:10, 13.01s/it]


Epoch 225: TRAIN Accuracy = 0.799 Loss = 0.403 AUC = 0.866
Epoch 225: VAL Accuracy = 0.8 Loss = 0.422 AUC = 0.86


10it [02:01, 12.19s/it]


Epoch 230: TRAIN Accuracy = 0.804 Loss = 0.403 AUC = 0.855
Epoch 230: VAL Accuracy = 0.797 Loss = 0.42 AUC = 0.866


10it [02:14, 13.43s/it]


Epoch 235: TRAIN Accuracy = 0.804 Loss = 0.402 AUC = 0.859
Epoch 235: VAL Accuracy = 0.797 Loss = 0.421 AUC = 0.871


10it [02:22, 14.27s/it]


Epoch 240: TRAIN Accuracy = 0.81 Loss = 0.402 AUC = 0.846
Epoch 240: VAL Accuracy = 0.8 Loss = 0.419 AUC = 0.86


10it [02:21, 14.13s/it]


Epoch 245: TRAIN Accuracy = 0.8 Loss = 0.402 AUC = 0.849
Epoch 245: VAL Accuracy = 0.8 Loss = 0.418 AUC = 0.862


10it [02:23, 14.34s/it]


Epoch 250: TRAIN Accuracy = 0.809 Loss = 0.4 AUC = 0.853
Epoch 250: VAL Accuracy = 0.797 Loss = 0.416 AUC = 0.868


10it [02:25, 14.57s/it]


Epoch 255: TRAIN Accuracy = 0.809 Loss = 0.397 AUC = 0.849
Epoch 255: VAL Accuracy = 0.8 Loss = 0.413 AUC = 0.88


10it [02:22, 14.24s/it]


Epoch 260: TRAIN Accuracy = 0.804 Loss = 0.399 AUC = 0.854
Epoch 260: VAL Accuracy = 0.797 Loss = 0.413 AUC = 0.873


10it [02:27, 14.76s/it]


Epoch 265: TRAIN Accuracy = 0.813 Loss = 0.395 AUC = 0.855
Epoch 265: VAL Accuracy = 0.8 Loss = 0.413 AUC = 0.873


10it [02:24, 14.40s/it]


Epoch 270: TRAIN Accuracy = 0.811 Loss = 0.396 AUC = 0.866
Epoch 270: VAL Accuracy = 0.797 Loss = 0.409 AUC = 0.874


10it [02:23, 14.40s/it]


Epoch 275: TRAIN Accuracy = 0.809 Loss = 0.389 AUC = 0.864
Epoch 275: VAL Accuracy = 0.8 Loss = 0.409 AUC = 0.871


10it [02:23, 14.31s/it]


Epoch 280: TRAIN Accuracy = 0.809 Loss = 0.393 AUC = 0.861
Epoch 280: VAL Accuracy = 0.803 Loss = 0.409 AUC = 0.874


10it [02:18, 13.83s/it]


Epoch 285: TRAIN Accuracy = 0.8 Loss = 0.4 AUC = 0.838
Epoch 285: VAL Accuracy = 0.8 Loss = 0.406 AUC = 0.883


10it [02:25, 14.53s/it]


Epoch 290: TRAIN Accuracy = 0.797 Loss = 0.398 AUC = 0.846
Epoch 290: VAL Accuracy = 0.8 Loss = 0.403 AUC = 0.878


10it [02:19, 13.96s/it]


Epoch 295: TRAIN Accuracy = 0.81 Loss = 0.387 AUC = 0.867
Epoch 295: VAL Accuracy = 0.803 Loss = 0.403 AUC = 0.878


10it [02:19, 13.92s/it]

Epoch 300: TRAIN Accuracy = 0.807 Loss = 0.386 AUC = 0.861
Epoch 300: VAL Accuracy = 0.803 Loss = 0.405 AUC = 0.878
CPU times: user 3h 10min 38s, sys: 51min 43s, total: 4h 2min 22s
Wall time: 1h 48min 32s



