In [1]:
import numpy as np
import pandas as pd
import random

from pathlib import Path
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.metrics import AUC
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dense, GRU, Input, BatchNormalization, Dropout, RepeatVector, TimeDistributed
from ncps.wirings import AutoNCP
from ncps.keras import LTC

2025-03-04 17:27:16.142424: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-04 17:27:16.247749: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-04 17:27:16.370691: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741098436.485832  435287 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741098436.514112  435287 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-04 17:27:16.825425: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

# Configuration

In [2]:
NUM_EPOCHS = 300
STEP = 5
NUM_EXPERIMENTS = 10

def create_model(train):
    input_dim = train.shape[1]
    feature_dim = train.shape[2]

    # Encoder
    inputs = Input(shape=(input_dim, feature_dim))
    x = TimeDistributed(Dense(feature_dim, activation='relu'))(inputs)
    encoded = TimeDistributed(Dense(16, activation='relu'))(x)

    # Decoder
    x = TimeDistributed(Dense(16, activation='relu'))(encoded)
    reconstructed = TimeDistributed(Dense(feature_dim, activation='relu'))(x)

    classification = Dropout(0.2)(encoded)
    classification = GRU(10, return_sequences=True)(classification)
    classification = Dropout(0.2)(classification)
    classification = GRU(10, return_sequences=True)(classification)
    classification = Dropout(0.2)(classification)
    classification = GRU(10, return_sequences=False)(classification)
    classification = Dense(6, activation='relu')(classification)
    classification = Dense(1, activation='sigmoid')(classification)

    autoencoder = Model(inputs, reconstructed, name="Autoencoder")
    classifier = Model(inputs, classification, name="Classifier")

    autoencoder.compile(optimizer=Adam(learning_rate=0.0003), loss='mse')
    classifier.compile(optimizer=Adam(learning_rate=0.0003), loss='binary_crossentropy', metrics=["accuracy", AUC(name="auc")])

    return autoencoder, classifier

# Experiment

In [3]:
ID = ["ID"]
IDS = ["SubjectID", "VideoID"]
TARGET = ["predefinedlabel"]
FEATURES = ["Raw", "Delta", "Theta", "Alpha1", "Alpha2", "Beta1", "Beta2", "Gamma1", "Gamma2"]
INIT_SEED = 5412

In [4]:
data_dir = Path("/home/aseliverstov/projects/brain_signals/data")
data = pd.read_csv(data_dir / "EEG_data.csv")

data["ID"] = (len(np.unique(data["VideoID"])) * data["SubjectID"] + data["VideoID"]).astype("int")
data = data[ID + FEATURES + TARGET]

data.head(3)

Unnamed: 0,ID,Raw,Delta,Theta,Alpha1,Alpha2,Beta1,Beta2,Gamma1,Gamma2,predefinedlabel
0,0,278.0,301963.0,90612.0,33735.0,23991.0,27946.0,45097.0,33228.0,8293.0,0.0
1,0,-50.0,73787.0,28083.0,1439.0,2240.0,2746.0,3687.0,5293.0,2740.0,0.0
2,0,101.0,758353.0,383745.0,201999.0,62107.0,36293.0,130536.0,57243.0,25354.0,0.0


In [5]:
def reshape_dataset(data):
    features = []
    target = []
    for cur_id in np.unique(data[ID].to_numpy()):
        cur_id_data = data[data[ID].to_numpy() == cur_id]
        target.append(np.mean(cur_id_data[TARGET].to_numpy()).astype("int"))
        features.append(cur_id_data[FEATURES].to_numpy())

    features = pad_sequences(features)
    return np.array(features), np.array(target)

def pad_sequences(arrays, pad_value=0):
    max_length = max(arr.shape[0] for arr in arrays)
    padded_arrays = [
        np.pad(
            arr,
            ((0, max_length - arr.shape[0]), (0, 0)),
            mode='constant',
            constant_values=pad_value)
            for arr in arrays
        ]
    return np.stack(padded_arrays)

In [6]:
X, _ = reshape_dataset(data)
autoencoder, model = create_model(X)
model.summary()

2025-03-04 17:27:22.972208: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [7]:
%%time

autoencoders = []
models = []

for i in range(NUM_EPOCHS // STEP):
    epoch_acc = []
    epoch_loss = []
    epoch_auc = []

    epoch_val_acc = []
    epoch_val_loss = []
    epoch_val_auc = []

    for j, seed in tqdm(enumerate(np.arange(NUM_EXPERIMENTS) + INIT_SEED)):
        np.random.seed(int(seed))
        random.seed(int(seed))
        tf.random.set_seed(int(seed))

        train_id = np.random.choice(np.unique(np.ravel(data[ID])), 70, replace=False)
        train_index = np.isin(data[ID], train_id)

        train = data.iloc[train_index]
        test = data.iloc[~train_index]

        X_train, y_train = reshape_dataset(train)
        X_test, y_test = reshape_dataset(test)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        if i == 0:
            autoencoder, model = create_model(X_train)
            autoencoders.append(autoencoder)
            models.append(model)
        else:
            autoencoder = autoencoders[j]
            model = models[j]

        autoencoder.fit(
            X_train, X_train,
            validation_data=(X_test, X_test),
            epochs=STEP,
            batch_size=16,
            verbose=0
        )
        history = model.fit(
            X_train, y_train,
            validation_data=(X_test, y_test),
            epochs=STEP,
            batch_size=16,
            verbose=0,
        )

        acc = history.history['accuracy'][0]
        loss = history.history['loss'][0]
        auc = history.history['auc'][0]

        val_acc = history.history['val_accuracy'][0]
        val_loss = history.history['val_loss'][0]
        val_auc = history.history['val_auc'][0]

        epoch_acc.append(acc)
        epoch_loss.append(loss)
        epoch_auc.append(auc)

        epoch_val_acc.append(val_acc)
        epoch_val_loss.append(val_loss)
        epoch_val_auc.append(val_auc)

    print(f"Epoch {(i + 1) * STEP}: TRAIN Accuracy = {np.round(np.mean(epoch_acc), 3)} Loss = {np.round(np.mean(epoch_loss), 3)} AUC = {np.round(np.mean(epoch_auc), 3)}")
    print(f"Epoch {(i + 1) * STEP}: VAL Accuracy = {np.round(np.mean(epoch_val_acc), 3)} Loss = {np.round(np.mean(epoch_val_loss), 3)} AUC = {np.round(np.mean(epoch_val_auc), 3)}")

10it [10:06, 60.67s/it]


Epoch 5: TRAIN Accuracy = 0.454 Loss = 0.7 AUC = 0.469
Epoch 5: VAL Accuracy = 0.537 Loss = 0.697 AUC = 0.514


10it [03:46, 22.61s/it]


Epoch 10: TRAIN Accuracy = 0.563 Loss = 0.682 AUC = 0.578
Epoch 10: VAL Accuracy = 0.55 Loss = 0.676 AUC = 0.593


10it [03:42, 22.24s/it]


Epoch 15: TRAIN Accuracy = 0.624 Loss = 0.664 AUC = 0.666
Epoch 15: VAL Accuracy = 0.613 Loss = 0.659 AUC = 0.671


10it [03:45, 22.60s/it]


Epoch 20: TRAIN Accuracy = 0.651 Loss = 0.651 AUC = 0.703
Epoch 20: VAL Accuracy = 0.677 Loss = 0.644 AUC = 0.724


10it [03:48, 22.85s/it]


Epoch 25: TRAIN Accuracy = 0.689 Loss = 0.639 AUC = 0.727
Epoch 25: VAL Accuracy = 0.707 Loss = 0.629 AUC = 0.748


10it [03:56, 23.64s/it]


Epoch 30: TRAIN Accuracy = 0.709 Loss = 0.622 AUC = 0.762
Epoch 30: VAL Accuracy = 0.717 Loss = 0.614 AUC = 0.759


10it [04:07, 24.75s/it]


Epoch 35: TRAIN Accuracy = 0.723 Loss = 0.61 AUC = 0.766
Epoch 35: VAL Accuracy = 0.717 Loss = 0.601 AUC = 0.767


10it [04:20, 26.02s/it]


Epoch 40: TRAIN Accuracy = 0.744 Loss = 0.598 AUC = 0.774
Epoch 40: VAL Accuracy = 0.713 Loss = 0.593 AUC = 0.776


10it [03:53, 23.39s/it]


Epoch 45: TRAIN Accuracy = 0.744 Loss = 0.582 AUC = 0.787
Epoch 45: VAL Accuracy = 0.717 Loss = 0.586 AUC = 0.773


10it [03:50, 23.00s/it]


Epoch 50: TRAIN Accuracy = 0.743 Loss = 0.567 AUC = 0.797
Epoch 50: VAL Accuracy = 0.73 Loss = 0.577 AUC = 0.772


10it [03:52, 23.26s/it]


Epoch 55: TRAIN Accuracy = 0.759 Loss = 0.548 AUC = 0.809
Epoch 55: VAL Accuracy = 0.723 Loss = 0.566 AUC = 0.781


10it [03:52, 23.23s/it]


Epoch 60: TRAIN Accuracy = 0.756 Loss = 0.538 AUC = 0.796
Epoch 60: VAL Accuracy = 0.727 Loss = 0.551 AUC = 0.782


10it [03:49, 22.98s/it]


Epoch 65: TRAIN Accuracy = 0.764 Loss = 0.522 AUC = 0.781
Epoch 65: VAL Accuracy = 0.72 Loss = 0.534 AUC = 0.788


10it [03:54, 23.45s/it]


Epoch 70: TRAIN Accuracy = 0.759 Loss = 0.515 AUC = 0.772
Epoch 70: VAL Accuracy = 0.727 Loss = 0.525 AUC = 0.787


10it [03:48, 22.87s/it]


Epoch 75: TRAIN Accuracy = 0.764 Loss = 0.502 AUC = 0.793
Epoch 75: VAL Accuracy = 0.73 Loss = 0.513 AUC = 0.786


10it [03:43, 22.39s/it]


Epoch 80: TRAIN Accuracy = 0.761 Loss = 0.498 AUC = 0.794
Epoch 80: VAL Accuracy = 0.763 Loss = 0.504 AUC = 0.791


10it [03:45, 22.56s/it]


Epoch 85: TRAIN Accuracy = 0.763 Loss = 0.49 AUC = 0.804
Epoch 85: VAL Accuracy = 0.727 Loss = 0.5 AUC = 0.796


10it [03:59, 23.91s/it]


Epoch 90: TRAIN Accuracy = 0.783 Loss = 0.479 AUC = 0.826
Epoch 90: VAL Accuracy = 0.767 Loss = 0.496 AUC = 0.788


10it [03:56, 23.62s/it]


Epoch 95: TRAIN Accuracy = 0.77 Loss = 0.479 AUC = 0.799
Epoch 95: VAL Accuracy = 0.773 Loss = 0.488 AUC = 0.803


10it [03:54, 23.45s/it]


Epoch 100: TRAIN Accuracy = 0.78 Loss = 0.48 AUC = 0.799
Epoch 100: VAL Accuracy = 0.773 Loss = 0.481 AUC = 0.8


10it [03:56, 23.69s/it]


Epoch 105: TRAIN Accuracy = 0.783 Loss = 0.471 AUC = 0.809
Epoch 105: VAL Accuracy = 0.78 Loss = 0.481 AUC = 0.81


10it [03:56, 23.65s/it]


Epoch 110: TRAIN Accuracy = 0.787 Loss = 0.467 AUC = 0.804
Epoch 110: VAL Accuracy = 0.783 Loss = 0.477 AUC = 0.81


10it [03:56, 23.61s/it]


Epoch 115: TRAIN Accuracy = 0.79 Loss = 0.457 AUC = 0.801
Epoch 115: VAL Accuracy = 0.78 Loss = 0.472 AUC = 0.813


10it [04:00, 24.05s/it]


Epoch 120: TRAIN Accuracy = 0.79 Loss = 0.45 AUC = 0.834
Epoch 120: VAL Accuracy = 0.78 Loss = 0.472 AUC = 0.81


10it [04:10, 25.01s/it]


Epoch 125: TRAIN Accuracy = 0.791 Loss = 0.452 AUC = 0.814
Epoch 125: VAL Accuracy = 0.78 Loss = 0.468 AUC = 0.812


10it [04:31, 27.15s/it]


Epoch 130: TRAIN Accuracy = 0.794 Loss = 0.448 AUC = 0.818
Epoch 130: VAL Accuracy = 0.78 Loss = 0.466 AUC = 0.816


10it [04:33, 27.30s/it]


Epoch 135: TRAIN Accuracy = 0.801 Loss = 0.44 AUC = 0.822
Epoch 135: VAL Accuracy = 0.78 Loss = 0.46 AUC = 0.815


10it [04:31, 27.13s/it]


Epoch 140: TRAIN Accuracy = 0.794 Loss = 0.443 AUC = 0.827
Epoch 140: VAL Accuracy = 0.783 Loss = 0.458 AUC = 0.825


10it [04:27, 26.71s/it]


Epoch 145: TRAIN Accuracy = 0.797 Loss = 0.442 AUC = 0.809
Epoch 145: VAL Accuracy = 0.783 Loss = 0.458 AUC = 0.824


10it [04:12, 25.23s/it]


Epoch 150: TRAIN Accuracy = 0.8 Loss = 0.435 AUC = 0.833
Epoch 150: VAL Accuracy = 0.777 Loss = 0.457 AUC = 0.82


10it [04:03, 24.40s/it]


Epoch 155: TRAIN Accuracy = 0.8 Loss = 0.432 AUC = 0.83
Epoch 155: VAL Accuracy = 0.78 Loss = 0.455 AUC = 0.824


10it [04:04, 24.46s/it]


Epoch 160: TRAIN Accuracy = 0.804 Loss = 0.433 AUC = 0.816
Epoch 160: VAL Accuracy = 0.777 Loss = 0.455 AUC = 0.825


10it [04:06, 24.69s/it]


Epoch 165: TRAIN Accuracy = 0.799 Loss = 0.435 AUC = 0.815
Epoch 165: VAL Accuracy = 0.78 Loss = 0.453 AUC = 0.831


10it [04:05, 24.55s/it]


Epoch 170: TRAIN Accuracy = 0.803 Loss = 0.428 AUC = 0.831
Epoch 170: VAL Accuracy = 0.777 Loss = 0.448 AUC = 0.837


10it [03:38, 21.80s/it]


Epoch 175: TRAIN Accuracy = 0.8 Loss = 0.429 AUC = 0.824
Epoch 175: VAL Accuracy = 0.78 Loss = 0.447 AUC = 0.84


10it [02:50, 17.01s/it]


Epoch 180: TRAIN Accuracy = 0.799 Loss = 0.427 AUC = 0.831
Epoch 180: VAL Accuracy = 0.79 Loss = 0.443 AUC = 0.837


10it [02:41, 16.11s/it]


Epoch 185: TRAIN Accuracy = 0.804 Loss = 0.423 AUC = 0.827
Epoch 185: VAL Accuracy = 0.787 Loss = 0.442 AUC = 0.837


10it [02:41, 16.14s/it]


Epoch 190: TRAIN Accuracy = 0.801 Loss = 0.418 AUC = 0.835
Epoch 190: VAL Accuracy = 0.783 Loss = 0.442 AUC = 0.842


10it [02:43, 16.32s/it]


Epoch 195: TRAIN Accuracy = 0.8 Loss = 0.419 AUC = 0.826
Epoch 195: VAL Accuracy = 0.783 Loss = 0.439 AUC = 0.838


10it [02:42, 16.24s/it]


Epoch 200: TRAIN Accuracy = 0.806 Loss = 0.416 AUC = 0.83
Epoch 200: VAL Accuracy = 0.79 Loss = 0.434 AUC = 0.836


10it [02:43, 16.32s/it]


Epoch 205: TRAIN Accuracy = 0.8 Loss = 0.422 AUC = 0.831
Epoch 205: VAL Accuracy = 0.793 Loss = 0.434 AUC = 0.843


10it [02:54, 17.50s/it]


Epoch 210: TRAIN Accuracy = 0.801 Loss = 0.412 AUC = 0.84
Epoch 210: VAL Accuracy = 0.79 Loss = 0.432 AUC = 0.85


10it [02:58, 17.87s/it]


Epoch 215: TRAIN Accuracy = 0.8 Loss = 0.415 AUC = 0.838
Epoch 215: VAL Accuracy = 0.79 Loss = 0.432 AUC = 0.853


10it [03:00, 18.05s/it]


Epoch 220: TRAIN Accuracy = 0.801 Loss = 0.416 AUC = 0.841
Epoch 220: VAL Accuracy = 0.793 Loss = 0.43 AUC = 0.856


10it [03:01, 18.20s/it]


Epoch 225: TRAIN Accuracy = 0.8 Loss = 0.412 AUC = 0.842
Epoch 225: VAL Accuracy = 0.797 Loss = 0.425 AUC = 0.867


10it [03:00, 18.07s/it]


Epoch 230: TRAIN Accuracy = 0.8 Loss = 0.406 AUC = 0.85
Epoch 230: VAL Accuracy = 0.797 Loss = 0.427 AUC = 0.859


10it [03:03, 18.30s/it]


Epoch 235: TRAIN Accuracy = 0.804 Loss = 0.405 AUC = 0.857
Epoch 235: VAL Accuracy = 0.79 Loss = 0.427 AUC = 0.872


10it [03:00, 18.04s/it]


Epoch 240: TRAIN Accuracy = 0.806 Loss = 0.405 AUC = 0.859
Epoch 240: VAL Accuracy = 0.79 Loss = 0.428 AUC = 0.86


10it [03:00, 18.04s/it]


Epoch 245: TRAIN Accuracy = 0.807 Loss = 0.397 AUC = 0.872
Epoch 245: VAL Accuracy = 0.787 Loss = 0.43 AUC = 0.865


10it [03:03, 18.31s/it]


Epoch 250: TRAIN Accuracy = 0.806 Loss = 0.405 AUC = 0.861
Epoch 250: VAL Accuracy = 0.787 Loss = 0.43 AUC = 0.868


10it [03:01, 18.13s/it]


Epoch 255: TRAIN Accuracy = 0.809 Loss = 0.396 AUC = 0.857
Epoch 255: VAL Accuracy = 0.79 Loss = 0.428 AUC = 0.869


10it [03:00, 18.06s/it]


Epoch 260: TRAIN Accuracy = 0.804 Loss = 0.398 AUC = 0.865
Epoch 260: VAL Accuracy = 0.793 Loss = 0.426 AUC = 0.874


10it [02:59, 17.96s/it]


Epoch 265: TRAIN Accuracy = 0.804 Loss = 0.398 AUC = 0.853
Epoch 265: VAL Accuracy = 0.793 Loss = 0.425 AUC = 0.875


10it [03:01, 18.20s/it]


Epoch 270: TRAIN Accuracy = 0.807 Loss = 0.396 AUC = 0.85
Epoch 270: VAL Accuracy = 0.797 Loss = 0.424 AUC = 0.871


10it [03:04, 18.42s/it]


Epoch 275: TRAIN Accuracy = 0.803 Loss = 0.397 AUC = 0.863
Epoch 275: VAL Accuracy = 0.797 Loss = 0.425 AUC = 0.873


10it [03:03, 18.39s/it]


Epoch 280: TRAIN Accuracy = 0.804 Loss = 0.39 AUC = 0.877
Epoch 280: VAL Accuracy = 0.793 Loss = 0.424 AUC = 0.876


10it [03:02, 18.25s/it]


Epoch 285: TRAIN Accuracy = 0.806 Loss = 0.392 AUC = 0.865
Epoch 285: VAL Accuracy = 0.79 Loss = 0.424 AUC = 0.875


10it [03:00, 18.06s/it]


Epoch 290: TRAIN Accuracy = 0.81 Loss = 0.388 AUC = 0.873
Epoch 290: VAL Accuracy = 0.79 Loss = 0.423 AUC = 0.878


10it [03:01, 18.18s/it]


Epoch 295: TRAIN Accuracy = 0.81 Loss = 0.388 AUC = 0.863
Epoch 295: VAL Accuracy = 0.787 Loss = 0.422 AUC = 0.879


10it [03:02, 18.27s/it]

Epoch 300: TRAIN Accuracy = 0.807 Loss = 0.393 AUC = 0.863
Epoch 300: VAL Accuracy = 0.79 Loss = 0.419 AUC = 0.879
CPU times: user 5h 37min 18s, sys: 1h 21min 33s, total: 6h 58min 51s
Wall time: 3h 39min 53s



