In [None]:
# General imports
import os
import uuid
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

In [None]:
# Single seed shared by every source of randomness in this notebook.
SKLRNG = 42

# Seed Python's `random`, NumPy's global RNG and the Keras backend so that weight
# initialisation, dropout masks and `fit(shuffle=True)` are reproducible on
# Restart & Run All. sklearn splitters still need `random_state=SKLRNG` explicitly.
keras.utils.set_random_seed(SKLRNG)

# Explicit JAX PRNG key (in the visible notebook only referenced by commented-out batching code).
RNG = jax.random.PRNGKey(SKLRNG)

In [None]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250, standardize = False):
    """
    Epoch every run of ``dataset`` for the given subjects and stack the results.

    For each subject / session / run returned by ``dataset.get_data``:
      1. band-pass the EEG channels with one IIR filter per ``(l_freq, h_freq)``
         pair in ``filters`` (applied one after another),
      2. cut epochs over ``dataset.interval`` for the events in ``dataset.event_id``,
      3. resample the epochs to ``sfreq`` Hz,
      4. optionally z-score each channel of each epoch over time (``standardize``).

    Args:
        dataset: object exposing ``get_data(subjects=...)`` (nested
            subject -> session -> run -> ``mne.io.Raw`` mapping), ``event_id``
            and ``interval`` (e.g. ``FatigueMI``).
        subjects: subject ids to load.
        channel_idx: indices into ``raw.info['ch_names']`` to keep; empty keeps
            every channel. Stim channels are always excluded.
        filters: iterable of ``(l_freq, h_freq)`` pairs passed to ``raw.filter``.
        sfreq: target sampling frequency for ``Epochs.resample``.
        standardize: if True, each channel of each epoch is scaled to zero mean and
            unit variance over time (flat channels are only centred). Default False
            returns the filtered, resampled signal unscaled.

    Returns:
        X: array of shape (n_epochs, n_channels, n_times).
        y: array of label strings (keys of ``dataset.event_id``), one per epoch.
        metadata: DataFrame with ``subject``, ``session`` and ``run`` per epoch.

    Raises:
        ValueError: if ``dataset.get_data`` returned no runs for ``subjects``.
    """

    def find_events(raw, event_id):
        # Recordings with a stim channel carry events there; otherwise use annotations.
        if len(mne.utils._get_stim_channel(None, raw.info, raise_error=False)) > 0:
            return mne.find_events(raw, shortest_event=0, verbose=False)
        return mne.events_from_annotations(raw, event_id=event_id, verbose=False)[0]

    data = dataset.get_data(subjects=subjects)

    tmin, tmax = dataset.interval[0], dataset.interval[1]
    # Map integer event codes back to their string labels.
    inv_events = {code: label for label, code in dataset.event_id.items()}

    X = []
    y = []
    metadata = []

    for subject_id in data.keys():
        for session_id in data[subject_id].keys():
            for run_id in data[subject_id][session_id].keys():
                raw = data[subject_id][session_id][run_id]

                # Raw.filter works in place, so the object held in `data` is filtered too.
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)

                events = find_events(raw, dataset.event_id)

                all_names = np.asarray(raw.info['ch_names'])
                channels = all_names[channel_idx] if len(channel_idx) > 0 else all_names
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                epochs = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                labels = [inv_events[e] for e in epochs.events[:, -1]]

                # Epochs.resample also works in place and returns the same object.
                epochs = epochs.resample(sfreq)
                # Take the (n_epochs, n_channels, n_times) array explicitly rather than
                # letting np.concatenate iterate over the Epochs object.
                epoch_data = epochs.get_data(copy=False)

                if standardize:
                    mean = epoch_data.mean(axis=-1, keepdims=True)
                    std = epoch_data.std(axis=-1, keepdims=True)
                    # Flat channels keep a scale of 1 (same convention as sklearn's StandardScaler).
                    epoch_data = (epoch_data - mean) / np.where(std == 0, 1.0, std)

                met = pd.DataFrame(index=range(epoch_data.shape[0]))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id

                X.append(epoch_data)
                # Flat label list: robust to runs that produced no epochs.
                y.extend(labels)
                metadata.append(met)

    if not X:
        raise ValueError(f"dataset.get_data returned no runs for subjects={subjects!r}")

    return np.concatenate(X, axis=0), np.asarray(y), pd.concat(metadata, ignore_index=True)

# Load the FatigueMI dataset and build the (X, y) arrays used by the cells below
# (X is stacked epoch data from data_generator; y holds the string labels).
fat_dataset = FatigueMI()
# Alternative configurations kept for reference: all subjects, a channel subset, several subjects.
# X, y, _ = data_generator(fat_dataset, subjects=fat_dataset.subject_list, sfreq=250)
X, y, _ = data_generator(fat_dataset, subjects=[5], channel_idx=[], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=[1, 2, 5, 12, 13, 14, 17, 18], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[5, 6, 12, 17, 24, 28, 29], channel_idx=[], sfreq=128)

In [None]:
# Integer-encode the string labels (LabelEncoder orders classes by sorted value).
# NOTE: expected mapping is 0 = left, 1 = right, 2 = unlabelled; unlabelled epochs are
# NOT filtered out here, so if present they count towards NUM_CLASSES.
y_encoded = LabelEncoder().fit_transform(y)
rprint({"X": X.shape, "y": y_encoded.shape})

# Model input dimensions are taken from the last two axes of the epoch array.
NUM_CHANNELS, NUM_SAMPLES = X.shape[-2:]
NUM_CLASSES = np.unique(y_encoded).size

rpprint({
    "NUM_SAMPLES": NUM_SAMPLES,
    "NUM_CHANNELS": NUM_CHANNELS,
    "NUM_CLASSES": NUM_CLASSES
})

# sns.histplot(y_encoded); # Plot the class distribution

In [None]:
# Hold out a stratified, shuffled test set of TEST_SIZE of the epochs; the rest is used for CV / training below.
# NOTE(review): TRAIN_SIZE is not passed to train_test_split — only TEST_SIZE determines the split.
TRAIN_SIZE = 0.8
TEST_SIZE = 0.2

# Alternatives kept for reference. NOTE(review): X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS) only reinterprets
# memory; it does not swap the channel/time axes (a transpose would be needed for (samples, channels) input).
# X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS), y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=None, shuffle=False)
# X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS), y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=y_encoded, shuffle=True)
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=y_encoded, shuffle=True)
# X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=None, shuffle=False)

# Report each split as an epoch count and a percentage of all epochs.
total = X_train.shape[0] + X_test.shape[0]
rpprint({
    "X_train": f"{X_train.shape[0]} ({X_train.shape[0] / total * 100:.2f}%)",
    "X_test": f"{X_test.shape[0]} ({X_test.shape[0] / total * 100:.2f}%)",
}, expand_all=True)

# batch_size = 32
# batches = batch_generator(X_train, batch_size, RNG)

In [83]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Element-wise square; ShallowConvNet's non-linearity before pooling."""
    return jnp.power(x, 2)


def shallow_conv_net_log_layer(x):
    """Element-wise natural log of ``x`` clamped to [1e-7, 10000] to avoid log(0) and overflow."""
    lower, upper = 1e-7, 10000
    return jnp.log(jnp.minimum(jnp.maximum(x, lower), upper))


# Name -> function registry of the custom activations, presumably for
# keras.models.load_model(..., custom_objects=CUSTOM_OBJECTS).
CUSTOM_OBJECTS = {
    fn.__name__: fn
    for fn in (shallow_conv_net_square_layer, shallow_conv_net_log_layer)
}
# endregion Helper funcs

# region Models
def eeg_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5, 
            kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """
    EEGNet as a Keras functional model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    Args:
        nb_classes: number of output classes (units of the final softmax).
        channels: number of EEG channels (input height).
        samples: number of time samples per epoch (input width).
        dropout_rate: dropout probability at the end of each block.
        kernLength: length of the first (temporal) convolution kernel.
        F1: number of temporal filters.
        D: depth multiplier of the depthwise spatial convolution.
        F2: number of filters of the separable convolution.
        norm_rate: max-norm constraint on the classifier kernel.
        dropoutType: 'Dropout' or 'SpatialDropout2D' (any other value raises KeyError).

    Returns:
        Uncompiled ``keras.Model`` mapping (channels, samples, 1) inputs to class probabilities.
    """
    dropoutType = {'Dropout': Dropout, 'SpatialDropout2D': SpatialDropout2D}[dropoutType]

    input1 = Input(shape = (channels, samples, 1))

    # Block 1: temporal convolution, then a depthwise convolution spanning all channels.
    # No `input_shape` is passed to Conv2D: the shape comes from `input1`, and Keras 3
    # layers only warn about (and otherwise ignore) that argument outside Sequential models.
    block1 = Conv2D(F1, (1, kernLength), padding = 'same', use_bias = False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((channels, 1), use_bias = False,
                             depth_multiplier = D,
                             depthwise_constraint = max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropout_rate)(block1)

    # Block 2: separable convolution over time.
    block2 = SeparableConv2D(F2, (1, 16), use_bias = False, padding = 'same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropout_rate)(block2)

    # Classifier: max-norm constrained dense layer + softmax.
    flatten = Flatten(name = 'flatten')(block2)
    dense = Dense(nb_classes, name = 'dense',
                  kernel_constraint = max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)

def deep_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5):
    """
    DeepConvNet as a Keras functional model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    Args:
        nb_classes: number of output classes.
        channels: number of EEG channels (input height).
        samples: number of time samples per epoch (input width).
        dropout_rate: dropout probability after each of the four conv blocks.

    Returns:
        Uncompiled ``keras.Model`` mapping (channels, samples, 1) inputs to class probabilities.
    """
    input_main = Input((channels, samples, 1))

    # First temporal convolution. No `input_shape` is passed: the shape comes from
    # `input_main`, and Keras 3 only warns about that argument outside Sequential models.
    x = Conv2D(25, (1, 5), kernel_constraint = max_norm(2., axis=(0, 1, 2)))(input_main)

    # Four conv -> BN -> ELU -> max-pool -> dropout blocks. The first uses a spatial
    # kernel spanning all channels; the rest are temporal (1, 5) convolutions.
    for n_filters, kernel in ((25, (channels, 1)), (50, (1, 5)), (100, (1, 5)), (200, (1, 5))):
        x = Conv2D(n_filters, kernel, kernel_constraint = max_norm(2., axis=(0, 1, 2)))(x)
        x = BatchNormalization(epsilon=1e-05, momentum=0.9)(x)
        x = Activation('elu')(x)
        x = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(x)
        x = Dropout(dropout_rate)(x)

    flatten = Flatten()(x)
    dense = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def shallow_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
    ShallowConvNet with tunable layer sizes and L2 regularisation.

    Adapted from: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    Args:
        nb_classes: number of output classes.
        channels: number of EEG channels (input height).
        samples: number of time samples per epoch (input width).
        **kwargs: optional hyper-parameters (defaults in brackets); unknown keys are ignored.
            pool_size_d2 [35], strides_d2 [7]: temporal average-pooling size / stride.
            conv_filters_d2 [13]: temporal kernel length of the first Conv2D.
            pool_size, strides, conv_filters: full 2-D tuples overriding the *_d2 values.
            conv2d_1_units [40], conv2d_2_units [40]: filters of the temporal / spatial convs.
            l2_reg_1, l2_reg_2, l2_reg_3 [0.01]: L2 factors for the two convs and the dense layer.
            dropout_rate [0.5]: dropout before the classifier.

    Returns:
        Uncompiled ``keras.Model`` mapping (channels, samples, 1) inputs to class probabilities.
    """
    _POOL_SIZE_D2_ = kwargs.get("pool_size_d2", 35)
    _STRIDES_D2_ = kwargs.get("strides_d2", 7)
    _CONV_FILTERS_D2_ = kwargs.get("conv_filters_d2", 13)

    _POOL_SIZE_ = kwargs.get("pool_size", (1, _POOL_SIZE_D2_))
    _STRIDES_ = kwargs.get("strides", (1, _STRIDES_D2_))
    # Despite the name, this is the kernel size of the first (temporal) convolution.
    _CONV_FILTERS_ = kwargs.get("conv_filters", (1, _CONV_FILTERS_D2_))

    _CONV2D_1_UNITS_ = kwargs.get("conv2d_1_units", 40)
    _CONV2D_2_UNITS_ = kwargs.get("conv2d_2_units", 40)
    _L2_REG_1_ = kwargs.get("l2_reg_1", 0.01)
    _L2_REG_2_ = kwargs.get("l2_reg_2", 0.01)
    _L2_REG_3_ = kwargs.get("l2_reg_3", 0.01)
    _DROPOUT_RATE_ = kwargs.get("dropout_rate", 0.5)

    input_main = Input(shape=(channels, samples, 1))

    # Temporal convolution. No `input_shape` is passed: the shape comes from `input_main`,
    # and Keras 3 only warns about that argument outside Sequential models.
    block1 = Conv2D(_CONV2D_1_UNITS_, _CONV_FILTERS_,
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)),
                    kernel_regularizer=keras.regularizers.L2(_L2_REG_1_))(input_main)
    # Spatial convolution spanning all channels.
    block1 = Conv2D(_CONV2D_2_UNITS_, (channels, 1), use_bias=False,
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)),
                    kernel_regularizer=keras.regularizers.L2(_L2_REG_2_))(block1)
    block1 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    # square -> mean-pool -> log: a log band-power style feature.
    block1 = Activation(shallow_conv_net_square_layer)(block1)
    block1 = AveragePooling2D(pool_size=_POOL_SIZE_, strides=_STRIDES_)(block1)
    block1 = Activation(shallow_conv_net_log_layer)(block1)
    block1 = Dropout(_DROPOUT_RATE_)(block1)

    flatten = Flatten()(block1)
    dense = Dense(nb_classes, kernel_constraint = max_norm(0.5),
                  kernel_regularizer=keras.regularizers.L2(_L2_REG_3_))(flatten)
    softmax = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def lstm_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, hidden_activation = 'relu'):
    """
    Three stacked LSTM(40) layers followed by a dense head.

    Each example is shaped (samples, channels). Epoch arrays shaped
    (n, channels, samples) must be axis-swapped (not reshaped) before use.

    Args:
        nb_classes: number of output classes.
        channels: number of EEG channels (features per time step).
        samples: number of time steps.
        hidden_activation: activation of the 50-unit hidden Dense layer. The original
            code hard-coded 'softmax' here, which forces the hidden features to sum to 1
            before the output softmax; pass 'softmax' to reproduce that model.

    Returns:
        Uncompiled ``keras.Sequential`` model.
    """
    model = Sequential()
    model.add(Input(shape=(samples, channels)))
    # Two sequence-returning LSTMs, each followed by dropout + batch norm.
    model.add(LSTM(40, return_sequences=True, stateful=False))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    model.add(LSTM(40, return_sequences=True, stateful=False))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    # Final LSTM returns only its last output.
    model.add(LSTM(40, stateful=False))
    model.add(Dropout(0.5))
    model.add(Dense(50, activation=hidden_activation))
    model.add(Dense(nb_classes, activation='softmax'))

    return model

def lstm_cnn_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
    Conv1D front-end followed by a single LSTM and a softmax classifier.

    Each example is shaped (samples, channels). Hyper-parameters are read from
    ``kwargs`` (defaults in brackets); unknown keys are ignored:
    conv1d_1_units [40], conv1d_1_kernel_size [20], conv1d_1_strides [4],
    conv1d_1_maxpool_size [4], conv1d_1_maxpool_strides [4], lstm_1_units [50],
    l2_reg_1 / l2_reg_2 / l2_reg_3 [0.01] (conv / LSTM / dense kernels),
    dropout_rate_1 [0.5] (after the conv), dropout_rate_2 [0.5] (after the LSTM).

    Returns an uncompiled ``keras.Sequential`` model.
    """
    defaults = {
        "conv1d_1_units": 40,
        "conv1d_1_kernel_size": 20,
        "conv1d_1_strides": 4,
        "conv1d_1_maxpool_size": 4,
        "conv1d_1_maxpool_strides": 4,
        "lstm_1_units": 50,
        "l2_reg_1": 0.01,
        "l2_reg_2": 0.01,
        "l2_reg_3": 0.01,
        "dropout_rate_1": 0.5,
        "dropout_rate_2": 0.5,
    }
    hp = {name: kwargs.get(name, fallback) for name, fallback in defaults.items()}
    l2 = keras.regularizers.L2

    return Sequential([
        Input(shape=(samples, channels)),
        # Convolutional stage.
        Conv1D(hp["conv1d_1_units"], kernel_size=hp["conv1d_1_kernel_size"],
               strides=hp["conv1d_1_strides"], kernel_regularizer=l2(hp["l2_reg_1"])),
        Activation('relu'),
        Dropout(hp["dropout_rate_1"]),
        BatchNormalization(),
        MaxPooling1D(pool_size=hp["conv1d_1_maxpool_size"], strides=hp["conv1d_1_maxpool_strides"]),
        # Recurrent stage.
        LSTM(hp["lstm_1_units"], return_sequences=True, stateful=False,
             kernel_regularizer=l2(hp["l2_reg_2"])),
        Dropout(hp["dropout_rate_2"]),
        BatchNormalization(),
        # Classifier.
        Flatten(),
        Dense(nb_classes, activation='softmax', kernel_regularizer=l2(hp["l2_reg_3"])),
    ])

def lstm_cnn_net_v2(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """
    LSTM front-end followed by a Conv1D stage and a softmax classifier.

    Each example is shaped (samples, channels). Returns an uncompiled
    ``keras.Sequential`` model.
    """
    inputs = Input(shape=(samples, channels))
    recurrent_stage = [
        LSTM(50, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
    ]
    conv_stage = [
        Conv1D(40, kernel_size=20, strides=4),
        Activation('relu'),
        Dropout(0.5),
        BatchNormalization(),
        MaxPooling1D(pool_size=4, strides=4),
    ]
    classifier = [
        Flatten(),
        Dense(nb_classes, activation='softmax'),
    ]
    return Sequential([inputs, *recurrent_stage, *conv_stage, *classifier])

# endregion Models

In [None]:
# K-fold cross-validation on the training split only (the held-out test set is untouched).
NUM_FOLDS = 5
# Seeded so the fold assignment is reproducible across re-runs.
kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SKLRNG)
acc_per_fold = []
loss_per_fold = []

# Architecture under evaluation. A FRESH model is built for every fold so no fold
# starts from weights already trained on its own validation data.
build_model = shallow_conv_net
# build_model = lstm_cnn_net

for fold_no, (train, test) in enumerate(kfold.split(X_train, y_train), start=1):
    model = build_model()
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    history = model.fit(
        X_train[train],
        y_train[train],
        batch_size=64,
        epochs=25,
        # Validation labels must come from y_train (same index space as X_train),
        # not from the full-dataset y_encoded.
        validation_data=(X_train[test], y_train[test]),
        shuffle=True,
        callbacks=[
            # NOTE: patience (75) exceeds epochs (25), so neither callback can fire as configured.
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
        ],
    )

    # return_dict gives stable metric names; in Keras 3 `model.metrics_names` may
    # report a "compile_metrics" container instead of "accuracy".
    scores = model.evaluate(X_train[test], y_train[test], verbose=0, return_dict=True)
    rprint(f'Score for fold {fold_no}: loss of {scores["loss"]}; accuracy of {scores["accuracy"] * 100}%')
    acc_per_fold.append(scores["accuracy"] * 100)
    loss_per_fold.append(scores["loss"])

In [None]:
# Summarise cross-validation (accuracy in %), then score on the held-out test set.
# NOTE: `model` is whatever the CV cell left bound to that name (the model as of the
# final fold), not a model trained on the whole training split.
mean_cross_val_score = np.mean(acc_per_fold)
rprint(f"Mean Cross Validation Score ({len(acc_per_fold)}-fold):", mean_cross_val_score)

results_test = model.evaluate(X_test, y_test)
rprint("Test Accuracy:", results_test[1])

In [None]:
# Re-score the current `model` on the held-out test set; evaluate returns [loss, accuracy].
results_test = model.evaluate(x=X_test, y=y_test)

In [None]:
# Train a single ShallowConvNet on the training split. Keras takes the validation_split
# fraction from the END of X_train (before shuffling); X_train was already shuffled by train_test_split.
model = shallow_conv_net()
# model = lstm_cnn_net()
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

history = model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=25,
    # validation_data=(X_test, y_test),
    validation_split=0.2,
    shuffle=True,
    callbacks=[
        # NOTE(review): patience=75 > epochs=25, so neither callback can trigger as configured.
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
        # keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
    ],
)

In [334]:
def _select_channels(trial, all_channels):
    """Suggest a keep/drop flag per channel name and return the indices of kept channels (may be empty)."""
    return [
        idx
        for idx, name in enumerate(all_channels)
        if trial.suggest_categorical(f"channels_{name}", [True, False])
    ]


def _suggest_model_params(trial, model_str):
    """Suggest the architecture-specific hyper-parameters for ``model_str``; unknown models get none."""
    if model_str == "shallow_conv_net":
        return {
            "pool_size_d2": trial.suggest_int("pool_size_d2", 5, 95, step=5),
            "strides_d2": trial.suggest_int("strides_d2", 1, 31, step=2),
            "conv_filters_d2": trial.suggest_int("conv_filters_d2", 5, 55, step=2),

            "conv2d_1_units": trial.suggest_int("conv2d_1_units", 10, 200, step=10),
            "conv2d_2_units": trial.suggest_int("conv2d_2_units", 10, 200, step=10),
            "l2_reg_1": trial.suggest_float("l2_reg_1", 0.001, 0.8),
            "l2_reg_2": trial.suggest_float("l2_reg_2", 0.001, 0.8),
            "l2_reg_3": trial.suggest_float("l2_reg_3", 0.001, 0.8),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.1, 0.9, step=0.1),
        }
    if model_str == "lstm_cnn_net":
        return {
            "conv1d_1_units": trial.suggest_int("conv1d_1_units", 10, 200, step=10),
            "conv1d_1_kernel_size": trial.suggest_int("conv1d_1_kernel_size", 1, 100, step=1),
            "conv1d_1_strides": trial.suggest_int("conv1d_1_strides", 1, 50, step=1),
            "conv1d_1_maxpool_strides": trial.suggest_int("conv1d_1_maxpool_strides", 1, 4, step=1),
            "lstm_1_units": trial.suggest_int("lstm_1_units", 10, 200, step=10),
            "l2_reg_1": trial.suggest_float("l2_reg_1", 0.001, 0.8),
            "l2_reg_2": trial.suggest_float("l2_reg_2", 0.001, 0.8),
            "l2_reg_3": trial.suggest_float("l2_reg_3", 0.001, 0.8),
            "dropout_rate_1": trial.suggest_float("dropout_rate_1", 0.1, 0.9, step=0.1),
            "dropout_rate_2": trial.suggest_float("dropout_rate_2", 0.1, 0.9, step=0.1),
        }
    return {}


def _objective_cost(history, n_channels):
    """
    Scalar cost minimised by the study (lower is better), from a Keras ``History.history`` dict.

    Terms: squared gap of the best and (half-weighted) of the mean validation accuracy
    to 1, a 1e-2 penalty per selected channel, a small penalty on the squared mean
    train/val accuracy gap, and +0.5 if the best training accuracy is outside [0.75, 1].
    """
    train_acc = np.asarray(history["accuracy"])
    val_acc = np.asarray(history["val_accuracy"])
    cost = (
        (1 - np.max(val_acc)) ** 2
        + 0.5 * (1 - np.mean(val_acc)) ** 2
        + 1e-2 * n_channels
        + 1e-3 * np.mean(train_acc - val_acc) ** 2
    )
    if not 0.75 <= np.max(train_acc) <= 1.00:
        cost += .5
    return cost


def objective_fn(trial, subjects=[], model_str = "lstm_cnn_net", sfreq = 128, batch_size = 128, **kwargs):
    """
    Optuna objective: pick channels and hyper-parameters, train ``model_str`` briefly, return a cost.

    Args:
        trial: the ``optuna.Trial`` being evaluated.
        subjects: FatigueMI subject ids to use; empty means ``fat_dataset.subject_list``.
        model_str: ``"lstm_cnn_net"`` or ``"shallow_conv_net"`` (other values raise KeyError).
        sfreq: resampling rate in Hz; falsy (e.g. None) lets the trial choose 128 / 256 / 300.
        batch_size: training batch size; falsy lets the trial choose 32..256 (step 32).
        **kwargs: accepted for compatibility; unused.

    Returns:
        The cost computed by ``_objective_cost`` (lower is better), or a large constant
        penalty when the trial selected no channels.

    Side effects:
        Stores a ``"trial_data"`` user attribute on ``trial``.
    """
    # Cost assigned to trials that select no channels (far above any achievable training cost).
    _NO_CHANNELS_COST_ = 1e3

    _SFREQ_ = sfreq if sfreq else trial.suggest_categorical("sfreq", [128, 256, 300])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = batch_size if batch_size else trial.suggest_int("batch_size", 32, 256, step=32)

    models = {
        "lstm_cnn_net": lstm_cnn_net,
        "shallow_conv_net": shallow_conv_net
    }

    subjects = subjects if len(subjects) > 0 else fat_dataset.subject_list
    model_fn = models[model_str]

    # Channel names from subject 1's first run, loaded ONCE per trial. The last entry
    # is dropped (presumably the stim channel — TODO confirm).
    all_channel_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names']
    all_channels = all_channel_names[:-1]

    channels_idx = _select_channels(trial, all_channels)
    if not channels_idx:
        # Re-suggesting the same parameter names within one trial returns the cached
        # values, so retrying can never change the selection (the old loop hung forever).
        trial.set_user_attr("trial_data", {"channels_selected": []})
        return _NO_CHANNELS_COST_

    X, y, _ = data_generator(fat_dataset, subjects=subjects, channel_idx=channels_idx, sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]

    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))

    if "lstm" in model_str:
        # Recurrent models take (samples, channels) per example: swap the last two axes.
        # A reshape would silently interleave channel and time samples.
        X = np.swapaxes(X, -1, -2)
    X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, shuffle=True, stratify=y_encoded)

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
        **_suggest_model_params(trial, model_str),
    }

    model = model_fn(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="rmsprop",
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=5,
        validation_split=0.2,
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    # evaluate returns [loss, accuracy] for the single compiled metric.
    test_accuracy = model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]

    trial.set_user_attr("trial_data", {
        # Plain list so the attribute stays JSON-serialisable (needed by RDB storages).
        "channels_selected": np.asarray(all_channel_names)[channels_idx].tolist(),
        "model": model.to_json(),
        "train_accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": test_accuracy,
    })

    return _objective_cost(history.history, len(channels_idx))

def early_stopping_check(study, trial, early_stopping_rounds=10):
    """
    Optuna callback: stop the study after ``early_stopping_rounds`` trials without improvement.

    Args:
        study: the running ``optuna.Study``.
        trial: the ``FrozenTrial`` that just finished.
        early_stopping_rounds: how many trials may follow the best one before stopping.
    """
    try:
        best_trial_number = study.best_trial.number
    except ValueError:
        # No trial has completed yet (e.g. the first ones were pruned), so there is
        # no best trial to measure stagnation against; keep optimising.
        return

    should_stop = (trial.number - best_trial_number) >= early_stopping_rounds
    if should_stop:
        optuna.logging._get_library_root_logger().info("Early stopping detected: %s", should_stop)
        study.stop()

def heuristic_optimizer(obj_fn, max_iter = 25, show_progress_bar = True, subjects=None, model_str="lstm_cnn_net", sfreq = 128, batch_size = 128, max_stag_count = 10, seed = None, **kwargs):
    """Minimize ``obj_fn`` with an Optuna study and return the study.

    Args:
        obj_fn: Optuna objective called as
            ``obj_fn(trial, subjects=..., model_str=..., sfreq=..., batch_size=...)``
            and returning the cost to minimize.
        max_iter: Maximum number of trials (``n_trials``).
        show_progress_bar: Forwarded to ``Study.optimize``.
        subjects: Subject ids forwarded to ``obj_fn``; ``None`` means a fresh
            empty list (avoids a shared mutable default).
        model_str: Model name forwarded to ``obj_fn``.
        sfreq: Sampling frequency forwarded to ``obj_fn``.
        batch_size: Batch size forwarded to ``obj_fn``.
        max_stag_count: Stop after this many trials without a new best trial
            (via ``early_stopping_check``); ``None`` disables early stopping.
        seed: Seed for the sampler; ``None`` (default) leaves it unseeded.
        **kwargs: Extra keyword arguments for ``Study.optimize``. A
            ``callbacks`` entry is run after the early-stopping callback
            instead of clashing with it.

    Returns:
        The ``optuna.Study``. It is also returned, partially filled, when the
        run is interrupted with KeyboardInterrupt.
    """
    if subjects is None:
        subjects = []

    # Silences Optuna's per-trial log output (a process-wide logger setting).
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    objective = partial(obj_fn, subjects=subjects, model_str=model_str, sfreq=sfreq, batch_size=batch_size)

    # NOTE(review): NSGAIIISampler is an (experimental) multi-objective genetic
    # sampler. Per the Optuna docs its default population_size is 50 and the first
    # generation is sampled at random, so a run of <= 50 trials is effectively
    # random search. Consider a smaller population_size or TPESampler; verify.
    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler(seed=seed))

    callbacks = list(kwargs.pop("callbacks", None) or [])
    if max_stag_count is not None:
        callbacks.insert(0, partial(early_stopping_check, early_stopping_rounds=max_stag_count))

    try:
        study.optimize(objective, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=callbacks, **kwargs)
    except KeyboardInterrupt:
        # Deliberate: a manual interrupt ends the search early but still returns
        # the study so the trials finished so far can be inspected.
        pass

    return study

# Channel/hyper-parameter search for subject 12 with the shallow ConvNet.
# sfreq and batch_size are passed as None; presumably objective_fn then picks
# them itself (its handling is not visible here).
study = heuristic_optimizer(
    obj_fn=objective_fn,
    model_str="shallow_conv_net",
    subjects=[12],
    sfreq=None,
    batch_size=None,
    max_iter=50,
    max_stag_count=10,
)

  study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler())


  0%|          | 0/50 [00:00<?, ?it/s]

Adding metadata with 3 columns


  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 817ms/step - accuracy: 0.5230 - loss: 7.9744 - val_accuracy: 0.6111 - val_loss: 4.7130 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.7461 - loss: 4.5281 - val_accuracy: 0.5556 - val_loss: 4.3598 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8072 - loss: 4.1678 - val_accuracy: 0.6667 - val_loss: 4.0307 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8938 - loss: 3.8339 - val_accuracy: 0.7222 - val_loss: 3.7844 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.9775 - loss: 3.5327 - val_accuracy: 0.7222 - val_loss: 3.5486 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 385ms/step - accuracy: 0.

  super().__init__(


Epoch 1/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.5502 - loss: 101.5019 - val_accuracy: 0.5000 - val_loss: 78.5219 - learning_rate: 0.0010
Epoch 2/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8199 - loss: 75.4410 - val_accuracy: 0.4444 - val_loss: 64.0073 - learning_rate: 0.0010
Epoch 3/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8199 - loss: 61.6612 - val_accuracy: 0.5556 - val_loss: 53.6127 - learning_rate: 0.0010
Epoch 4/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8401 - loss: 51.8435 - val_accuracy: 0.5556 - val_loss: 45.9884 - learning_rate: 0.0010
Epoch 5/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.9452 - loss: 44.2447 - val_accuracy: 0.5000 - val_loss: 39.4880 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 467ms/step - accu

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5588 - loss: 87.0757 - val_accuracy: 0.5556 - val_loss: 69.6356 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.6324 - loss: 69.6016 - val_accuracy: 0.6111 - val_loss: 61.9135 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.7353 - loss: 61.8462 - val_accuracy: 0.6111 - val_loss: 56.0866 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.6471 - loss: 56.0591 - val_accuracy: 0.6667 - val_loss: 51.3230 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8382 - loss: 51.2439 - val_accuracy: 0.6667 - val_loss: 47.2642 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 693ms/step - accur

  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 847ms/step - accuracy: 0.5572 - loss: 6.7768 - val_accuracy: 0.5000 - val_loss: 5.7957 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6287 - loss: 5.6944 - val_accuracy: 0.4444 - val_loss: 5.4244 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5372 - loss: 5.4075 - val_accuracy: 0.5000 - val_loss: 5.1426 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6859 - loss: 4.9253 - val_accuracy: 0.5000 - val_loss: 4.8773 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6144 - loss: 4.7548 - val_accuracy: 0.5556 - val_loss: 4.6816 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 424ms/step - accuracy: 0.

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.5000 - loss: 107.4217 - val_accuracy: 0.5000 - val_loss: 94.0545 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step - accuracy: 0.7059 - loss: 93.9041 - val_accuracy: 0.4444 - val_loss: 80.4627 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7500 - loss: 80.2061 - val_accuracy: 0.5556 - val_loss: 70.8650 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8676 - loss: 70.2185 - val_accuracy: 0.5556 - val_loss: 63.6051 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7059 - loss: 62.7182 - val_accuracy: 0.4444 - val_loss: 58.6406 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 563ms/step - accu

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5147 - loss: 14.0023 - val_accuracy: 0.6667 - val_loss: 12.2597 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.7794 - loss: 12.1666 - val_accuracy: 0.6111 - val_loss: 11.0377 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8529 - loss: 10.8431 - val_accuracy: 0.5556 - val_loss: 10.1178 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8824 - loss: 9.8669 - val_accuracy: 0.5000 - val_loss: 9.3436 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9265 - loss: 9.0220 - val_accuracy: 0.5000 - val_loss: 8.7146 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 488ms/step - accuracy: 0.5455 - 

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 26s/step - accuracy: 0.4412 - loss: 191.9791 - val_accuracy: 0.6667 - val_loss: 160.1072 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.6471 - loss: 160.1323 - val_accuracy: 0.5556 - val_loss: 142.3359 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.6176 - loss: 142.2985 - val_accuracy: 0.5556 - val_loss: 128.8402 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6618 - loss: 128.8058 - val_accuracy: 0.5556 - val_loss: 117.8254 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7500 - loss: 117.7690 - val_accuracy: 0.6111 - val_loss: 108.4199 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 511ms/

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.4265 - loss: 7.7265 - val_accuracy: 0.5000 - val_loss: 6.0923 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.6029 - loss: 6.2555 - val_accuracy: 0.3889 - val_loss: 5.9072 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.4559 - loss: 6.2074 - val_accuracy: 0.5000 - val_loss: 5.7120 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7059 - loss: 5.7110 - val_accuracy: 0.4444 - val_loss: 5.5686 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6324 - loss: 5.4639 - val_accuracy: 0.3889 - val_loss: 5.3772 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 643ms/step - accuracy: 0.2727 - loss: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4559 - loss: 42.8375 - val_accuracy: 0.5000 - val_loss: 32.8214 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.7941 - loss: 32.6302 - val_accuracy: 0.4444 - val_loss: 29.7278 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9265 - loss: 29.4562 - val_accuracy: 0.5000 - val_loss: 27.3440 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9265 - loss: 27.0321 - val_accuracy: 0.5000 - val_loss: 25.3962 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9853 - loss: 25.0321 - val_accuracy: 0.4444 - val_loss: 23.7168 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 699ms/step - accuracy: 0.590

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5147 - loss: 12.2999 - val_accuracy: 0.4444 - val_loss: 8.4480 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.6029 - loss: 8.3895 - val_accuracy: 0.6111 - val_loss: 7.8151 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.5882 - loss: 7.7794 - val_accuracy: 0.4444 - val_loss: 7.3593 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6912 - loss: 7.2380 - val_accuracy: 0.4444 - val_loss: 6.9527 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6618 - loss: 6.8467 - val_accuracy: 0.4444 - val_loss: 6.6234 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 426ms/step - accuracy: 0.5455 - loss:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.4853 - loss: 235.7976 - val_accuracy: 0.4444 - val_loss: 187.4459 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step - accuracy: 0.6471 - loss: 186.3750 - val_accuracy: 0.5556 - val_loss: 163.2521 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.4412 - loss: 158.6896 - val_accuracy: 0.4444 - val_loss: 140.6676 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.7059 - loss: 136.3092 - val_accuracy: 0.5556 - val_loss: 120.4014 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.8971 - loss: 119.3370 - val_accuracy: 0.5000 - val_loss: 106.2733 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 579ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5588 - loss: 109.0380 - val_accuracy: 0.6667 - val_loss: 87.1289 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.5882 - loss: 87.7141 - val_accuracy: 0.5556 - val_loss: 81.3745 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.4559 - loss: 79.4348 - val_accuracy: 0.4444 - val_loss: 69.2824 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8824 - loss: 68.7921 - val_accuracy: 0.5556 - val_loss: 63.1555 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.9265 - loss: 62.1925 - val_accuracy: 0.4444 - val_loss: 57.9070 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 875ms/step - accu

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5294 - loss: 16.2382 - val_accuracy: 0.6667 - val_loss: 10.6140 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6912 - loss: 10.5379 - val_accuracy: 0.5556 - val_loss: 9.7288 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6029 - loss: 9.9799 - val_accuracy: 0.5556 - val_loss: 9.1468 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8529 - loss: 8.7240 - val_accuracy: 0.8333 - val_loss: 8.1288 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8971 - loss: 7.9564 - val_accuracy: 0.7778 - val_loss: 7.5965 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 710ms/step - accuracy: 0.8182 - los

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5147 - loss: 12.8762 - val_accuracy: 0.6667 - val_loss: 8.4870 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8382 - loss: 8.3450 - val_accuracy: 0.5556 - val_loss: 8.5384 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.5735 - loss: 8.0775 - val_accuracy: 0.5000 - val_loss: 8.2673 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7059 - loss: 7.4865 - val_accuracy: 0.6111 - val_loss: 7.3265 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8235 - loss: 6.8338 - val_accuracy: 0.7222 - val_loss: 6.7139 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 395ms/step - accuracy: 0.7727 - loss:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5588 - loss: 34.7421 - val_accuracy: 0.5556 - val_loss: 26.3157 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7353 - loss: 26.1768 - val_accuracy: 0.5000 - val_loss: 23.7795 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7794 - loss: 23.6636 - val_accuracy: 0.5556 - val_loss: 21.8588 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7647 - loss: 21.7292 - val_accuracy: 0.5556 - val_loss: 20.2798 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8529 - loss: 20.1088 - val_accuracy: 0.5556 - val_loss: 18.9109 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 588ms/step - accuracy: 0.545

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.4559 - loss: 130.3545 - val_accuracy: 0.5000 - val_loss: 97.2046 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step - accuracy: 0.8676 - loss: 96.8456 - val_accuracy: 0.5000 - val_loss: 78.6595 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.9706 - loss: 78.3577 - val_accuracy: 0.4444 - val_loss: 68.8143 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.7206 - loss: 65.8276 - val_accuracy: 0.5556 - val_loss: 58.9266 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step - accuracy: 0.6029 - loss: 56.5969 - val_accuracy: 0.4444 - val_loss: 52.4243 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 613ms/step - accuracy: 0.50

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2s/step - accuracy: 0.5202 - loss: 34.0995 - val_accuracy: 0.6111 - val_loss: 26.3736 - learning_rate: 0.0010
Epoch 2/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5398 - loss: 25.5083 - val_accuracy: 0.4444 - val_loss: 23.3751 - learning_rate: 0.0010
Epoch 3/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6351 - loss: 22.3087 - val_accuracy: 0.5556 - val_loss: 20.3894 - learning_rate: 0.0010
Epoch 4/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.7148 - loss: 19.2907 - val_accuracy: 0.6111 - val_loss: 17.7281 - learning_rate: 0.0010
Epoch 5/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.9400 - loss: 16.8848 - val_accuracy: 0.6111 - val_loss: 15.7833 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 469ms/step - accuracy: 0.545

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4265 - loss: 67.4277 - val_accuracy: 0.4444 - val_loss: 57.8631 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6176 - loss: 57.7912 - val_accuracy: 0.2778 - val_loss: 52.0802 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6324 - loss: 52.0001 - val_accuracy: 0.3333 - val_loss: 47.6823 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.5441 - loss: 47.6012 - val_accuracy: 0.4444 - val_loss: 44.0710 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6912 - loss: 43.9277 - val_accuracy: 0.4444 - val_loss: 40.9629 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 647ms/step - accuracy: 0.454

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.6471 - loss: 191.5438 - val_accuracy: 0.6111 - val_loss: 154.5200 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.6912 - loss: 154.4304 - val_accuracy: 0.5000 - val_loss: 133.0341 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.7206 - loss: 132.8671 - val_accuracy: 0.5000 - val_loss: 117.3808 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.7647 - loss: 117.0494 - val_accuracy: 0.5556 - val_loss: 104.7385 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.8676 - loss: 104.4350 - val_accuracy: 0.5556 - val_loss: 94.2057 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 542ms/step - accura

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 1s/step - accuracy: 0.5000 - loss: 18.5900 - val_accuracy: 0.5000 - val_loss: 14.5857 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6027 - loss: 14.1608 - val_accuracy: 0.5000 - val_loss: 12.5070 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6018 - loss: 12.0021 - val_accuracy: 0.6667 - val_loss: 10.8891 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.7656 - loss: 10.3927 - val_accuracy: 0.7222 - val_loss: 9.5858 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8640 - loss: 9.1577 - val_accuracy: 0.5556 - val_loss: 8.5475 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 648ms/step - accuracy: 0.6364 -

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 985ms/step - accuracy: 0.4727 - loss: 23.1200 - val_accuracy: 0.8333 - val_loss: 17.7022 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.7700 - loss: 17.1297 - val_accuracy: 0.8333 - val_loss: 15.0101 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7011 - loss: 14.5415 - val_accuracy: 0.6667 - val_loss: 13.0245 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.7652 - loss: 12.6218 - val_accuracy: 0.5556 - val_loss: 11.4591 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8679 - loss: 11.0389 - val_accuracy: 0.6111 - val_loss: 10.0909 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 491ms/step - accuracy: 0.

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.6176 - loss: 66.4857 - val_accuracy: 0.4444 - val_loss: 53.8176 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7794 - loss: 53.6688 - val_accuracy: 0.3889 - val_loss: 46.3014 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.6618 - loss: 46.1535 - val_accuracy: 0.3333 - val_loss: 40.8435 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7353 - loss: 40.6060 - val_accuracy: 0.4444 - val_loss: 36.5098 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.7941 - loss: 36.1814 - val_accuracy: 0.4444 - val_loss: 32.9307 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 799ms/step - accuracy: 0.500

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.3971 - loss: 98.9253 - val_accuracy: 0.6111 - val_loss: 85.7980 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.5588 - loss: 86.0381 - val_accuracy: 0.6111 - val_loss: 77.7726 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6471 - loss: 77.8619 - val_accuracy: 0.7222 - val_loss: 71.6345 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6176 - loss: 71.8392 - val_accuracy: 0.6667 - val_loss: 66.5732 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.7647 - loss: 66.4934 - val_accuracy: 0.7778 - val_loss: 62.1123 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 426ms/step - accuracy: 0.500

In [335]:
# Collect the metrics each trial stored in its "trial_data" user attribute.
# The frame is indexed by trial number, so later cells can use
# `trial_metrics_df.index` values directly as positions into `study.trials`.
trial_metrics_dict = {
    "train_acc": [],
    "test_acc": [],
    "val_acc": [],
    "train_val_diff": [],
    "scores": [],
    "channels_selected": [],
}
trial_numbers = []
for trial in study.trials:
    trial_user_attrs = trial.user_attrs.get("trial_data")
    if trial_user_attrs is None:
        # Failed or interrupted trials (e.g. a KeyboardInterrupt mid-trial) never
        # stored "trial_data"; skip them instead of crashing on a missing value.
        continue
    max_train_acc = np.max(trial_user_attrs["train_accuracy"])
    max_val_acc_trial = np.max(trial_user_attrs["val_accuracy"])
    trial_numbers.append(trial.number)
    trial_metrics_dict["scores"].append(trial.value)
    trial_metrics_dict["train_acc"].append(max_train_acc)
    trial_metrics_dict["test_acc"].append(trial_user_attrs["test_accuracy"])
    trial_metrics_dict["val_acc"].append(max_val_acc_trial)
    trial_metrics_dict["channels_selected"].append(trial_user_attrs["channels_selected"])
    trial_metrics_dict["train_val_diff"].append(abs(max_train_acc - max_val_acc_trial))
trial_metrics_df = pd.DataFrame(trial_metrics_dict, index=trial_numbers)
trial_metrics_df.sort_values("scores", ascending=True)

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
12,0.897059,0.818182,0.833333,0.063726,0.189693,"[P3, C3, F3, C4, P4, Pz, Fp1, T5, O2, T6, T4]"
22,0.764706,0.5,0.777778,0.013072,0.191303,"[P3, C3, F3, C4, P4, T3, F7, A2, T6]"
20,0.852941,0.454545,0.833333,0.019608,0.192778,"[P3, Fz, F4, C4, Pz, Fp1, O1, O2, F7, F8, T6, T4]"
13,0.838235,0.772727,0.722222,0.116013,0.232784,"[P3, C3, C4, Cz, T3, T5, O2, T4]"
6,0.75,0.454545,0.666667,0.083333,0.255618,"[P4, Cz, Fp1, O1, O2, T4]"
0,0.970588,0.545455,0.722222,0.248366,0.256498,"[C3, F3, Fz, C4, P4, Pz, Fp2, T3, T5, O2, F8, T6]"
19,0.852941,0.636364,0.722222,0.130719,0.271672,"[C3, F3, Fz, F4, Cz, Pz, Fp1, O1, F7, F8, A2]"
5,0.926471,0.545455,0.666667,0.259804,0.28505,"[F3, F4, Cz, Pz, Fp2, T3, T5, A2]"
2,0.838235,0.545455,0.666667,0.171569,0.322473,"[C3, Fz, F4, C4, P4, Cz, Fp1, Fp2, O2, F7, F8,..."
18,0.867647,0.545455,0.611111,0.256536,0.335038,"[F3, Fz, C4, T3, T5, F7, F8, A2]"


In [339]:
# Among the 10 trials with the highest val_acc whose train_acc exceeds their
# val_acc, show those tied for the best val_acc (highest train_acc first).
top_val_candidates = (
    trial_metrics_df
    .sort_values("val_acc", ascending=False)
    .query("train_acc > val_acc")
    .head(10)
    .sort_values("train_acc", ascending=False)
)
max_val_acc = top_val_candidates["val_acc"].max()
# A boolean mask (rather than query(f"val_acc == {max_val_acc}")) does not depend
# on the float's string form and does not fail when max_val_acc is NaN (no candidates).
top_val_candidates[top_val_candidates["val_acc"] == max_val_acc]

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
12,0.897059,0.818182,0.833333,0.063726,0.189693,"[P3, C3, F3, C4, P4, Pz, Fp1, T5, O2, T6, T4]"
20,0.852941,0.454545,0.833333,0.019608,0.192778,"[P3, Fz, F4, C4, Pz, Fp1, O1, O2, F7, F8, T6, T4]"


In [337]:
# Pick the trial to export. Among the 10 trials with the highest val_acc whose
# train_acc exceeds their val_acc, keep those tied for the best val_acc that also
# reach MIN_TRAIN_ACC. Prefer the highest train_acc, then the smallest train/val gap.
MIN_TRAIN_ACC = 0.7
top_val_candidates = (
    trial_metrics_df
    .sort_values("val_acc", ascending=False)
    .query("train_acc > val_acc")
    .head(10)
    .sort_values("train_acc", ascending=False)
)
max_val_acc = top_val_candidates["val_acc"].max()
best_candidate = (
    top_val_candidates[top_val_candidates["val_acc"] == max_val_acc]
    .loc[lambda df: df["train_acc"] >= MIN_TRAIN_ACC]
    # One multi-key sort makes the train_val_diff tie-break explicit. Two
    # successive single-key sorts rely on the default quicksort, which is not stable.
    .sort_values(["train_acc", "train_val_diff"], ascending=[False, True])
)
display(best_candidate)
if best_candidate.empty:
    raise ValueError(
        f"No candidate trial with val_acc == {max_val_acc} and train_acc >= {MIN_TRAIN_ACC}; "
        "relax the selection criteria."
    )
# The frame index is used as a position into study.trials.
best_candidate_trial_id = best_candidate.index[0]
best_candidate_trial = study.trials[best_candidate_trial_id]
best_candidate_trial_data = best_candidate_trial.user_attrs["trial_data"]

# Print the selected trial's own hyperparameters (not study.best_trial's), so they
# belong to the same trial as the metrics printed below.
rpprint({ k: v for k, v in best_candidate_trial.params.items() if not k.startswith("channels") })
rprint("test_accuracy =", best_candidate_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_candidate_trial_data["val_accuracy"]))
rprint("channels_selected =", best_candidate_trial_data["channels_selected"])

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
12,0.897059,0.818182,0.833333,0.063726,0.189693,"[P3, C3, F3, C4, P4, Pz, Fp1, T5, O2, T6, T4]"
20,0.852941,0.454545,0.833333,0.019608,0.192778,"[P3, Fz, F4, C4, Pz, Fp1, O1, O2, F7, F8, T6, T4]"


In [139]:
# max_val_acc = trial_metrics_df.sort_values("train_val_diff", ascending=True).query("train_acc >= val_acc")[:15]["val_acc"].max()
# # max_val_acc = trial_metrics_df.sort_values("train_val_diff", ascending=True)[:15]["val_acc"].max()
# best_candidate = trial_metrics_df.sort_values("train_acc", ascending=False).query(f"val_acc == {max_val_acc}")
# display(best_candidate)
# best_candidate_trial_id = best_candidate.index[0]

# rpprint({ k: v for k, v in study.best_trial.params.items() if not k.startswith("channels") })
# rprint("test_accuracy =", study.trials[best_candidate_trial_id].user_attrs["trial_data"]["test_accuracy"])
# rprint("val_accuracy =", np.max(study.trials[best_candidate_trial_id].user_attrs["trial_data"]["val_accuracy"]))
# rprint("channels_selected =", study.trials[best_candidate_trial_id].user_attrs["trial_data"]["channels_selected"])

In [338]:
# Save the selected trial to disk.
# np.save wraps the (non-array) trial object in a 0-d object array and pickles it.
# Reload with `np.load(path, allow_pickle=True).item()`, and only load files you trust.
results_dir = "./best_subjects_search_results"
os.makedirs(results_dir, exist_ok=True)  # np.save does not create missing directories
np.save(
    os.path.join(results_dir, f"subject_12_shallow_conv_net_{uuid.uuid4()}.npy"),
    study.trials[best_candidate_trial_id],
)

In [340]:
# Summary of the study's best trial (lowest objective value, since the study minimizes).
study_best_trial = study.best_trial
study_best_trial_data = study_best_trial.user_attrs["trial_data"]
study_best_hyper_params = {
    name: value
    for name, value in study_best_trial.params.items()
    if not name.startswith("channels")
}
rpprint(study_best_hyper_params)
rprint("test_accuracy =", study_best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(study_best_trial_data["val_accuracy"]))
rprint("channels_selected =", study_best_trial_data["channels_selected"])

In [375]:
# Rebuild the selected trial's network from the architecture JSON stored in its
# "trial_data" user attribute.
# NOTE(review): keras.models.model_from_json restores the architecture only, not
# trained weights, so `subject_model` starts from freshly initialised weights.
# Nothing in this cell loads weights before evaluate(), so the accuracy it reports
# is that of an untrained network.
# TODO: persist trained weights (or the full model) per trial and load them here.
subject_data = study.trials[best_candidate_trial_id]
subject_trial_data = subject_data.user_attrs['trial_data']
subject_model_json = subject_trial_data['model']
subject_model = keras.models.model_from_json(
    subject_model_json,
    custom_objects=CUSTOM_OBJECTS,
)

def channels_to_channels_idx(channels, dataset):
    """Map channel names to their positions in the dataset's reference channel list.

    The reference list is ``info['ch_names']`` of subject 1, keys ``'0'``/``'0'``
    of ``dataset.get_data(...)``, with its last entry dropped. This is the same
    expression the objective functions use to build their channel masks.

    Args:
        channels: Collection of channel names to select.
        dataset: Object exposing ``get_data(subjects=[...])``.

    Returns:
        Ascending list of integer positions (into the reference list) of the
        names found in ``channels``. Names not in the reference list are ignored.
    """
    reference_names = dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
    # dict.fromkeys de-duplicates while keeping first-occurrence order, exactly
    # like the {name: flag} mask dicts used to select channels elsewhere.
    unique_names = list(dict.fromkeys(reference_names))
    return [position for position, name in enumerate(unique_names) if name in channels]

# Reload subject 12's data restricted to the selected trial's channels and sampling
# rate, then evaluate the rebuilt model on it.
subject_channels_idx = channels_to_channels_idx(subject_trial_data['channels_selected'], fat_dataset)
# Trials that did not record 'sfreq' fall back to 128.
subject_sfreq = subject_trial_data.get('sfreq', 128)
subject_data_hyper_params = {
    name: value
    for name, value in subject_data.params.items()
    if not name.startswith('channels_')
}

X, y, _ = data_generator(fat_dataset, subjects=[12], channel_idx=subject_channels_idx, sfreq=subject_sfreq)
y_encoded = LabelEncoder().fit_transform(y)

# NOTE(review): evaluation below uses all of subject 12's data, which presumably
# includes the samples the trial was trained/validated on, so it is not a
# held-out score; the split below is left disabled.
# X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=SKLRNG, shuffle=True, stratify=y_encoded)

subject_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
subject_model.evaluate(X, y_encoded)

Adding metadata with 3 columns
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 705ms/step - accuracy: 0.5016 - loss: 16.4667


[16.48405647277832, 0.5092592835426331]

In [None]:
# channels = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
# channels_of_interest = ['C3', 'F3', 'C4', 'T5', 'O1', 'O2', 'A2', 'T6']
# channels_of_interest_idx = np.where(np.isin(channels, channels_of_interest))[0]

# channels_of_interest_idx

In [None]:
# Model/pipeline hyperparameter optimization test
# Module-level scratch dict; presumably filled by the objective defined below
# (its use is not visible here).
data_store_dict = dict()

def objective_fn_lstm_cnn_net(trial, **kwargs):
    """Optuna objective: jointly select EEG channels and score an LSTM-CNN on subject 6.

    Each ``channels_<name>`` boolean parameter decides whether that channel is fed
    to the model. The selected data are split into train/test sets, the model is
    trained with early stopping, and the trial's ``"trial_data"`` user attribute
    records the selected channel names, the per-epoch validation accuracy and the
    test accuracy.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial supplying the per-channel selection parameters.
    **kwargs
        Accepted for interface compatibility; unused.

    Returns
    -------
    float
        ``(1 - last-epoch val_accuracy) ** 2 + 5e-5 * n_selected_channels``
        (lower is better); the small linear term favours smaller channel sets.

    Raises
    ------
    optuna.TrialPruned
        If the sampler switches every channel off (nothing to train on).
    """
    _SFREQ_ = 300  # fixed resampling rate; could be tuned via trial.suggest_categorical
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE_ = 256

    # Reference channel names, read once per trial instead of reloading subject 1 for
    # every use. [:-1] drops the last channel name (presumably the stim channel -- TODO confirm).
    all_channels = list(fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1])

    # One boolean hyperparameter per channel: keep it or drop it.
    channels_idx = [
        idx
        for idx, name in enumerate(all_channels)
        if trial.suggest_categorical(f"channels_{name}", [True, False])
    ]
    if not channels_idx:
        # Re-suggesting an already-suggested parameter name within one trial returns the
        # cached value, so a retry loop can never change the outcome and would spin forever.
        # Prune the trial so the sampler moves on.
        raise optuna.TrialPruned("No EEG channels selected for this trial.")

    X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=channels_idx, sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]
    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))

    # The target layout is (samples, channels) while X's trailing axes are (channels, samples).
    # A bare reshape only reinterprets the buffer and would interleave channels with time
    # points, so swap the axes first; the reshape then just folds any leading dimensions
    # into the batch axis.
    X_seq = np.swapaxes(np.asarray(X), -1, -2).reshape(-1, _NUM_SAMPLES_, _NUM_CHANNELS_)
    X_train, X_test, y_train, y_test = train_test_split(
        X_seq, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded
    )

    model = lstm_cnn_net_v2(
        nb_classes=_NUM_CLASSES_,
        channels=_NUM_CHANNELS_,
        samples=_NUM_SAMPLES_,
    )
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )

    # NOTE(review): both callbacks use patience=3 on val_loss, so training will usually stop
    # at about the point the learning rate would be reduced; the LR drop may rarely matter.
    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE_,
        epochs=25,
        validation_split=0.2,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1),
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(all_channels)[channels_idx],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE_)[1],
    })

    # NOTE(review): scored on the LAST epoch's val_accuracy, which may differ from the
    # epoch whose weights EarlyStopping(restore_best_weights=True) keeps.
    return (1 - history.history["val_accuracy"][-1]) ** 2 + 0.00005 * len(channels_idx)

def heuristic_optimizer(obj_fn=None, max_iter=15, show_progress_bar=True, seed=None, **kwargs):
    """Run a minimisation Optuna study driven by an NSGA-II sampler.

    Parameters
    ----------
    obj_fn : callable
        Objective ``obj_fn(trial) -> float`` to minimise. Required.
    max_iter : int, default 15
        Number of trials to run.
    show_progress_bar : bool, default True
        Forwarded to ``study.optimize``.
    seed : int, optional
        Seed for the NSGA-II sampler, making the search reproducible. ``None``
        (the default) leaves it unseeded, as before.
    **kwargs
        Forwarded to ``study.optimize``. A ``callbacks`` entry is merged
        rather than colliding with the keyword this function passes.

    Returns
    -------
    optuna.study.Study
        The finished study.

    Raises
    ------
    TypeError
        If ``obj_fn`` is not callable.
    """
    # The former default (a zero-argument lambda) could never be called with a trial;
    # fail early with a clear message instead.
    if not callable(obj_fn):
        raise TypeError("heuristic_optimizer requires a callable obj_fn(trial) -> float.")

    # Process-wide side effect: silences Optuna's logger.
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIISampler(seed=seed))
    callbacks = list(kwargs.pop("callbacks", None) or [])
    study.optimize(obj_fn, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=callbacks, **kwargs)

    return study

# Launch the channel-selection search for the LSTM-CNN model.
study = heuristic_optimizer(obj_fn=objective_fn_lstm_cnn_net)

In [312]:
# Summarise the best trial: its non-channel hyperparameters, test accuracy,
# peak per-epoch validation accuracy, and the channel subset it selected.
rpprint({
    param_name: param_value
    for param_name, param_value in study.best_trial.params.items()
    if not param_name.startswith("channels")
})
rprint("test_accuracy =", study.best_trial.user_attrs["trial_data"]["test_accuracy"])
rprint("val_accuracy =", np.max(study.best_trial.user_attrs["trial_data"]["val_accuracy"]))
rprint("channels_selected =", study.best_trial.user_attrs["trial_data"]["channels_selected"])

In [None]:
# Save model
# NOTE(review): no global `model` is created in the cells shown here; the `model` built in
# objective_fn_lstm_cnn_net is local to that function. This relies on a `model` left by
# another cell -- confirm which model is actually being saved.
model.save("model.keras")

In [None]:
# Load model from disk (reads the "model.keras" file the save cell writes).
# CUSTOM_OBJECTS is defined outside this section; presumably it maps the custom
# layers/functions the architecture needs for deserialisation -- TODO confirm.
loaded_model = keras.saving.load_model("model.keras", custom_objects=CUSTOM_OBJECTS)

In [None]:
## Pruning and quantization test
# Magnitude-based pruning of the loaded model with the TensorFlow Model Optimization
# Toolkit (tfmot). Quantization is not attempted in this cell.
#
# NOTE(review): tfmot operates on tf.keras (Keras 2) models, whereas this notebook builds
# models with Keras 3 on the JAX backend; expect prune_low_magnitude to reject
# `loaded_model` unless it is rebuilt as a tf.keras model -- confirm before relying on this.
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper  # kept for any later cell that uses it

baseline_model = loaded_model
pruning_params = {
    # Polynomially ramp the target sparsity from 0% to 80% between steps 0 and 100
    # (steps only advance while training with tfmot.sparsity.keras.UpdatePruningStep).
    # PolynomialDecay is exposed by tfmot's public sparsity API; it is not an attribute
    # of the pruning_wrapper module.
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.00,
                                                           final_sparsity=0.80,
                                                           begin_step=0,
                                                           end_step=100)
}
# prune_low_magnitude is the public entry point for whole models: it wraps every prunable
# layer. PruneLowMagnitude itself is a per-layer wrapper.
sparse_model = tfmot.sparsity.keras.prune_low_magnitude(baseline_model, **pruning_params)

sparse_model.summary()

### Data exploration and visualization

In [None]:
def create_spectrogram(epoch_data, fs=50):
    """Compute a spectrogram for every channel of every epoch.

    Parameters
    ----------
    epoch_data : array-like
        Epochs of per-channel 1-D signals, e.g. an
        ``(n_epochs, n_channels, n_samples)`` array. Time runs along the last axis.
    fs : float, default 50
        Sampling frequency passed to ``scipy.signal.spectrogram``. The default keeps
        the previously hard-coded value. Data from ``data_generator`` in this notebook
        is resampled to other rates (e.g. 128, 250 or 300 Hz), so pass the real rate
        for correctly scaled power and frequency bins.

    Returns
    -------
    np.ndarray
        Power array of shape ``epoch_data.shape[:-1] + (n_freq_bins, n_time_windows)``,
        the same layout the former per-channel loop produced. An empty input gives an
        empty array.
    """
    epoch_array = np.asarray(epoch_data)
    if len(epoch_array) == 0:
        return np.array([])
    # spectrogram transforms along `axis` and appends (freq, time) axes, so one
    # vectorised call replaces the nested epoch/channel loops.
    _, _, power = sp.signal.spectrogram(epoch_array, fs=fs, axis=-1)
    return power

# Per-channel spectrograms of the epochs currently held in X / y
# (set by the most recent data_generator call).
spectograms = create_spectrogram(X)
rprint(spectograms.shape)  # (n_epochs, n_channels, n_freq_bins, n_time_windows), e.g. (180, 21, 129, 2)

num_left_hand_epochs_to_visualize = 2
num_right_hand_epochs_to_visualize = 2

# Select distinct example epochs per class with a seeded generator so the figure is
# reproducible. Sampling with replacement could show the same epoch twice.
y_labels = np.asarray(y)
viz_rng = np.random.default_rng(SKLRNG)
left_hand_candidates = np.flatnonzero(y_labels == 'left_hand')
right_hand_candidates = np.flatnonzero(y_labels == 'right_hand')
left_hand_epochs_to_visualize = viz_rng.choice(
    left_hand_candidates,
    min(num_left_hand_epochs_to_visualize, left_hand_candidates.size),
    replace=False,
)
right_hand_epochs_to_visualize = viz_rng.choice(
    right_hand_candidates,
    min(num_right_hand_epochs_to_visualize, right_hand_candidates.size),
    replace=False,
)
epochs_to_visualize = np.concatenate((left_hand_epochs_to_visualize, right_hand_epochs_to_visualize))

rprint(epochs_to_visualize)

# One panel per selected epoch. Each panel shows the (channel x frequency-bin) power of
# the epoch's FIRST time window (last index 0): rows are channels, not time.
# squeeze=False keeps `axs` 2-D so indexing also works with a single panel.
fig, axs = plt.subplots(epochs_to_visualize.size, 1, figsize=(10, 10), squeeze=False)
for ax, idx in zip(axs[:, 0], epochs_to_visualize):
    ax.matshow(spectograms[idx, :, :, 0], cmap='viridis')
    ax.set_title(y_labels[idx])
    ax.set_xlabel('Frequency bin')
    ax.set_ylabel('Channel')
    # Shows the first 60 frequency bins. Their width in Hz depends on the fs/nperseg used
    # by create_spectrogram, so this is not a 60 Hz cut-off.
    ax.set_xlim(0, 60)

plt.tight_layout()
plt.show()