In [None]:
# General imports
import os
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

In [None]:
# Single seed shared by the scikit-learn splits and the JAX PRNG key below.
SKLRNG = 42
RNG = jax.random.PRNGKey(seed=SKLRNG)

In [None]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250, standardize = False):
    """
    Load, band-pass filter, epoch and resample EEG trials from ``dataset``.

    For every subject / session / run returned by ``dataset.get_data``, the raw
    recording is IIR-filtered (EEG channels only; ``Raw.filter`` works in place),
    epoched around the dataset's events over ``dataset.interval`` and resampled
    to ``sfreq``.

    Args:
        dataset: object exposing ``get_data(subjects=...)`` (nested dict
            ``subject -> session -> run -> mne Raw``), ``event_id``
            (label -> event code) and ``interval`` ((tmin, tmax) in seconds).
            Presumably a MOABB-style dataset — TODO confirm for other sources.
        subjects: subject ids forwarded to ``dataset.get_data``. Not mutated.
        channel_idx: indices into ``raw.info['ch_names']`` to keep; empty keeps
            every channel. Stim channels are always dropped. Not mutated.
        filters: sequence of ``(fmin, fmax)`` pairs. Several pairs are applied
            one after another (cascaded), not as a filter bank.
        sfreq: target sampling frequency of the returned epochs.
        standardize: when True, each epoch is passed through
            ``StandardScaler().fit_transform`` before being returned. Defaults to
            False, which returns the resampled data unscaled (the behavior the
            function always had, since the scaled array was never returned).

    Returns:
        ``(X, y, metadata)``: ``X`` of shape ``(n_epochs, n_channels, n_times)``,
        ``y`` the label name (key of ``dataset.event_id``) of every epoch, and a
        DataFrame with the subject / session / run of every epoch.

    Raises:
        ValueError: if no epochs were produced for the requested subjects.
    """

    def find_events(raw, event_id):
        # Prefer a stim channel when the recording has one; otherwise use annotations.
        if len(mne.utils._get_stim_channel(None, raw.info, raise_error=False)) > 0:
            return mne.find_events(raw, shortest_event=0, verbose=False)
        return mne.events_from_annotations(raw, event_id=event_id, verbose=False)[0]

    data = dataset.get_data(subjects=subjects)

    # Loop-invariant values, computed once instead of per run.
    inv_events = {code: label for label, code in dataset.event_id.items()}
    tmin = dataset.interval[0]
    tmax = dataset.interval[1]

    X = []
    y = []
    metadata = []

    for subject_id, sessions in data.items():
        for session_id, runs in sessions.items():
            for run_id, raw in runs.items():
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)

                events = find_events(raw, dataset.event_id)

                all_channels = np.asarray(raw.info['ch_names'])
                channels = all_channels[channel_idx] if len(channel_idx) > 0 else all_channels

                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                epochs = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                labels = [inv_events[e] for e in epochs.events[:, -1]]

                # Epochs.resample modifies the object in place.
                epochs.resample(sfreq)
                # Convert explicitly instead of relying on np.concatenate iterating the Epochs object.
                epochs_data = epochs.get_data(copy=False)

                if standardize and len(epochs_data) > 0:
                    # NOTE(review): StandardScaler scales columns and each epoch here is
                    # (n_channels, n_times), so this standardizes every time point across
                    # channels rather than every channel over time. Kept as originally
                    # written — confirm which one is intended.
                    epochs_data = np.asarray([StandardScaler().fit_transform(epoch) for epoch in epochs_data])

                n = epochs_data.shape[0]
                met = pd.DataFrame(index=range(n))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id
                epochs.metadata = met.copy()

                X.append(epochs_data)
                y.append(labels)
                metadata.append(met)

    if not X:
        raise ValueError(f"No epochs found for subjects {subjects!r}.")

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

fat_dataset = FatigueMI()

# Active configuration: subject 5, every channel, resampled to 128 Hz.
X, y, _ = data_generator(fat_dataset, sfreq=128, channel_idx=[], subjects=[5])

# Alternative configurations (swap in one of these instead):
# X, y, _ = data_generator(fat_dataset, subjects=fat_dataset.subject_list, sfreq=250)
# X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=[1, 2, 5, 12, 13, 14, 17, 18], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[5, 6, 12, 17, 24, 28, 29], channel_idx=[], sfreq=128)

In [None]:
# Integer-encode the string labels. Author's note on the resulting codes:
# 0 = left, 1 = right, 2 = unlabelled (should be ignored for binary classification).
# NOTE(review): LabelEncoder assigns codes in sorted label order — confirm the mapping
# above against the encoder's classes_. Nothing below filters out class 2.
y_encoded = LabelEncoder().fit_transform(y)
rprint({"X": X.shape, "y": y_encoded.shape})

# Model input dimensions derived from the (epochs, channels, samples) array.
NUM_CHANNELS, NUM_SAMPLES = X.shape[-2], X.shape[-1]
NUM_CLASSES = len(np.unique(y_encoded))

rpprint({
    "NUM_SAMPLES": NUM_SAMPLES,
    "NUM_CHANNELS": NUM_CHANNELS,
    "NUM_CLASSES": NUM_CLASSES,
})

# sns.histplot(y_encoded); # Plot the class distribution

In [None]:
# Hold out a stratified, shuffled 20% test split using the shared seed.
# TRAIN_SIZE is informational only; the split is driven by TEST_SIZE.
TRAIN_SIZE = 0.8
TEST_SIZE = 0.2

# NOTE(review): the commented variants below use X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS),
# which does not transpose (it interleaves channel and time values); use
# np.swapaxes(X, 1, 2) if (samples, channels) inputs are needed.
# X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS), y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=None, shuffle=False)
# X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS), y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=y_encoded, shuffle=True)
# X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=None, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=TEST_SIZE,
    shuffle=True,
    stratify=y_encoded,
    random_state=SKLRNG,
)

total = len(X_train) + len(X_test)
rpprint(
    {
        split_name: f"{split.shape[0]} ({split.shape[0] / total * 100:.2f}%)"
        for split_name, split in (("X_train", X_train), ("X_test", X_test))
    },
    expand_all=True,
)

# batch_size = 32
# batches = batch_generator(X_train, batch_size, RNG)

In [None]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Element-wise square, used as ShallowConvNet's 'squaring' activation."""
    squared = jnp.square(x)
    return squared

def shallow_conv_net_log_layer(x):
    """Natural log after clamping ``x`` to [1e-7, 10000] so log never sees 0 or negatives."""
    bounded = jnp.minimum(jnp.maximum(x, 1e-7), 10000)
    return jnp.log(bounded)

# Name -> function map for Keras deserialization of models using the custom activations.
CUSTOM_OBJECTS = {
    fn.__name__: fn
    for fn in (shallow_conv_net_square_layer, shallow_conv_net_log_layer)
}
# endregion Helper funcs

# region Models
def eeg_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5, 
            kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """
        EEGNet-style classifier as a Keras functional model.

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        Args:
            nb_classes: number of output classes (softmax units).
            channels: number of EEG channels (input height).
            samples: number of time samples per trial (input width).
            dropout_rate: dropout rate applied after each of the two blocks.
            kernLength: width of the first (temporal) convolution kernel.
            F1: number of temporal filters.
            D: depth multiplier of the depthwise (spatial) convolution.
            F2: number of filters of the separable convolution.
            norm_rate: max-norm constraint on the final dense layer.
            dropoutType: 'Dropout' or 'SpatialDropout2D'.

        Returns:
            keras.Model mapping (batch, channels, samples, 1) inputs to class probabilities.

        Raises:
            ValueError: if ``dropoutType`` is not one of the supported names.
    """
    dropout_layers = {'Dropout': Dropout, 'SpatialDropout2D': SpatialDropout2D}
    if dropoutType not in dropout_layers:
        raise ValueError("dropoutType must be one of SpatialDropout2D or Dropout, passed as a string.")
    dropoutType = dropout_layers[dropoutType]

    input1 = Input(shape = (channels, samples, 1))

    # Block 1: temporal conv -> depthwise spatial conv across all channels.
    # No `input_shape` kwarg here: the Input layer already fixes the shape, and Keras 3
    # warns when a layer inside a functional model receives one.
    block1 = Conv2D(F1, (1, kernLength), padding = 'same', use_bias = False)(input1)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((channels, 1), use_bias = False,
                             depth_multiplier = D,
                             depthwise_constraint = max_norm(1.))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation('elu')(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropout_rate)(block1)

    # Block 2: separable conv, then further temporal pooling.
    block2 = SeparableConv2D(F2, (1, 16), use_bias = False, padding = 'same')(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropout_rate)(block2)

    flatten = Flatten(name = 'flatten')(block2)

    dense = Dense(nb_classes, name = 'dense',
                  kernel_constraint = max_norm(norm_rate))(flatten)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)

def deep_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5):
    """
        DeepConvNet-style classifier as a Keras functional model.

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        Args:
            nb_classes: number of output classes (softmax units).
            channels: number of EEG channels (input height).
            samples: number of time samples per trial (input width).
            dropout_rate: dropout rate used after every conv block.

        Returns:
            keras.Model mapping (batch, channels, samples, 1) inputs to class probabilities.
    """
    conv_constraint_axes = (0, 1, 2)

    def bn_elu_pool_dropout(x):
        """BatchNorm -> ELU -> (1, 2) max-pool -> dropout, the tail shared by all four blocks."""
        x = BatchNormalization(epsilon=1e-05, momentum=0.9)(x)
        x = Activation('elu')(x)
        x = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(x)
        return Dropout(dropout_rate)(x)

    input_main = Input((channels, samples, 1))

    # Block 1: temporal conv followed by a spatial conv spanning every channel.
    # No `input_shape` kwarg: the Input layer already fixes it and Keras 3 warns on it.
    block = Conv2D(25, (1, 5), kernel_constraint = max_norm(2., axis=conv_constraint_axes))(input_main)
    block = Conv2D(25, (channels, 1), kernel_constraint = max_norm(2., axis=conv_constraint_axes))(block)
    block = bn_elu_pool_dropout(block)

    # Blocks 2-4: temporal convs with a growing number of filters.
    for n_filters in (50, 100, 200):
        block = Conv2D(n_filters, (1, 5), kernel_constraint = max_norm(2., axis=conv_constraint_axes))(block)
        block = bn_elu_pool_dropout(block)

    flatten = Flatten()(block)

    dense = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def shallow_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
        ShallowConvNet-style classifier with tunable sizes and L2 regularization.

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        Args:
            nb_classes: number of output classes (softmax units).
            channels: number of EEG channels (input height).
            samples: number of time samples per trial (input width).
            **kwargs: optional overrides (defaults in parentheses):
                pool_size_d2 (35), strides_d2 (7), conv_filters_d2 (13): time-axis
                    pooling size, pooling stride and temporal kernel width;
                pool_size, strides, conv_filters: full 2-D tuples overriding the
                    ``*_d2`` values (defaults ``(1, <*_d2>)``);
                conv2d_1_units (40), conv2d_2_units (40): filters of the temporal
                    and spatial convolutions;
                l2_reg_1, l2_reg_2, l2_reg_3 (0.01 each): L2 factors for the two
                    convolutions and the dense layer;
                dropout_rate (0.5).
                Unknown keys are ignored.

        Returns:
            keras.Model mapping (batch, channels, samples, 1) inputs to class probabilities.
    """
    _POOL_SIZE_D2_ = kwargs.get("pool_size_d2", 35)
    _STRIDES_D2_ = kwargs.get("strides_d2", 7)
    _CONV_FILTERS_D2_ = kwargs.get("conv_filters_d2", 13)

    _POOL_SIZE_ = kwargs.get("pool_size", (1, _POOL_SIZE_D2_))
    _STRIDES_ = kwargs.get("strides", (1, _STRIDES_D2_))
    # Despite the name this is the temporal kernel size, not a filter count.
    _CONV_FILTERS_ = kwargs.get("conv_filters", (1, _CONV_FILTERS_D2_))

    _CONV2D_1_UNITS_ = kwargs.get("conv2d_1_units", 40)
    _CONV2D_2_UNITS_ = kwargs.get("conv2d_2_units", 40)
    _L2_REG_1_ = kwargs.get("l2_reg_1", 0.01)
    _L2_REG_2_ = kwargs.get("l2_reg_2", 0.01)
    _L2_REG_3_ = kwargs.get("l2_reg_3", 0.01)
    _DROPOUT_RATE_ = kwargs.get("dropout_rate", 0.5)

    input_main = Input(shape=(channels, samples, 1))
    # Temporal convolution. No `input_shape` kwarg: the Input layer already fixes it and
    # Keras 3 warns when a layer in a functional model receives one.
    block1 = Conv2D(_CONV2D_1_UNITS_, _CONV_FILTERS_,
                    kernel_constraint = max_norm(2., axis=(0,1,2)),
                    kernel_regularizer=keras.regularizers.L2(_L2_REG_1_))(input_main)
    # Spatial convolution across all channels.
    block1 = Conv2D(_CONV2D_2_UNITS_, (channels, 1), use_bias=False,
                    kernel_constraint = max_norm(2., axis=(0,1,2)),
                    kernel_regularizer=keras.regularizers.L2(_L2_REG_2_))(block1)
    block1 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    # square -> average pool -> log.
    block1 = Activation(shallow_conv_net_square_layer)(block1)
    block1 = AveragePooling2D(pool_size=_POOL_SIZE_, strides=_STRIDES_)(block1)
    block1 = Activation(shallow_conv_net_log_layer)(block1)
    block1 = Dropout(_DROPOUT_RATE_)(block1)
    flatten = Flatten()(block1)
    dense = Dense(nb_classes, kernel_constraint = max_norm(0.5),
                  kernel_regularizer=keras.regularizers.L2(_L2_REG_3_))(flatten)
    softmax = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def lstm_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """
    Three stacked LSTM(40) layers with dropout and batch norm, then a dense classifier head.

    Args:
        nb_classes: number of output classes.
        channels: number of EEG channels (features per time step).
        samples: number of time steps per trial.

    Returns:
        keras.Sequential taking (batch, samples, channels) inputs and returning class probabilities.
    """
    model = Sequential()
    model.add(Input(shape=(samples, channels)))
    model.add(LSTM(40, return_sequences=True, stateful=False))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    model.add(LSTM(40, return_sequences=True, stateful=False))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    # Last recurrent layer returns only its final state.
    model.add(LSTM(40, stateful=False))
    model.add(Dropout(0.5))
    # Hidden layer uses ReLU. A softmax here forces the 50 features to sum to 1
    # before the real softmax output, which bottlenecks the classifier head.
    model.add(Dense(50, activation='relu'))
    model.add(Dense(nb_classes, activation='softmax'))

    return model

def lstm_cnn_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate_1 = 0.5, dropout_rate_2 = 0.5):
    """
    1-D CNN front end followed by a single LSTM(128) and a softmax classifier.

    Args:
        nb_classes: number of output classes.
        channels: number of EEG channels (features per time step).
        samples: number of time steps per trial.
        dropout_rate_1: dropout after the convolution.
        dropout_rate_2: dropout after the LSTM.

    Returns:
        keras.Sequential taking (batch, samples, channels) inputs.
    """
    net = Sequential()
    stack = net.add

    stack(Input(shape=(samples, channels)))

    # Convolutional stage: 40 filters (kernel 20, stride 4), then 4x max-pooling.
    stack(Conv1D(40, kernel_size=20, strides=4))
    stack(Activation('relu'))
    stack(Dropout(dropout_rate_1))
    stack(BatchNormalization())
    stack(MaxPooling1D(pool_size=4, strides=4))

    # Recurrent stage: keep the whole sequence so Flatten sees every time step.
    stack(LSTM(128, return_sequences=True, stateful=False))
    stack(Dropout(dropout_rate_2))
    stack(BatchNormalization())

    # Classifier head.
    stack(Flatten())
    stack(Dense(nb_classes, activation='softmax'))

    return net

def lstm_cnn_net_v2(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """
    Variant of ``lstm_cnn_net`` with the order reversed: an LSTM(50) first, then a 1-D CNN.

    Args:
        nb_classes: number of output classes.
        channels: number of EEG channels (features per time step).
        samples: number of time steps per trial.

    Returns:
        keras.Sequential taking (batch, samples, channels) inputs.
    """
    net = Sequential()
    stack = net.add

    stack(Input(shape=(samples, channels)))

    # Recurrent stage (full sequence is passed on to the convolution).
    stack(LSTM(50, return_sequences=True, stateful=False))
    stack(Dropout(0.5))
    stack(BatchNormalization())

    # Convolutional stage: 40 filters (kernel 20, stride 4), then 4x max-pooling.
    stack(Conv1D(40, kernel_size=20, strides=4))
    stack(Activation('relu'))
    stack(Dropout(0.5))
    stack(BatchNormalization())
    stack(MaxPooling1D(pool_size=4, strides=4))

    # Classifier head.
    stack(Flatten())
    stack(Dense(nb_classes, activation='softmax'))

    return net

# endregion Models

In [None]:
NUM_FOLDS = 5
# Seeded so the fold assignment is reproducible across kernel restarts.
kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SKLRNG)
acc_per_fold = []
loss_per_fold = []

# K-fold cross-validation on the training split only; X_test stays held out.
# `train` / `test` below are positions inside X_train / y_train.
for fold_no, (train, test) in enumerate(kfold.split(X_train, y_train), start=1):
    # A fresh, untrained model per fold. Re-using one model would carry weights learned
    # on earlier folds' validation data into later folds and inflate the scores.
    model = shallow_conv_net()
    # model = lstm_cnn_net()
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    history = model.fit(
        X_train[train],
        y_train[train],
        batch_size=64,
        epochs=100,
        # Labels must come from y_train: `test` indexes X_train, not the full dataset.
        validation_data=(X_train[test], y_train[test]),
        # validation_split=0.2,
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
            # keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
            # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    # Per-fold generalization metrics, looked up by name rather than by position.
    scores = model.evaluate(X_train[test], y_train[test], verbose=0, return_dict=True)
    rprint(f'Score for fold {fold_no}: loss of {scores["loss"]}; accuracy of {scores["accuracy"]*100}%')
    acc_per_fold.append(scores["accuracy"] * 100)
    loss_per_fold.append(scores["loss"])

In [None]:
# Average per-fold validation accuracy (in %), labelled with the actual fold count.
mean_cross_val_score = np.mean(acc_per_fold)
rprint(f"Mean Cross Validation Score ({NUM_FOLDS}-fold):", mean_cross_val_score)

# `model` is the model object left behind by the cross-validation cell.
results_test = model.evaluate(X_test, y_test)
rprint("Test Accuracy:", results_test[1])

In [None]:
# Re-evaluate the current `model` on the held-out test split.
# results_val = model.evaluate(X_val, y_val)
results_test = model.evaluate(x=X_test, y=y_test)

# X_test.reshape(-1, NUM_SAMPLES, NUM_CHANNELS).shape

In [None]:
# Fit a fresh model on the whole training split, with validation_split=0.2 held out by Keras.
# NOTE(review): both callbacks use patience=75 while epochs=25, so neither can trigger here.
model = shallow_conv_net()
# model = lstm_cnn_net()
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

history = model.fit(
    x=X_train,
    y=y_train,
    epochs=25,
    batch_size=128,
    shuffle=True,
    validation_split=0.2,
    # validation_data=(X_test, y_test),
    callbacks=[
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5),
        # Shorter-patience alternative:
        # keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1),
    ],
)

In [None]:
# model.history.history.keys()

In [25]:
def objective_fn(trial, subjects=[], model_str = "lstm_cnn_net", sfreq = 128, batch_size = 128, **kwargs):
    """
    Optuna objective: sample a channel subset and hyperparameters, train briefly, return a cost.

    Args:
        trial: ``optuna.Trial`` used to sample every searched value.
        subjects: subjects to load; empty means ``fat_dataset.subject_list``. Not mutated.
        model_str: which builder to train, "lstm_cnn_net" or "shallow_conv_net".
        sfreq: resampling frequency; a falsy value lets Optuna pick from [128, 256, 300].
        batch_size: training batch size; a falsy value lets Optuna pick 32..256 (step 32).
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        Cost to minimize: ``(1 - best val accuracy)**2`` plus small penalties on the
        channel count, the squared mean train/val accuracy gap and ``(1 - best train
        accuracy)**2``, plus 1 when the best train accuracy is outside [0.70, 1.00] or
        fewer than 5 channels are selected.

    Raises:
        optuna.TrialPruned: when the sampled mask selects no channel at all.
        KeyError: for an unknown ``model_str``.
    """
    _SFREQ_ = sfreq if sfreq else trial.suggest_categorical("sfreq", [128, 256, 300])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = batch_size if batch_size else trial.suggest_int("batch_size", 32, 256, step=32)

    models = {
        "lstm_cnn_net": lstm_cnn_net,
        "shallow_conv_net": shallow_conv_net
    }

    subjects = subjects if len(subjects) > 0 else fat_dataset.subject_list
    model_fn = models[model_str]

    # Read the channel names once (previously subject 1 was reloaded for every lookup).
    ch_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names']
    # The last name is left out of the search — presumably the stim channel; TODO confirm.
    channels_idx = [
        i for i, ch in enumerate(ch_names[:-1])
        if trial.suggest_categorical(f"channels_{ch}", [True, False])
    ]
    if len(channels_idx) == 0:
        # Re-suggesting a parameter name within one trial returns the value already
        # sampled, so the old `while` re-sampling loop could never terminate.
        raise optuna.TrialPruned("No channel selected for this trial.")

    X, y, _ = data_generator(fat_dataset, subjects=subjects, channel_idx=channels_idx, sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]

    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))
    if "lstm" in model_str:
        # Recurrent / 1-D conv models take (samples, channels) inputs. Swap the last two
        # axes: reshape(-1, samples, channels) would interleave channel and time values.
        X_model = np.swapaxes(X, -1, -2)
        X_train, X_test, y_train, y_test = train_test_split(X_model, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded)
    else:
        X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, shuffle=True, stratify=y_encoded)

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }
    if model_str == "shallow_conv_net":
        model_params.update({
            "pool_size_d2": trial.suggest_int("pool_size_d2", 15, 75, step=5),
            "strides_d2": trial.suggest_int("strides_d2", 3, 31, step=2),
            "conv_filters_d2": trial.suggest_int("conv_filters_d2", 5, 35, step=2),

            "conv2d_1_units": trial.suggest_int("conv2d_1_units", 10, 200, step=10),
            "conv2d_2_units": trial.suggest_int("conv2d_2_units", 10, 200, step=10),
            "l2_reg_1": trial.suggest_float("l2_reg_1", 0.001, 0.5),
            "l2_reg_2": trial.suggest_float("l2_reg_2", 0.001, 0.5),
            "l2_reg_3": trial.suggest_float("l2_reg_3", 0.001, 0.5),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.1, 0.9, step=0.1),
        })
    else:
        # lstm_cnn_net only takes two dropout rates; passing the ShallowConvNet keys
        # made it raise TypeError (unexpected keyword argument).
        model_params.update({
            "dropout_rate_1": trial.suggest_float("dropout_rate_1", 0.1, 0.9, step=0.1),
            "dropout_rate_2": trial.suggest_float("dropout_rate_2", 0.1, 0.9, step=0.1),
        })

    model = model_fn(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=5,
        validation_split=0.2,
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(ch_names)[channels_idx],
        "model": model.to_json(),
        "train_accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    # Earlier cost variants kept for reference:
    # cost = (1 - np.mean(history.history["val_accuracy"]))**2 + 1e-2*(1 - np.max(history.history["val_accuracy"]))**2 + 1e-4*len(channels_idx)
    # cost = (1 - np.max(history.history["val_accuracy"]))**2 + 8e-5*len(channels_idx)
    # cost = (1 - np.mean(history.history["val_accuracy"][-2:])) + 0.005*len(channels_idx)
    cost = (1 - np.max(history.history["val_accuracy"]))**2 + 1e-5*len(channels_idx) + 1e-4*np.mean(np.asarray(history.history["accuracy"]) - np.asarray(history.history["val_accuracy"]))**2 + 1e-5*np.mean(1 - np.max(history.history["accuracy"]))**2
    # Hard penalty: implausible training accuracy or too few channels.
    if not 0.70 <= np.max(history.history["accuracy"]) <= 1.00 or len(channels_idx) < 5:
        cost += 1

    return cost

def heuristic_optimizer(obj_fn, max_iter = 25, show_progress_bar = True, subjects=None, model_str="lstm_cnn_net", sfreq = 128, batch_size = 128, seed = None, **kwargs):
    """Run an Optuna NSGA-III search that minimises ``obj_fn``.

    ``obj_fn`` is invoked by Optuna as
    ``obj_fn(trial, subjects=..., model_str=..., sfreq=..., batch_size=...)``.

    Parameters
    ----------
    obj_fn : callable
        Objective taking an Optuna trial plus the keyword arguments above.
    max_iter : int
        Number of trials (``n_trials`` for ``study.optimize``).
    show_progress_bar : bool
        Forwarded to ``study.optimize``.
    subjects : list, optional
        Subject ids forwarded to ``obj_fn``. ``None`` means an empty list; a
        fresh list is created per call instead of sharing a mutable default.
    model_str, sfreq, batch_size
        Forwarded unchanged to ``obj_fn``.
    seed : int, optional
        Seed for the NSGA-III sampler. ``None`` keeps the unseeded behaviour.
    **kwargs
        Extra keyword arguments for ``study.optimize``.

    Returns
    -------
    optuna.Study
        The study, including any trials finished before a Ctrl-C.

    Notes
    -----
    Sets Optuna's global logging verbosity to CRITICAL.
    """
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    if subjects is None:
        subjects = []

    objective = partial(obj_fn, subjects=subjects, model_str=model_str, sfreq=sfreq, batch_size=batch_size)

    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler(seed=seed))
    try:
        study.optimize(objective, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=[], **kwargs)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends the search early but the partially filled
        # study is still returned for inspection.
        pass

    return study

# Channel-selection search for the shallow ConvNet on subject 1.
study = heuristic_optimizer(
    obj_fn=objective_fn,
    subjects=[1],
    model_str="shallow_conv_net",
    sfreq=128,
    batch_size=None,
    max_iter=25,
)

  study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler())


  0%|          | 0/25 [00:00<?, ?it/s]

Adding metadata with 3 columns


  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4706 - loss: 21.5412 - val_accuracy: 0.3889 - val_loss: 20.6021 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6912 - loss: 20.4956 - val_accuracy: 0.4444 - val_loss: 20.3488 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6471 - loss: 20.0398 - val_accuracy: 0.4444 - val_loss: 19.9266 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6618 - loss: 19.5692 - val_accuracy: 0.4444 - val_loss: 19.3984 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6912 - loss: 19.0288 - val_accuracy: 0.3333 - val_loss: 18.8728 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 382ms/step - accurac

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 1s/step - accuracy: 0.4619 - loss: 122.0442 - val_accuracy: 0.5556 - val_loss: 102.0166 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.7960 - loss: 97.3265 - val_accuracy: 0.6111 - val_loss: 81.0662 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.9051 - loss: 76.9452 - val_accuracy: 0.6111 - val_loss: 63.4293 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.9428 - loss: 59.7839 - val_accuracy: 0.6111 - val_loss: 48.8345 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 1.0000 - loss: 45.7210 - val_accuracy: 0.5556 - val_loss: 37.0436 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 544ms/step - accuracy: 0.545

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4265 - loss: 38.0038 - val_accuracy: 0.5556 - val_loss: 32.8780 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.4853 - loss: 32.8680 - val_accuracy: 0.6111 - val_loss: 31.2903 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6765 - loss: 31.2580 - val_accuracy: 0.6667 - val_loss: 29.7630 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7941 - loss: 29.6965 - val_accuracy: 0.6111 - val_loss: 28.2869 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7941 - loss: 28.2071 - val_accuracy: 0.6111 - val_loss: 26.8614 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 460ms/step - accurac

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 977ms/step - accuracy: 0.4237 - loss: 18.6565 - val_accuracy: 0.5556 - val_loss: 16.3940 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7114 - loss: 15.6662 - val_accuracy: 0.4444 - val_loss: 14.4509 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.8869 - loss: 13.6377 - val_accuracy: 0.5000 - val_loss: 12.6513 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.9584 - loss: 11.7782 - val_accuracy: 0.5000 - val_loss: 10.9822 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 1.0000 - loss: 10.1146 - val_accuracy: 0.5556 - val_loss: 9.5065 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 391ms/step - accuracy: 0.590

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5441 - loss: 41.8548 - val_accuracy: 0.2778 - val_loss: 36.9012 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - accuracy: 0.6765 - loss: 36.7532 - val_accuracy: 0.3889 - val_loss: 35.4057 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step - accuracy: 0.7500 - loss: 35.0004 - val_accuracy: 0.4444 - val_loss: 33.8333 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7500 - loss: 33.3398 - val_accuracy: 0.4444 - val_loss: 32.2893 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8824 - loss: 31.7037 - val_accuracy: 0.4444 - val_loss: 30.7502 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 495ms/step - accurac

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.5751 - loss: 10.7111 - val_accuracy: 0.6111 - val_loss: 9.4993 - learning_rate: 0.0010
Epoch 2/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8851 - loss: 9.3068 - val_accuracy: 0.6667 - val_loss: 8.7166 - learning_rate: 0.0010
Epoch 3/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8701 - loss: 8.5056 - val_accuracy: 0.7222 - val_loss: 7.9839 - learning_rate: 0.0010
Epoch 4/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.9400 - loss: 7.7589 - val_accuracy: 0.6667 - val_loss: 7.3009 - learning_rate: 0.0010
Epoch 5/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.8851 - loss: 7.0651 - val_accuracy: 0.6111 - val_loss: 6.6656 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 544ms/step - accuracy: 0.6364 - loss: 6

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4559 - loss: 59.7056 - val_accuracy: 0.5556 - val_loss: 54.7336 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6324 - loss: 54.6619 - val_accuracy: 0.5556 - val_loss: 52.4431 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7059 - loss: 52.2978 - val_accuracy: 0.5556 - val_loss: 50.2026 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.8235 - loss: 49.9870 - val_accuracy: 0.5000 - val_loss: 48.0233 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8971 - loss: 47.8071 - val_accuracy: 0.5556 - val_loss: 45.9116 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 715ms/step - accurac

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5735 - loss: 9.8821 - val_accuracy: 0.4444 - val_loss: 8.9524 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - accuracy: 0.6176 - loss: 8.7874 - val_accuracy: 0.3333 - val_loss: 8.9310 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7206 - loss: 8.3611 - val_accuracy: 0.5000 - val_loss: 8.5719 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7500 - loss: 7.9843 - val_accuracy: 0.6667 - val_loss: 8.0829 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8676 - loss: 7.5893 - val_accuracy: 0.6667 - val_loss: 7.6504 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 431ms/step - accuracy: 0.5455 - loss: 7.

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.3971 - loss: 75.5674 - val_accuracy: 0.4444 - val_loss: 68.1071 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7206 - loss: 68.0679 - val_accuracy: 0.5556 - val_loss: 63.4800 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.8824 - loss: 63.3739 - val_accuracy: 0.5556 - val_loss: 59.0613 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9118 - loss: 58.9292 - val_accuracy: 0.5556 - val_loss: 54.8591 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9412 - loss: 54.6954 - val_accuracy: 0.6111 - val_loss: 50.8761 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 503ms/step - accuracy: 0.5455 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5882 - loss: 96.9187 - val_accuracy: 0.5000 - val_loss: 92.0253 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.5147 - loss: 91.8866 - val_accuracy: 0.5000 - val_loss: 87.5355 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.5000 - loss: 87.4479 - val_accuracy: 0.4444 - val_loss: 83.2455 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5882 - loss: 83.2090 - val_accuracy: 0.3333 - val_loss: 79.1217 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5441 - loss: 79.0646 - val_accuracy: 0.3333 - val_loss: 75.1497 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 421ms/step - accuracy: 0.6364 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.4853 - loss: 88.8212 - val_accuracy: 0.3333 - val_loss: 81.1678 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - accuracy: 0.7941 - loss: 81.0068 - val_accuracy: 0.5556 - val_loss: 74.7363 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step - accuracy: 0.8382 - loss: 74.4976 - val_accuracy: 0.5556 - val_loss: 68.6507 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.8676 - loss: 68.3480 - val_accuracy: 0.5556 - val_loss: 62.9348 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.9265 - loss: 62.5502 - val_accuracy: 0.5556 - val_loss: 57.5922 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 596ms/step - accuracy: 0.5455 

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 33.7050 - val_accuracy: 0.5556 - val_loss: 30.3549 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.5882 - loss: 30.2832 - val_accuracy: 0.5556 - val_loss: 29.0490 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.5735 - loss: 29.0092 - val_accuracy: 0.5000 - val_loss: 27.7664 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.5735 - loss: 27.6353 - val_accuracy: 0.5000 - val_loss: 26.5127 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6176 - loss: 26.4040 - val_accuracy: 0.5556 - val_loss: 25.2947 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 449ms/step - accurac

  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 972ms/step - accuracy: 0.4892 - loss: 106.2254 - val_accuracy: 0.5556 - val_loss: 90.1136 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.6434 - loss: 86.6046 - val_accuracy: 0.5556 - val_loss: 73.9640 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.6360 - loss: 70.9033 - val_accuracy: 0.5556 - val_loss: 60.0399 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.8263 - loss: 57.3660 - val_accuracy: 0.5556 - val_loss: 48.2616 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.8068 - loss: 45.9366 - val_accuracy: 0.6111 - val_loss: 38.4934 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 470ms/step - acc

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4706 - loss: 35.7102 - val_accuracy: 0.6667 - val_loss: 32.6885 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 1.0000 - loss: 32.4381 - val_accuracy: 0.6111 - val_loss: 30.7468 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0000 - loss: 30.4050 - val_accuracy: 0.6111 - val_loss: 28.8134 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0000 - loss: 28.4084 - val_accuracy: 0.6111 - val_loss: 26.9272 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0000 - loss: 26.4924 - val_accuracy: 0.6111 - val_loss: 25.1029 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 491ms/step - accuracy: 0.5455 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.6029 - loss: 102.6921 - val_accuracy: 0.3889 - val_loss: 95.5182 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.7500 - loss: 95.4283 - val_accuracy: 0.5000 - val_loss: 89.6248 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8382 - loss: 89.4649 - val_accuracy: 0.5000 - val_loss: 83.9815 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.9118 - loss: 83.7670 - val_accuracy: 0.5556 - val_loss: 78.5899 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.9853 - loss: 78.2934 - val_accuracy: 0.5556 - val_loss: 73.4496 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 424ms/step - accuracy: 0.3636

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5735 - loss: 12.5231 - val_accuracy: 0.4444 - val_loss: 11.2849 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.6176 - loss: 11.1694 - val_accuracy: 0.4444 - val_loss: 10.8030 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6912 - loss: 10.6773 - val_accuracy: 0.3333 - val_loss: 10.3442 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7794 - loss: 10.1829 - val_accuracy: 0.5000 - val_loss: 9.9015 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.7353 - loss: 9.7228 - val_accuracy: 0.4444 - val_loss: 9.4748 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 302ms/step - accuracy: 0.5000 - l

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2s/step - accuracy: 0.5600 - loss: 24.4154 - val_accuracy: 0.5000 - val_loss: 21.6757 - learning_rate: 0.0010
Epoch 2/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5901 - loss: 21.3481 - val_accuracy: 0.5556 - val_loss: 20.5868 - learning_rate: 0.0010
Epoch 3/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6599 - loss: 20.2870 - val_accuracy: 0.5000 - val_loss: 19.5525 - learning_rate: 0.0010
Epoch 4/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.6599 - loss: 19.2094 - val_accuracy: 0.5556 - val_loss: 18.5636 - learning_rate: 0.0010
Epoch 5/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6854 - loss: 18.2572 - val_accuracy: 0.5000 - val_loss: 17.6282 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 356ms/step - accuracy: 0.5455 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 16.0417 - val_accuracy: 0.4444 - val_loss: 14.1185 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.4853 - loss: 13.8069 - val_accuracy: 0.4444 - val_loss: 13.7075 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.5000 - loss: 13.5367 - val_accuracy: 0.4444 - val_loss: 13.3543 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.4853 - loss: 13.2693 - val_accuracy: 0.4444 - val_loss: 13.0334 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.4706 - loss: 12.9628 - val_accuracy: 0.4444 - val_loss: 12.7314 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 329ms/step - accuracy: 0.5000 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5294 - loss: 93.9075 - val_accuracy: 0.2778 - val_loss: 85.0478 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.7500 - loss: 84.9021 - val_accuracy: 0.3333 - val_loss: 79.3893 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8971 - loss: 79.1251 - val_accuracy: 0.5000 - val_loss: 73.9817 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 1.0000 - loss: 73.6303 - val_accuracy: 0.6111 - val_loss: 68.8410 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 1.0000 - loss: 68.4282 - val_accuracy: 0.6111 - val_loss: 63.9603 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 444ms/step - accuracy: 0.5455 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5588 - loss: 13.2459 - val_accuracy: 0.5000 - val_loss: 9.4457 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6912 - loss: 9.3345 - val_accuracy: 0.4444 - val_loss: 9.1797 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7941 - loss: 8.8953 - val_accuracy: 0.3889 - val_loss: 8.8173 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8235 - loss: 8.4735 - val_accuracy: 0.4444 - val_loss: 8.4034 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8529 - loss: 8.0682 - val_accuracy: 0.5000 - val_loss: 7.9979 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 636ms/step - accuracy: 0.5000 - loss: 7

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5882 - loss: 91.8599 - val_accuracy: 0.6111 - val_loss: 86.2221 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.6765 - loss: 86.1690 - val_accuracy: 0.5000 - val_loss: 81.2752 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7647 - loss: 81.1277 - val_accuracy: 0.6111 - val_loss: 76.5378 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.8235 - loss: 76.3040 - val_accuracy: 0.5556 - val_loss: 72.0122 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.8676 - loss: 71.6906 - val_accuracy: 0.5556 - val_loss: 67.6850 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 467ms/step - accuracy: 0.5000 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4412 - loss: 58.0935 - val_accuracy: 0.4444 - val_loss: 54.5463 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.6324 - loss: 54.4960 - val_accuracy: 0.5000 - val_loss: 51.2953 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7059 - loss: 51.1611 - val_accuracy: 0.3889 - val_loss: 48.1264 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7206 - loss: 47.9648 - val_accuracy: 0.5000 - val_loss: 45.0816 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.7206 - loss: 44.9070 - val_accuracy: 0.5556 - val_loss: 42.1927 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 753ms/step - accuracy: 0.6364 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5147 - loss: 7.2262 - val_accuracy: 0.5000 - val_loss: 6.7048 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.8382 - loss: 6.2946 - val_accuracy: 0.4444 - val_loss: 6.5965 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.8676 - loss: 6.0245 - val_accuracy: 0.5000 - val_loss: 6.3023 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.9118 - loss: 5.7249 - val_accuracy: 0.5000 - val_loss: 5.9818 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9559 - loss: 5.4622 - val_accuracy: 0.6111 - val_loss: 5.7076 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 489ms/step - accuracy: 0.6364 - loss: 5.

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.6029 - loss: 69.6781 - val_accuracy: 0.3889 - val_loss: 65.6108 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.4118 - loss: 65.7044 - val_accuracy: 0.4444 - val_loss: 62.1552 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.6029 - loss: 62.0923 - val_accuracy: 0.4444 - val_loss: 58.7973 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6618 - loss: 58.7193 - val_accuracy: 0.5000 - val_loss: 55.5492 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7500 - loss: 55.3889 - val_accuracy: 0.5000 - val_loss: 52.4204 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 785ms/step - accurac

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4412 - loss: 32.5166 - val_accuracy: 0.6111 - val_loss: 28.2603 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.6176 - loss: 28.2201 - val_accuracy: 0.5556 - val_loss: 26.5349 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7059 - loss: 26.4641 - val_accuracy: 0.5556 - val_loss: 24.8691 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7500 - loss: 24.7873 - val_accuracy: 0.5556 - val_loss: 23.2678 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7794 - loss: 23.1533 - val_accuracy: 0.5556 - val_loss: 21.7369 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 802ms/step - accuracy: 0.4545 

In [26]:
# Per-trial summary built from the `trial_data` user attribute.
# Only COMPLETE trials that stored `trial_data` are included: interrupted
# (Ctrl-C), failed or pruned trials have no results and would otherwise break
# the table. Rows are indexed by Optuna trial number.
trial_metrics_dict = {
    "train_acc": [],
    "test_acc": [],
    "val_acc": [],
    "scores": [],
    "channels_selected": [],
}
trial_numbers = []
for frozen_trial in study.trials:
    trial_data = frozen_trial.user_attrs.get("trial_data")
    if frozen_trial.state != optuna.trial.TrialState.COMPLETE or trial_data is None:
        continue
    trial_numbers.append(frozen_trial.number)
    trial_metrics_dict["scores"].append(frozen_trial.value)
    # Some objectives do not record per-epoch training accuracy.
    trial_metrics_dict["train_acc"].append(
        np.max(trial_data["train_accuracy"]) if "train_accuracy" in trial_data else np.nan
    )
    trial_metrics_dict["test_acc"].append(trial_data["test_accuracy"])
    trial_metrics_dict["val_acc"].append(np.max(trial_data["val_accuracy"]))
    trial_metrics_dict["channels_selected"].append(trial_data["channels_selected"])
trial_metrics_df = pd.DataFrame(trial_metrics_dict, index=trial_numbers)
trial_metrics_df.sort_values("scores", ascending=True)

Unnamed: 0,train_acc,test_acc,val_acc,scores,channels_selected
5,0.941176,0.636364,0.722222,0.077244,"[P3, Fz, C4, P4, Fp1, F7, F8, T6]"
2,0.794118,0.5,0.666667,0.111192,"[C3, Fz, F4, Cz, Fp1, Fp2, T3, F8]"
7,0.867647,0.545455,0.666667,0.111215,"[P3, C3, F3, Fz, F4, P4, O1, A2, T6, T4]"
13,1.0,0.545455,0.666667,0.111238,"[P3, F3, Fz, F4, P4, Cz, Fp2, T3, T5, O2, F7, A2]"
20,0.867647,0.5,0.611111,0.151308,"[F3, Cz, Fp1, T3, T5, O2, F8]"
12,0.808824,0.5,0.611111,0.151316,"[Fz, C4, Pz, Fp1, Fp2, T3, F7, F8]"
1,1.0,0.545455,0.611111,0.15134,"[F4, P4, Cz, Pz, Fp1, T5, O2, F7, F8, T4]"
8,0.941176,0.545455,0.611111,0.15136,"[P3, F3, F4, Cz, Pz, Fp1, Fp2, T3, O1, F7, F8,..."
22,0.955882,0.636364,0.611111,0.151364,"[F3, Fz, F4, P4, Cz, Pz, Fp1, T3, O1, O2, F7, T4]"
24,0.779412,0.454545,0.611111,0.151366,"[P3, C3, Fz, C4, P4, Cz, Pz, T3, O1, O2, F8, T..."


In [27]:
# Report the best trial: its non-channel hyperparameters and stored metrics.
best_trial = study.best_trial
best_trial_data = best_trial.user_attrs["trial_data"]
rpprint({name: value for name, value in best_trial.params.items() if not name.startswith("channels")})
rprint("test_accuracy =", best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_trial_data["val_accuracy"]))
rprint("channels_selected =", best_trial_data["channels_selected"])

In [None]:
# Channel names (last entry excluded) and the subset of interest.
channels = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
channels_of_interest = ["C3", "F3", "C4", "T5", "O1", "O2", "A2", "T6"]
# Ascending positions within `channels`; names not present are silently skipped.
channels_of_interest_idx = np.flatnonzero(np.isin(channels, channels_of_interest))

channels_of_interest_idx

In [None]:
# Model/pipeline hyperparameter optimization test
# NOTE(review): not referenced elsewhere in this section — confirm it is still needed.
data_store_dict = dict()

def objective_fn_lstm_cnn_net(trial, **kwargs):
    """Optuna objective: choose an EEG channel subset, train ``lstm_cnn_net_v2``, score it.

    Every channel name except the last one becomes a boolean trial parameter
    ``channels_<name>``. Subject 3 is loaded with only the selected channels
    (``sfreq=250``), split 80/20 (stratified) into train/test, and the model is
    trained for up to 50 epochs with a further 20% validation split, early
    stopping and learning-rate reduction on ``val_loss``.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to suggest the channel mask and to store results.
    **kwargs
        Accepted for interface compatibility; ignored.

    Returns
    -------
    float
        ``(1 - final-epoch val_accuracy)**2 + 5e-5 * n_selected_channels``
        (lower is better).

    Raises
    ------
    optuna.TrialPruned
        When the trial selects no channel at all.

    Side effects
    ------------
    Stores ``trial_data`` (selected channel names, per-epoch train/val
    accuracy, test accuracy) as a trial user attribute.
    """
    _SFREQ_ = 250
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = 256

    # `get_data` loads the subject's recordings, so read the channel names once.
    # NOTE(review): names come from subject 1 while data comes from subject 3 —
    # assumes both share the same montage/order.
    ch_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names']
    # The last channel is excluded from the search (presumably the stim channel — confirm).
    channels_dict = {k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in ch_names[:-1]}
    channels_idx = [i for i, selected in enumerate(channels_dict.values()) if selected]
    if not channels_idx:
        # Re-suggesting an already-suggested name within the same trial returns
        # the same value, so retrying can never yield a non-empty selection.
        raise optuna.TrialPruned("no channels selected")

    X, y, _ = data_generator(fat_dataset, subjects=[3], channel_idx=channels_idx, sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]
    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))

    # X is laid out (..., channels, samples) and the model is fed
    # (-1, samples, channels). A bare `reshape` would reinterpret the buffer and
    # interleave channels with time, so swap the last two axes first.
    X_time_major = np.swapaxes(X, -1, -2).reshape(-1, _NUM_SAMPLES_, _NUM_CHANNELS_)
    X_train, X_test, y_train, y_test = train_test_split(
        X_time_major, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded
    )

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }

    model = lstm_cnn_net_v2(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=50,
        validation_split=0.2,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(ch_names)[channels_idx],
        # Recorded for parity with `objective_fn`; the trial-metrics table reads it.
        "train_accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    return (1 - history.history["val_accuracy"][-1])**2 + 0.00005*len(channels_idx)

def heuristic_optimizer(obj_fn = None, max_iter = 15, show_progress_bar = True, sampler = None, seed = None, **kwargs):
    """Run a single-objective Optuna study that minimizes ``obj_fn``.

    Args:
        obj_fn: Objective that Optuna calls with one ``optuna.trial.Trial`` argument;
            it must return the scalar value to minimize.
        max_iter: Number of trials to run (forwarded as ``n_trials``).
        show_progress_bar: Forwarded to ``study.optimize``.
        sampler: Optional Optuna sampler. When omitted, ``NSGAIISampler(seed=seed)`` is
            used, which matches the previous behaviour when ``seed`` is None.
        seed: Seed for the default sampler (e.g. ``SKLRNG``) to make the search
            reproducible. Ignored when ``sampler`` is given.
        **kwargs: Extra keyword arguments forwarded to ``study.optimize``
            (e.g. ``timeout``, ``n_jobs``, ``callbacks``).

    Returns:
        The completed ``optuna.Study``.

    Raises:
        TypeError: If ``obj_fn`` is not callable.
    """
    # Fail fast with a clear message. The old placeholder default (``lambda: None``) could
    # never accept the trial argument, so it only failed later, deep inside study.optimize.
    if not callable(obj_fn):
        raise TypeError("obj_fn must be a callable that takes an optuna Trial")

    # NOTE: this silences Optuna logging globally, not only for this study.
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    if sampler is None:
        # NOTE(review): NSGA-II is designed for multi-objective search. Optuna's default
        # population_size is 50, so with max_iter <= 50 every trial likely falls in the
        # initial (random) generation. Consider e.g. TPESampler for short single-objective runs.
        sampler = optuna.samplers.NSGAIISampler(seed=seed)

    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(obj_fn, n_trials=max_iter, show_progress_bar=show_progress_bar, **kwargs)

    return study

# Search channel subsets with the LSTM-CNN objective, using the optimizer's default trial budget and sampler.
study = heuristic_optimizer(objective_fn_lstm_cnn_net)

In [None]:
# Save `model` to disk in the native `.keras` format.
# NOTE(review): `model` must exist at notebook (global) scope. The model built inside
# objective_fn_lstm_cnn_net is local to that function, so confirm which model is saved here.
keras.saving.save_model(model, "model.keras")

In [None]:
# Reload the saved model. CUSTOM_OBJECTS maps names to the custom classes/functions
# that Keras needs to deserialize it.
loaded_model = keras.saving.load_model(
    filepath="model.keras",
    custom_objects=CUSTOM_OBJECTS,
)

In [None]:
## Pruning and quantization test (TODO: quantization step not implemented yet)
#
# Uses the public tfmot API. The original imported `PolynomialDecay` from `pruning_wrapper`,
# where it does not exist (it lives in the pruning_schedule module), and used the per-layer
# `PruneLowMagnitude` wrapper on a whole model.
# NOTE(review): tensorflow_model_optimization targets tf.keras (Keras 2 / tf_keras) models,
# while this notebook builds Keras 3 models on the JAX backend (KERAS_BACKEND="jax").
# prune_low_magnitude is therefore likely to reject `loaded_model`; confirm before relying on this cell.
import tensorflow_model_optimization as tfmot

baseline_model = loaded_model
pruning_params = {
    # Ramp target sparsity from 0% to 80% of prunable weights between steps 0 and 100.
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.00,
        final_sparsity=0.80,
        begin_step=0,
        end_step=100,
    )
}
# prune_low_magnitude accepts a whole Model and wraps each prunable layer internally.
sparse_model = tfmot.sparsity.keras.prune_low_magnitude(baseline_model, **pruning_params)

sparse_model.summary()

### Data exploration, etc.

In [None]:
def create_spectrogram(epoch_data):
    spectrograms = []
    for epoch in epoch_data:
        epoch_spectrograms = []
        for channel in epoch:
            f, t, Sxx = sp.signal.spectrogram(channel, fs=50)
            epoch_spectrograms.append(Sxx)
        spectrograms.append(np.array(epoch_spectrograms))
    return np.array(spectrograms)

spectograms = create_spectrogram(X)
rprint(spectograms.shape)  # (n_epochs, n_channels, n_freqs, n_times), e.g. (180, 21, 129, 2)

# Must match the sampling frequency hard-coded in create_spectrogram (fs=50).
SPECTROGRAM_FS = 50
# Frequency axis (Hz) shared by every channel's spectrogram: same signal length and fs everywhere.
spectrogram_freqs, _, _ = sp.signal.spectrogram(X[0][0], fs=SPECTROGRAM_FS)
MAX_FREQ_HZ = 60

num_left_hand_epochs_to_visualize = 2
num_right_hand_epochs_to_visualize = 2

# np.asarray so the element-wise comparison also works if `y` is a plain list.
labels = np.asarray(y)
vis_rng = np.random.default_rng(SKLRNG)  # seeded: the same epochs are shown on every re-run

# Select epochs for visualization (without replacement, so no epoch is plotted twice)
left_hand_epochs_to_visualize = np.flatnonzero(labels == 'left_hand')
right_hand_epochs_to_visualize = np.flatnonzero(labels == 'right_hand')
left_hand_epochs_to_visualize = vis_rng.choice(
    left_hand_epochs_to_visualize,
    min(num_left_hand_epochs_to_visualize, left_hand_epochs_to_visualize.size),
    replace=False,
)
right_hand_epochs_to_visualize = vis_rng.choice(
    right_hand_epochs_to_visualize,
    min(num_right_hand_epochs_to_visualize, right_hand_epochs_to_visualize.size),
    replace=False,
)
epochs_to_visualize = np.concatenate((left_hand_epochs_to_visualize, right_hand_epochs_to_visualize))

rprint(epochs_to_visualize)

# Visualize spectrograms: one panel per epoch, channels (rows) x frequency (columns),
# first time window only. squeeze=False keeps `axs` 2-D even for a single panel.
fig, axs = plt.subplots(epochs_to_visualize.size, 1, figsize=(10, 10), squeeze=False)
for ax, idx in zip(axs[:, 0], epochs_to_visualize):
    ax.imshow(
        spectograms[idx, :, :, 0],
        cmap='viridis',
        aspect='auto',
        origin='lower',
        # Map columns to real frequencies so the x-limit below is in Hz, not in bins.
        extent=(spectrogram_freqs[0], spectrogram_freqs[-1], -0.5, spectograms.shape[1] - 0.5),
    )
    ax.set_title(labels[idx])
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Channel')
    # Limit x-axis to 60 Hz (or Nyquist, whichever is lower)
    ax.set_xlim(0, min(MAX_FREQ_HZ, spectrogram_freqs[-1]))

plt.tight_layout()
plt.show()