In [107]:
# General imports
import os
import uuid
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Integer seed handed to the scikit-learn splitters as `random_state`.
SKLRNG = 42
# JAX PRNG key built from the same integer seed.
RNG = jax.random.PRNGKey(SKLRNG)

2024-02-22 18:16:57.793317: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250, standardize = False):
    """
    Build epoch arrays, labels and per-epoch metadata for the requested subjects.

    Every run returned by ``dataset.get_data(subjects=subjects)`` is filtered
    with each ``(fmin, fmax)`` pair in `filters` (one after another, IIR,
    EEG channels only), cut into epochs for the events in ``dataset.event_id``
    over ``dataset.interval`` and resampled to `sfreq`.

    Parameters
    ----------
    dataset
        Object providing ``get_data(subjects=...)`` (nested dict
        subject -> session -> run -> raw recording), ``event_id``
        (label -> event code) and ``interval`` (``(tmin, tmax)``).
    subjects : list
        Subject ids forwarded to ``dataset.get_data``. Never mutated here.
    channel_idx : list of int
        Positions in ``raw.info['ch_names']`` of the channels to keep. An empty
        list keeps every channel. Stim channels are always excluded.
    filters : iterable of (l_freq, h_freq)
        Band edges passed to ``raw.filter``; applied sequentially.
    sfreq : float
        Target sampling frequency passed to ``Epochs.resample``.
    standardize : bool, default False
        When True, each epoch's 2-D array is replaced by
        ``StandardScaler().fit_transform(epoch)`` (a fresh scaler per epoch,
        i.e. every column of that 2-D array is standardised independently).
        The default returns the resampled data unscaled, which is what this
        function returned before the flag existed.

    Returns
    -------
    tuple
        ``(X, y, metadata)``: the epoch data of all runs concatenated along the
        first axis, the matching label array, and a DataFrame with one row per
        epoch holding the ``subject``, ``session`` and ``run`` ids.
    """

    def _find_events(raw, stim_channels):
        # Prefer the stim channel when the recording has one, otherwise fall
        # back to the annotations.
        if len(stim_channels) > 0:
            return mne.find_events(raw, shortest_event=0, verbose=False)
        return mne.events_from_annotations(raw, event_id=dataset.event_id, verbose=False)[0]

    data = dataset.get_data(subjects=subjects)

    # Loop-invariant values: epoch window and the event code -> label lookup.
    tmin, tmax = dataset.interval[0], dataset.interval[1]
    inv_events = {code: label for label, code in dataset.event_id.items()}

    X = []
    y = []
    metadata = []

    for subject_id, sessions in data.items():
        for session_id, runs in sessions.items():
            for run_id, raw in runs.items():
                # NOTE(review): raw.filter presumably filters the recording in
                # place, so the objects held by `data` are modified too -- confirm.
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)

                # Private MNE helper; used both to choose the event source and
                # to keep stim channels out of the picks.
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                events = _find_events(raw, stim_channels)

                all_names = np.asarray(raw.info['ch_names'])
                channels = all_names[channel_idx] if len(channel_idx) > 0 else all_names
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                epochs = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                labels = [inv_events[code] for code in epochs.events[:, -1]]

                epochs = epochs.resample(sfreq)
                epoch_data = epochs.get_data(copy=False)

                if standardize and len(epoch_data) > 0:
                    epoch_data = np.asarray([
                        StandardScaler().fit_transform(single_epoch)
                        for single_epoch in epoch_data
                    ])

                # One metadata row per epoch of this run.
                met = pd.DataFrame(index=range(len(epoch_data)))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id

                X.append(epoch_data)
                y.append(labels)
                metadata.append(met)

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

# Dataset wrapper shared by the rest of the notebook (also read inside objective_fn).
fat_dataset = FatigueMI()
# X, y, _ = data_generator(fat_dataset, subjects=fat_dataset.subject_list, sfreq=250)
# Active configuration: subject 5 only, no channel subset (empty channel_idx keeps
# every channel), epochs resampled with sfreq=128. The metadata frame is discarded.
X, y, _ = data_generator(fat_dataset, subjects=[5], channel_idx=[], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=[1, 2, 5, 12, 13, 14, 17, 18], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[5, 6, 12, 17, 24, 28, 29], channel_idx=[], sfreq=128)

Adding metadata with 3 columns


In [4]:
# Turn the string labels into integer class ids.
# Author's note: 0 = left, 1 = right, 2 = unlabelled
# (!NOTE: the unlabelled class should be ignored for binary classification).
y_encoded = LabelEncoder().fit_transform(y)
rprint(dict(X=X.shape, y=y_encoded.shape))

# The last two axes of X are read as (channels, samples); these constants are
# the default arguments of the model builders defined further down.
NUM_CHANNELS, NUM_SAMPLES = X.shape[-2:]
NUM_CLASSES = np.unique(y_encoded).size

rpprint(dict(
    NUM_SAMPLES=NUM_SAMPLES,
    NUM_CHANNELS=NUM_CHANNELS,
    NUM_CLASSES=NUM_CLASSES,
))

# Optional: sns.histplot(y_encoded); shows the class distribution.

In [5]:
# Hold-out fractions. Only TEST_SIZE is passed to the splitter below.
TRAIN_SIZE = 0.8
TEST_SIZE = 0.2

# Shuffled, class-stratified split with a fixed seed. X keeps the layout
# produced by data_generator (no reshape).
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=TEST_SIZE,
    random_state=SKLRNG,
    shuffle=True,
    stratify=y_encoded,
)

# Report how many epochs landed in each partition.
total = len(X_train) + len(X_test)
rpprint(
    {
        "X_train": f"{X_train.shape[0]} ({X_train.shape[0] / total * 100:.2f}%)",
        "X_test": f"{X_test.shape[0]} ({X_test.shape[0] / total * 100:.2f}%)",
    },
    expand_all=True,
)

In [6]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Return the element-wise square of `x` (activation used by `shallow_conv_net`)."""
    squared = jnp.square(x)
    return squared

def shallow_conv_net_log_layer(x):
    """Return the natural log of `x` after clipping it to the range [1e-7, 10000]."""
    lower_bound, upper_bound = 1e-7, 10000
    clipped = jnp.clip(x, lower_bound, upper_bound)
    return jnp.log(clipped)

# Name -> callable registry of the two custom activations; each key is the
# function's own name. Presumably meant to be passed as `custom_objects` when a
# saved model is reloaded -- it is not used elsewhere in the visible cells.
CUSTOM_OBJECTS = {
    fn.__name__: fn
    for fn in (shallow_conv_net_square_layer, shallow_conv_net_log_layer)
}
# endregion Helper funcs

# region Models
def eeg_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5, 
            kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """
    Build an (uncompiled) EEGNet-style functional Keras model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    The default values of `nb_classes`, `channels` and `samples` are the
    notebook-level constants as they were when this cell was executed.

    Parameters
    ----------
    nb_classes : int
        Number of units of the final softmax layer.
    channels, samples : int
        The model input has shape ``(channels, samples, 1)``.
    dropout_rate : float
        Rate used by both dropout layers.
    kernLength : int
        Width of the first convolution kernel, ``(1, kernLength)``.
    F1, D, F2 : int
        Filters of the first convolution, depth multiplier of the depthwise
        convolution, and filters of the separable convolution.
    norm_rate : float
        Max-norm constraint applied to the dense kernel.
    dropoutType : {'Dropout', 'SpatialDropout2D'}
        Which dropout layer class to use; any other value raises ``KeyError``.

    Returns
    -------
    keras.Model
        Model mapping the input tensor to class probabilities.
    """

    dropoutType = {'Dropout': Dropout, 'SpatialDropout2D': SpatialDropout2D}[dropoutType]

    input1   = Input(shape = (channels, samples, 1))

    ##################################################################
    # The Input tensor already fixes the input shape, so no `input_shape`
    # argument is given to the first layer (Keras 3 warns about it and
    # ignores it in functional models).
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    # Kernel spans the whole `channels` axis.
    block1       = DepthwiseConv2D((channels, 1), use_bias = False, 
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropout_rate)(block1)
    
    block2       = SeparableConv2D(F2, (1, 16), use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropout_rate)(block2)
        
    flatten      = Flatten(name = 'flatten')(block2)
    
    dense        = Dense(nb_classes, name = 'dense', 
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)
    
    return Model(inputs=input1, outputs=softmax)

def deep_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5):
    """
    Build an (uncompiled) DeepConvNet-style functional Keras model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    The default values of `nb_classes`, `channels` and `samples` are the
    notebook-level constants as they were when this cell was executed.

    Parameters
    ----------
    nb_classes : int
        Number of units of the final softmax layer.
    channels, samples : int
        The model input has shape ``(channels, samples, 1)``.
    dropout_rate : float
        Rate shared by the dropout layer that closes each of the four blocks.

    Returns
    -------
    keras.Model
        Model mapping the input tensor to class probabilities.
    """
    input_main   = Input((channels, samples, 1))
    # No `input_shape` on the first layer: the Input tensor already defines it
    # (Keras 3 warns about the argument and ignores it in functional models).
    block1       = Conv2D(25, (1, 5), 
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(input_main)
    # Kernel spans the whole `channels` axis.
    block1       = Conv2D(25, (channels, 1),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block1)
    block1       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    block1       = Activation('elu')(block1)
    block1       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block1)
    block1       = Dropout(dropout_rate)(block1)
  
    # Blocks 2-4 repeat conv -> batch norm -> ELU -> max pool -> dropout with
    # 50, 100 and 200 filters respectively.
    block2       = Conv2D(50, (1, 5),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block1)
    block2       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block2)
    block2       = Activation('elu')(block2)
    block2       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block2)
    block2       = Dropout(dropout_rate)(block2)
    
    block3       = Conv2D(100, (1, 5),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block2)
    block3       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block3)
    block3       = Activation('elu')(block3)
    block3       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block3)
    block3       = Dropout(dropout_rate)(block3)
    
    block4       = Conv2D(200, (1, 5),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block3)
    block4       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block4)
    block4       = Activation('elu')(block4)
    block4       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block4)
    block4       = Dropout(dropout_rate)(block4)
    
    flatten      = Flatten()(block4)
    
    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax      = Activation('softmax')(dense)
    
    return Model(inputs=input_main, outputs=softmax)

def shallow_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
    Build an (uncompiled) ShallowConvNet-style functional Keras model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    The default values of `nb_classes`, `channels` and `samples` are the
    notebook-level constants as they were when this cell was executed.

    Parameters
    ----------
    nb_classes : int
        Number of units of the final softmax layer.
    channels, samples : int
        The model input has shape ``(channels, samples, 1)``.
    **kwargs
        Optional hyper-parameters; unknown keys are ignored.

        - ``pool_size_d2`` (35), ``strides_d2`` (7), ``conv_filters_d2`` (13):
          second component of the pooling window, the pooling strides and the
          first convolution kernel; the first component is 1.
        - ``pool_size``, ``strides``, ``conv_filters``: full 2-tuples that
          override the values derived from the ``*_d2`` keys.
        - ``conv2d_1_units`` (40), ``conv2d_2_units`` (40): filters of the two
          convolutions.
        - ``l2_reg_1``, ``l2_reg_2``, ``l2_reg_3`` (0.01 each): L2 factors for
          the first convolution, second convolution and dense kernels.
        - ``dropout_rate`` (0.5): rate of the dropout layer.

    Returns
    -------
    keras.Model
        Model mapping the input tensor to class probabilities.
    """

    pool_size_d2 = kwargs.get("pool_size_d2", 35)
    strides_d2 = kwargs.get("strides_d2", 7)
    conv_filters_d2 = kwargs.get("conv_filters_d2", 13)

    pool_size = kwargs.get("pool_size", (1, pool_size_d2))
    strides = kwargs.get("strides", (1, strides_d2))
    conv_filters = kwargs.get("conv_filters", (1, conv_filters_d2))

    conv2d_1_units = kwargs.get("conv2d_1_units", 40)
    conv2d_2_units = kwargs.get("conv2d_2_units", 40)
    l2_reg_1 = kwargs.get("l2_reg_1", 0.01)
    l2_reg_2 = kwargs.get("l2_reg_2", 0.01)
    l2_reg_3 = kwargs.get("l2_reg_3", 0.01)
    dropout_rate = kwargs.get("dropout_rate", 0.5)

    input_main   = Input(shape=(channels, samples, 1))
    # No `input_shape` on the first layer: the Input tensor already defines it
    # (Keras 3 warns about the argument and ignores it in functional models).
    block1       = Conv2D(conv2d_1_units, conv_filters,
                          kernel_constraint = max_norm(2., axis=(0,1,2)),
                          kernel_regularizer = keras.regularizers.L2(l2_reg_1))(input_main)
    # Kernel spans the whole `channels` axis.
    block1       = Conv2D(conv2d_2_units, (channels, 1), use_bias=False,
                          kernel_constraint = max_norm(2., axis=(0,1,2)),
                          kernel_regularizer = keras.regularizers.L2(l2_reg_2))(block1)
    block1       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    # square -> average pool -> log of the clipped value.
    block1       = Activation(shallow_conv_net_square_layer)(block1)
    block1       = AveragePooling2D(pool_size=pool_size, strides=strides)(block1)
    block1       = Activation(shallow_conv_net_log_layer)(block1)
    block1       = Dropout(dropout_rate)(block1)
    flatten      = Flatten()(block1)
    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5),
                         kernel_regularizer = keras.regularizers.L2(l2_reg_3))(flatten)
    softmax      = Activation('softmax')(dense)
    
    return Model(inputs=input_main, outputs=softmax)

def lstm_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """
    Build an (uncompiled) Sequential model made of three stacked LSTM layers.

    The input has shape ``(samples, channels)``. The first two LSTM layers
    return full sequences and are each followed by dropout (0.5) and batch
    normalisation; the third returns only its last output and is followed by
    dropout, a 50-unit dense layer and the `nb_classes`-unit softmax output.
    """
    net = Sequential()
    layer_stack = (
        Input(shape=(samples, channels)),
        LSTM(40, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
        LSTM(40, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
        LSTM(40, stateful=False),
        Dropout(0.5),
        # NOTE(review): softmax on a hidden layer is unusual; confirm it is intended.
        Dense(50,activation='softmax'),
        Dense(nb_classes, activation='softmax'),
    )
    for layer in layer_stack:
        net.add(layer)

    return net

def lstm_cnn_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate_1 = 0.5, dropout_rate_2 = 0.5):
    """
    Build an (uncompiled) Sequential model: one Conv1D stage, then one LSTM stage.

    The input has shape ``(samples, channels)``. `dropout_rate_1` is the
    dropout rate of the convolutional stage and `dropout_rate_2` that of the
    LSTM stage. The LSTM returns full sequences, which are flattened before
    the `nb_classes`-unit softmax output.
    """
    net = Sequential()

    conv_stage = (
        Conv1D(40, kernel_size=20, strides=4),
        Activation('relu'),
        Dropout(dropout_rate_1),
        BatchNormalization(),
        MaxPooling1D(pool_size=4, strides=4),
    )
    lstm_stage = (
        LSTM(128, return_sequences=True, stateful=False),
        Dropout(dropout_rate_2),
        BatchNormalization(),
    )
    classifier = (
        Flatten(),
        Dense(nb_classes, activation='softmax'),
    )

    net.add(Input(shape=(samples, channels)))
    for layer in (*conv_stage, *lstm_stage, *classifier):
        net.add(layer)

    return net

def lstm_cnn_net_v2(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """
    Build an (uncompiled) Sequential model: one LSTM stage, then one Conv1D stage.

    Same ingredients as `lstm_cnn_net` but in the opposite order and with fixed
    dropout rates of 0.5. The input has shape ``(samples, channels)``.
    """
    net = Sequential()

    lstm_stage = (
        LSTM(50, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
    )
    conv_stage = (
        Conv1D(40, kernel_size=20, strides=4),
        Activation('relu'),
        Dropout(0.5),
        BatchNormalization(),
        MaxPooling1D(pool_size=4, strides=4),
    )
    # Flatten, then the fully connected softmax layer used for classification.
    classifier = (
        Flatten(),
        Dense(nb_classes, activation='softmax'),
    )

    net.add(Input(shape=(samples, channels)))
    for layer in (*lstm_stage, *conv_stage, *classifier):
        net.add(layer)
    return net

# endregion Models

In [None]:
NUM_FOLDS = 5
# Seeded so the fold assignment is the same after a kernel restart.
kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SKLRNG)
acc_per_fold = []
loss_per_fold = []

# K-fold cross-validation on the training partition only; X_test / y_test are
# not touched here.
for fold_no, (train, test) in enumerate(kfold.split(X_train, y_train), start=1):
    # Build and compile a fresh model for every fold. Re-using one instance
    # would carry the weights learned on earlier folds into later ones, so
    # later folds would be scored on data the network had already trained on.
    model = shallow_conv_net()
    # model = lstm_cnn_net()
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    history = model.fit(
        X_train[train],
        y_train[train],
        batch_size=64,
        epochs=100,
        # `train` / `test` index into X_train / y_train, so the validation
        # labels must come from y_train as well (not from y_encoded).
        validation_data=(X_train[test], y_train[test]),
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
        ],
    )

    # evaluate() returns the loss followed by the compiled metrics, here
    # [loss, accuracy].
    scores = model.evaluate(X_train[test], y_train[test], verbose=0)
    rprint(f'Score for fold {fold_no}: loss of {scores[0]}; accuracy of {scores[1]*100}%')
    acc_per_fold.append(scores[1] * 100)
    loss_per_fold.append(scores[0])

# After the loop `model` is the network trained on the last fold.

In [None]:
# Average of the per-fold accuracies (in percent) collected by the K-fold cell.
mean_cross_val_score = np.asarray(acc_per_fold).mean()
rprint("Mean Cross Validation Score (5-fold):", mean_cross_val_score)

# Score whatever `model` currently holds on the hold-out partition;
# element 1 of the result is the accuracy metric.
results_test = model.evaluate(X_test, y_test)
rprint("Test Accuracy:", results_test[1])

In [None]:
# results_val = model.evaluate(X_val, y_val)
# Evaluates the current `model` on the hold-out partition again; this repeats
# the evaluation already done in the previous cell.
results_test = model.evaluate(X_test, y_test)

# X_test.reshape(-1, NUM_SAMPLES, NUM_CHANNELS).shape

In [None]:
# Train one model on the whole training partition. This rebinds the global
# `model` used by the evaluation cells.
model = shallow_conv_net()
# model = lstm_cnn_net()
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

history = model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=25,
    # validation_data=(X_test, y_test),
    # A fraction of the training arrays is held out by Keras for validation.
    validation_split=0.2,
    shuffle=True,
    # NOTE(review): both callbacks use patience=75 while epochs=25, so presumably
    # neither early stopping nor the learning-rate reduction can trigger in this run.
    callbacks=[
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
        # keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
    ],
)

In [None]:
# model.history.history.keys()

In [270]:
def objective_fn(trial, subjects=[], model_str = "lstm_cnn_net", sfreq = 128, batch_size = 128, **kwargs):
    """
    Optuna objective: pick channels and hyper-parameters, train briefly and
    return a scalar cost (lower is better).

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample every hyper-parameter.
    subjects : list
        Subject ids to train on; an empty list means ``fat_dataset.subject_list``.
    model_str : {"lstm_cnn_net", "shallow_conv_net"}
        Which model builder to use; any other value raises ``KeyError``.
    sfreq : int or None
        Fixed resampling frequency; a falsy value lets the trial choose among
        128, 256 and 300.
    batch_size : int or None
        Fixed batch size; a falsy value lets the trial choose 32..256 (step 32).
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    -------
    float
        Cost built from the validation accuracy (max and mean), the number of
        selected channels and the train/validation accuracy gap, plus a fixed
        0.5 penalty when the best training accuracy is below 0.75.

    Raises
    ------
    optuna.TrialPruned
        When the trial selects no channel at all.

    Side effects
    ------------
    Stores a ``"trial_data"`` user attribute on the trial (selected channel
    names, model JSON, accuracy histories and hold-out accuracy).
    """
    _SFREQ_ = sfreq if sfreq else trial.suggest_categorical("sfreq", [128, 256, 300])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = batch_size if batch_size else trial.suggest_int("batch_size", 32, 256, step=32)

    models = {
        "lstm_cnn_net": lstm_cnn_net,
        "shallow_conv_net": shallow_conv_net
    }

    subjects = subjects if len(subjects) > 0 else fat_dataset.subject_list
    model_fn = models[model_str]

    # Channel names are read once per trial from subject 1, session '0', run '0'.
    # NOTE(review): assumes every subject shares this channel layout, and that
    # the last entry (excluded from the search) is not an EEG channel -- confirm.
    ch_names = list(fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'])
    channels_idx = [
        idx
        for idx, ch_name in enumerate(ch_names[:-1])
        if trial.suggest_categorical(f"channels_{ch_name}", [True, False])
    ]
    if len(channels_idx) == 0:
        # Asking the same trial again for the same parameter names returns the
        # same values, so retrying in a loop can never produce a different
        # selection. Discard the trial instead.
        raise optuna.TrialPruned("no channel selected for this trial")

    X, y, _ = data_generator(fat_dataset, subjects=subjects, channel_idx=channels_idx, sfreq=_SFREQ_)

    # The last two axes of X are read as (channels, samples).
    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]

    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))

    if "lstm" in model_str:
        # The recurrent models take (samples, channels) inputs. The two axes
        # must be swapped: a plain reshape keeps the memory order and would
        # mix values from different channels and time points.
        model_input = np.swapaxes(X, 1, 2)
    else:
        model_input = X
    X_train, X_test, y_train, y_test = train_test_split(
        model_input, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, shuffle=True, stratify=y_encoded
    )

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }
    # Only sample (and pass) the hyper-parameters the selected builder accepts;
    # lstm_cnn_net has no **kwargs and rejects the shallow_conv_net ones.
    if model_str == "shallow_conv_net":
        model_params.update({
            "pool_size_d2": trial.suggest_int("pool_size_d2", 5, 95, step=5),
            "strides_d2": trial.suggest_int("strides_d2", 1, 31, step=2),
            "conv_filters_d2": trial.suggest_int("conv_filters_d2", 5, 55, step=2),

            "conv2d_1_units": trial.suggest_int("conv2d_1_units", 10, 200, step=10),
            "conv2d_2_units": trial.suggest_int("conv2d_2_units", 10, 200, step=10),
            "l2_reg_1": trial.suggest_float("l2_reg_1", 0.001, 0.8),
            "l2_reg_2": trial.suggest_float("l2_reg_2", 0.001, 0.8),
            "l2_reg_3": trial.suggest_float("l2_reg_3", 0.001, 0.8),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.1, 0.9, step=0.1),
        })
    elif model_str == "lstm_cnn_net":
        model_params.update({
            "dropout_rate_1": trial.suggest_float("dropout_rate_1", 0.1, 0.9, step=0.1),
            "dropout_rate_2": trial.suggest_float("dropout_rate_2", 0.1, 0.9, step=0.1),
        })

    model = model_fn(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy", 
        optimizer="adam", 
        # Named "accuracy" so the history keys are "accuracy" / "val_accuracy".
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=5,
        validation_split=0.2,
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    train_acc = np.asarray(history.history["accuracy"])
    val_acc = np.asarray(history.history["val_accuracy"])

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(ch_names)[channels_idx],
        "model": model.to_json(),
        "train_accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    cost = (
        (1 - np.max(val_acc))**2                # best validation accuracy
        + 0.5*(1 - np.mean(val_acc))**2         # average validation accuracy
        + 1e-5*len(channels_idx)                # fewer channels is cheaper
        + 1e-3*np.mean(train_acc - val_acc)**2  # mean train/validation gap, squared
    )
    # Fixed penalty when training accuracy never reaches 0.75.
    if not 0.75 <= np.max(train_acc) <= 1.00:
        cost += .5

    return cost

def heuristic_optimizer(obj_fn, max_iter = 25, show_progress_bar = True, subjects=None, model_str="lstm_cnn_net", sfreq = 128, batch_size = 128, *, seed = None, **kwargs):
    """Run an Optuna NSGA-III study that minimizes ``obj_fn`` and return the study.

    Parameters
    ----------
    obj_fn : callable
        Objective called as ``obj_fn(trial, subjects=..., model_str=...,
        sfreq=..., batch_size=...)``; its return value is minimized.
    max_iter : int
        Number of trials (forwarded as ``n_trials``).
    show_progress_bar : bool
        Forwarded to ``study.optimize``.
    subjects : list or None
        Bound to ``obj_fn`` as ``subjects``. ``None`` (the default) is
        replaced by a fresh empty list on every call.
    model_str, sfreq, batch_size
        Bound to ``obj_fn`` unchanged.
    seed : int or None, keyword-only
        Seed for the NSGA-III sampler. ``None`` keeps the sampler unseeded.
    **kwargs
        Extra keyword arguments forwarded to ``study.optimize``
        (``callbacks`` defaults to an empty list when not supplied).

    Returns
    -------
    optuna.study.Study
        The study, possibly with fewer than ``max_iter`` finished trials if
        the run was interrupted from the keyboard.

    Notes
    -----
    Sets Optuna's logging verbosity to ``CRITICAL`` as a side effect.
    """
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    # A fresh list per call: a shared mutable default would leak any mutation
    # done by obj_fn into later calls.
    subjects = [] if subjects is None else subjects

    objective = partial(obj_fn, subjects = subjects, model_str = model_str, sfreq = sfreq, batch_size = batch_size)

    # Let callers supply their own callbacks through **kwargs instead of
    # colliding with a hard-coded ``callbacks=[]`` argument.
    kwargs.setdefault("callbacks", [])

    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler(seed=seed))
    try:
        study.optimize(objective, n_trials=max_iter, show_progress_bar=show_progress_bar, **kwargs)
    except KeyboardInterrupt:
        # Deliberate best-effort: a manual interrupt stops the search early
        # and the trials gathered so far are still returned.
        pass

    return study

# Launch a 30-trial search for subject 6 using the "shallow_conv_net" model.
# sfreq and batch_size are passed as None; how objective_fn interprets None
# is not visible in this cell -- TODO confirm.
study = heuristic_optimizer(
    obj_fn=objective_fn,
    max_iter=30,
    subjects=[6],
    model_str="shallow_conv_net",
    sfreq=None,
    batch_size=None,
)

  study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler())


  0%|          | 0/30 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/5


  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4559 - loss: 70.3270 - val_accuracy: 0.7778 - val_loss: 63.2632 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6765 - loss: 63.2858 - val_accuracy: 0.8333 - val_loss: 59.8051 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.9118 - loss: 59.6576 - val_accuracy: 0.7778 - val_loss: 56.4540 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.9853 - loss: 56.1996 - val_accuracy: 0.8889 - val_loss: 53.2192 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 1.0000 - loss: 52.8890 - val_accuracy: 0.8889 - val_loss: 50.1255 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 613ms/step - accuracy: 0.5909 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5441 - loss: 61.4941 - val_accuracy: 0.5556 - val_loss: 54.7392 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.7353 - loss: 54.6618 - val_accuracy: 0.5556 - val_loss: 51.7847 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7206 - loss: 51.6449 - val_accuracy: 0.6111 - val_loss: 48.9407 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9118 - loss: 48.6820 - val_accuracy: 0.6111 - val_loss: 46.1973 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.9412 - loss: 45.8916 - val_accuracy: 0.5556 - val_loss: 43.5584 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 505ms/step - accuracy: 0.7273 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5147 - loss: 66.9759 - val_accuracy: 0.5556 - val_loss: 60.0416 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8235 - loss: 59.7896 - val_accuracy: 0.5000 - val_loss: 56.7535 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.9412 - loss: 56.3328 - val_accuracy: 0.5000 - val_loss: 53.5037 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 1.0000 - loss: 52.9939 - val_accuracy: 0.5000 - val_loss: 50.3139 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 1.0000 - loss: 49.7714 - val_accuracy: 0.5556 - val_loss: 47.2254 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 427ms/step - accuracy: 0.4091 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4853 - loss: 145.9614 - val_accuracy: 0.3889 - val_loss: 134.3968 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5882 - loss: 134.1661 - val_accuracy: 0.4444 - val_loss: 127.4904 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7794 - loss: 127.1654 - val_accuracy: 0.4444 - val_loss: 120.7672 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.8676 - loss: 120.2690 - val_accuracy: 0.4444 - val_loss: 114.2581 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.9412 - loss: 113.6583 - val_accuracy: 0.3889 - val_loss: 107.9657 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 345ms/step - accurac

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 777ms/step - accuracy: 0.4805 - loss: 41.3818 - val_accuracy: 0.7222 - val_loss: 37.1202 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7006 - loss: 36.2804 - val_accuracy: 0.6667 - val_loss: 33.2826 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.7123 - loss: 32.5499 - val_accuracy: 0.7222 - val_loss: 29.7463 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.8267 - loss: 28.9192 - val_accuracy: 0.5556 - val_loss: 26.5240 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.8405 - loss: 25.6681 - val_accuracy: 0.5000 - val_loss: 23.5634 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 358ms/step - accuracy: 0.4091 -

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5000 - loss: 23.3914 - val_accuracy: 0.3889 - val_loss: 20.1167 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.9412 - loss: 19.8196 - val_accuracy: 0.5000 - val_loss: 19.3239 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0000 - loss: 18.9043 - val_accuracy: 0.5000 - val_loss: 18.5487 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.9853 - loss: 18.0419 - val_accuracy: 0.5000 - val_loss: 17.7898 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0000 - loss: 17.2122 - val_accuracy: 0.5000 - val_loss: 17.0349 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 490ms/step - accuracy: 0.5000 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 197.2717 - val_accuracy: 0.5556 - val_loss: 181.4330 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.7647 - loss: 181.3262 - val_accuracy: 0.4444 - val_loss: 171.3394 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.7794 - loss: 171.1821 - val_accuracy: 0.4444 - val_loss: 161.6122 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9559 - loss: 161.4143 - val_accuracy: 0.6111 - val_loss: 152.2724 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.9412 - loss: 152.0763 - val_accuracy: 0.7222 - val_loss: 143.3315 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 402ms/step - accurac

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - accuracy: 0.6029 - loss: 58.7624 - val_accuracy: 0.5000 - val_loss: 54.6063 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step - accuracy: 0.8529 - loss: 53.8263 - val_accuracy: 0.6111 - val_loss: 55.1355 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - accuracy: 0.5882 - loss: 53.5644 - val_accuracy: 0.5556 - val_loss: 49.7071 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - accuracy: 0.9706 - loss: 46.9327 - val_accuracy: 0.4444 - val_loss: 47.8815 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step - accuracy: 0.9853 - loss: 43.6950 - val_accuracy: 0.4444 - val_loss: 49.3597 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 705ms/step - accuracy: 0.3636 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4265 - loss: 31.0544 - val_accuracy: 0.7222 - val_loss: 24.8531 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7059 - loss: 24.7952 - val_accuracy: 0.6111 - val_loss: 23.9710 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.8676 - loss: 23.6192 - val_accuracy: 0.6667 - val_loss: 22.9853 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9559 - loss: 22.5402 - val_accuracy: 0.7222 - val_loss: 22.0219 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.9559 - loss: 21.5599 - val_accuracy: 0.7222 - val_loss: 21.1037 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 351ms/step - accuracy: 0.5909 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5588 - loss: 63.1612 - val_accuracy: 0.4444 - val_loss: 55.9106 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.6029 - loss: 55.5729 - val_accuracy: 0.6111 - val_loss: 53.1735 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.9853 - loss: 52.8378 - val_accuracy: 0.7222 - val_loss: 50.8464 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8971 - loss: 50.5709 - val_accuracy: 0.6667 - val_loss: 48.6672 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 1.0000 - loss: 48.2510 - val_accuracy: 0.6111 - val_loss: 46.6195 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 595ms/step - accuracy: 0.6818 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4853 - loss: 98.7750 - val_accuracy: 0.6111 - val_loss: 88.3662 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.7206 - loss: 88.2966 - val_accuracy: 0.6111 - val_loss: 83.0823 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.8824 - loss: 82.9727 - val_accuracy: 0.7778 - val_loss: 78.0326 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.8971 - loss: 77.8866 - val_accuracy: 0.7778 - val_loss: 73.2082 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9265 - loss: 73.0321 - val_accuracy: 0.7778 - val_loss: 68.6004 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 470ms/step - accuracy: 0.5909 

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.5496 - loss: 56.5908 - val_accuracy: 0.5556 - val_loss: 52.0614 - learning_rate: 0.0010
Epoch 2/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.5699 - loss: 51.2446 - val_accuracy: 0.4444 - val_loss: 47.5930 - learning_rate: 0.0010
Epoch 3/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.5849 - loss: 46.8041 - val_accuracy: 0.3889 - val_loss: 43.4035 - learning_rate: 0.0010
Epoch 4/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.6847 - loss: 42.5392 - val_accuracy: 0.4444 - val_loss: 39.4695 - learning_rate: 0.0010
Epoch 5/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.6697 - loss: 38.6924 - val_accuracy: 0.3889 - val_loss: 35.7952 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 453ms/step - accuracy: 0.3182 - lo

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - accuracy: 0.4559 - loss: 102.9635 - val_accuracy: 0.5556 - val_loss: 92.3594 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step - accuracy: 0.7941 - loss: 92.1631 - val_accuracy: 0.5000 - val_loss: 86.6803 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step - accuracy: 0.8529 - loss: 86.4106 - val_accuracy: 0.5000 - val_loss: 81.1569 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step - accuracy: 0.8971 - loss: 80.8454 - val_accuracy: 0.5000 - val_loss: 75.8284 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step - accuracy: 0.8971 - loss: 75.4637 - val_accuracy: 0.4444 - val_loss: 70.7204 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 647ms/step - accuracy: 0.6364

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.3971 - loss: 15.9670 - val_accuracy: 0.7222 - val_loss: 14.1377 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9559 - loss: 13.9678 - val_accuracy: 0.7222 - val_loss: 13.4366 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9853 - loss: 13.1761 - val_accuracy: 0.7222 - val_loss: 12.7196 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9853 - loss: 12.4182 - val_accuracy: 0.7778 - val_loss: 12.0119 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 1.0000 - loss: 11.6653 - val_accuracy: 0.7222 - val_loss: 11.3237 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 396ms/step - accuracy: 0.5455 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.5000 - loss: 14.8337 - val_accuracy: 0.5556 - val_loss: 13.2971 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - accuracy: 0.9706 - loss: 13.1144 - val_accuracy: 0.6667 - val_loss: 12.6329 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step - accuracy: 0.9853 - loss: 12.3570 - val_accuracy: 0.7222 - val_loss: 11.9697 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step - accuracy: 0.9853 - loss: 11.6409 - val_accuracy: 0.7222 - val_loss: 11.3120 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 1.0000 - loss: 10.9246 - val_accuracy: 0.7222 - val_loss: 10.6653 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 693ms/step - accuracy: 0.6364 

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 767ms/step - accuracy: 0.4818 - loss: 44.6237 - val_accuracy: 0.6111 - val_loss: 39.7423 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.5866 - loss: 38.9922 - val_accuracy: 0.6111 - val_loss: 35.7496 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.6659 - loss: 34.9077 - val_accuracy: 0.5556 - val_loss: 32.0074 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.8219 - loss: 31.1647 - val_accuracy: 0.4444 - val_loss: 28.5720 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.7661 - loss: 27.8572 - val_accuracy: 0.4444 - val_loss: 25.4542 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 405ms/step - accuracy: 0.4545 -

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4706 - loss: 40.0927 - val_accuracy: 0.5556 - val_loss: 32.8935 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6618 - loss: 32.8449 - val_accuracy: 0.3889 - val_loss: 30.8897 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6324 - loss: 30.8049 - val_accuracy: 0.5000 - val_loss: 28.9321 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.8088 - loss: 28.7781 - val_accuracy: 0.5000 - val_loss: 27.0401 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.7647 - loss: 26.8488 - val_accuracy: 0.3889 - val_loss: 25.2216 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 411ms/step - accuracy: 0.5000 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5882 - loss: 49.0193 - val_accuracy: 0.4444 - val_loss: 45.1605 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8529 - loss: 44.9419 - val_accuracy: 0.6111 - val_loss: 43.5377 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8971 - loss: 43.1010 - val_accuracy: 0.6667 - val_loss: 41.7864 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9853 - loss: 41.3355 - val_accuracy: 0.5556 - val_loss: 40.2477 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 1.0000 - loss: 39.6803 - val_accuracy: 0.5556 - val_loss: 38.7393 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 384ms/step - accuracy: 0.4091 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 28.1180 - val_accuracy: 0.6667 - val_loss: 23.2472 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.7059 - loss: 23.2618 - val_accuracy: 0.6667 - val_loss: 22.0879 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9118 - loss: 22.0317 - val_accuracy: 0.6667 - val_loss: 20.9552 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9559 - loss: 20.8547 - val_accuracy: 0.6667 - val_loss: 19.8595 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9706 - loss: 19.7132 - val_accuracy: 0.7222 - val_loss: 18.8076 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 403ms/step - accuracy: 0.5909 

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 972ms/step - accuracy: 0.4770 - loss: 150.1596 - val_accuracy: 0.6667 - val_loss: 124.1584 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8297 - loss: 118.3696 - val_accuracy: 0.5000 - val_loss: 98.1624 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7877 - loss: 93.0262 - val_accuracy: 0.4444 - val_loss: 76.3257 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8709 - loss: 71.8148 - val_accuracy: 0.4444 - val_loss: 58.0622 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.9623 - loss: 54.4437 - val_accuracy: 0.5556 - val_loss: 43.5902 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 622ms/step - accuracy: 0

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 882ms/step - accuracy: 0.5650 - loss: 134.2306 - val_accuracy: 0.7222 - val_loss: 118.4379 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.8146 - loss: 114.6381 - val_accuracy: 0.7778 - val_loss: 100.9885 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.8219 - loss: 97.5711 - val_accuracy: 0.7778 - val_loss: 85.5187 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.8102 - loss: 82.4204 - val_accuracy: 0.7778 - val_loss: 71.9027 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.9511 - loss: 69.1588 - val_accuracy: 0.7778 - val_loss: 60.0453 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 348ms/step - accuracy: 0.63

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5147 - loss: 30.7183 - val_accuracy: 0.3333 - val_loss: 25.1658 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.7794 - loss: 25.0212 - val_accuracy: 0.4444 - val_loss: 23.9591 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.8235 - loss: 23.7855 - val_accuracy: 0.5000 - val_loss: 22.7946 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.8382 - loss: 22.5811 - val_accuracy: 0.6111 - val_loss: 21.6805 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.9706 - loss: 21.4192 - val_accuracy: 0.6111 - val_loss: 20.6192 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 428ms/step - accuracy: 0.7273 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5441 - loss: 140.2227 - val_accuracy: 0.5556 - val_loss: 127.7394 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.5441 - loss: 127.7354 - val_accuracy: 0.6111 - val_loss: 120.5368 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.6471 - loss: 120.4692 - val_accuracy: 0.7222 - val_loss: 113.6097 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7941 - loss: 113.4915 - val_accuracy: 0.7778 - val_loss: 106.9657 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.8676 - loss: 106.8084 - val_accuracy: 0.7778 - val_loss: 100.6118 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 474ms/step - accurac

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.3676 - loss: 75.4411 - val_accuracy: 0.3333 - val_loss: 70.6629 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.8529 - loss: 70.3746 - val_accuracy: 0.4444 - val_loss: 66.6239 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9265 - loss: 66.2381 - val_accuracy: 0.4444 - val_loss: 62.7202 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.9853 - loss: 62.2363 - val_accuracy: 0.3889 - val_loss: 58.9375 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 1.0000 - loss: 58.3984 - val_accuracy: 0.4444 - val_loss: 55.2902 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 499ms/step - accuracy: 0.5000 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4265 - loss: 36.0717 - val_accuracy: 0.5556 - val_loss: 30.4984 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5000 - loss: 30.2803 - val_accuracy: 0.5556 - val_loss: 29.8106 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.6471 - loss: 29.3245 - val_accuracy: 0.5556 - val_loss: 28.6149 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.7500 - loss: 28.2725 - val_accuracy: 0.6111 - val_loss: 27.6532 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.7647 - loss: 27.4085 - val_accuracy: 0.6111 - val_loss: 26.8243 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 326ms/step - accuracy: 0.5455 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4559 - loss: 66.8578 - val_accuracy: 0.6667 - val_loss: 59.3876 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7647 - loss: 59.3064 - val_accuracy: 0.6667 - val_loss: 56.2833 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.9265 - loss: 56.1399 - val_accuracy: 0.7222 - val_loss: 53.2918 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.9706 - loss: 53.0988 - val_accuracy: 0.6667 - val_loss: 50.4127 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9412 - loss: 50.1959 - val_accuracy: 0.7222 - val_loss: 47.6478 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 457ms/step - accuracy: 0.4545 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5147 - loss: 39.2520 - val_accuracy: 0.5556 - val_loss: 34.0395 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6324 - loss: 33.9383 - val_accuracy: 0.6111 - val_loss: 32.5594 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8382 - loss: 32.3478 - val_accuracy: 0.6111 - val_loss: 31.0916 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8529 - loss: 30.8544 - val_accuracy: 0.5556 - val_loss: 29.6497 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.9559 - loss: 29.3138 - val_accuracy: 0.5556 - val_loss: 28.2388 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 370ms/step - accuracy: 0.5909 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.6176 - loss: 78.8236 - val_accuracy: 0.6111 - val_loss: 73.6116 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7941 - loss: 73.3126 - val_accuracy: 0.5000 - val_loss: 69.8842 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.8824 - loss: 69.4250 - val_accuracy: 0.5556 - val_loss: 66.0278 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9853 - loss: 65.4001 - val_accuracy: 0.3889 - val_loss: 62.2831 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9559 - loss: 61.5162 - val_accuracy: 0.4444 - val_loss: 58.5018 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 512ms/step - accuracy: 0.5000 

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 2s/step - accuracy: 0.4301 - loss: 50.8491 - val_accuracy: 0.6111 - val_loss: 46.9933 - learning_rate: 0.0010
Epoch 2/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7650 - loss: 46.5300 - val_accuracy: 0.6111 - val_loss: 44.1168 - learning_rate: 0.0010
Epoch 3/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.7500 - loss: 43.6612 - val_accuracy: 0.5000 - val_loss: 41.1814 - learning_rate: 0.0010
Epoch 4/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7102 - loss: 40.6854 - val_accuracy: 0.4444 - val_loss: 38.3147 - learning_rate: 0.0010
Epoch 5/5
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6749 - loss: 37.6972 - val_accuracy: 0.5000 - val_loss: 35.5395 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 539ms/step - accuracy: 0.3182 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5735 - loss: 141.8617 - val_accuracy: 0.5000 - val_loss: 130.4115 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6324 - loss: 130.3350 - val_accuracy: 0.5556 - val_loss: 122.5059 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.8824 - loss: 122.2602 - val_accuracy: 0.6111 - val_loss: 114.8785 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.9265 - loss: 114.5982 - val_accuracy: 0.6111 - val_loss: 107.5557 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.9412 - loss: 107.2531 - val_accuracy: 0.6111 - val_loss: 100.5560 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 371ms/step - accurac

In [271]:
# Summarise every finished trial of `study` as one row of `trial_metrics_df`.
trial_metrics_dict = {
    "train_acc": [],
    "test_acc": [],
    "val_acc": [],
    "train_val_diff": [],
    "scores": [],
    "channels_selected": [],
}
trial_numbers = []
for trial in study.trials:
    trial_data = trial.user_attrs.get("trial_data")
    # heuristic_optimizer swallows KeyboardInterrupt, so the study may hold a
    # failed/interrupted trial that has no objective value or stored metrics.
    if trial.state != optuna.trial.TrialState.COMPLETE or trial_data is None:
        continue

    # Best (maximum) accuracy reached over the recorded epochs.
    best_train_acc = np.max(trial_data["train_accuracy"])
    best_val_acc = np.max(trial_data["val_accuracy"])

    trial_numbers.append(trial.number)
    trial_metrics_dict["scores"].append(trial.value)
    trial_metrics_dict["train_acc"].append(best_train_acc)
    trial_metrics_dict["test_acc"].append(trial_data["test_accuracy"])
    trial_metrics_dict["val_acc"].append(best_val_acc)
    trial_metrics_dict["channels_selected"].append(trial_data["channels_selected"])
    trial_metrics_dict["train_val_diff"].append(abs(best_train_acc - best_val_acc))

# Rows are indexed by Optuna trial number so skipped trials leave visible gaps
# instead of silently renumbering the remaining rows.
trial_metrics_df = pd.DataFrame(trial_metrics_dict, index=trial_numbers)
# Alternative orderings:
# trial_metrics_df.sort_values("test_acc", ascending=False)
# trial_metrics_df.sort_values("scores", ascending=True)
trial_metrics_df.sort_values("train_val_diff", ascending=True)

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
22,0.867647,0.454545,0.777778,0.089869,0.097868,"[P3, Cz, Pz, Fp1, Fp2, T3, A2, T6, T4]"
0,1.0,0.590909,0.888889,0.111111,0.026365,"[P3, F3, Fz, F4, C4, P4, Fp2, T3, O1, O2, F7, ..."
4,0.852941,0.409091,0.722222,0.130719,0.144459,"[F4, P4, T3, T5, O2, A2, T4]"
11,0.691176,0.318182,0.555556,0.135621,0.851982,"[P3, Fz, C4, Cz, Fp1, Fp2, T3, O2, F7, A2]"
10,0.926471,0.590909,0.777778,0.148693,0.091226,"[C3, F4, C4, Cz, Pz, Fp1, O1, O2, F7, A2, T6]"
24,0.764706,0.545455,0.611111,0.153595,0.240452,"[C3, F3, Fp1, T3, T5, A2, T6, T4]"
28,0.764706,0.318182,0.611111,0.153595,0.260291,"[P3, C3, F3, Fz, F4, C4, P4, Cz, Fp1, T5, F7, ..."
20,0.941176,0.636364,0.777778,0.163399,0.076685,"[Fz, F4, P4, Pz, Fp1, T5, F7, T6]"
15,0.823529,0.454545,0.611111,0.212418,0.26023,"[P3, Fz, C4, Pz, Fp2, T5, O2, F7, T4]"
13,1.0,0.545455,0.777778,0.222222,0.085076,"[P3, C3, F3, F4, C4, P4, Pz, Fp1, O1, O2, F8, T6]"


In [273]:
# Take the ten trials with the highest validation accuracy among those whose
# training accuracy exceeds their validation accuracy, ordered by training
# accuracy, then show the row(s) that reach the best validation accuracy.
top_val_candidates = (
    trial_metrics_df
    .sort_values("val_acc", ascending=False)
    .query("train_acc > val_acc")[:10]
    .sort_values('train_acc', ascending=False)
)
max_val_acc = top_val_candidates["val_acc"].max()
top_val_candidates.query(f"val_acc == {max_val_acc}")

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
0,1.0,0.590909,0.888889,0.111111,0.026365,"[P3, F3, Fz, F4, C4, P4, Fp2, T3, O1, O2, F7, ..."


In [276]:
# Pick the "best candidate" trial:
#   1. keep trials with train_acc >= val_acc and take the 15 with the smallest
#      train/validation gap (drop the query to consider every trial instead),
#   2. find the highest validation accuracy among those,
#   3. select every trial in the table with that validation accuracy, highest
#      training accuracy first.
max_val_acc = (
    trial_metrics_df
    .sort_values("train_val_diff", ascending=True)
    .query("train_acc >= val_acc")[:15]["val_acc"]
    .max()
)
best_candidate = trial_metrics_df.sort_values("train_acc", ascending=False).query(f"val_acc == {max_val_acc}")
display(best_candidate)

# The table's index is used as the position of the trial in `study.trials`.
best_candidate_trial_id = best_candidate.index[0]
best_candidate_trial = study.trials[best_candidate_trial_id]
best_candidate_data = best_candidate_trial.user_attrs["trial_data"]

# Report the hyperparameters of the *selected candidate* (not of
# `study.best_trial`, which may be a different trial) so that the parameters
# and the accuracies printed below describe the same trial.
rpprint({ k: v for k, v in best_candidate_trial.params.items() if not k.startswith("channels") })
rprint("test_accuracy =", best_candidate_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_candidate_data["val_accuracy"]))
rprint("channels_selected =", best_candidate_data["channels_selected"])

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
0,1.0,0.590909,0.888889,0.111111,0.026365,"[P3, F3, Fz, F4, C4, P4, Fp2, T3, O1, O2, F7, ..."


In [267]:
# Save to disk
# np.save stores the trial object as a pickled object array, so it has to be
# read back with np.load(..., allow_pickle=True).
results_dir = "./best_subjects_search_results"
os.makedirs(results_dir, exist_ok=True)  # np.save does not create missing directories
np.save(f"{results_dir}/subject_6_shallow_conv_net_{uuid.uuid4()}.npy", study.trials[best_candidate_trial_id])

In [278]:
# Report the study's best trial: its hyperparameters (without the per-channel
# on/off flags) followed by the accuracies and channel subset it recorded.
best_trial = study.best_trial
best_trial_data = best_trial.user_attrs["trial_data"]
non_channel_params = {
    name: value
    for name, value in best_trial.params.items()
    if not name.startswith("channels")
}
rpprint(non_channel_params)
rprint("test_accuracy =", best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_trial_data["val_accuracy"]))
rprint("channels_selected =", best_trial_data["channels_selected"])

In [181]:
# channels = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
# channels_of_interest = ['C3', 'F3', 'C4', 'T5', 'O1', 'O2', 'A2', 'T6']
# channels_of_interest_idx = np.where(np.isin(channels, channels_of_interest))[0]

# channels_of_interest_idx

In [280]:
# Model/pipeline hyperparameter optimization test
# Empty module-level store; nothing in this cell writes to it.
data_store_dict = dict()

def objective_fn_lstm_cnn_net(trial, **kwargs):
    """Optuna objective: train the LSTM-CNN model on a sampled subset of channels.

    For every channel name (all but the last entry of the recording's
    ``ch_names``) the trial suggests an on/off flag. The data of subject 6 is
    loaded for the selected channels, split 80/20 into train/test (stratified),
    and ``lstm_cnn_net_v2`` is trained with a further 20% validation split.

    Relies on the notebook-level names ``fat_dataset``, ``data_generator``,
    ``lstm_cnn_net_v2`` and ``SKLRNG``.

    Args:
        trial: The Optuna trial used to suggest the channel flags and to store
            the results under the ``"trial_data"`` user attribute
            (``channels_selected``, per-epoch ``train_accuracy`` and
            ``val_accuracy``, and ``test_accuracy``).
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        The value to minimise: ``(1 - last-epoch val_accuracy)**2`` plus a
        small penalty of ``0.00005`` per selected channel.

    Raises:
        optuna.TrialPruned: If the sampled configuration selects no channel.
    """
    # _SFREQ_ = trial.suggest_categorical("sfreq", [128, 250])
    _SFREQ_ = 300
    # _TRAIN_SIZE_ = trial.suggest_categorical("train_size", [0.8, 0.9])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = 256

    # Look the channel names up once; the last entry of ch_names is excluded.
    channel_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
    channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in channel_names }
    channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    if not channels_idx:
        # Re-suggesting a parameter inside the same trial returns the value
        # already drawn, so retrying in a loop would never terminate. Discard
        # the trial instead and let the sampler draw a new configuration.
        raise optuna.TrialPruned("No channel selected in this trial.")

    X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=channels_idx, sfreq=_SFREQ_)

    # The last two axes of X are treated as (channels, samples).
    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]

    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))

    # The model is fed (epochs, samples, channels). Swap the axes rather than
    # reshape: a plain reshape keeps the memory order and would interleave
    # channels and time points instead of transposing them.
    X_time_major = np.swapaxes(X.reshape(-1, _NUM_CHANNELS_, _NUM_SAMPLES_), 1, 2)
    X_train, X_test, y_train, y_test = train_test_split(
        X_time_major,
        y_encoded,
        test_size=_TEST_SIZE_,
        random_state=SKLRNG,
        stratify=y_encoded,
    )

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }

    model = lstm_cnn_net_v2(**model_params)
    # model = lstm_cnn_net(**model_params)
    # model = shallow_conv_net(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=25,
        validation_split=0.2,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(channel_names)[channels_idx],
        # Recorded alongside val_accuracy so the trial-summary cells, which
        # read "train_accuracy", also work for studies run with this objective.
        "train_accuracy": history.history["accuracy"],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    # return (1 - np.mean(history.history["val_accuracy"][-2:])) + 0.005*len(channels_idx)
    return (1 - history.history["val_accuracy"][-1])**2 + 0.00005*len(channels_idx)

def heuristic_optimizer(obj_fn = lambda: None, max_iter = 15, show_progress_bar = True, seed = None, **kwargs):
    """Run an Optuna study that minimises ``obj_fn`` with the NSGA-II sampler.

    Sets Optuna's logging verbosity to CRITICAL before creating the study.

    Args:
        obj_fn: Objective passed to ``study.optimize``.
        max_iter: Number of trials to run.
        show_progress_bar: Whether Optuna displays a progress bar.
        seed: Optional seed for the sampler; ``None`` (the default) keeps the
            sampler unseeded, so repeated runs may differ.
        **kwargs: Forwarded to ``study.optimize``. ``callbacks`` may be given
            here; it defaults to an empty list.

    Returns:
        The optimised ``optuna.Study``.
    """
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    # Let callers supply their own callbacks without a duplicate-keyword error.
    kwargs.setdefault("callbacks", [])

    # study = optuna.create_study(direction="minimize", sampler=optuna.samplers.CmaEsSampler())
    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIISampler(seed=seed))
    study.optimize(obj_fn, n_trials=max_iter, show_progress_bar=show_progress_bar, **kwargs)

    return study

# Run the search with the optimizer's defaults (15 trials, progress bar shown).
study = heuristic_optimizer(obj_fn=objective_fn_lstm_cnn_net)

  0%|          | 0/15 [00:00<?, ?it/s]

Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns
Epoch 1/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 1.6515 - val_accuracy: 0.5556 - val_loss: 0.7129 - learning_rate: 0.0010
Epoch 2/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step - accuracy: 0.5735 - loss: 1.1823 - val_accuracy: 0.3889 - val_loss: 0.7137 - learning_rate: 0.0010
Epoch 3/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 59ms/step - accuracy: 0.5735 - loss: 1.0984 - val_accuracy: 0.5000 - val_loss: 0.7231 - learning_rate: 0.0010
Epoch 4/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step - accuracy: 0.4412 - loss: 1.1790 - val_accuracy: 0.4444 - val_loss: 0.7600 - learning_rate: 0.0010
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 295ms/step - accuracy: 0.5909 - loss: 0.6946
Sampling frequency of the instance is already 300.0, retu

In [282]:
# Best trial of the study: hyperparameters other than the per-channel flags,
# then the accuracies and channel subset stored under "trial_data".
best_trial = study.best_trial
best_trial_data = best_trial.user_attrs["trial_data"]
rpprint({
    name: value
    for name, value in best_trial.params.items()
    if not name.startswith("channels")
})
rprint("test_accuracy =", best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_trial_data["val_accuracy"]))
rprint("channels_selected =", best_trial_data["channels_selected"])

In [None]:
# Save model
# NOTE(review): in the visible cells `model` is only assigned inside the
# objective function; confirm it exists at notebook scope before running this.
saved_model_path = "model.keras"
model.save(saved_model_path)

In [None]:
# Load model from disk
# CUSTOM_OBJECTS is defined elsewhere in the notebook and passed as Keras' custom_objects.
load_model_options = {"custom_objects": CUSTOM_OBJECTS}
loaded_model = keras.saving.load_model("model.keras", **load_model_options)

In [None]:
## Pruning and quantization test

# Use the public tfmot API: the PolynomialDecay schedule is not defined in the
# internal `pruning_wrapper` module, and `PruneLowMagnitude` there wraps a
# single layer, whereas `prune_low_magnitude` accepts a whole model.
# NOTE(review): tfmot targets tf.keras models; confirm that a model loaded with
# the Keras JAX backend used in this notebook is accepted.
import tensorflow_model_optimization as tfmot

baseline_model = loaded_model
pruning_params = {
    # Ramp sparsity from 0% to 80% over the first 100 training steps.
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.00,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=100)
}
sparse_model = tfmot.sparsity.keras.prune_low_magnitude(baseline_model, **pruning_params)

sparse_model.summary()

### Data exploration, etc.

In [None]:
def create_spectrogram(epoch_data, fs=50):
    """Compute a spectrogram for every channel of every epoch.

    Args:
        epoch_data: Array-like whose last axis is the time axis; the cell that
            calls this passes data indexed as (epoch, channel, sample).
        fs: Sampling frequency handed to ``scipy.signal.spectrogram``.
            Defaults to 50, the value that was previously hard-coded.
            NOTE(review): other cells in this notebook load data at 300 Hz;
            confirm which rate applies to the data passed in.

    Returns:
        Array of spectrogram values with the leading axes of ``epoch_data``
        followed by (frequency bin, time segment). Empty input yields an
        empty 1-D array.
    """
    epochs = np.asarray(epoch_data)
    if epochs.size == 0:
        return np.array([])
    # spectrogram works along the last axis, so one call covers all epochs and
    # channels; only the spectrogram values (third return value) are kept.
    return sp.signal.spectrogram(epochs, fs=fs)[2]

# NOTE(review): `X` and `y` must exist at notebook scope here; in the visible
# cells they are only assigned inside the objective function — confirm.
spectograms = create_spectrogram(X)
rprint(spectograms.shape) # (180, 21, 129, 2)

num_left_hand_epochs_to_visualize = 2
num_right_hand_epochs_to_visualize = 2

# Select epochs for visualization: seeded so the figure is reproducible, and
# drawn without replacement so the same epoch is not plotted twice.
epoch_rng = np.random.default_rng(SKLRNG)
left_hand_epochs_to_visualize = np.where((y == 'left_hand'))[0]
right_hand_epochs_to_visualize = np.where((y == 'right_hand'))[0]
left_hand_epochs_to_visualize = epoch_rng.choice(left_hand_epochs_to_visualize, num_left_hand_epochs_to_visualize, replace=False)
right_hand_epochs_to_visualize = epoch_rng.choice(right_hand_epochs_to_visualize, num_right_hand_epochs_to_visualize, replace=False)
epochs_to_visualize = np.concatenate((left_hand_epochs_to_visualize, right_hand_epochs_to_visualize))

rprint(epochs_to_visualize)

# Visualize spectrograms: one panel per selected epoch, showing the first time
# segment as a (channel x frequency-bin) image.
# squeeze=False keeps `axs` two-dimensional even when a single panel is requested.
fig, axs = plt.subplots(epochs_to_visualize.size, 1, figsize=(10, 10), squeeze=False)
for ax, idx in zip(axs[:, 0], epochs_to_visualize):
    ax.matshow(spectograms[idx, :, :, 0], cmap='viridis')
    ax.set_title(y[idx])
    # Rows of the image are channels and columns are frequency bins (indices, not Hz).
    ax.set_xlabel('Frequency bin')
    ax.set_ylabel('Channel')
    ax.set_xlim(0, 60) # Show only the first 60 frequency bins

plt.tight_layout()
plt.show()