In [1]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, RepeatVector
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-04 16:15:08.483533: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-04 16:15:08.602757: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-04 16:15:08.721723: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741094108.816668  275971 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741094108.849810  275971 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-04 16:15:09.092932: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration and model definition

In [2]:
NUM_EPOCHS = 300       # total training epochs per experiment
STEP = 5               # epochs per fit() call between progress reports
NUM_EXPERIMENTS = 10   # independent random train/validation splits (one seed each)

# Training runs NUM_EPOCHS // STEP chunks; fail fast instead of silently
# dropping the remainder when the two do not divide evenly.
assert NUM_EPOCHS % STEP == 0, "NUM_EPOCHS must be a multiple of STEP"

def create_model(train, units=10, dropout=0.2, dense_units=6, learning_rate=3e-4):
    """Build a GRU sequence autoencoder and a classifier that share its encoder.

    Parameters
    ----------
    train : np.ndarray
        Array of shape (n_samples, time_steps, n_features). Only its shape is
        used: it fixes the model's input length and feature count.
    units : int, default 10
        Hidden size of every GRU except the decoder's output GRU.
    dropout : float, default 0.2
        Rate of every Dropout layer (on the input and between GRU layers).
    dense_units : int, default 6
        Width of the ReLU hidden layer of the classification head.
    learning_rate : float, default 3e-4
        Adam learning rate for both models.

    Returns
    -------
    (autoencoder, classifier) : tuple of compiled keras Models
        ``autoencoder`` maps a sequence to a reconstruction of the same shape
        (MSE loss). ``classifier`` maps it to one sigmoid probability (binary
        cross-entropy, accuracy and AUC metrics). Both are built on the same
        encoder layers, so fitting either one updates the encoder weights the
        other uses.
    """
    time_steps = train.shape[1]
    n_features = train.shape[2]

    # Encoder: input -> stacked GRUs -> fixed-size code vector.
    inputs = Input(shape=(time_steps, n_features))
    x = Dropout(dropout)(inputs)
    x = GRU(units, return_sequences=True)(x)
    x = Dropout(dropout)(x)
    x = GRU(units, return_sequences=True)(x)
    x = Dropout(dropout)(x)
    encoded = GRU(units, return_sequences=False)(x)

    # Decoder: repeat the code for every time step and unroll it back to features.
    # NOTE(review): the reconstruction comes from a GRU with Keras' default tanh
    # activation, so it is bounded to [-1, 1], while the raw FEATURES reach ~1e5
    # (see data.head()). Consider scaling the inputs or using a linear output
    # layer before trusting the autoencoder loss.
    x = RepeatVector(time_steps)(encoded)
    x = GRU(units, return_sequences=True)(x)
    x = Dropout(dropout)(x)
    x = GRU(units, return_sequences=True)(x)
    x = Dropout(dropout)(x)
    reconstructed = GRU(n_features, return_sequences=True)(x)

    # Classification head on the shared code vector.
    classification = Dense(dense_units, activation='relu')(encoded)
    classification = Dense(1, activation='sigmoid')(classification)

    autoencoder = Model(inputs, reconstructed, name="Autoencoder")
    classifier = Model(inputs, classification, name="Classifier")

    autoencoder.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse')
    classifier.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=["accuracy", AUC(name="auc")],
    )

    return autoencoder, classifier

# Experiment

In [3]:
INIT_SEED = 5412  # first of NUM_EXPERIMENTS consecutive per-experiment seeds

# Column groups of EEG_data.csv. They are lists so they can be concatenated
# when selecting columns (e.g. ID + FEATURES + TARGET).
IDS = ["SubjectID", "VideoID"]    # raw columns that together identify one recording
ID = ["ID"]                       # synthetic per-recording key derived from IDS
TARGET = ["predefinedlabel"]      # label predicted by the classifier
FEATURES = "Raw Delta Theta Alpha1 Alpha2 Beta1 Beta2 Gamma1 Gamma2".split()

In [4]:
# Directory holding EEG_data.csv. Set the EEG_DATA_DIR environment variable to
# point elsewhere instead of editing this absolute path.
data_dir = Path(os.environ.get("EEG_DATA_DIR", "/home/aseliverstov/projects/brain_signals/data"))
raw_data = pd.read_csv(data_dir / "EEG_data.csv")

# One integer ID per (SubjectID, VideoID) recording, numbered 0..k-1 in sorted
# order. This equals the former `n_videos * SubjectID + VideoID` when VideoIDs
# are 0..n_videos-1 and every pair is present, but it cannot collide otherwise
# (e.g. with 1-based VideoIDs).
data = (
    raw_data
    .assign(ID=raw_data.groupby(IDS).ngroup())
    .loc[:, ID + FEATURES + TARGET]
)

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    """Turn the long-format EEG frame into one padded sequence per recording.

    One sample is produced per distinct value of the ``ID`` column, in
    ascending ID order. Rows keep their original order within a recording.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the columns listed in ``ID``, ``FEATURES`` and ``TARGET``.

    Returns
    -------
    features : np.ndarray, shape (n_recordings, max_len, len(FEATURES))
        Each recording's FEATURES rows, zero-padded at the end to the longest
        recording in ``data`` (via ``pad_sequences``).
    target : np.ndarray of int, shape (n_recordings,)
        One label per recording: the mean of its ``TARGET`` values rounded to
        the nearest integer (a majority vote for 0/1 labels).

    Raises
    ------
    ValueError
        If ``data`` is empty (nothing to pad).
    """
    features = []
    target = []
    # A sorted groupby visits IDs in the same ascending order np.unique did,
    # but in one pass instead of one boolean scan of the whole frame per ID.
    for _, recording in data.groupby(ID[0], sort=True):
        # Round instead of truncating: with mixed labels, truncation would map
        # e.g. a mean of 0.9 to 0.
        target.append(int(np.rint(recording[TARGET].to_numpy().mean())))
        features.append(recording[FEATURES].to_numpy())

    return pad_sequences(features), np.array(target)

def pad_sequences(arrays, pad_value=0, max_length=None):
    """Stack 2-D arrays of varying length into one 3-D array by padding at the end.

    Parameters
    ----------
    arrays : iterable of np.ndarray
        Arrays of shape (length_i, n_features); all must share n_features.
    pad_value : scalar, default 0
        Value written into the padded rows.
    max_length : int, optional
        Target length of the first axis. Defaults to the longest input array,
        which preserves the original behavior. Pass a fixed value to give
        several splits the same time dimension.

    Returns
    -------
    np.ndarray of shape (len(arrays), max_length, n_features)

    Raises
    ------
    ValueError
        If ``arrays`` is empty, or if an array is longer than ``max_length``.
    """
    arrays = list(arrays)
    if not arrays:
        raise ValueError("pad_sequences() needs at least one array")

    longest = max(arr.shape[0] for arr in arrays)
    if max_length is None:
        max_length = longest
    elif max_length < longest:
        raise ValueError(
            f"max_length={max_length} is shorter than the longest sequence ({longest})"
        )

    return np.stack([
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode="constant",
            constant_values=pad_value,
        )
        for arr in arrays
    ])

In [7]:
# Instantiate once on the full dataset purely to print the classifier's
# architecture; the experiment cell below builds fresh models per split.
X, _ = reshape_dataset(data)
autoencoder, model = create_model(train=X)
model.summary()

In [8]:
%%time

autoencoders = []
models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            autoencoder, model = create_model(X_train)
            autoencoders.append(autoencoder)
            models.append(model)
        else:
            autoencoder = autoencoders[j]
            model = models[j]

        autoencoder.fit(
            X_train, X_train,
            validation_data=(X_test, X_test),
            epochs=STEP,
            batch_size=16,
            verbose=0
        )
        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )

        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [03:58, 23.83s/it]


Epoch 5: TRAIN Accuracy = 0.486 Loss = 0.698 AUC = 0.519
Epoch 5: VAL Accuracy = 0.507 Loss = 0.699 AUC = 0.532


10it [02:49, 16.92s/it]


Epoch 10: TRAIN Accuracy = 0.571 Loss = 0.682 AUC = 0.617
Epoch 10: VAL Accuracy = 0.57 Loss = 0.687 AUC = 0.614


10it [02:52, 17.22s/it]


Epoch 15: TRAIN Accuracy = 0.567 Loss = 0.673 AUC = 0.628
Epoch 15: VAL Accuracy = 0.59 Loss = 0.678 AUC = 0.643


10it [02:56, 17.61s/it]


Epoch 20: TRAIN Accuracy = 0.589 Loss = 0.664 AUC = 0.676
Epoch 20: VAL Accuracy = 0.597 Loss = 0.673 AUC = 0.665


10it [02:50, 17.03s/it]


Epoch 25: TRAIN Accuracy = 0.554 Loss = 0.658 AUC = 0.695
Epoch 25: VAL Accuracy = 0.593 Loss = 0.668 AUC = 0.686


10it [03:11, 19.19s/it]


Epoch 30: TRAIN Accuracy = 0.584 Loss = 0.651 AUC = 0.715
Epoch 30: VAL Accuracy = 0.587 Loss = 0.661 AUC = 0.69


10it [03:29, 20.94s/it]


Epoch 35: TRAIN Accuracy = 0.604 Loss = 0.64 AUC = 0.734
Epoch 35: VAL Accuracy = 0.62 Loss = 0.65 AUC = 0.72


10it [03:30, 21.08s/it]


Epoch 40: TRAIN Accuracy = 0.633 Loss = 0.626 AUC = 0.764
Epoch 40: VAL Accuracy = 0.637 Loss = 0.638 AUC = 0.734


10it [04:27, 26.71s/it]


Epoch 45: TRAIN Accuracy = 0.65 Loss = 0.614 AUC = 0.761
Epoch 45: VAL Accuracy = 0.693 Loss = 0.627 AUC = 0.712


10it [03:23, 20.32s/it]


Epoch 50: TRAIN Accuracy = 0.686 Loss = 0.604 AUC = 0.741
Epoch 50: VAL Accuracy = 0.727 Loss = 0.615 AUC = 0.739


10it [03:01, 18.16s/it]


Epoch 55: TRAIN Accuracy = 0.7 Loss = 0.588 AUC = 0.776
Epoch 55: VAL Accuracy = 0.74 Loss = 0.603 AUC = 0.761


10it [03:04, 18.46s/it]


Epoch 60: TRAIN Accuracy = 0.727 Loss = 0.572 AUC = 0.78
Epoch 60: VAL Accuracy = 0.747 Loss = 0.59 AUC = 0.754


10it [03:16, 19.69s/it]


Epoch 65: TRAIN Accuracy = 0.736 Loss = 0.559 AUC = 0.782
Epoch 65: VAL Accuracy = 0.75 Loss = 0.577 AUC = 0.766


10it [03:39, 21.97s/it]


Epoch 70: TRAIN Accuracy = 0.763 Loss = 0.544 AUC = 0.788
Epoch 70: VAL Accuracy = 0.763 Loss = 0.566 AUC = 0.77


10it [02:57, 17.78s/it]


Epoch 75: TRAIN Accuracy = 0.77 Loss = 0.533 AUC = 0.778
Epoch 75: VAL Accuracy = 0.767 Loss = 0.554 AUC = 0.783


10it [02:59, 17.93s/it]


Epoch 80: TRAIN Accuracy = 0.773 Loss = 0.524 AUC = 0.786
Epoch 80: VAL Accuracy = 0.773 Loss = 0.542 AUC = 0.776


10it [03:03, 18.36s/it]


Epoch 85: TRAIN Accuracy = 0.787 Loss = 0.508 AUC = 0.788
Epoch 85: VAL Accuracy = 0.773 Loss = 0.533 AUC = 0.786


10it [03:00, 18.07s/it]


Epoch 90: TRAIN Accuracy = 0.783 Loss = 0.497 AUC = 0.788
Epoch 90: VAL Accuracy = 0.773 Loss = 0.525 AUC = 0.778


10it [03:00, 18.07s/it]


Epoch 95: TRAIN Accuracy = 0.781 Loss = 0.488 AUC = 0.812
Epoch 95: VAL Accuracy = 0.773 Loss = 0.516 AUC = 0.786


10it [03:01, 18.18s/it]


Epoch 100: TRAIN Accuracy = 0.793 Loss = 0.482 AUC = 0.801
Epoch 100: VAL Accuracy = 0.773 Loss = 0.509 AUC = 0.789


10it [02:58, 17.84s/it]


Epoch 105: TRAIN Accuracy = 0.791 Loss = 0.475 AUC = 0.794
Epoch 105: VAL Accuracy = 0.78 Loss = 0.5 AUC = 0.791


10it [02:59, 17.98s/it]


Epoch 110: TRAIN Accuracy = 0.784 Loss = 0.472 AUC = 0.801
Epoch 110: VAL Accuracy = 0.78 Loss = 0.494 AUC = 0.788


10it [03:33, 21.38s/it]


Epoch 115: TRAIN Accuracy = 0.797 Loss = 0.462 AUC = 0.795
Epoch 115: VAL Accuracy = 0.78 Loss = 0.49 AUC = 0.799


10it [03:39, 22.00s/it]


Epoch 120: TRAIN Accuracy = 0.797 Loss = 0.461 AUC = 0.798
Epoch 120: VAL Accuracy = 0.78 Loss = 0.486 AUC = 0.807


10it [03:36, 21.66s/it]


Epoch 125: TRAIN Accuracy = 0.796 Loss = 0.458 AUC = 0.8
Epoch 125: VAL Accuracy = 0.783 Loss = 0.48 AUC = 0.814


10it [03:50, 23.07s/it]


Epoch 130: TRAIN Accuracy = 0.799 Loss = 0.453 AUC = 0.811
Epoch 130: VAL Accuracy = 0.783 Loss = 0.475 AUC = 0.808


10it [03:49, 22.92s/it]


Epoch 135: TRAIN Accuracy = 0.796 Loss = 0.451 AUC = 0.799
Epoch 135: VAL Accuracy = 0.783 Loss = 0.471 AUC = 0.809


10it [03:47, 22.72s/it]


Epoch 140: TRAIN Accuracy = 0.8 Loss = 0.447 AUC = 0.814
Epoch 140: VAL Accuracy = 0.783 Loss = 0.467 AUC = 0.816


10it [03:46, 22.70s/it]


Epoch 145: TRAIN Accuracy = 0.803 Loss = 0.444 AUC = 0.809
Epoch 145: VAL Accuracy = 0.787 Loss = 0.464 AUC = 0.811


10it [03:47, 22.70s/it]


Epoch 150: TRAIN Accuracy = 0.804 Loss = 0.439 AUC = 0.805
Epoch 150: VAL Accuracy = 0.79 Loss = 0.46 AUC = 0.811


10it [04:00, 24.02s/it]


Epoch 155: TRAIN Accuracy = 0.806 Loss = 0.437 AUC = 0.817
Epoch 155: VAL Accuracy = 0.79 Loss = 0.458 AUC = 0.814


10it [04:10, 25.08s/it]


Epoch 160: TRAIN Accuracy = 0.803 Loss = 0.438 AUC = 0.808
Epoch 160: VAL Accuracy = 0.79 Loss = 0.456 AUC = 0.81


10it [03:48, 22.86s/it]


Epoch 165: TRAIN Accuracy = 0.804 Loss = 0.434 AUC = 0.813
Epoch 165: VAL Accuracy = 0.793 Loss = 0.453 AUC = 0.814


10it [03:45, 22.50s/it]


Epoch 170: TRAIN Accuracy = 0.81 Loss = 0.431 AUC = 0.82
Epoch 170: VAL Accuracy = 0.797 Loss = 0.451 AUC = 0.81


10it [03:46, 22.63s/it]


Epoch 175: TRAIN Accuracy = 0.81 Loss = 0.428 AUC = 0.824
Epoch 175: VAL Accuracy = 0.8 Loss = 0.447 AUC = 0.811


10it [03:46, 22.60s/it]


Epoch 180: TRAIN Accuracy = 0.807 Loss = 0.427 AUC = 0.815
Epoch 180: VAL Accuracy = 0.8 Loss = 0.445 AUC = 0.822


10it [03:43, 22.36s/it]


Epoch 185: TRAIN Accuracy = 0.804 Loss = 0.43 AUC = 0.808
Epoch 185: VAL Accuracy = 0.8 Loss = 0.442 AUC = 0.826


10it [03:43, 22.39s/it]


Epoch 190: TRAIN Accuracy = 0.807 Loss = 0.424 AUC = 0.836
Epoch 190: VAL Accuracy = 0.803 Loss = 0.442 AUC = 0.818


10it [03:41, 22.18s/it]


Epoch 195: TRAIN Accuracy = 0.807 Loss = 0.424 AUC = 0.82
Epoch 195: VAL Accuracy = 0.797 Loss = 0.441 AUC = 0.84


10it [03:37, 21.80s/it]


Epoch 200: TRAIN Accuracy = 0.809 Loss = 0.417 AUC = 0.853
Epoch 200: VAL Accuracy = 0.797 Loss = 0.44 AUC = 0.824


10it [03:39, 21.93s/it]


Epoch 205: TRAIN Accuracy = 0.804 Loss = 0.424 AUC = 0.835
Epoch 205: VAL Accuracy = 0.8 Loss = 0.437 AUC = 0.844


10it [03:40, 22.09s/it]


Epoch 210: TRAIN Accuracy = 0.806 Loss = 0.421 AUC = 0.83
Epoch 210: VAL Accuracy = 0.8 Loss = 0.435 AUC = 0.842


10it [03:48, 22.85s/it]


Epoch 215: TRAIN Accuracy = 0.81 Loss = 0.413 AUC = 0.854
Epoch 215: VAL Accuracy = 0.8 Loss = 0.433 AUC = 0.846


10it [03:41, 22.19s/it]


Epoch 220: TRAIN Accuracy = 0.809 Loss = 0.419 AUC = 0.834
Epoch 220: VAL Accuracy = 0.8 Loss = 0.432 AUC = 0.852


10it [03:44, 22.49s/it]


Epoch 225: TRAIN Accuracy = 0.811 Loss = 0.412 AUC = 0.846
Epoch 225: VAL Accuracy = 0.8 Loss = 0.43 AUC = 0.858


10it [03:44, 22.49s/it]


Epoch 230: TRAIN Accuracy = 0.813 Loss = 0.414 AUC = 0.839
Epoch 230: VAL Accuracy = 0.8 Loss = 0.429 AUC = 0.853


10it [03:44, 22.43s/it]


Epoch 235: TRAIN Accuracy = 0.804 Loss = 0.418 AUC = 0.815
Epoch 235: VAL Accuracy = 0.8 Loss = 0.428 AUC = 0.853


10it [03:42, 22.28s/it]


Epoch 240: TRAIN Accuracy = 0.809 Loss = 0.413 AUC = 0.832
Epoch 240: VAL Accuracy = 0.8 Loss = 0.428 AUC = 0.855


10it [03:43, 22.33s/it]


Epoch 245: TRAIN Accuracy = 0.803 Loss = 0.413 AUC = 0.849
Epoch 245: VAL Accuracy = 0.797 Loss = 0.426 AUC = 0.869


10it [03:56, 23.68s/it]


Epoch 250: TRAIN Accuracy = 0.804 Loss = 0.406 AUC = 0.864
Epoch 250: VAL Accuracy = 0.797 Loss = 0.424 AUC = 0.868


10it [04:12, 25.25s/it]


Epoch 255: TRAIN Accuracy = 0.807 Loss = 0.408 AUC = 0.858
Epoch 255: VAL Accuracy = 0.8 Loss = 0.422 AUC = 0.87


10it [04:15, 25.50s/it]


Epoch 260: TRAIN Accuracy = 0.811 Loss = 0.404 AUC = 0.857
Epoch 260: VAL Accuracy = 0.8 Loss = 0.421 AUC = 0.875


10it [04:10, 25.02s/it]


Epoch 265: TRAIN Accuracy = 0.8 Loss = 0.411 AUC = 0.843
Epoch 265: VAL Accuracy = 0.8 Loss = 0.417 AUC = 0.876


10it [04:13, 25.39s/it]


Epoch 270: TRAIN Accuracy = 0.814 Loss = 0.405 AUC = 0.85
Epoch 270: VAL Accuracy = 0.8 Loss = 0.416 AUC = 0.876


10it [04:03, 24.31s/it]


Epoch 275: TRAIN Accuracy = 0.816 Loss = 0.402 AUC = 0.86
Epoch 275: VAL Accuracy = 0.8 Loss = 0.415 AUC = 0.875


10it [03:43, 22.40s/it]


Epoch 280: TRAIN Accuracy = 0.806 Loss = 0.398 AUC = 0.864
Epoch 280: VAL Accuracy = 0.8 Loss = 0.412 AUC = 0.885


10it [03:50, 23.07s/it]


Epoch 285: TRAIN Accuracy = 0.809 Loss = 0.402 AUC = 0.854
Epoch 285: VAL Accuracy = 0.803 Loss = 0.412 AUC = 0.882


10it [03:49, 22.93s/it]


Epoch 290: TRAIN Accuracy = 0.804 Loss = 0.407 AUC = 0.851
Epoch 290: VAL Accuracy = 0.803 Loss = 0.411 AUC = 0.876


10it [03:48, 22.87s/it]


Epoch 295: TRAIN Accuracy = 0.813 Loss = 0.396 AUC = 0.87
Epoch 295: VAL Accuracy = 0.803 Loss = 0.41 AUC = 0.873


10it [03:49, 22.93s/it]

Epoch 300: TRAIN Accuracy = 0.791 Loss = 0.4 AUC = 0.875
Epoch 300: VAL Accuracy = 0.8 Loss = 0.41 AUC = 0.869
CPU times: user 7h 45min, sys: 3h 11min 17s, total: 10h 56min 17s
Wall time: 3h 35min 52s



