In [None]:
# General imports
import os
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

In [None]:
SKLRNG = 42
RNG = jax.random.PRNGKey(SKLRNG)

In [None]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250):

    find_events = lambda raw, event_id: mne.find_events(raw, shortest_event=0, verbose=False) if len(mne.utils._get_stim_channel(None, raw.info, raise_error=False)) > 0 else mne.events_from_annotations(raw, event_id=event_id, verbose=False)[0]
    
    data = dataset.get_data(subjects=subjects)
    
    X = []
    y = []
    metadata = []

    for subject_id in data.keys():
        for session_id in data[subject_id].keys():
            for run_id in data[subject_id][session_id].keys():
                raw = data[subject_id][session_id][run_id]
                
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)
                
                events = find_events(raw, dataset.event_id)

                tmin = dataset.interval[0]
                tmax = dataset.interval[1]

                channels = np.asarray(raw.info['ch_names'])[channel_idx] if len(channel_idx) > 0 else np.asarray(raw.info['ch_names'])

                # rpprint(channels)
                
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                x = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                x_events = x.events
                inv_events = {k: v for v, k in dataset.event_id.items()}
                labels = [inv_events[e] for e in x_events[:, -1]]

                # rpprint({
                #     "X": np.asarray(x.get_data(copy=False)).shape,
                #     "y": np.asarray(labels).shape,
                #     "channels selected": np.asarray(raw.info['ch_names'])[channel_idx]
                # })

                # x.plot(scalings="auto")
                # display(x.info)
                
                x_resampled = x.resample(sfreq) # Resampler_Epoch
                x_resampled_data = x_resampled.get_data(copy=False) # Convert_Epoch_Array
                x_resampled_data_standard_scaler = np.asarray([
                    StandardScaler().fit_transform(x_resampled_data[i])
                    for i in np.arange(x_resampled_data.shape[0])
                ]) # Standard_Scaler_Epoch

                # x_resampled.plot(scalings="auto")
                # display(x_resampled.info)

                n = x_resampled_data_standard_scaler.shape[0]
                # n = x.get_data(copy=False).shape[0]
                met = pd.DataFrame(index=range(n))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id
                x.metadata = met.copy()
                
                # X.append(x_resampled_data_standard_scaler)
                X.append(x)
                y.append(labels)
                metadata.append(met)

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

fat_dataset = FatigueMI()
# X, y, _ = data_generator(fat_dataset, subjects=fat_dataset.subject_list, sfreq=250)
X, y, _ = data_generator(fat_dataset, subjects=[5], channel_idx=[], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=[1, 2, 5, 12, 13, 14, 17, 18], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[5, 6, 12, 17, 24, 28, 29], channel_idx=[], sfreq=128)

In [None]:
y_encoded = LabelEncoder().fit_transform(y) # 0 = left, 1 = right, 2 = unlabelled (!NOTE: should be ignored for binary classification)
# rpprint(y_encoded)
rprint({
    "X": X.shape,
    "y": y_encoded.shape
})

NUM_SAMPLES = X.shape[-1]
NUM_CHANNELS = X.shape[-2]
NUM_CLASSES = len(np.unique(y_encoded))

rpprint({
    "NUM_SAMPLES": NUM_SAMPLES,
    "NUM_CHANNELS": NUM_CHANNELS,
    "NUM_CLASSES": NUM_CLASSES
})

# sns.histplot(y_encoded); # Plot the class distribution

In [None]:
TRAIN_SIZE = 0.8
TEST_SIZE = 0.2

# X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS), y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=None, shuffle=False)
# X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, NUM_SAMPLES, NUM_CHANNELS), y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=y_encoded, shuffle=True)
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=y_encoded, shuffle=True)
# X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=TEST_SIZE, random_state=SKLRNG, stratify=None, shuffle=False)

total = X_train.shape[0] + X_test.shape[0]
rpprint({
    "X_train": f"{X_train.shape[0]} ({X_train.shape[0] / total * 100:.2f}%)",
    "X_test": f"{X_test.shape[0]} ({X_test.shape[0] / total * 100:.2f}%)",
}, expand_all=True)

# batch_size = 32
# batches = batch_generator(X_train, batch_size, RNG)

In [None]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    return jnp.square(x)

def shallow_conv_net_log_layer(x):
    return jnp.log(jnp.clip(x, 1e-7, 10000))

CUSTOM_OBJECTS = {
    "shallow_conv_net_square_layer": shallow_conv_net_square_layer, 
    "shallow_conv_net_log_layer": shallow_conv_net_log_layer 
}
# endregion Helper funcs

# region Models
def eeg_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5, 
            kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """
        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py
    """

    dropoutType = {'Dropout': Dropout, 'SpatialDropout2D': SpatialDropout2D}[dropoutType]

    input1   = Input(shape = (channels, samples, 1))

    ##################################################################
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   input_shape = (channels, samples, 1),
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((channels, 1), use_bias = False, 
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropout_rate)(block1)
    
    block2       = SeparableConv2D(F2, (1, 16), use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropout_rate)(block2)
        
    flatten      = Flatten(name = 'flatten')(block2)
    
    dense        = Dense(nb_classes, name = 'dense', 
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)
    
    return Model(inputs=input1, outputs=softmax)

def deep_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5):
    """
        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py
    """
    input_main   = Input((channels, samples, 1))
    block1       = Conv2D(25, (1, 5), 
                                 input_shape=(channels, samples, 1),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(input_main)
    block1       = Conv2D(25, (channels, 1),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block1)
    block1       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    block1       = Activation('elu')(block1)
    block1       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block1)
    block1       = Dropout(dropout_rate)(block1)
  
    block2       = Conv2D(50, (1, 5),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block1)
    block2       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block2)
    block2       = Activation('elu')(block2)
    block2       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block2)
    block2       = Dropout(dropout_rate)(block2)
    
    block3       = Conv2D(100, (1, 5),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block2)
    block3       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block3)
    block3       = Activation('elu')(block3)
    block3       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block3)
    block3       = Dropout(dropout_rate)(block3)
    
    block4       = Conv2D(200, (1, 5),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)))(block3)
    block4       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block4)
    block4       = Activation('elu')(block4)
    block4       = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block4)
    block4       = Dropout(dropout_rate)(block4)
    
    flatten      = Flatten()(block4)
    
    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax      = Activation('softmax')(dense)
    
    return Model(inputs=input_main, outputs=softmax)

def shallow_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py
    """

    _POOL_SIZE_D2_ = kwargs.get("pool_size_d2", 35)
    _STRIDES_D2_ = kwargs.get("strides_d2", 7)
    _CONV_FILTERS_D2_ = kwargs.get("conv_filters_d2", 13)

    _POOL_SIZE_ = kwargs.get("pool_size", (1, _POOL_SIZE_D2_))
    _STRIDES_ = kwargs.get("strides", (1, _STRIDES_D2_))
    _CONV_FILTERS_ = kwargs.get("conv_filters", (1, _CONV_FILTERS_D2_))

    _CONV2D_1_UNITS_ = kwargs.get("conv2d_1_units", 40)
    _CONV2D_2_UNITS_ = kwargs.get("conv2d_2_units", 40)
    _L2_REG_1_ = kwargs.get("l2_reg_1", 0.01)
    _L2_REG_2_ = kwargs.get("l2_reg_2", 0.01)
    _L2_REG_3_ = kwargs.get("l2_reg_3", 0.01)
    _DROPOUT_RATE_ = kwargs.get("dropout_rate", 0.5)

    input_main   = Input(shape=(channels, samples, 1))
    block1       = Conv2D(_CONV2D_1_UNITS_, _CONV_FILTERS_,
                                 input_shape=(channels, samples, 1),
                                 kernel_constraint = max_norm(2., axis=(0,1,2)), kernel_regularizer=keras.regularizers.L2(_L2_REG_1_))(input_main)
    # block1       = Conv2D(40, (channels, 1), use_bias=False, 
    #                       kernel_constraint = max_norm(2., axis=(0,1,2)))(block1)
    block1       = Conv2D(_CONV2D_2_UNITS_, (channels, 1), use_bias=False, 
                          kernel_constraint = max_norm(2., axis=(0,1,2)), kernel_regularizer=keras.regularizers.L2(_L2_REG_2_))(block1)
    block1       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    block1       = Activation(shallow_conv_net_square_layer)(block1)
    block1       = AveragePooling2D(pool_size=_POOL_SIZE_, strides=_STRIDES_)(block1)
    block1       = Activation(shallow_conv_net_log_layer)(block1)
    block1       = Dropout(_DROPOUT_RATE_)(block1)
    flatten      = Flatten()(block1)
    # dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5), kernel_regularizer=keras.regularizers.L2(_L2_REG_3_))(flatten)
    softmax      = Activation('softmax')(dense)
    
    return Model(inputs=input_main, outputs=softmax)

def lstm_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    model = Sequential()
    model.add(Input(shape=(samples, channels)))
    model.add(LSTM(40, return_sequences=True, stateful=False))
    # model.add(LSTM(40, return_sequences=True, stateful=False, batch_input_shape=(batch_size, timesteps, data_dim)))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    model.add(LSTM(40, return_sequences=True, stateful=False))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    model.add(LSTM(40, stateful=False))
    model.add(Dropout(0.5))
    # model.add(TimeDistributed(Dense(T_train)))
    model.add(Dense(50,activation='softmax'))
    model.add(Dense(nb_classes, activation='softmax'))
    
    return model

def lstm_cnn_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate_1 = 0.5, dropout_rate_2 = 0.5):
    model = Sequential()
    model.add(Input(shape=(samples, channels)))

    # add 1-layer cnn
    # model.add(Conv1D(40, kernel_size=20, strides=4))
    model.add(Conv1D(40, kernel_size=20, strides=4))
    # model.add(Conv1D(40, kernel_size=20, strides=4, kernel_regularizer=keras.regularizers.L2(0.001), activity_regularizer=keras.regularizers.L2(0.001)))
    model.add(Activation('relu'))
    model.add(Dropout(dropout_rate_1))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=4, strides=4))


    # add 1-layer lstm
    # model.add(LSTM(50, return_sequences=True, stateful=False))
    model.add(LSTM(128, return_sequences=True, stateful=False))
    # model.add(LSTM(50, return_sequences=True, stateful=False, kernel_regularizer=keras.regularizers.L2(0.01)))
    # model.add(LSTM(128, return_sequences=True, stateful=False))
    # model.add(LSTM(32, return_sequences=True, stateful=False))
    model.add(Dropout(dropout_rate_2))
    model.add(BatchNormalization())
    model.add(Flatten())
    model.add(Dense(nb_classes, activation='softmax'))
    # model.add(Dense(nb_classes, activation='softmax', kernel_regularizer=keras.regularizers.L2(0.01)))
    
    return model

def lstm_cnn_net_v2(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    model = Sequential()
    model.add(Input(shape=(samples, channels)))

    # Add 1-layer LSTM
    model.add(LSTM(50, return_sequences=True, stateful=False))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    # Add 1-layer CNN
    model.add(Conv1D(40, kernel_size=20, strides=4))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=4, strides=4))
    model.add(Flatten())
    # Fully connected layer for classification
    model.add(Dense(nb_classes, activation='softmax'))
    return model

# endregion Models

In [109]:
NUM_FOLDS = 5
kfold = KFold(n_splits=NUM_FOLDS, shuffle=True)
acc_per_fold = []
loss_per_fold = []

model = shallow_conv_net()
# model = lstm_cnn_net()

# K-fold Cross Validation model evaluation
fold_num = 1
for train, test in kfold.split(X_train, y_train):
        # Fit data to model
        model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
        
        history = model.fit(
                X_train[train],
                y_train[train],
                batch_size=64,
                epochs=100,
                validation_data=(X_train[test], y_encoded[test]),
                # validation_split=0.2,
                shuffle=True,
                callbacks=[
                        keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
                        keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
                        # keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
                        # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
                ],
        )
        
        # Generate generalization metrics
        scores = model.evaluate(X_train[test], y_train[test], verbose=0)
        rprint(f'Score for fold {fold_num}: {model.metrics_names[0]} of {scores[0]}; {model.metrics_names[1]} of {scores[1]*100}%')
        acc_per_fold.append(scores[1] * 100)
        loss_per_fold.append(scores[0])
        
        # Increase fold number
        fold_num = fold_num + 1

Epoch 1/100


  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2s/step - accuracy: 0.5052 - loss: 1.8567 - val_accuracy: 0.7222 - val_loss: 1.4612 - learning_rate: 0.0010
Epoch 2/100
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.6103 - loss: 1.5053 - val_accuracy: 0.6111 - val_loss: 1.4915 - learning_rate: 0.0010
Epoch 3/100
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.6599 - loss: 1.4071 - val_accuracy: 0.5556 - val_loss: 1.5090 - learning_rate: 0.0010
Epoch 4/100
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.6749 - loss: 1.3466 - val_accuracy: 0.5000 - val_loss: 1.5168 - learning_rate: 0.0010
Epoch 5/100
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.6801 - loss: 1.4520 - val_accuracy: 0.4444 - val_loss: 1.5222 - learning_rate: 0.0010
Epoch 6/100
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy

Epoch 1/100
[1m1/2[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m0s[0m 865ms/step - accuracy: 0.8906 - loss: 0.7738

KeyboardInterrupt: 

In [None]:
mean_cross_val_score = np.mean(acc_per_fold)
rprint("Mean Cross Validation Score (5-fold):", mean_cross_val_score)

results_test = model.evaluate(X_test, y_test)
rprint("Test Accuracy:", results_test[1])

In [None]:
# results_val = model.evaluate(X_val, y_val)
results_test = model.evaluate(X_test, y_test)

# X_test.reshape(-1, NUM_SAMPLES, NUM_CHANNELS).shape

In [110]:
model = shallow_conv_net()
# model = lstm_cnn_net()
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

history = model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=25,
    # validation_data=(X_test, y_test),
    validation_split=0.2,
    shuffle=True,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
        # keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
    ],
)

Epoch 1/25


  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5735 - loss: 1.7503 - val_accuracy: 0.5556 - val_loss: 1.5632 - learning_rate: 0.0010
Epoch 2/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6324 - loss: 1.4596 - val_accuracy: 0.5000 - val_loss: 1.5906 - learning_rate: 0.0010
Epoch 3/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7353 - loss: 1.4258 - val_accuracy: 0.3889 - val_loss: 1.6076 - learning_rate: 0.0010
Epoch 4/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8529 - loss: 1.3003 - val_accuracy: 0.3889 - val_loss: 1.6232 - learning_rate: 0.0010
Epoch 5/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8971 - loss: 1.2525 - val_accuracy: 0.3333 - val_loss: 1.6357 - learning_rate: 0.0010
Epoch 6/25
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8

In [114]:
# model.history.history.keys()

dict_keys(['accuracy', 'loss', 'val_accuracy', 'val_loss', 'learning_rate'])

In [207]:
def objective_fn(trial, subjects=[], model_str = "lstm_cnn_net", sfreq = 128, batch_size = 128, **kwargs):
    _SFREQ_ = sfreq if sfreq else trial.suggest_categorical("sfreq", [128, 256, 300])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = batch_size if batch_size else trial.suggest_int("batch_size", 32, 256, step=32)

    models = {
        "lstm_cnn_net": lstm_cnn_net,
        "shallow_conv_net": shallow_conv_net
    }
    
    subjects = subjects if len(subjects) > 0 else fat_dataset.subject_list
    model_fn = models[model_str]
    
    channels_idx = []
    # channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1] }
    # channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    # while len(channels_idx) == 0:
    #     channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1] }
    #     channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    
    X, y, _ = data_generator(fat_dataset, subjects=subjects, channel_idx=channels_idx, sfreq=_SFREQ_)
    
    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]
    
    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))
    if "lstm" in model_str:
        X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, _NUM_SAMPLES_, _NUM_CHANNELS_), y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded)
    else:
        X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, shuffle=True, stratify=y_encoded)
    
    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
        
        "pool_size_d2": trial.suggest_int("pool_size_d2", 15, 75, step=5),
        "strides_d2": trial.suggest_int("strides_d2", 3, 31, step=2),
        "conv_filters_d2": trial.suggest_int("conv_filters_d2", 5, 35, step=2),
        
        "conv2d_1_units": trial.suggest_int("conv2d_1_units", 20, 60, step=10),
        "conv2d_2_units": trial.suggest_int("conv2d_2_units", 20, 60, step=10),
        "l2_reg_1": trial.suggest_float("l2_reg_1", 0.01, 0.1, step=0.01),
        "l2_reg_2": trial.suggest_float("l2_reg_2", 0.01, 0.1, step=0.01),
        "l2_reg_3": trial.suggest_float("l2_reg_3", 0.01, 0.1, step=0.01),
        "dropout_rate": trial.suggest_float("dropout_rate", 0.1, 0.9, step=0.1),
    }

    NUM_FOLDS = 2
    kfold = KFold(n_splits=NUM_FOLDS, shuffle=True)
    acc_per_fold = []
    loss_per_fold = []

    # K-fold Cross Validation model evaluation
    fold_num = 1
    for train, test in kfold.split(X_train, y_train):
        model = model_fn(**model_params)
        model.compile(
            loss="sparse_categorical_crossentropy", 
            optimizer="adam", 
            metrics=["accuracy"]
        )
        
        history = model.fit(
            X_train[train],
            y_train[train],
            batch_size=_BATCH_SIZE,
            epochs=10,
            # validation_split=0.2,
            validation_data=(X_train[test], y_encoded[test]),
            shuffle=True,
            callbacks=[
                # keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
                # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
                keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
                keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
            ],
        )

        trial.set_user_attr("trial_data", {
            "fold_no": fold_num,
            "channels_selected": np.asarray(fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'])[channels_idx],
            # "history": history,
            "model": model.to_json(),
            "train_accuracy": history.history["accuracy"],
            "val_accuracy": history.history["val_accuracy"],
            "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1],
        })
        
        # Generate generalization metrics
        scores = model.evaluate(X_train[test], y_train[test], verbose=0)
        acc_per_fold.append(scores[1])
        loss_per_fold.append(scores[0])
        
        # Increase fold number
        fold_num = fold_num + 1

    # return (1 - np.mean(history.history["val_accuracy"][-2:])) + 0.005*len(channels_idx)
    # return (1 - np.mean(history.history["val_accuracy"]))**2 + 2.5e-1*(np.mean(history.history["accuracy"]) - np.mean(history.history["val_accuracy"]))**2 + 1.5e-1*(1 - np.max(history.history["val_accuracy"]))**2 + 5e-2*len(channels_idx)
    return (1 - np.mean(acc_per_fold))**2 #+ 5e-4*len(channels_idx)

def heuristic_optimizer(obj_fn, max_iter = 25, show_progress_bar = True, subjects=[], model_str="lstm_cnn_net", sfreq = 128, batch_size = 128, **kwargs):
    
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    objective = partial(obj_fn, subjects = subjects, model_str = model_str, sfreq = sfreq, batch_size = batch_size)
    
    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler())
    study.optimize(objective, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=[], **kwargs)
    
    return study

study = heuristic_optimizer(obj_fn = objective_fn, model_str="shallow_conv_net", subjects=[1], sfreq = None, batch_size = None, max_iter = 25)

  study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler())


  0%|          | 0/25 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/10


  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4186 - loss: 3.6707 - val_accuracy: 0.5814 - val_loss: 2.9876 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 115ms/step - accuracy: 0.7442 - loss: 2.7774 - val_accuracy: 0.5349 - val_loss: 3.0522 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 62ms/step - accuracy: 0.8605 - loss: 2.6585 - val_accuracy: 0.5349 - val_loss: 3.0764 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step - accuracy: 0.8837 - loss: 2.5547 - val_accuracy: 0.5349 - val_loss: 3.0535 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step - accuracy: 0.9302 - loss: 2.4828 - val_accuracy: 0.4884 - val_loss: 3.0059 - learning_rate: 1.0000e-04
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5116 - loss: 3.7284 - val_accuracy: 0.4419 - val_loss: 2.9446 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7442 - loss: 2.7925 - val_accuracy: 0.4419 - val_loss: 2.9010 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.8140 - loss: 2.6848 - val_accuracy: 0.4186 - val_loss: 2.8647 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.9070 - loss: 2.5616 - val_accuracy: 0.4419 - val_loss: 2.8274 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9535 - loss: 2.4685 - val_accuracy: 0.4651 - val_loss: 2.7899 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.9

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2s/step - accuracy: 0.4663 - loss: 1.9575 - val_accuracy: 0.5116 - val_loss: 1.9212 - learning_rate: 0.0010
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.5543 - loss: 1.3958 - val_accuracy: 0.5116 - val_loss: 1.8887 - learning_rate: 0.0010
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.6839 - loss: 1.3571 - val_accuracy: 0.4884 - val_loss: 1.8514 - learning_rate: 0.0010
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.6371 - loss: 1.2492 - val_accuracy: 0.4186 - val_loss: 1.8019 - learning_rate: 0.0010
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8602 - loss: 1.1569 - val_accuracy: 0.3953 - val_loss: 1.7721 - learning_rate: 0.0010
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1s/step - accuracy: 0.4506 - loss: 2.0981 - val_accuracy: 0.4419 - val_loss: 1.5183 - learning_rate: 0.0010
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.6424 - loss: 1.4706 - val_accuracy: 0.4186 - val_loss: 1.5256 - learning_rate: 0.0010
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6892 - loss: 1.3772 - val_accuracy: 0.4186 - val_loss: 1.5251 - learning_rate: 0.0010
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7720 - loss: 1.2430 - val_accuracy: 0.4651 - val_loss: 1.5285 - learning_rate: 0.0010
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.8394 - loss: 1.1352 - val_accuracy: 0.4186 - val_loss: 1.5189 - learning_rate: 1.0000e-04
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4884 - loss: 2.8297 - val_accuracy: 0.5116 - val_loss: 2.0367 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6279 - loss: 1.7375 - val_accuracy: 0.5116 - val_loss: 1.9819 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6047 - loss: 1.6695 - val_accuracy: 0.5116 - val_loss: 1.9456 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.7209 - loss: 1.6060 - val_accuracy: 0.5116 - val_loss: 1.9319 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.6279 - loss: 1.6840 - val_accuracy: 0.5116 - val_loss: 1.9080 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5349 - loss: 2.5265 - val_accuracy: 0.5581 - val_loss: 1.7044 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.4884 - loss: 1.7393 - val_accuracy: 0.5581 - val_loss: 1.7001 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.5814 - loss: 1.7890 - val_accuracy: 0.5581 - val_loss: 1.6960 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6744 - loss: 1.7049 - val_accuracy: 0.6047 - val_loss: 1.6941 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6512 - loss: 1.7056 - val_accuracy: 0.5814 - val_loss: 1.6914 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.7

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2s/step - accuracy: 0.5027 - loss: 3.0234 - val_accuracy: 0.4186 - val_loss: 2.3904 - learning_rate: 0.0010
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.8188 - loss: 2.1559 - val_accuracy: 0.5116 - val_loss: 2.3595 - learning_rate: 0.0010
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8653 - loss: 2.0167 - val_accuracy: 0.5349 - val_loss: 2.3284 - learning_rate: 0.0010
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9586 - loss: 1.8652 - val_accuracy: 0.5581 - val_loss: 2.3098 - learning_rate: 0.0010
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 1.0000 - loss: 1.7725 - val_accuracy: 0.4884 - val_loss: 2.2990 - learning_rate: 0.0010
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 1.0

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1s/step - accuracy: 0.3421 - loss: 3.1944 - val_accuracy: 0.4419 - val_loss: 2.4431 - learning_rate: 0.0010
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.7098 - loss: 2.1296 - val_accuracy: 0.4419 - val_loss: 2.3646 - learning_rate: 0.0010
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9067 - loss: 1.9519 - val_accuracy: 0.4651 - val_loss: 2.3330 - learning_rate: 0.0010
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.9482 - loss: 1.8195 - val_accuracy: 0.4884 - val_loss: 2.2940 - learning_rate: 0.0010
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9741 - loss: 1.7331 - val_accuracy: 0.4884 - val_loss: 2.2628 - learning_rate: 0.0010
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5349 - loss: 9.0087 - val_accuracy: 0.5814 - val_loss: 8.1162 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.6279 - loss: 8.0474 - val_accuracy: 0.4884 - val_loss: 7.9525 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8605 - loss: 7.6234 - val_accuracy: 0.5116 - val_loss: 7.7130 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9535 - loss: 7.3345 - val_accuracy: 0.5349 - val_loss: 7.4794 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 1.0000 - loss: 6.9972 - val_accuracy: 0.4651 - val_loss: 7.2547 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5349 - loss: 8.8240 - val_accuracy: 0.5116 - val_loss: 8.3056 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.5814 - loss: 8.0174 - val_accuracy: 0.5349 - val_loss: 8.7045 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8605 - loss: 7.6255 - val_accuracy: 0.5349 - val_loss: 8.6253 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.8837 - loss: 7.3260 - val_accuracy: 0.5349 - val_loss: 8.2470 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9302 - loss: 6.9949 - val_accuracy: 0.5116 - val_loss: 7.7908 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.6047 - loss: 6.3850 - val_accuracy: 0.4884 - val_loss: 5.8882 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.9535 - loss: 5.6999 - val_accuracy: 0.4419 - val_loss: 5.9344 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9302 - loss: 5.3827 - val_accuracy: 0.4419 - val_loss: 5.6536 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 1.0000 - loss: 5.0943 - val_accuracy: 0.5814 - val_loss: 5.3674 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 1.0000 - loss: 4.8520 - val_accuracy: 0.5581 - val_loss: 5.1581 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4419 - loss: 6.7166 - val_accuracy: 0.3953 - val_loss: 5.9604 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6977 - loss: 5.7746 - val_accuracy: 0.5349 - val_loss: 6.0535 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.7907 - loss: 5.4652 - val_accuracy: 0.5116 - val_loss: 5.8010 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9302 - loss: 5.1908 - val_accuracy: 0.4419 - val_loss: 5.5152 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 1.0000 - loss: 4.9201 - val_accuracy: 0.4651 - val_loss: 5.3257 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5581 - loss: 4.9418 - val_accuracy: 0.5814 - val_loss: 4.5729 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.4186 - loss: 4.6346 - val_accuracy: 0.5814 - val_loss: 4.5178 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6047 - loss: 4.4252 - val_accuracy: 0.5814 - val_loss: 4.4635 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.7674 - loss: 4.2699 - val_accuracy: 0.5814 - val_loss: 4.3814 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step - accuracy: 0.8372 - loss: 4.0222 - val_accuracy: 0.5814 - val_loss: 4.2814 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.3953 - loss: 5.1933 - val_accuracy: 0.5349 - val_loss: 4.5853 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.5814 - loss: 4.5349 - val_accuracy: 0.5581 - val_loss: 4.4533 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5814 - loss: 4.4900 - val_accuracy: 0.5814 - val_loss: 4.3669 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9070 - loss: 4.2189 - val_accuracy: 0.5814 - val_loss: 4.2861 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8140 - loss: 4.0589 - val_accuracy: 0.5814 - val_loss: 4.2042 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4186 - loss: 5.5190 - val_accuracy: 0.5814 - val_loss: 4.3939 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6512 - loss: 4.3717 - val_accuracy: 0.4651 - val_loss: 4.4228 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8140 - loss: 4.1860 - val_accuracy: 0.3953 - val_loss: 4.3922 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.8837 - loss: 3.9952 - val_accuracy: 0.3953 - val_loss: 4.3094 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9767 - loss: 3.8422 - val_accuracy: 0.4186 - val_loss: 4.2059 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5814 - loss: 5.3125 - val_accuracy: 0.5581 - val_loss: 4.5615 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7209 - loss: 4.3399 - val_accuracy: 0.5349 - val_loss: 4.4736 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8605 - loss: 4.1561 - val_accuracy: 0.5581 - val_loss: 4.3992 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9302 - loss: 4.0075 - val_accuracy: 0.5349 - val_loss: 4.3069 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9535 - loss: 3.8278 - val_accuracy: 0.5116 - val_loss: 4.2170 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.7209 - loss: 6.3861 - val_accuracy: 0.4651 - val_loss: 5.8913 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7209 - loss: 5.4493 - val_accuracy: 0.4651 - val_loss: 5.7796 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7442 - loss: 5.2788 - val_accuracy: 0.4651 - val_loss: 5.7486 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7674 - loss: 5.2323 - val_accuracy: 0.4419 - val_loss: 5.6603 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8372 - loss: 4.9638 - val_accuracy: 0.4186 - val_loss: 5.5230 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5116 - loss: 7.0715 - val_accuracy: 0.5116 - val_loss: 5.7166 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.5349 - loss: 5.7029 - val_accuracy: 0.5116 - val_loss: 5.6304 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.5814 - loss: 5.5789 - val_accuracy: 0.5581 - val_loss: 5.5400 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7442 - loss: 5.3546 - val_accuracy: 0.5581 - val_loss: 5.4477 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6977 - loss: 5.1606 - val_accuracy: 0.5349 - val_loss: 5.3535 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5349 - loss: 6.0597 - val_accuracy: 0.4884 - val_loss: 5.6468 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.4651 - loss: 5.4378 - val_accuracy: 0.5116 - val_loss: 5.3537 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6047 - loss: 5.2449 - val_accuracy: 0.4651 - val_loss: 5.1745 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.6744 - loss: 5.0564 - val_accuracy: 0.4884 - val_loss: 5.0329 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.7209 - loss: 4.9348 - val_accuracy: 0.6047 - val_loss: 4.9035 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.7

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4419 - loss: 6.3446 - val_accuracy: 0.5581 - val_loss: 5.5558 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.4419 - loss: 5.6158 - val_accuracy: 0.5581 - val_loss: 5.4330 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.4651 - loss: 5.4415 - val_accuracy: 0.4419 - val_loss: 5.3049 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5349 - loss: 5.2881 - val_accuracy: 0.3953 - val_loss: 5.1762 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.7209 - loss: 5.0910 - val_accuracy: 0.3953 - val_loss: 5.0479 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.6

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4186 - loss: 5.0643 - val_accuracy: 0.5814 - val_loss: 3.9872 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6977 - loss: 3.9360 - val_accuracy: 0.5814 - val_loss: 3.9403 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7907 - loss: 3.8424 - val_accuracy: 0.5581 - val_loss: 3.8943 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9070 - loss: 3.6479 - val_accuracy: 0.5116 - val_loss: 3.8387 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7442 - loss: 3.6477 - val_accuracy: 0.5116 - val_loss: 3.7766 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 942ms/step - accuracy: 0.6047 - loss: 4.7236 - val_accuracy: 0.5581 - val_loss: 4.1045 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8372 - loss: 3.8592 - val_accuracy: 0.5581 - val_loss: 4.0088 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.8605 - loss: 3.7664 - val_accuracy: 0.5581 - val_loss: 3.9545 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9070 - loss: 3.6349 - val_accuracy: 0.5581 - val_loss: 3.9043 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8837 - loss: 3.5255 - val_accuracy: 0.5581 - val_loss: 3.8511 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2s/step - accuracy: 0.4869 - loss: 5.4771 - val_accuracy: 0.5581 - val_loss: 5.0651 - learning_rate: 0.0010
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.7306 - loss: 4.7530 - val_accuracy: 0.5581 - val_loss: 4.9479 - learning_rate: 0.0010
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.7771 - loss: 4.5980 - val_accuracy: 0.5349 - val_loss: 4.7844 - learning_rate: 0.0010
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.6994 - loss: 4.4234 - val_accuracy: 0.5814 - val_loss: 4.5948 - learning_rate: 0.0010
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.8653 - loss: 4.1690 - val_accuracy: 0.6047 - val_loss: 4.4091 - learning_rate: 0.0010
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.9

  super().__init__(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1s/step - accuracy: 0.5078 - loss: 5.6311 - val_accuracy: 0.4884 - val_loss: 5.1025 - learning_rate: 0.0010
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.6684 - loss: 4.8469 - val_accuracy: 0.4884 - val_loss: 5.1149 - learning_rate: 0.0010
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.7253 - loss: 4.5907 - val_accuracy: 0.4419 - val_loss: 4.9925 - learning_rate: 0.0010
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.7565 - loss: 4.4444 - val_accuracy: 0.3721 - val_loss: 4.7935 - learning_rate: 0.0010
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.7670 - loss: 4.2423 - val_accuracy: 0.2791 - val_loss: 4.5965 - learning_rate: 0.0010
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.8

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.6047 - loss: 1.9891 - val_accuracy: 0.5581 - val_loss: 1.6457 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.8372 - loss: 1.4399 - val_accuracy: 0.4186 - val_loss: 1.6233 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 1.0000 - loss: 1.2932 - val_accuracy: 0.4186 - val_loss: 1.6377 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 1.0000 - loss: 1.1859 - val_accuracy: 0.3953 - val_loss: 1.6679 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 1.0000 - loss: 1.1033 - val_accuracy: 0.3953 - val_loss: 1.7168 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 988ms/step - accuracy: 0.4651 - loss: 2.0650 - val_accuracy: 0.5349 - val_loss: 1.6247 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.8140 - loss: 1.4700 - val_accuracy: 0.5581 - val_loss: 1.5774 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9302 - loss: 1.2998 - val_accuracy: 0.5116 - val_loss: 1.5917 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 1.0000 - loss: 1.2011 - val_accuracy: 0.5116 - val_loss: 1.6349 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 1.0000 - loss: 1.1141 - val_accuracy: 0.4419 - val_loss: 1.6709 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5116 - loss: 4.7900 - val_accuracy: 0.4651 - val_loss: 4.2773 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.5814 - loss: 4.0933 - val_accuracy: 0.4884 - val_loss: 4.0869 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7209 - loss: 3.9303 - val_accuracy: 0.5814 - val_loss: 3.9827 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.7209 - loss: 3.8235 - val_accuracy: 0.5116 - val_loss: 3.9122 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8372 - loss: 3.6835 - val_accuracy: 0.4651 - val_loss: 3.8494 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 917ms/step - accuracy: 0.4884 - loss: 4.8735 - val_accuracy: 0.5581 - val_loss: 4.1371 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.5116 - loss: 4.1019 - val_accuracy: 0.4884 - val_loss: 4.0203 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.7674 - loss: 3.9337 - val_accuracy: 0.4651 - val_loss: 3.9746 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7907 - loss: 3.7813 - val_accuracy: 0.3953 - val_loss: 3.9392 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6977 - loss: 3.7670 - val_accuracy: 0.4419 - val_loss: 3.8931 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4186 - loss: 4.8338 - val_accuracy: 0.4186 - val_loss: 3.1564 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.6047 - loss: 3.0882 - val_accuracy: 0.4186 - val_loss: 3.4249 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6512 - loss: 2.9973 - val_accuracy: 0.4186 - val_loss: 3.5850 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7907 - loss: 2.7675 - val_accuracy: 0.4186 - val_loss: 3.6173 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.8605 - loss: 2.7151 - val_accuracy: 0.4186 - val_loss: 3.4704 - learning_rate: 1.0000e-04
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 932ms/step - accuracy: 0.5116 - loss: 3.8895 - val_accuracy: 0.4884 - val_loss: 3.0772 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.6279 - loss: 3.0489 - val_accuracy: 0.4884 - val_loss: 3.0573 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7442 - loss: 2.8631 - val_accuracy: 0.4651 - val_loss: 3.0353 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8140 - loss: 2.7478 - val_accuracy: 0.4884 - val_loss: 3.0133 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8140 - loss: 2.7514 - val_accuracy: 0.5116 - val_loss: 2.9927 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5349 - loss: 3.8037 - val_accuracy: 0.4884 - val_loss: 3.3742 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.9535 - loss: 3.0615 - val_accuracy: 0.4884 - val_loss: 3.5327 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 1.0000 - loss: 2.8682 - val_accuracy: 0.4884 - val_loss: 3.5740 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 1.0000 - loss: 2.7429 - val_accuracy: 0.4884 - val_loss: 3.5072 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 1.0000 - loss: 2.6142 - val_accuracy: 0.4651 - val_loss: 3.4325 - learning_rate: 1.0000e-04
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 917ms/step - accuracy: 0.5814 - loss: 3.7957 - val_accuracy: 0.5116 - val_loss: 3.3158 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.8372 - loss: 3.1225 - val_accuracy: 0.4419 - val_loss: 3.2804 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9767 - loss: 2.9251 - val_accuracy: 0.4419 - val_loss: 3.2254 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.9767 - loss: 2.7800 - val_accuracy: 0.4419 - val_loss: 3.1660 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 1.0000 - loss: 2.6630 - val_accuracy: 0.4884 - val_loss: 3.1060 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5116 - loss: 9.9319 - val_accuracy: 0.5349 - val_loss: 8.9581 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.6279 - loss: 8.8012 - val_accuracy: 0.5349 - val_loss: 8.6510 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.7674 - loss: 8.4300 - val_accuracy: 0.5349 - val_loss: 8.3483 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.7907 - loss: 8.0753 - val_accuracy: 0.5349 - val_loss: 8.0437 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9302 - loss: 7.7036 - val_accuracy: 0.5581 - val_loss: 7.7414 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.8

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 908ms/step - accuracy: 0.5116 - loss: 10.0361 - val_accuracy: 0.4419 - val_loss: 9.1019 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step - accuracy: 0.6279 - loss: 8.8036 - val_accuracy: 0.4419 - val_loss: 8.7593 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7674 - loss: 8.4327 - val_accuracy: 0.4419 - val_loss: 8.4369 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8837 - loss: 8.0904 - val_accuracy: 0.4651 - val_loss: 8.1340 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8605 - loss: 7.7459 - val_accuracy: 0.4651 - val_loss: 7.8329 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.3488 - loss: 2.2284 - val_accuracy: 0.4884 - val_loss: 1.3634 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 63ms/step - accuracy: 0.5814 - loss: 1.2459 - val_accuracy: 0.4186 - val_loss: 1.7362 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 61ms/step - accuracy: 0.7907 - loss: 1.1919 - val_accuracy: 0.4186 - val_loss: 1.8941 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step - accuracy: 0.8372 - loss: 1.0698 - val_accuracy: 0.4186 - val_loss: 1.8318 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.9302 - loss: 0.9715 - val_accuracy: 0.4186 - val_loss: 1.6975 - learning_rate: 1.0000e-04
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.3721 - loss: 2.2292 - val_accuracy: 0.3721 - val_loss: 1.6334 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step - accuracy: 0.6744 - loss: 1.2470 - val_accuracy: 0.3953 - val_loss: 2.3511 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.8372 - loss: 1.0663 - val_accuracy: 0.3953 - val_loss: 2.6603 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.8605 - loss: 0.9666 - val_accuracy: 0.3953 - val_loss: 2.6658 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - accuracy: 0.8837 - loss: 0.9486 - val_accuracy: 0.3953 - val_loss: 2.3723 - learning_rate: 1.0000e-04
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy:

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.6047 - loss: 5.4425 - val_accuracy: 0.4651 - val_loss: 4.8457 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 75ms/step - accuracy: 0.6744 - loss: 4.6732 - val_accuracy: 0.3721 - val_loss: 4.7347 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step - accuracy: 0.8372 - loss: 4.4493 - val_accuracy: 0.4651 - val_loss: 4.6412 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step - accuracy: 0.9535 - loss: 4.2548 - val_accuracy: 0.4651 - val_loss: 4.5495 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step - accuracy: 0.9767 - loss: 4.1253 - val_accuracy: 0.4884 - val_loss: 4.4542 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4884 - loss: 5.4835 - val_accuracy: 0.4884 - val_loss: 4.7418 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.6279 - loss: 4.6009 - val_accuracy: 0.4884 - val_loss: 4.6868 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.8837 - loss: 4.3587 - val_accuracy: 0.4419 - val_loss: 4.6527 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.9767 - loss: 4.1528 - val_accuracy: 0.4419 - val_loss: 4.5876 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.9767 - loss: 4.0129 - val_accuracy: 0.4186 - val_loss: 4.5123 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4884 - loss: 10.2818 - val_accuracy: 0.5581 - val_loss: 9.5637 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7907 - loss: 9.3721 - val_accuracy: 0.5814 - val_loss: 9.2847 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9302 - loss: 8.8758 - val_accuracy: 0.5116 - val_loss: 8.9400 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 1.0000 - loss: 8.3974 - val_accuracy: 0.4651 - val_loss: 8.6039 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 1.0000 - loss: 7.9716 - val_accuracy: 0.3721 - val_loss: 8.2848 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 1.

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4884 - loss: 10.4638 - val_accuracy: 0.3953 - val_loss: 10.0726 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.8605 - loss: 9.5115 - val_accuracy: 0.3953 - val_loss: 10.2126 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.8837 - loss: 9.0270 - val_accuracy: 0.4186 - val_loss: 9.7396 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9767 - loss: 8.5883 - val_accuracy: 0.3488 - val_loss: 9.1623 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 1.0000 - loss: 8.1679 - val_accuracy: 0.4651 - val_loss: 8.6957 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5814 - loss: 7.9555 - val_accuracy: 0.4651 - val_loss: 7.1423 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.7907 - loss: 6.8826 - val_accuracy: 0.4419 - val_loss: 6.9570 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step - accuracy: 0.9767 - loss: 6.5004 - val_accuracy: 0.4186 - val_loss: 6.7430 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 1.0000 - loss: 6.1850 - val_accuracy: 0.4651 - val_loss: 6.5319 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 1.0000 - loss: 5.9060 - val_accuracy: 0.4884 - val_loss: 6.2936 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4419 - loss: 8.1498 - val_accuracy: 0.4419 - val_loss: 7.1139 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.8372 - loss: 6.8579 - val_accuracy: 0.4186 - val_loss: 6.9720 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9767 - loss: 6.5324 - val_accuracy: 0.4651 - val_loss: 6.8585 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 1.0000 - loss: 6.2195 - val_accuracy: 0.5116 - val_loss: 6.7849 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 1.0000 - loss: 5.9401 - val_accuracy: 0.4419 - val_loss: 6.6464 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4419 - loss: 7.5784 - val_accuracy: 0.4419 - val_loss: 6.3865 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.5116 - loss: 6.2765 - val_accuracy: 0.4419 - val_loss: 6.2446 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5349 - loss: 6.1567 - val_accuracy: 0.4651 - val_loss: 6.1264 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6512 - loss: 6.0201 - val_accuracy: 0.4651 - val_loss: 6.0243 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6512 - loss: 5.9554 - val_accuracy: 0.4419 - val_loss: 5.9334 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.5

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4651 - loss: 7.7088 - val_accuracy: 0.3953 - val_loss: 6.3706 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.4419 - loss: 6.4642 - val_accuracy: 0.3953 - val_loss: 6.3353 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.6047 - loss: 6.1212 - val_accuracy: 0.3721 - val_loss: 6.2589 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7209 - loss: 5.9367 - val_accuracy: 0.3953 - val_loss: 6.1556 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6047 - loss: 5.8853 - val_accuracy: 0.3953 - val_loss: 6.0435 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.3953 - loss: 9.5043 - val_accuracy: 0.5116 - val_loss: 8.5118 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7907 - loss: 8.2305 - val_accuracy: 0.4884 - val_loss: 8.9677 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8605 - loss: 7.8310 - val_accuracy: 0.4884 - val_loss: 8.6915 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9302 - loss: 7.4751 - val_accuracy: 0.5116 - val_loss: 8.1390 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 1.0000 - loss: 7.1377 - val_accuracy: 0.4884 - val_loss: 7.6712 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.6744 - loss: 9.4151 - val_accuracy: 0.4651 - val_loss: 8.4218 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7907 - loss: 8.2694 - val_accuracy: 0.4884 - val_loss: 8.2842 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8605 - loss: 7.8396 - val_accuracy: 0.4884 - val_loss: 7.9575 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9767 - loss: 7.4586 - val_accuracy: 0.5814 - val_loss: 7.7068 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 1.0000 - loss: 7.0971 - val_accuracy: 0.5814 - val_loss: 7.4808 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.6047 - loss: 4.1040 - val_accuracy: 0.4651 - val_loss: 3.3447 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6512 - loss: 3.0126 - val_accuracy: 0.4651 - val_loss: 3.1347 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6047 - loss: 3.0893 - val_accuracy: 0.5116 - val_loss: 3.0457 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.5814 - loss: 3.0025 - val_accuracy: 0.5814 - val_loss: 3.0097 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6047 - loss: 2.8902 - val_accuracy: 0.5349 - val_loss: 2.9937 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.8

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.4419 - loss: 4.1973 - val_accuracy: 0.4419 - val_loss: 3.1591 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.4884 - loss: 3.2037 - val_accuracy: 0.4419 - val_loss: 3.1380 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.7209 - loss: 2.9784 - val_accuracy: 0.4419 - val_loss: 3.1167 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.5581 - loss: 3.0294 - val_accuracy: 0.4419 - val_loss: 3.1010 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6512 - loss: 2.9412 - val_accuracy: 0.4419 - val_loss: 3.0798 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5116 - loss: 8.2263 - val_accuracy: 0.6279 - val_loss: 7.5093 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.7674 - loss: 7.4674 - val_accuracy: 0.4884 - val_loss: 7.4275 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.8837 - loss: 7.1063 - val_accuracy: 0.4884 - val_loss: 7.2479 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9767 - loss: 6.7985 - val_accuracy: 0.5116 - val_loss: 6.9926 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.9767 - loss: 6.5169 - val_accuracy: 0.5349 - val_loss: 6.7309 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 1.0

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5349 - loss: 8.1787 - val_accuracy: 0.6512 - val_loss: 7.5813 - learning_rate: 0.0010
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.6977 - loss: 7.4856 - val_accuracy: 0.5349 - val_loss: 7.3825 - learning_rate: 0.0010
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.8837 - loss: 7.1514 - val_accuracy: 0.5116 - val_loss: 7.2068 - learning_rate: 0.0010
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.9767 - loss: 6.8390 - val_accuracy: 0.5116 - val_loss: 6.9931 - learning_rate: 0.0010
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.9535 - loss: 6.5506 - val_accuracy: 0.5581 - val_loss: 6.7591 - learning_rate: 0.0010
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9

In [208]:
trial_metrics_dict = {
    "train_acc": [],
    "test_acc": [],
    "val_acc": [],
    "scores": [],
    "channels_selected": [],
}
for i, trial in enumerate(study.trials_dataframe().itertuples()):
    trial_user_attrs = trial.user_attrs_trial_data
    trial_metrics_dict["scores"].append(trial.value)
    trial_metrics_dict["train_acc"].append(np.max(trial_user_attrs['train_accuracy']))
    trial_metrics_dict["test_acc"].append(trial_user_attrs['test_accuracy'])
    trial_metrics_dict["val_acc"].append(np.max(trial_user_attrs['val_accuracy']))
    trial_metrics_dict["channels_selected"].append(trial_user_attrs['channels_selected'])
trial_metrics_df = pd.DataFrame(trial_metrics_dict)
# trial_metrics_df.sort_values("test_acc", ascending=False)
trial_metrics_df.sort_values("scores", ascending=True)

Unnamed: 0,train_acc,test_acc,val_acc,scores,channels_selected
12,1.0,0.590909,0.55814,0.16563,[]
24,1.0,0.545455,0.651163,0.17523,[]
2,0.790698,0.5,0.604651,0.1851,[]
17,0.906977,0.681818,0.395349,0.195241,[]
20,1.0,0.409091,0.511628,0.205652,[]
4,1.0,0.772727,0.534884,0.205652,[]
0,1.0,0.545455,0.55814,0.216333,[]
9,0.860465,0.5,0.55814,0.227285,[]
23,0.813953,0.636364,0.488372,0.227285,[]
13,0.860465,0.5,0.55814,0.238507,[]


In [206]:
rpprint({ k: v for k, v in study.best_trial.params.items() if not k.startswith("channels") })
rprint("test_accuracy =", study.best_trial.user_attrs["trial_data"]["test_accuracy"])
rprint("val_accuracy =", study.best_trial.user_attrs["trial_data"]["val_accuracy"][-1])
rprint("channels_selected =", study.best_trial.user_attrs["trial_data"]["channels_selected"])

In [None]:
channels = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
channels_of_interest = ['C3', 'F3', 'C4', 'T5', 'O1', 'O2', 'A2', 'T6']
channels_of_interest_idx = np.where(np.isin(channels, channels_of_interest))[0]

channels_of_interest_idx

In [None]:
# Model/pipeline hyperparameter optimization test
data_store_dict = {}

def objective_fn_lstm_cnn_net(trial, **kwargs):
    # _SFREQ_ = trial.suggest_categorical("sfreq", [128, 250])
    _SFREQ_ = 250
    # _TRAIN_SIZE_ = trial.suggest_categorical("train_size", [0.8, 0.9])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = 256
    
    channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1] }
    channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    while len(channels_idx) == 0:
        channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1] }
        channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    
    X, y, _ = data_generator(fat_dataset, subjects=[3], channel_idx=channels_idx, sfreq=_SFREQ_)
    
    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]
    # rpprint({
    #     "NUM_SAMPLES": _NUM_SAMPLES_,
    #     "NUM_CHANNELS": _NUM_CHANNELS_,
    #     "X": X.shape,
    #     "y": y.shape,
    #     "channels selected": np.asarray(fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'])[channels_idx]
    # })
    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))
    X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, _NUM_SAMPLES_, _NUM_CHANNELS_), y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded)
    # X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded)
    
    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }
    
    model = lstm_cnn_net_v2(**model_params)
    # model = lstm_cnn_net(**model_params)
    # model = shallow_conv_net(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy", 
        optimizer="adam", 
        metrics=["accuracy"]
    )
    
    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=50,
        validation_split=0.2,
        callbacks=[
            # keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
            # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'])[channels_idx],
        # "history": history,
        # "model": model,
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    # return (1 - np.mean(history.history["val_accuracy"][-2:])) + 0.005*len(channels_idx)
    return (1 - history.history["val_accuracy"][-1])**2 + 0.00005*len(channels_idx)

def heuristic_optimizer(obj_fn = lambda: None, max_iter = 15, show_progress_bar = True, **kwargs):
    
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    # trial_data = []
    # trial_callback = lambda study, trial: trial_data.append(trial.user_attrs["trial_data"])
    
    # study = optuna.create_study(direction="minimize", sampler=optuna.samplers.CmaEsSampler())
    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIISampler())
    study.optimize(obj_fn, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=[], **kwargs)
    
    return study

study = heuristic_optimizer(obj_fn = objective_fn_lstm_cnn_net)

In [None]:
# Save model
model.save("model.keras")

In [None]:
# Load model from disk
loaded_model = keras.saving.load_model("model.keras", custom_objects=CUSTOM_OBJECTS)

In [None]:
## Pruning and quantization test

from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

baseline_model = loaded_model
pruning_params = {
    'pruning_schedule': pruning_wrapper.PolynomialDecay(initial_sparsity=0.00,
                                                        final_sparsity=0.80,
                                                        begin_step=0,
                                                        end_step=100)
}
sparse_model = pruning_wrapper.PruneLowMagnitude(baseline_model, **pruning_params)

sparse_model.summary()

### Data exploration, etc.

In [None]:
def create_spectrogram(epoch_data):
    spectrograms = []
    for epoch in epoch_data:
        epoch_spectrograms = []
        for channel in epoch:
            f, t, Sxx = sp.signal.spectrogram(channel, fs=50)
            epoch_spectrograms.append(Sxx)
        spectrograms.append(np.array(epoch_spectrograms))
    return np.array(spectrograms)

spectograms = create_spectrogram(X)
rprint(spectograms.shape) # (180, 21, 129, 2)

num_left_hand_epochs_to_visualize = 2
num_right_hand_epochs_to_visualize = 2

# Select epochs for visualization
left_hand_epochs_to_visualize = np.where((y == 'left_hand'))[0]
right_hand_epochs_to_visualize = np.where((y == 'right_hand'))[0]
left_hand_epochs_to_visualize = np.random.choice(left_hand_epochs_to_visualize, num_left_hand_epochs_to_visualize)
right_hand_epochs_to_visualize = np.random.choice(right_hand_epochs_to_visualize, num_right_hand_epochs_to_visualize)
epochs_to_visualize = np.concatenate((left_hand_epochs_to_visualize, right_hand_epochs_to_visualize))

rprint(epochs_to_visualize)

# Visualize spectrograms
fig, axs = plt.subplots(epochs_to_visualize.size, 1, figsize=(10, 10))
for i, idx in enumerate(epochs_to_visualize):
    axs[i].matshow(spectograms[idx, :, :, 0], cmap='viridis')
    axs[i].set_title(y[idx])
    axs[i].set_xlabel('Frequency (Hz)')
    axs[i].set_ylabel('Time (s)')
    axs[i].set_xlim(0, 60) # Limit x-axis to 60Hz

plt.tight_layout()
plt.show()