In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, LSTM
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-02-19 11:19:16.735989: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 11:19:16.751911: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-19 11:19:16.784550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739953156.820067 2618431 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739953156.828311 2618431 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-19 11:19:16.867300: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300        # total training epochs per model
STEP = 5                # epochs per `fit` call; metrics are reported every STEP epochs
NUM_EXPERIMENTS = 10    # number of seeds, i.e. independent random train/validation splits

# The training loop runs NUM_EPOCHS // STEP chunks, so a non-multiple would silently
# train fewer than NUM_EPOCHS epochs.
assert NUM_EPOCHS % STEP == 0, "NUM_EPOCHS must be a multiple of STEP"

def create_model(train, units=10, dropout=0.2, learning_rate=0.0003):
    """Build and compile a three-layer stacked LSTM binary classifier.

    Parameters
    ----------
    train : np.ndarray
        Padded training tensor shaped (n_sessions, n_timesteps, n_features), as produced
        by `reshape_dataset`. Only its feature axis (`train.shape[2]`) sizes the input.
    units : int, default 10
        Hidden size of each of the three LSTM layers.
    dropout : float, default 0.2
        Dropout rate applied before every LSTM layer and before the dense head.
    learning_rate : float, default 0.0003
        Adam learning rate.

    Returns
    -------
    A compiled `Sequential` model with one sigmoid output, trained with binary
    cross-entropy and tracking accuracy and AUC (logged as "auc").
    """
    model = Sequential()
    # The time axis is left as None: reshape_dataset pads each split to that split's
    # longest session, so train and validation tensors may differ in length. Pinning it
    # to train.shape[1] would mismatch such validation data.
    model.add(Input(shape=(None, train.shape[2])))

    # Dropout -> LSTM, three times; only the last LSTM collapses the time axis.
    for return_sequences in (True, True, False):
        model.add(Dropout(dropout))
        model.add(LSTM(units, return_sequences=return_sequences))

    model.add(Dropout(dropout))
    model.add(Dense(6, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )
    return model

# Experiment

In [3]:
# Column groups of EEG_data.csv used throughout the notebook.
ID = ["ID"]                       # synthetic per-session key, derived from SubjectID/VideoID
IDS = ["SubjectID", "VideoID"]    # raw columns that together identify a session
TARGET = ["predefinedlabel"]      # per-row label, used as the binary target
FEATURES = [
    "Raw", "Delta", "Theta",
    "Alpha1", "Alpha2",
    "Beta1", "Beta2",
    "Gamma1", "Gamma2",
]
INIT_SEED = 5412                  # experiment j is seeded with INIT_SEED + j

In [4]:
# Defaults to the author's checkout; set BRAIN_SIGNALS_DATA_DIR to run elsewhere.
data_dir = Path(os.environ.get("BRAIN_SIGNALS_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
raw_data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer ID per (SubjectID, VideoID) session. ngroup() numbers the sorted unique
# pairs 0..n-1, so IDs can never collide, whereas `n_videos * SubjectID + VideoID` is
# only guaranteed collision-free when VideoID is an integer in 0..n_videos-1.
data = (
    raw_data
    .assign(ID=raw_data.groupby(IDS, sort=True).ngroup())
    .loc[:, ID + FEATURES + TARGET]
)

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn long-format EEG rows into one padded sequence and one label per session.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the columns named by the module-level ID, FEATURES and TARGET lists.

    Returns
    -------
    features : np.ndarray
        Shape (n_sessions, max_len, len(FEATURES)). Sessions are ordered by ascending ID;
        shorter sessions are padded with zeros after their last row (`pad_sequences`).
    target : np.ndarray
        Shape (n_sessions,): the mean of the TARGET column per session, cast to int.
    """
    features = []
    target = []
    # A single sorted groupby pass (same ascending order as np.unique) instead of
    # re-scanning the whole frame per ID. Grouping on the single column also avoids the
    # (n, 1) 2-D boolean mask that `data[ID].to_numpy() == cur_id` produced.
    for _, session in data.groupby(ID[0], sort=True):
        # NOTE(review): astype("int") truncates, so with 0/1 labels a session with mixed
        # labels becomes 0 unless every row is 1. Fine if labels are constant per
        # session; confirm.
        target.append(np.mean(session[TARGET].to_numpy()).astype("int"))
        features.append(session[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0, padding="post"):
    """Stack variable-length arrays into one array by padding their first axis.

    Parameters
    ----------
    arrays : iterable of np.ndarray
        Arrays that agree on every axis except the first, e.g. (time, features).
    pad_value : scalar, default 0
        Fill value for the padded positions.
    padding : {"post", "pre"}, default "post"
        Pad after each array (the default) or before it. "pre" keeps the real samples
        next to the final timestep, which is the step an `LSTM(return_sequences=False)`
        layer outputs.

    Returns
    -------
    np.ndarray
        Shape (len(arrays), max_len, *arrays[0].shape[1:]).

    Raises
    ------
    ValueError
        If `arrays` is empty or `padding` is neither "post" nor "pre".
    """
    if padding not in ("post", "pre"):
        raise ValueError(f"padding must be 'post' or 'pre', got {padding!r}")
    # Materialise once so generators work (they would be exhausted by max()).
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")

    max_length = max(arr.shape[0] for arr in arrays)

    def _pad_one(arr):
        # Pad only the first axis; every trailing axis gets (0, 0).
        missing = max_length - arr.shape[0]
        first_axis = (missing, 0) if padding == "pre" else (0, missing)
        width = (first_axis,) + ((0, 0),) * (arr.ndim - 1)
        return np.pad(arr, width, mode='constant', constant_values=pad_value)

    return np.stack([_pad_one(arr) for arr in arrays])

In [6]:
# Throwaway model built on the all-session tensor, only to inspect the architecture;
# the experiment loop below builds its own models per split.
X, _ = reshape_dataset(data)   # labels are not needed here
model = create_model(train=X)
model.summary()

2025-02-19 11:19:20.523367: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            model = create_model(X_train)
            models.append(model)
        else:
            model = models[j]

        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )
        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [01:16,  7.69s/it]


Epoch 5: TRAIN Accuracy = 0.48 Loss = 0.695 AUC = 0.496
Epoch 5: VAL Accuracy = 0.487 Loss = 0.696 AUC = 0.487


10it [00:41,  4.20s/it]


Epoch 10: TRAIN Accuracy = 0.536 Loss = 0.688 AUC = 0.576
Epoch 10: VAL Accuracy = 0.55 Loss = 0.688 AUC = 0.572


10it [00:42,  4.21s/it]


Epoch 15: TRAIN Accuracy = 0.581 Loss = 0.678 AUC = 0.661
Epoch 15: VAL Accuracy = 0.54 Loss = 0.681 AUC = 0.629


10it [00:42,  4.24s/it]


Epoch 20: TRAIN Accuracy = 0.61 Loss = 0.67 AUC = 0.692
Epoch 20: VAL Accuracy = 0.597 Loss = 0.671 AUC = 0.7


10it [00:42,  4.28s/it]


Epoch 25: TRAIN Accuracy = 0.65 Loss = 0.659 AUC = 0.717
Epoch 25: VAL Accuracy = 0.623 Loss = 0.661 AUC = 0.725


10it [00:43,  4.30s/it]


Epoch 30: TRAIN Accuracy = 0.657 Loss = 0.649 AUC = 0.724
Epoch 30: VAL Accuracy = 0.637 Loss = 0.649 AUC = 0.749


10it [00:43,  4.32s/it]


Epoch 35: TRAIN Accuracy = 0.68 Loss = 0.633 AUC = 0.75
Epoch 35: VAL Accuracy = 0.663 Loss = 0.638 AUC = 0.75


10it [00:43,  4.33s/it]


Epoch 40: TRAIN Accuracy = 0.696 Loss = 0.623 AUC = 0.768
Epoch 40: VAL Accuracy = 0.66 Loss = 0.625 AUC = 0.743


10it [00:43,  4.33s/it]


Epoch 45: TRAIN Accuracy = 0.696 Loss = 0.601 AUC = 0.751
Epoch 45: VAL Accuracy = 0.683 Loss = 0.611 AUC = 0.754


10it [00:44,  4.43s/it]


Epoch 50: TRAIN Accuracy = 0.723 Loss = 0.587 AUC = 0.776
Epoch 50: VAL Accuracy = 0.703 Loss = 0.597 AUC = 0.764


10it [00:43,  4.36s/it]


Epoch 55: TRAIN Accuracy = 0.73 Loss = 0.571 AUC = 0.781
Epoch 55: VAL Accuracy = 0.717 Loss = 0.588 AUC = 0.756


10it [00:43,  4.33s/it]


Epoch 60: TRAIN Accuracy = 0.753 Loss = 0.56 AUC = 0.788
Epoch 60: VAL Accuracy = 0.733 Loss = 0.579 AUC = 0.764


10it [00:43,  4.39s/it]


Epoch 65: TRAIN Accuracy = 0.754 Loss = 0.551 AUC = 0.78
Epoch 65: VAL Accuracy = 0.74 Loss = 0.569 AUC = 0.771


10it [00:44,  4.42s/it]


Epoch 70: TRAIN Accuracy = 0.766 Loss = 0.537 AUC = 0.823
Epoch 70: VAL Accuracy = 0.747 Loss = 0.561 AUC = 0.773


10it [00:44,  4.41s/it]


Epoch 75: TRAIN Accuracy = 0.766 Loss = 0.531 AUC = 0.812
Epoch 75: VAL Accuracy = 0.75 Loss = 0.556 AUC = 0.779


10it [00:44,  4.43s/it]


Epoch 80: TRAIN Accuracy = 0.773 Loss = 0.528 AUC = 0.8
Epoch 80: VAL Accuracy = 0.753 Loss = 0.548 AUC = 0.784


10it [00:47,  4.70s/it]


Epoch 85: TRAIN Accuracy = 0.779 Loss = 0.51 AUC = 0.825
Epoch 85: VAL Accuracy = 0.757 Loss = 0.543 AUC = 0.793


10it [00:43,  4.34s/it]


Epoch 90: TRAIN Accuracy = 0.769 Loss = 0.508 AUC = 0.817
Epoch 90: VAL Accuracy = 0.757 Loss = 0.534 AUC = 0.802


10it [00:43,  4.33s/it]


Epoch 95: TRAIN Accuracy = 0.774 Loss = 0.495 AUC = 0.819
Epoch 95: VAL Accuracy = 0.757 Loss = 0.53 AUC = 0.802


10it [00:43,  4.32s/it]


Epoch 100: TRAIN Accuracy = 0.781 Loss = 0.485 AUC = 0.827
Epoch 100: VAL Accuracy = 0.753 Loss = 0.526 AUC = 0.809


10it [00:43,  4.36s/it]


Epoch 105: TRAIN Accuracy = 0.777 Loss = 0.474 AUC = 0.84
Epoch 105: VAL Accuracy = 0.753 Loss = 0.517 AUC = 0.813


10it [00:43,  4.35s/it]


Epoch 110: TRAIN Accuracy = 0.783 Loss = 0.47 AUC = 0.837
Epoch 110: VAL Accuracy = 0.75 Loss = 0.518 AUC = 0.815


10it [00:43,  4.37s/it]


Epoch 115: TRAIN Accuracy = 0.794 Loss = 0.455 AUC = 0.861
Epoch 115: VAL Accuracy = 0.757 Loss = 0.514 AUC = 0.82


10it [00:43,  4.40s/it]


Epoch 120: TRAIN Accuracy = 0.777 Loss = 0.471 AUC = 0.836
Epoch 120: VAL Accuracy = 0.763 Loss = 0.506 AUC = 0.819


10it [00:43,  4.38s/it]


Epoch 125: TRAIN Accuracy = 0.793 Loss = 0.46 AUC = 0.832
Epoch 125: VAL Accuracy = 0.753 Loss = 0.508 AUC = 0.822


10it [00:44,  4.42s/it]


Epoch 130: TRAIN Accuracy = 0.8 Loss = 0.443 AUC = 0.85
Epoch 130: VAL Accuracy = 0.763 Loss = 0.505 AUC = 0.823


10it [00:43,  4.40s/it]


Epoch 135: TRAIN Accuracy = 0.801 Loss = 0.437 AUC = 0.852
Epoch 135: VAL Accuracy = 0.757 Loss = 0.498 AUC = 0.822


10it [00:43,  4.37s/it]


Epoch 140: TRAIN Accuracy = 0.786 Loss = 0.446 AUC = 0.843
Epoch 140: VAL Accuracy = 0.757 Loss = 0.495 AUC = 0.834


10it [00:44,  4.41s/it]


Epoch 145: TRAIN Accuracy = 0.801 Loss = 0.424 AUC = 0.852
Epoch 145: VAL Accuracy = 0.75 Loss = 0.497 AUC = 0.832


10it [00:43,  4.38s/it]


Epoch 150: TRAIN Accuracy = 0.804 Loss = 0.431 AUC = 0.86
Epoch 150: VAL Accuracy = 0.75 Loss = 0.493 AUC = 0.831


10it [00:43,  4.39s/it]


Epoch 155: TRAIN Accuracy = 0.804 Loss = 0.427 AUC = 0.863
Epoch 155: VAL Accuracy = 0.753 Loss = 0.499 AUC = 0.832


10it [00:43,  4.39s/it]


Epoch 160: TRAIN Accuracy = 0.801 Loss = 0.416 AUC = 0.861
Epoch 160: VAL Accuracy = 0.753 Loss = 0.499 AUC = 0.831


10it [00:44,  4.45s/it]


Epoch 165: TRAIN Accuracy = 0.791 Loss = 0.425 AUC = 0.859
Epoch 165: VAL Accuracy = 0.753 Loss = 0.491 AUC = 0.834


10it [00:44,  4.42s/it]


Epoch 170: TRAIN Accuracy = 0.804 Loss = 0.422 AUC = 0.855
Epoch 170: VAL Accuracy = 0.75 Loss = 0.487 AUC = 0.837


10it [00:44,  4.45s/it]


Epoch 175: TRAIN Accuracy = 0.807 Loss = 0.408 AUC = 0.856
Epoch 175: VAL Accuracy = 0.75 Loss = 0.49 AUC = 0.832


10it [00:44,  4.42s/it]


Epoch 180: TRAIN Accuracy = 0.806 Loss = 0.409 AUC = 0.866
Epoch 180: VAL Accuracy = 0.75 Loss = 0.489 AUC = 0.839


10it [00:45,  4.52s/it]


Epoch 185: TRAIN Accuracy = 0.79 Loss = 0.416 AUC = 0.863
Epoch 185: VAL Accuracy = 0.753 Loss = 0.483 AUC = 0.836


10it [00:44,  4.43s/it]


Epoch 190: TRAIN Accuracy = 0.811 Loss = 0.404 AUC = 0.864
Epoch 190: VAL Accuracy = 0.753 Loss = 0.492 AUC = 0.838


10it [00:44,  4.41s/it]


Epoch 195: TRAIN Accuracy = 0.813 Loss = 0.405 AUC = 0.862
Epoch 195: VAL Accuracy = 0.75 Loss = 0.502 AUC = 0.84


10it [00:44,  4.43s/it]


Epoch 200: TRAIN Accuracy = 0.806 Loss = 0.399 AUC = 0.873
Epoch 200: VAL Accuracy = 0.757 Loss = 0.495 AUC = 0.839


10it [00:43,  4.40s/it]


Epoch 205: TRAIN Accuracy = 0.796 Loss = 0.409 AUC = 0.862
Epoch 205: VAL Accuracy = 0.75 Loss = 0.499 AUC = 0.838


10it [00:44,  4.47s/it]


Epoch 210: TRAIN Accuracy = 0.809 Loss = 0.4 AUC = 0.864
Epoch 210: VAL Accuracy = 0.76 Loss = 0.485 AUC = 0.838


10it [00:44,  4.41s/it]


Epoch 215: TRAIN Accuracy = 0.809 Loss = 0.398 AUC = 0.878
Epoch 215: VAL Accuracy = 0.757 Loss = 0.49 AUC = 0.84


10it [00:44,  4.40s/it]


Epoch 220: TRAIN Accuracy = 0.8 Loss = 0.388 AUC = 0.868
Epoch 220: VAL Accuracy = 0.75 Loss = 0.489 AUC = 0.838


10it [00:36,  3.68s/it]


Epoch 225: TRAIN Accuracy = 0.806 Loss = 0.393 AUC = 0.868
Epoch 225: VAL Accuracy = 0.753 Loss = 0.49 AUC = 0.841


10it [00:34,  3.41s/it]


Epoch 230: TRAIN Accuracy = 0.81 Loss = 0.382 AUC = 0.882
Epoch 230: VAL Accuracy = 0.753 Loss = 0.488 AUC = 0.843


10it [00:34,  3.42s/it]


Epoch 235: TRAIN Accuracy = 0.813 Loss = 0.375 AUC = 0.883
Epoch 235: VAL Accuracy = 0.763 Loss = 0.491 AUC = 0.842


10it [00:34,  3.41s/it]


Epoch 240: TRAIN Accuracy = 0.806 Loss = 0.396 AUC = 0.865
Epoch 240: VAL Accuracy = 0.76 Loss = 0.48 AUC = 0.848


10it [00:33,  3.39s/it]


Epoch 245: TRAIN Accuracy = 0.817 Loss = 0.375 AUC = 0.887
Epoch 245: VAL Accuracy = 0.753 Loss = 0.489 AUC = 0.846


10it [00:33,  3.39s/it]


Epoch 250: TRAIN Accuracy = 0.809 Loss = 0.382 AUC = 0.882
Epoch 250: VAL Accuracy = 0.763 Loss = 0.481 AUC = 0.848


10it [00:33,  3.39s/it]


Epoch 255: TRAIN Accuracy = 0.83 Loss = 0.374 AUC = 0.886
Epoch 255: VAL Accuracy = 0.76 Loss = 0.474 AUC = 0.846


10it [00:33,  3.40s/it]


Epoch 260: TRAIN Accuracy = 0.817 Loss = 0.379 AUC = 0.884
Epoch 260: VAL Accuracy = 0.753 Loss = 0.486 AUC = 0.845


10it [00:34,  3.46s/it]


Epoch 265: TRAIN Accuracy = 0.824 Loss = 0.371 AUC = 0.886
Epoch 265: VAL Accuracy = 0.76 Loss = 0.481 AUC = 0.846


10it [00:35,  3.60s/it]


Epoch 270: TRAIN Accuracy = 0.804 Loss = 0.381 AUC = 0.883
Epoch 270: VAL Accuracy = 0.757 Loss = 0.483 AUC = 0.844


10it [00:33,  3.36s/it]


Epoch 275: TRAIN Accuracy = 0.811 Loss = 0.379 AUC = 0.878
Epoch 275: VAL Accuracy = 0.76 Loss = 0.481 AUC = 0.847


10it [00:33,  3.36s/it]


Epoch 280: TRAIN Accuracy = 0.82 Loss = 0.368 AUC = 0.891
Epoch 280: VAL Accuracy = 0.753 Loss = 0.491 AUC = 0.846


10it [00:36,  3.63s/it]


Epoch 285: TRAIN Accuracy = 0.81 Loss = 0.374 AUC = 0.884
Epoch 285: VAL Accuracy = 0.76 Loss = 0.487 AUC = 0.853


10it [00:43,  4.32s/it]


Epoch 290: TRAIN Accuracy = 0.813 Loss = 0.366 AUC = 0.893
Epoch 290: VAL Accuracy = 0.757 Loss = 0.481 AUC = 0.851


10it [00:43,  4.36s/it]


Epoch 295: TRAIN Accuracy = 0.819 Loss = 0.372 AUC = 0.891
Epoch 295: VAL Accuracy = 0.76 Loss = 0.48 AUC = 0.853


10it [00:44,  4.45s/it]

Epoch 300: TRAIN Accuracy = 0.821 Loss = 0.37 AUC = 0.887
Epoch 300: VAL Accuracy = 0.763 Loss = 0.473 AUC = 0.854
CPU times: user 1h 10min 37s, sys: 23min 1s, total: 1h 33min 39s
Wall time: 42min 21s



