In [2]:
# General imports
import os
import uuid
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

/home/arazzz/anaconda3/envs/moabb_model_optimization/lib/python3.11/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


In [9]:
# Integer seed passed as `random_state` to the scikit-learn splitters in this notebook.
SKLRNG = 42
# JAX PRNG key built from the same seed.
RNG = jax.random.PRNGKey(SKLRNG)

2024-03-01 23:10:59.891010: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [80]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250):
    """Load, band-pass filter, epoch and resample the recordings of ``subjects``.

    Args:
        dataset: Dataset object exposing ``get_data(subjects=...)``, ``event_id``
            (label -> event code mapping) and ``interval`` (``[tmin, tmax]``).
        subjects: Subject ids forwarded to ``dataset.get_data``. The default
            list is never mutated here.
        channel_idx: Indices into ``raw.info['ch_names']`` selecting the channels
            to keep; an empty sequence keeps every channel. Stim channels are
            always excluded. The default list is never mutated here.
        filters: Iterable of ``(fmin, fmax)`` pairs; each is applied in turn as
            an IIR band-pass on the EEG channels.
        sfreq: Sampling frequency the epochs are resampled to.

    Returns:
        Tuple ``(X, y, metadata)``:
          * ``X`` - epoch data of every run concatenated along the first axis
            (epochs, channels, samples).
          * ``y`` - array with the label name of every epoch.
          * ``metadata`` - DataFrame with one row per epoch and the columns
            ``subject``, ``session`` and ``run``.
    """

    def _find_events(raw, stim_channels):
        # Prefer the stim channel when the recording has one, otherwise fall
        # back to the annotations.
        if len(stim_channels) > 0:
            return mne.find_events(raw, shortest_event=0, verbose=False)
        return mne.events_from_annotations(raw, event_id=dataset.event_id, verbose=False)[0]

    data = dataset.get_data(subjects=subjects)

    # Loop-invariant lookups.
    tmin = dataset.interval[0]
    tmax = dataset.interval[1]
    inv_events = {code: label for label, code in dataset.event_id.items()}

    X = []
    y = []
    metadata = []

    for subject_id, sessions in data.items():
        for session_id, runs in sessions.items():
            for run_id, raw in runs.items():
                # NOTE: `raw.filter` is applied to the object handed out by
                # `dataset.get_data` (no copy is made here).
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)

                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                events = _find_events(raw, stim_channels)

                ch_names = np.asarray(raw.info['ch_names'])
                channels = ch_names[channel_idx] if len(channel_idx) > 0 else ch_names
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                x = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                labels = [inv_events[e] for e in x.events[:, -1]]

                # Resample, then take the plain array explicitly instead of
                # relying on `np.concatenate` to coerce Epochs objects.
                x.resample(sfreq)
                epoch_data = x.get_data(copy=False)

                # The previous per-epoch StandardScaler pass was computed and
                # then discarded (the unscaled epochs were returned), so it is
                # dropped; the returned data is unchanged (still unscaled).
                met = pd.DataFrame(index=range(epoch_data.shape[0]))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id

                X.append(epoch_data)
                y.append(labels)
                metadata.append(met)

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

# Instantiate the dataset and build the epoch tensor / labels used by the cells below.
fat_dataset = FatigueMI()
# Alternative subject / channel selections tried previously:
# X, y, _ = data_generator(fat_dataset, subjects=fat_dataset.subject_list, sfreq=250)
# Active selection: subject 5, all channels (empty `channel_idx`), resampled to 128 Hz.
X, y, _ = data_generator(fat_dataset, subjects=[5], channel_idx=[], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=[1, 2, 5, 12, 13, 14, 17, 18], sfreq=128)
# X, y, _ = data_generator(fat_dataset, subjects=[5, 6, 12, 17, 24, 28, 29], channel_idx=[], sfreq=128)

Adding metadata with 3 columns


In [7]:
# Integer-encode the label names.
# Author's note: 0 = left, 1 = right, 2 = unlabelled
# (!NOTE: should be ignored for binary classification)
y_encoded = LabelEncoder().fit_transform(y)
rprint({"X": X.shape, "y": y_encoded.shape})

# Model input dimensions are read off the two trailing axes of X.
NUM_CHANNELS, NUM_SAMPLES = X.shape[-2], X.shape[-1]
NUM_CLASSES = len(np.unique(y_encoded))

rpprint({
    "NUM_SAMPLES": NUM_SAMPLES,
    "NUM_CHANNELS": NUM_CHANNELS,
    "NUM_CLASSES": NUM_CLASSES
})

# sns.histplot(y_encoded); # Plot the class distribution

In [10]:
TRAIN_SIZE = 0.8
TEST_SIZE = 0.2

# Shuffled, label-stratified hold-out split, seeded with SKLRNG.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=TEST_SIZE,
    random_state=SKLRNG,
    stratify=y_encoded,
    shuffle=True,
)

# Report the absolute and relative size of each partition.
n_train, n_test = X_train.shape[0], X_test.shape[0]
total = n_train + n_test
rpprint({
    "X_train": f"{n_train} ({n_train / total * 100:.2f}%)",
    "X_test": f"{n_test} ({n_test / total * 100:.2f}%)",
}, expand_all=True)

In [14]:
# Sanity check of the training tensor shape (epochs, channels, samples).
X_train.shape

(86, 20, 256)

In [11]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Element-wise square, used as an activation by `shallow_conv_net`."""
    squared = jnp.square(x)
    return squared

def shallow_conv_net_log_layer(x):
    """Natural log of `x` after clamping it to [1e-7, 10000].

    The clamp keeps the logarithm finite for zero / very small inputs.
    """
    bounded = jnp.clip(x, 1e-7, 10000)
    return jnp.log(bounded)

# Name -> callable table of the custom activations defined above
# (presumably meant for Keras `custom_objects` when reloading a model - confirm).
CUSTOM_OBJECTS = {
    fn.__name__: fn
    for fn in (shallow_conv_net_square_layer, shallow_conv_net_log_layer)
}
# endregion Helper funcs

# region Models
def eeg_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5,
            kernLength = 64, F1 = 8, D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """
        Build the EEGNet classifier (uncompiled functional Keras model).

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        Args:
            nb_classes: Number of softmax outputs.
            channels, samples: Input dimensions; the model input is
                ``(channels, samples, 1)``.
            dropout_rate: Rate used by both dropout layers.
            kernLength: Width of the first (1 x kernLength) convolution.
            F1: Number of filters of the first convolution.
            D: Depth multiplier of the depthwise (channels x 1) convolution.
            F2: Number of filters of the separable convolution.
            norm_rate: Max-norm constraint applied to the dense kernel.
            dropoutType: ``'Dropout'`` or ``'SpatialDropout2D'``; any other
                value raises ``KeyError``.

        Returns:
            ``keras.Model`` mapping the input tensor to class probabilities.
    """

    dropoutType = {'Dropout': Dropout, 'SpatialDropout2D': SpatialDropout2D}[dropoutType]

    input1   = Input(shape = (channels, samples, 1))

    ##################################################################
    # The input shape is already fixed by `input1`; passing `input_shape` to
    # the layer as well only triggers a Keras 3 warning, so it is omitted.
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((channels, 1), use_bias = False,
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropout_rate)(block1)

    block2       = SeparableConv2D(F2, (1, 16), use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropout_rate)(block2)

    flatten      = Flatten(name = 'flatten')(block2)

    dense        = Dense(nb_classes, name = 'dense',
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)

def deep_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, dropout_rate = 0.5):
    """
        Build the DeepConvNet classifier (uncompiled functional Keras model).

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        Args:
            nb_classes: Number of softmax outputs.
            channels, samples: Input dimensions; the model input is
                ``(channels, samples, 1)``.
            dropout_rate: Rate of the dropout layer closing each block.

        Returns:
            ``keras.Model`` mapping the input tensor to class probabilities.
    """
    input_main   = Input((channels, samples, 1))

    # Block 1: temporal (1 x 5) convolution followed by a convolution across
    # all channels. The input shape comes from `input_main`; passing
    # `input_shape` to the layer only triggers a Keras 3 warning, so it is
    # omitted.
    block        = Conv2D(25, (1, 5),
                          kernel_constraint = max_norm(2., axis=(0,1,2)))(input_main)
    block        = Conv2D(25, (channels, 1),
                          kernel_constraint = max_norm(2., axis=(0,1,2)))(block)
    block        = BatchNormalization(epsilon=1e-05, momentum=0.9)(block)
    block        = Activation('elu')(block)
    block        = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block)
    block        = Dropout(dropout_rate)(block)

    # Blocks 2-4 share one layout and differ only in the number of filters.
    for n_filters in (50, 100, 200):
        block    = Conv2D(n_filters, (1, 5),
                          kernel_constraint = max_norm(2., axis=(0,1,2)))(block)
        block    = BatchNormalization(epsilon=1e-05, momentum=0.9)(block)
        block    = Activation('elu')(block)
        block    = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block)
        block    = Dropout(dropout_rate)(block)

    flatten      = Flatten()(block)

    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax      = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def shallow_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
        Build the ShallowConvNet classifier (uncompiled functional Keras model)
        with tunable hyper-parameters.

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        Args:
            nb_classes: Number of softmax outputs.
            channels, samples: Input dimensions; the model input is
                ``(channels, samples, 1)``.
            **kwargs: Optional hyper-parameters (unknown keys are ignored):
                pool_size_d2 (35), strides_d2 (7), conv_filters_d2 (13) -
                    second-axis size of the pooling window, pooling stride and
                    first convolution kernel;
                pool_size, strides, conv_filters - full 2-tuples overriding
                    the ``(1, *_d2)`` values above;
                conv2d_1_units (40), conv2d_2_units (40) - filters of the two
                    convolutions;
                l2_reg_1, l2_reg_2, l2_reg_3 (0.01) - L2 factors of the two
                    convolutions and the dense layer;
                dropout_rate (0.5).

        Returns:
            ``keras.Model`` mapping the input tensor to class probabilities.
    """

    _POOL_SIZE_D2_ = kwargs.get("pool_size_d2", 35)
    _STRIDES_D2_ = kwargs.get("strides_d2", 7)
    _CONV_FILTERS_D2_ = kwargs.get("conv_filters_d2", 13)

    _POOL_SIZE_ = kwargs.get("pool_size", (1, _POOL_SIZE_D2_))
    _STRIDES_ = kwargs.get("strides", (1, _STRIDES_D2_))
    _CONV_FILTERS_ = kwargs.get("conv_filters", (1, _CONV_FILTERS_D2_))

    _CONV2D_1_UNITS_ = kwargs.get("conv2d_1_units", 40)
    _CONV2D_2_UNITS_ = kwargs.get("conv2d_2_units", 40)
    _L2_REG_1_ = kwargs.get("l2_reg_1", 0.01)
    _L2_REG_2_ = kwargs.get("l2_reg_2", 0.01)
    _L2_REG_3_ = kwargs.get("l2_reg_3", 0.01)
    _DROPOUT_RATE_ = kwargs.get("dropout_rate", 0.5)

    input_main   = Input(shape=(channels, samples, 1))
    # The input shape is already fixed by `input_main`; passing `input_shape`
    # to the layer as well only triggers a Keras 3 warning, so it is omitted.
    block1       = Conv2D(_CONV2D_1_UNITS_, _CONV_FILTERS_,
                          kernel_constraint = max_norm(2., axis=(0,1,2)),
                          kernel_regularizer=keras.regularizers.L2(_L2_REG_1_))(input_main)
    block1       = Conv2D(_CONV2D_2_UNITS_, (channels, 1), use_bias=False,
                          kernel_constraint = max_norm(2., axis=(0,1,2)),
                          kernel_regularizer=keras.regularizers.L2(_L2_REG_2_))(block1)
    block1       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    # square -> average pool -> log
    block1       = Activation(shallow_conv_net_square_layer)(block1)
    block1       = AveragePooling2D(pool_size=_POOL_SIZE_, strides=_STRIDES_)(block1)
    block1       = Activation(shallow_conv_net_log_layer)(block1)
    block1       = Dropout(_DROPOUT_RATE_)(block1)
    flatten      = Flatten()(block1)
    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5),
                         kernel_regularizer=keras.regularizers.L2(_L2_REG_3_))(flatten)
    softmax      = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def lstm_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """Three stacked 40-unit LSTM layers followed by two dense layers.

    The model input is ``(samples, channels)``; the output is a softmax over
    ``nb_classes``. Returns an uncompiled ``Sequential`` model.
    """
    return Sequential([
        Input(shape=(samples, channels)),
        # Two sequence-to-sequence LSTM stages ...
        LSTM(40, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
        LSTM(40, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
        # ... and a final LSTM that keeps only its last output.
        LSTM(40, stateful=False),
        Dropout(0.5),
        # NOTE(review): softmax on this hidden layer is unusual - confirm it is intended.
        Dense(50,activation='softmax'),
        Dense(nb_classes, activation='softmax'),
    ])

def lstm_cnn_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """One Conv1D stage followed by one LSTM stage and a softmax classifier.

    The model input is ``(samples, channels)``. Hyper-parameters are read from
    ``**kwargs`` (unknown keys are ignored); the defaults are listed below.
    Returns an uncompiled ``Sequential`` model.
    """
    hp = {
        "conv1d_1_units": 40,
        "conv1d_1_kernel_size": 20,
        "conv1d_1_strides": 4,
        "conv1d_1_maxpool_size": 4,
        "conv1d_1_maxpool_strides": 4,
        "lstm_1_units": 50,
        "l2_reg_1": 0.01,
        "l2_reg_2": 0.01,
        "l2_reg_3": 0.01,
        "dropout_rate_1": 0.5,
        "dropout_rate_2": 0.5,
        **kwargs,
    }

    return Sequential([
        Input(shape=(samples, channels)),

        # 1-layer CNN
        Conv1D(
            hp["conv1d_1_units"],
            kernel_size=hp["conv1d_1_kernel_size"],
            strides=hp["conv1d_1_strides"],
            kernel_regularizer=keras.regularizers.L2(hp["l2_reg_1"]),
        ),
        Activation('relu'),
        Dropout(hp["dropout_rate_1"]),
        BatchNormalization(),
        MaxPooling1D(pool_size=hp["conv1d_1_maxpool_size"], strides=hp["conv1d_1_maxpool_strides"]),

        # 1-layer LSTM
        LSTM(
            hp["lstm_1_units"],
            return_sequences=True,
            stateful=False,
            kernel_regularizer=keras.regularizers.L2(hp["l2_reg_2"]),
        ),
        Dropout(hp["dropout_rate_2"]),
        BatchNormalization(),
        Flatten(),
        Dense(nb_classes, activation='softmax', kernel_regularizer=keras.regularizers.L2(hp["l2_reg_3"])),
    ])

def lstm_cnn_net_v2(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES):
    """LSTM stage followed by a Conv1D stage and a softmax classifier.

    Same ingredients as `lstm_cnn_net` with the two stages swapped and fixed
    hyper-parameters. The model input is ``(samples, channels)``. Returns an
    uncompiled ``Sequential`` model.
    """
    stack = (
        Input(shape=(samples, channels)),
        # LSTM stage
        LSTM(50, return_sequences=True, stateful=False),
        Dropout(0.5),
        BatchNormalization(),
        # CNN stage
        Conv1D(40, kernel_size=20, strides=4),
        Activation('relu'),
        Dropout(0.5),
        BatchNormalization(),
        MaxPooling1D(pool_size=4, strides=4),
        Flatten(),
        # Classification head
        Dense(nb_classes, activation='softmax'),
    )

    model = Sequential()
    for layer in stack:
        model.add(layer)
    return model

# endregion Models

In [None]:
NUM_FOLDS = 5
# Seeded so the fold assignment is reproducible across runs.
kfold = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SKLRNG)
acc_per_fold = []
loss_per_fold = []

# K-fold Cross Validation model evaluation
for fold_no, (train, test) in enumerate(kfold.split(X_train, y_train), start=1):
    # Build a fresh model for every fold: reusing one instance would carry the
    # weights fitted on previous folds (which contain this fold's held-out
    # samples) into the current fold and inflate its score.
    model = shallow_conv_net()
    # model = lstm_cnn_net()
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    # NOTE(review): patience (75) exceeds the number of epochs (25), so neither
    # callback can trigger with the current settings.
    history = model.fit(
        X_train[train],
        y_train[train],
        batch_size=64,
        epochs=25,
        # `train` / `test` index into X_train / y_train, so the validation
        # labels must come from y_train (not from the unsplit y_encoded).
        validation_data=(X_train[test], y_train[test]),
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
        ],
    )

    # Generate generalization metrics. `evaluate` returns [loss, accuracy] for
    # the single metric compiled above; the names are spelled out instead of
    # relying on `model.metrics_names`.
    fold_loss, fold_accuracy = model.evaluate(X_train[test], y_train[test], verbose=0)
    rprint(f'Score for fold {fold_no}: loss of {fold_loss}; accuracy of {fold_accuracy*100}%')
    acc_per_fold.append(fold_accuracy * 100)
    loss_per_fold.append(fold_loss)

In [None]:
# Mean of the per-fold accuracies (stored in percent by the cross-validation cell).
mean_cross_val_score = np.mean(acc_per_fold)
# Report the number of folds actually evaluated instead of a hard-coded "5".
rprint(f"Mean Cross Validation Score ({len(acc_per_fold)}-fold):", mean_cross_val_score)

# Hold-out evaluation of the most recently trained `model`.
# NOTE: this value is a fraction in [0, 1], whereas the score above is in percent.
results_test = model.evaluate(X_test, y_test)
rprint("Test Accuracy:", results_test[1])

In [None]:
# Evaluate the current `model` on the held-out test split.
# results_val = model.evaluate(X_val, y_val)
results_test = model.evaluate(X_test, y_test)

# X_test.reshape(-1, NUM_SAMPLES, NUM_CHANNELS).shape

In [None]:
# Train a single ShallowConvNet on the whole training split, holding out the
# last 20% of it for validation.
model = shallow_conv_net()
# model = lstm_cnn_net()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# NOTE(review): patience (75) exceeds the number of epochs (25), so neither
# callback can trigger with the current settings.
training_callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5),
]

history = model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=25,
    validation_split=0.2,
    shuffle=True,
    callbacks=training_callbacks,
)

In [34]:
def objective_fn(trial, subjects=[], model_str = "lstm_cnn_net", sfreq = 128, batch_size = 128, **kwargs):
    """Optuna objective: pick channels and hyper-parameters, train, return a cost.

    Args:
        trial: Optuna trial used to sample the search space.
        subjects: Subject ids to load; an empty sequence means every subject in
            ``fat_dataset.subject_list``. The default list is never mutated.
        model_str: ``"lstm_cnn_net"`` or ``"shallow_conv_net"``.
        sfreq: Resampling frequency; a falsy value lets Optuna choose among
            128, 256 and 300.
        batch_size: Batch size; a falsy value lets Optuna choose in 32..256.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        Cost to minimise. It combines the best and the mean validation
        accuracy, the number of selected channels and the train/validation
        accuracy gap, plus a 0.5 penalty when the best training accuracy is
        below 0.75. Trials that select no channel get a fixed cost that is
        worse than any reachable one.

    Side effects:
        Stores the selected channels, the model JSON and weights and the
        train/validation/test accuracies in the ``"trial_data"`` user
        attribute of ``trial``.
    """
    _SFREQ_ = sfreq if sfreq else trial.suggest_categorical("sfreq", [128, 256, 300])
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = batch_size if batch_size else trial.suggest_int("batch_size", 32, 256, step=32)

    models = {
        "lstm_cnn_net": lstm_cnn_net,
        "shallow_conv_net": shallow_conv_net
    }

    subjects = subjects if len(subjects) > 0 else fat_dataset.subject_list
    model_fn = models[model_str]

    # Channel names are read once from subject 1; the last entry is dropped.
    all_channels = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
    channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in all_channels }
    channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    if len(channels_idx) == 0:
        # Re-suggesting a parameter inside the same trial returns the value
        # already drawn, so retrying in a loop would never terminate. Score the
        # empty selection with a cost above the largest reachable one
        # (1 + 0.5 + 1e-2*n_channels + 1e-3 + 0.5) instead.
        return 2.5 + 1e-2*len(all_channels)

    X, y, _ = data_generator(fat_dataset, subjects=subjects, channel_idx=channels_idx, sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]

    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))
    if "lstm" in model_str:
        # The LSTM models take (samples, channels). Swap the two trailing axes;
        # a plain reshape keeps the memory order and would interleave channel
        # and time values.
        X = np.transpose(X, (0, 2, 1))
    X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, shuffle=True, stratify=y_encoded)

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }

    if model_str == "shallow_conv_net":
        model_params = {
            **model_params,

            "pool_size_d2": trial.suggest_int("pool_size_d2", 5, 95, step=5),
            "strides_d2": trial.suggest_int("strides_d2", 1, 31, step=2),
            "conv_filters_d2": trial.suggest_int("conv_filters_d2", 5, 55, step=2),

            "conv2d_1_units": trial.suggest_int("conv2d_1_units", 10, 200, step=10),
            "conv2d_2_units": trial.suggest_int("conv2d_2_units", 10, 200, step=10),
            "l2_reg_1": trial.suggest_float("l2_reg_1", 0.001, 0.8),
            "l2_reg_2": trial.suggest_float("l2_reg_2", 0.001, 0.8),
            "l2_reg_3": trial.suggest_float("l2_reg_3", 0.001, 0.8),
            "dropout_rate": trial.suggest_float("dropout_rate", 0.1, 0.9, step=0.1),
        }

    if model_str == "lstm_cnn_net":
        model_params = {
            **model_params,

            "conv1d_1_units": trial.suggest_int("conv1d_1_units", 10, 200, step=10),
            "conv1d_1_kernel_size": trial.suggest_int("conv1d_1_kernel_size", 1, 100, step=1),
            "conv1d_1_strides": trial.suggest_int("conv1d_1_strides", 1, 50, step=1),
            # "conv1d_1_maxpool_size": trial.suggest_int("conv1d_1_maxpool_size", 1, 100, step=2),
            "conv1d_1_maxpool_strides": trial.suggest_int("conv1d_1_maxpool_strides", 1, 4, step=1),
            "lstm_1_units": trial.suggest_int("lstm_1_units", 10, 200, step=10),
            "l2_reg_1": trial.suggest_float("l2_reg_1", 0.001, 0.8),
            "l2_reg_2": trial.suggest_float("l2_reg_2", 0.001, 0.8),
            "l2_reg_3": trial.suggest_float("l2_reg_3", 0.001, 0.8),
            "dropout_rate_1": trial.suggest_float("dropout_rate_1", 0.1, 0.9, step=0.1),
            "dropout_rate_2": trial.suggest_float("dropout_rate_2", 0.1, 0.9, step=0.1),
        }

    model = model_fn(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="rmsprop",
        # Named explicitly so the history keys are "accuracy" / "val_accuracy".
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=5,
        validation_split=0.2,
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    print("\n")

    train_accuracy = history.history["accuracy"]
    val_accuracy = history.history["val_accuracy"]

    trial.set_user_attr("trial_data", {
        # Reuse the channel names loaded above instead of reading the dataset again.
        "channels_selected": np.asarray(all_channels)[channels_idx],
        # "history": history,
        "model": model.to_json(),
        "model_weights": model.get_weights(),
        "train_accuracy": train_accuracy,
        "val_accuracy": val_accuracy,
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1],
    })

    print("\n")

    # Cost terms, in order: best validation accuracy, mean validation accuracy,
    # number of channels used, mean train/validation accuracy gap.
    cost = (1 - np.max(val_accuracy))**2 + 0.5*(1 - np.mean(val_accuracy))**2 + 1e-2*len(channels_idx) + 1e-3*np.mean(np.asarray(train_accuracy) - np.asarray(val_accuracy))**2
    # Extra penalty for models that never reach 75% training accuracy.
    if not 0.75 <= np.max(train_accuracy) <= 1.00:
        cost += .5

    return cost

def early_stopping_check(study, trial, early_stopping_rounds=10):
    """Optuna callback that stops the study once it stops improving.

    The study is stopped when the trial that just finished is at least
    ``early_stopping_rounds`` trials after the current best trial.

    Args:
        study: Running study; ``study.stop()`` is called when the limit is hit.
        trial: The trial that just finished (only ``trial.number`` is read).
        early_stopping_rounds: Allowed number of trials without a new best.
    """
    import logging  # stdlib; imported locally because the notebook does not import it

    try:
        best_trial_number = study.best_trial.number
    except ValueError:
        # `best_trial` raises while no trial has completed yet (e.g. the first
        # trials failed or were pruned); there is nothing to compare against.
        return

    should_stop = (trial.number - best_trial_number) >= early_stopping_rounds
    if should_stop:
        # "optuna" is the library's root logger, reached here through the
        # public logging API rather than Optuna's private helper.
        logging.getLogger("optuna").info("Early stopping detected: %s", should_stop)
        study.stop()

def heuristic_optimizer(obj_fn, max_iter = 25, show_progress_bar = True, subjects=None, model_str="lstm_cnn_net", sfreq = 128, batch_size = 128, max_stag_count = 10, seed = None, **kwargs):
    """Run an NSGA-III Optuna search that minimises ``obj_fn``.

    Args:
        obj_fn: Objective called as ``obj_fn(trial, subjects=..., model_str=...,
            sfreq=..., batch_size=...)``; must return the cost to minimise.
        max_iter: Maximum number of trials.
        show_progress_bar: Whether Optuna shows its progress bar.
        subjects: Subject identifiers forwarded to ``obj_fn``. ``None`` (the
            default) becomes a fresh empty list on every call.
        model_str: Model name forwarded to ``obj_fn``.
        sfreq: Sampling frequency forwarded to ``obj_fn``.
        batch_size: Batch size forwarded to ``obj_fn``.
        max_stag_count: Trials without a new best after which the search stops
            (handled by ``early_stopping_check``).
        seed: Optional seed for the sampler; ``None`` keeps the sampler unseeded.
        **kwargs: Extra keyword arguments passed to ``study.optimize``.

    Returns:
        The Optuna study, including when the search was interrupted from the keyboard.
    """
    # A literal `[]` default would be one list shared by every call and handed to
    # the objective, so a mutation in one run would leak into the next.
    if subjects is None:
        subjects = []

    # Silence Optuna's per-trial logging; only CRITICAL messages get through.
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    objective = partial(obj_fn, subjects = subjects, model_str = model_str, sfreq = sfreq, batch_size = batch_size)
    stagnation_callback = partial(early_stopping_check, early_stopping_rounds=max_stag_count)

    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler(seed=seed))
    try:
        study.optimize(objective, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=[stagnation_callback], **kwargs)
    except KeyboardInterrupt:
        # Deliberate best-effort: interrupting the cell keeps the trials gathered so far.
        pass

    return study

# Launch the search for subject 12 with the "shallow_conv_net" model: at most 50 trials,
# stopped early after 10 trials without a new best (max_stag_count).
# sfreq=None and batch_size=None are forwarded to objective_fn as-is.
# NOTE(review): presumably objective_fn picks these two itself when given None; confirm there.
study = heuristic_optimizer(obj_fn = objective_fn, model_str="shallow_conv_net", subjects=[12], sfreq = None, batch_size = None, max_iter = 50, max_stag_count = 10)

  study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIIISampler())


  0%|          | 0/50 [00:00<?, ?it/s]

Adding metadata with 3 columns


  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 887ms/step - accuracy: 0.4805 - loss: 17.3617 - val_accuracy: 0.5000 - val_loss: 12.5354 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.6668 - loss: 12.1372 - val_accuracy: 0.3333 - val_loss: 11.1348 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6399 - loss: 10.8530 - val_accuracy: 0.4444 - val_loss: 10.0629 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7500 - loss: 9.7167 - val_accuracy: 0.5000 - val_loss: 9.1637 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.7137 - loss: 8.9267 - val_accuracy: 0.5556 - val_loss: 8.4057 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 432ms/step - accura

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.5294 - loss: 67.4647 - val_accuracy: 0.5556 - val_loss: 56.4078 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step - accuracy: 0.6618 - loss: 56.3769 - val_accuracy: 0.5556 - val_loss: 50.2950 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - accuracy: 0.6471 - loss: 49.6135 - val_accuracy: 0.5556 - val_loss: 46.5866 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - accuracy: 0.5588 - loss: 45.3197 - val_accuracy: 0.5556 - val_loss: 42.1449 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step - accuracy: 0.7206 - loss: 40.6352 - val_accuracy: 0.6111 - val_loss: 37.4952 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 737ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4706 - loss: 37.2294 - val_accuracy: 0.5556 - val_loss: 30.3164 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.5735 - loss: 30.3119 - val_accuracy: 0.5000 - val_loss: 30.6231 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.6029 - loss: 28.7373 - val_accuracy: 0.6667 - val_loss: 24.5432 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.9412 - loss: 24.1468 - val_accuracy: 0.5000 - val_loss: 22.9223 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.9559 - loss: 22.0214 - val_accuracy: 0.5556 - val_loss: 20.8703 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 489ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5294 - loss: 36.2586 - val_accuracy: 0.5556 - val_loss: 30.5167 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.6324 - loss: 30.5104 - val_accuracy: 0.5000 - val_loss: 27.7761 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6176 - loss: 27.3057 - val_accuracy: 0.4444 - val_loss: 25.2891 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.7206 - loss: 24.8212 - val_accuracy: 0.5000 - val_loss: 23.2991 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.9412 - loss: 22.6900 - val_accuracy: 0.4444 - val_loss: 21.4162 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 490ms/step - accur

  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 854ms/step - accuracy: 0.5455 - loss: 44.6405 - val_accuracy: 0.4444 - val_loss: 35.5819 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.5905 - loss: 34.4208 - val_accuracy: 0.5556 - val_loss: 30.7721 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6668 - loss: 29.9487 - val_accuracy: 0.6667 - val_loss: 27.1901 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.6776 - loss: 26.5007 - val_accuracy: 0.6667 - val_loss: 24.2216 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.8263 - loss: 23.6273 - val_accuracy: 0.5556 - val_loss: 21.6819 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 427ms/step - acc

  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 1s/step - accuracy: 0.4770 - loss: 70.8420 - val_accuracy: 0.5556 - val_loss: 52.6310 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.4926 - loss: 48.6858 - val_accuracy: 0.5000 - val_loss: 41.0236 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5949 - loss: 37.5694 - val_accuracy: 0.5556 - val_loss: 31.7768 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6507 - loss: 30.2368 - val_accuracy: 0.5556 - val_loss: 26.3650 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.8610 - loss: 24.4748 - val_accuracy: 0.5556 - val_loss: 21.4303 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 667ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.5441 - loss: 223.2331 - val_accuracy: 0.6667 - val_loss: 167.5450 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step - accuracy: 0.8088 - loss: 167.4390 - val_accuracy: 0.5000 - val_loss: 137.0431 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - accuracy: 0.7794 - loss: 136.7630 - val_accuracy: 0.6111 - val_loss: 115.4525 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step - accuracy: 0.7941 - loss: 115.3132 - val_accuracy: 0.4444 - val_loss: 100.4441 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step - accuracy: 0.7059 - loss: 98.8379 - val_accuracy: 0.5556 - val_loss: 87.1152 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 615ms/step

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - accuracy: 0.6471 - loss: 23.2521 - val_accuracy: 0.3889 - val_loss: 20.1847 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 75ms/step - accuracy: 0.8529 - loss: 19.8326 - val_accuracy: 0.3889 - val_loss: 17.3881 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.9853 - loss: 16.8836 - val_accuracy: 0.4444 - val_loss: 15.4585 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.9559 - loss: 14.8113 - val_accuracy: 0.5556 - val_loss: 14.6348 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 63ms/step - accuracy: 0.8676 - loss: 13.3040 - val_accuracy: 0.5000 - val_loss: 15.0149 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 769ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step - accuracy: 0.4706 - loss: 73.3815 - val_accuracy: 0.4444 - val_loss: 58.5806 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9118 - loss: 57.8820 - val_accuracy: 0.4444 - val_loss: 55.9401 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 64ms/step - accuracy: 0.6176 - loss: 50.6521 - val_accuracy: 0.5556 - val_loss: 47.9495 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step - accuracy: 0.4706 - loss: 43.2604 - val_accuracy: 0.4444 - val_loss: 42.9712 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 63ms/step - accuracy: 0.7941 - loss: 37.8338 - val_accuracy: 0.5556 - val_loss: 34.8390 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 825ms/step - accur

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4412 - loss: 74.1248 - val_accuracy: 0.5000 - val_loss: 61.0009 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.6912 - loss: 60.9210 - val_accuracy: 0.6667 - val_loss: 54.8365 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.8235 - loss: 54.7095 - val_accuracy: 0.7222 - val_loss: 50.1908 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.8676 - loss: 50.0235 - val_accuracy: 0.7222 - val_loss: 46.3987 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step - accuracy: 0.9118 - loss: 46.1510 - val_accuracy: 0.5556 - val_loss: 43.1316 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 502ms/step - accuracy: 0.681

  super().__init__(


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 1s/step - accuracy: 0.3826 - loss: 58.6704 - val_accuracy: 0.4444 - val_loss: 46.8738 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5568 - loss: 45.1231 - val_accuracy: 0.5000 - val_loss: 39.0955 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.7461 - loss: 37.7478 - val_accuracy: 0.5556 - val_loss: 33.3310 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.8635 - loss: 32.0277 - val_accuracy: 0.6667 - val_loss: 28.6299 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.8752 - loss: 27.4983 - val_accuracy: 0.6667 - val_loss: 24.9044 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 581ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5441 - loss: 69.0628 - val_accuracy: 0.6111 - val_loss: 56.5829 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step - accuracy: 0.7794 - loss: 56.3754 - val_accuracy: 0.5000 - val_loss: 49.5106 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7353 - loss: 48.9271 - val_accuracy: 0.4444 - val_loss: 44.0562 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.7941 - loss: 43.3533 - val_accuracy: 0.5000 - val_loss: 39.5907 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - accuracy: 0.9118 - loss: 38.8609 - val_accuracy: 0.6667 - val_loss: 35.5435 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 508ms/step - accur

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5147 - loss: 17.5704 - val_accuracy: 0.6111 - val_loss: 11.5562 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.5735 - loss: 12.0113 - val_accuracy: 0.5556 - val_loss: 11.2002 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6471 - loss: 11.1829 - val_accuracy: 0.5000 - val_loss: 10.7544 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6324 - loss: 10.7666 - val_accuracy: 0.5000 - val_loss: 10.3456 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.6471 - loss: 10.2171 - val_accuracy: 0.5000 - val_loss: 9.9621 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 648ms/step - accura

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.5441 - loss: 144.6169 - val_accuracy: 0.4444 - val_loss: 122.0854 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step - accuracy: 0.5294 - loss: 122.0523 - val_accuracy: 0.4444 - val_loss: 108.3968 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.7500 - loss: 108.1008 - val_accuracy: 0.3889 - val_loss: 97.9126 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7206 - loss: 97.7652 - val_accuracy: 0.4444 - val_loss: 90.2073 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.6618 - loss: 89.4070 - val_accuracy: 0.5000 - val_loss: 82.2784 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 559ms/step - accuracy: 

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 38.8051 - val_accuracy: 0.6111 - val_loss: 30.6796 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.6618 - loss: 30.5880 - val_accuracy: 0.5556 - val_loss: 28.3745 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.6176 - loss: 28.2593 - val_accuracy: 0.5000 - val_loss: 26.6228 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.6324 - loss: 26.4686 - val_accuracy: 0.4444 - val_loss: 25.1333 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7647 - loss: 24.9347 - val_accuracy: 0.5000 - val_loss: 23.8282 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 483ms/step - accur

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.4706 - loss: 7.6063 - val_accuracy: 0.5556 - val_loss: 6.2018 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.7794 - loss: 6.0535 - val_accuracy: 0.5556 - val_loss: 5.6664 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8971 - loss: 5.4858 - val_accuracy: 0.5556 - val_loss: 5.2749 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.9706 - loss: 5.0639 - val_accuracy: 0.6111 - val_loss: 4.9481 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9559 - loss: 4.7260 - val_accuracy: 0.6111 - val_loss: 4.6766 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 624ms/step - accuracy: 0.8182 - loss: 

  super().__init__(


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4118 - loss: 72.0755 - val_accuracy: 0.4444 - val_loss: 58.4119 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - accuracy: 0.7353 - loss: 58.2708 - val_accuracy: 0.5556 - val_loss: 50.6149 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.6618 - loss: 50.2401 - val_accuracy: 0.4444 - val_loss: 46.0878 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.6618 - loss: 44.4053 - val_accuracy: 0.5556 - val_loss: 40.2018 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.7794 - loss: 39.5902 - val_accuracy: 0.4444 - val_loss: 36.6692 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 329ms/step - accuracy: 0.500

  super().__init__(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 892ms/step - accuracy: 0.5338 - loss: 95.6195 - val_accuracy: 0.4444 - val_loss: 68.4448 - learning_rate: 0.0010
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.6820 - loss: 64.4864 - val_accuracy: 0.3333 - val_loss: 51.9365 - learning_rate: 0.0010
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.5875 - loss: 49.3110 - val_accuracy: 0.3889 - val_loss: 40.7892 - learning_rate: 0.0010
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.6772 - loss: 38.6940 - val_accuracy: 0.3889 - val_loss: 32.5082 - learning_rate: 0.0010
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 0.7808 - loss: 30.6739 - val_accuracy: 0.5556 - val_loss: 25.9664 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 424ms/step - accuracy: 0.

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 138.9495 - val_accuracy: 0.5000 - val_loss: 113.6793 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.6912 - loss: 113.5782 - val_accuracy: 0.5000 - val_loss: 100.2464 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.6324 - loss: 100.1989 - val_accuracy: 0.5556 - val_loss: 90.1998 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6618 - loss: 90.1323 - val_accuracy: 0.5556 - val_loss: 82.0555 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6618 - loss: 81.9803 - val_accuracy: 0.4444 - val_loss: 75.1681 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 490ms/step - 

  super().__init__(


Epoch 1/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step - accuracy: 0.5000 - loss: 126.3484 - val_accuracy: 0.6111 - val_loss: 104.8660 - learning_rate: 0.0010
Epoch 2/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 112ms/step - accuracy: 0.5588 - loss: 103.7674 - val_accuracy: 0.4444 - val_loss: 96.6603 - learning_rate: 0.0010
Epoch 3/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 106ms/step - accuracy: 0.5441 - loss: 92.5707 - val_accuracy: 0.4444 - val_loss: 86.1050 - learning_rate: 0.0010
Epoch 4/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 107ms/step - accuracy: 0.7206 - loss: 79.1065 - val_accuracy: 0.5556 - val_loss: 76.3749 - learning_rate: 0.0010
Epoch 5/5
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 108ms/step - accuracy: 0.6176 - loss: 71.2330 - val_accuracy: 0.4444 - val_loss: 71.5466 - learning_rate: 0.0010


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 718ms/step 

In [35]:
# Summarise every trial of `study` as one row of metrics, indexed by trial number
# so that `study.trials[<index>]` returns the trial a row came from.
trial_metrics_dict = {
    "train_acc": [],
    "test_acc": [],
    "val_acc": [],
    "train_val_diff": [],
    "scores": [],
    "channels_selected": [],
}
trial_numbers = []
for trial in study.trials_dataframe().itertuples():
    trial_user_attrs = getattr(trial, "user_attrs_trial_data", None)
    if not isinstance(trial_user_attrs, dict):
        # Trials that never stored "trial_data" (e.g. failed or pruned) show up
        # as NaN here; skip them instead of crashing on the key lookups below.
        continue
    # Best epoch of each accuracy curve, computed once and reused for the gap.
    best_train_acc = np.max(trial_user_attrs['train_accuracy'])
    best_val_acc = np.max(trial_user_attrs['val_accuracy'])
    trial_numbers.append(trial.number)
    trial_metrics_dict["scores"].append(trial.value)
    trial_metrics_dict["train_acc"].append(best_train_acc)
    trial_metrics_dict["test_acc"].append(trial_user_attrs['test_accuracy'])
    trial_metrics_dict["val_acc"].append(best_val_acc)
    trial_metrics_dict["channels_selected"].append(trial_user_attrs['channels_selected'])
    trial_metrics_dict["train_val_diff"].append(abs(best_train_acc - best_val_acc))
trial_metrics_df = pd.DataFrame(trial_metrics_dict, index=trial_numbers)
# Alternative orderings: sort_values("test_acc", ascending=False) or sort_values("train_val_diff", ascending=True)
trial_metrics_df.sort_values("scores", ascending=True)

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
9,0.911765,0.681818,0.722222,0.189542,0.214396,"[F3, Fz, C4, P4, O1, F8, T4]"
10,0.867647,0.727273,0.666667,0.20098,0.285014,"[P3, C3, Fz, C4, Cz, Fp2, T5, A2]"
4,0.808824,0.590909,0.666667,0.142157,0.300254,"[F3, C4, P4, Cz, Pz, Fp1, Fp2, F7, A2, T4]"
2,0.955882,0.5,0.666667,0.289216,0.3199,"[F3, Fz, F4, C4, P4, Cz, Fp2, O2, F7, F8, T6]"
15,0.970588,0.818182,0.611111,0.359477,0.320427,"[F3, Fz, F4, Pz, Fp1, T3, O2, A2]"
6,0.808824,0.5,0.666667,0.142157,0.329906,"[P3, C3, F3, Fz, P4, Cz, Fp1, T5, F7, A2, T6, T4]"
11,0.911765,0.681818,0.666667,0.245098,0.33492,"[P3, C3, F3, Fz, C4, Fp1, Fp2, T5, O2, F7, F8,..."
14,0.764706,0.363636,0.611111,0.153595,0.345383,"[P4, Pz, T3, T5, O1, O2, A2, T6]"
5,0.823529,0.454545,0.555556,0.267974,0.391301,"[P3, F3, F4, Pz, Fp1, T5, O1, T6, T4]"
16,0.779412,0.5,0.555556,0.223856,0.418174,"[C3, Fz, P4, Fp2, O1, O2, F7, T6, T4]"


In [1]:
# Needs `trial_metrics_df` from the metrics-table cell; the recorded NameError
# comes from running this cell before that one.
# Shortlist: the ten highest-val_acc trials whose training accuracy exceeds
# their validation accuracy, ordered by training accuracy.
top_val_trials = (
    trial_metrics_df
    .sort_values("val_acc", ascending=False)
    .query("train_acc > val_acc")[:10]
    .sort_values("train_acc", ascending=False)
)
max_val_acc = top_val_trials["val_acc"].max()
# Show only the shortlisted trials that reach the best validation accuracy.
top_val_trials.query(f"val_acc == {max_val_acc}")

NameError: name 'trial_metrics_df' is not defined

In [41]:
# Pick the trial to keep: among the ten highest-val_acc trials whose training
# accuracy exceeds validation accuracy, take those reaching the best validation
# accuracy with train_acc >= 0.7, and prefer the highest training accuracy.
top_val_trials = (
    trial_metrics_df
    .sort_values("val_acc", ascending=False)
    .query("train_acc > val_acc")[:10]
    .sort_values("train_acc", ascending=False)
)
max_val_acc = top_val_trials["val_acc"].max()
best_candidate = (
    # Boolean mask instead of an f-string query: an empty shortlist gives NaN,
    # which the query parser cannot evaluate.
    top_val_trials[top_val_trials["val_acc"] == max_val_acc]
    .query("train_acc >= 0.7")
    .sort_values("train_val_diff", ascending=True)
)
best_candidate = best_candidate.sort_values("train_acc", ascending=False)
display(best_candidate)
if best_candidate.empty:
    raise ValueError("No trial passes the candidate filters (train_acc > val_acc, train_acc >= 0.7).")
best_candidate_trial_id = best_candidate.index[0]

# Report the hyper-parameters and metrics of the SELECTED trial. The parameters
# used to be read from study.best_trial, which can be a different trial than
# the one whose accuracies are printed below.
best_candidate_trial = study.trials[best_candidate_trial_id]
best_candidate_trial_data = best_candidate_trial.user_attrs["trial_data"]
rpprint({ k: v for k, v in best_candidate_trial.params.items() if not k.startswith("channels") })
rprint("test_accuracy =", best_candidate_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_candidate_trial_data["val_accuracy"]))
rprint("channels_selected =", best_candidate_trial_data["channels_selected"])

Unnamed: 0,train_acc,test_acc,val_acc,train_val_diff,scores,channels_selected
9,0.911765,0.681818,0.722222,0.189542,0.214396,"[F3, Fz, C4, P4, O1, F8, T4]"


In [139]:
# max_val_acc = trial_metrics_df.sort_values("train_val_diff", ascending=True).query("train_acc >= val_acc")[:15]["val_acc"].max()
# # max_val_acc = trial_metrics_df.sort_values("train_val_diff", ascending=True)[:15]["val_acc"].max()
# best_candidate = trial_metrics_df.sort_values("train_acc", ascending=False).query(f"val_acc == {max_val_acc}")
# display(best_candidate)
# best_candidate_trial_id = best_candidate.index[0]

# rpprint({ k: v for k, v in study.best_trial.params.items() if not k.startswith("channels") })
# rprint("test_accuracy =", study.trials[best_candidate_trial_id].user_attrs["trial_data"]["test_accuracy"])
# rprint("val_accuracy =", np.max(study.trials[best_candidate_trial_id].user_attrs["trial_data"]["val_accuracy"]))
# rprint("channels_selected =", study.trials[best_candidate_trial_id].user_attrs["trial_data"]["channels_selected"])

In [17]:
# Persist the selected trial and its model weights; the random UUID keeps
# repeated saves from overwriting each other.
subject = 12
temp_id = uuid.uuid4()
selected_trial = study.trials[best_candidate_trial_id]
trial_path = f"./best_subjects_search_results/subject_{subject}_12_shallow_conv_net_{temp_id}.npy"
weights_path = f"./best_subjects_search_results/weights/subject_{subject}_12_shallow_conv_net_{temp_id}__weights.npy"
# The trial is a Python object, so np.save stores it pickled; reading it back
# requires np.load(..., allow_pickle=True).
np.save(trial_path, selected_trial)
np.save(weights_path, selected_trial.user_attrs["trial_data"]["model_weights"])

  arr = np.asanyarray(arr)


In [42]:
# Report the study's lowest-cost trial: its non-channel hyper-parameters first,
# then the accuracies and channel subset stored under "trial_data".
study_best_trial = study.best_trial
rpprint({
    name: value
    for name, value in study_best_trial.params.items()
    if not name.startswith("channels")
})
study_best_trial_data = study_best_trial.user_attrs["trial_data"]
rprint("test_accuracy =", study_best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(study_best_trial_data["val_accuracy"]))
rprint("channels_selected =", study_best_trial_data["channels_selected"])

In [44]:
# Rebuild the selected trial's network from what the trial stored under its
# "trial_data" user attribute: the architecture as JSON and the weight list.
subject_data = study.trials[best_candidate_trial_id]
subject_trial_data = subject_data.user_attrs['trial_data']
subject_model_json = subject_trial_data['model']
# CUSTOM_OBJECTS is defined elsewhere in the notebook.
# NOTE(review): presumably it maps the custom layer/constraint names used by the model builders; confirm.
subject_model = keras.models.model_from_json(subject_model_json, custom_objects=CUSTOM_OBJECTS)
subject_model_weights = subject_trial_data['model_weights']
subject_model.set_weights(subject_model_weights)

def channels_to_channels_idx(channels, dataset):
    """Translate channel names into their positions in the dataset's channel list.

    Args:
        channels: Container of channel names to look up (tested with ``in``).
        dataset: Object whose ``get_data(subjects=[1])[1]['0']['0'].info['ch_names']``
            lists the channel names; the final entry of that list is ignored.

    Returns:
        Positions, in channel-list order, of the names that occur in ``channels``.
    """
    recording = dataset.get_data(subjects=[1])[1]['0']['0']
    montage = recording.info['ch_names'][:-1]
    # dict.fromkeys keeps one slot per distinct name (first occurrence), so a
    # repeated name does not shift the positions of the names after it.
    return [
        position
        for position, name in enumerate(dict.fromkeys(montage))
        if name in channels
    ]

# Re-create the inputs the selected trial was trained on and score the rebuilt model.
# NOTE(review): the recorded run of this cell failed with
#   expected shape=(None, 7, 601, 1), found shape=(32, 7, 256)
# so the sample count (see the 128 Hz fallback below) and the missing trailing
# axis look suspect; confirm against the preprocessing used inside objective_fn.
subject_channels_idx = channels_to_channels_idx(subject_trial_data['channels_selected'], fat_dataset)
if 'sfreq' in subject_trial_data:
    subject_sfreq = subject_trial_data['sfreq']
else:
    # The trial did not record a sampling rate; fall back to 128.
    subject_sfreq = 128
subject_data_hyper_params = {
    name: value
    for name, value in subject_data.params.items()
    if not name.startswith('channels_')
}
subject_batch_size = subject_data.params['batch_size']

# Same report as for the best trial, but for the selected candidate.
rpprint({
    name: value
    for name, value in subject_data.params.items()
    if not name.startswith("channels")
})
selected_trial_data = subject_data.user_attrs["trial_data"]
rprint("test_accuracy =", selected_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(selected_trial_data["val_accuracy"]))
rprint("channels_selected =", selected_trial_data["channels_selected"])

X, y, _ = data_generator(fat_dataset, subjects=[12], channel_idx=subject_channels_idx, sfreq=subject_sfreq)
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# A model rebuilt from JSON has to be compiled again before evaluate().
subject_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
subject_model.evaluate(X, y_encoded)

Adding metadata with 3 columns


RuntimeError: Unable to automatically build the model. Please build it yourself before calling fit/evaluate/predict. A model is 'built' when its variables have been created and its `self.built` attribute is True. Usually, calling the model on a batch of data is the right way to build it.
Exception encountered:
'Input 0 of layer "functional_191" is incompatible with the layer: expected shape=(None, 7, 601, 1), found shape=(32, 7, 256)'

In [None]:
# channels = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]
# channels_of_interest = ['C3', 'F3', 'C4', 'T5', 'O1', 'O2', 'A2', 'T6']
# channels_of_interest_idx = np.where(np.isin(channels, channels_of_interest))[0]

# channels_of_interest_idx

In [None]:
# Model/pipeline hyperparameter optimization test
# NOTE(review): data_store_dict is not referenced by any cell visible in this part of the notebook; confirm it is still needed.
data_store_dict = {}

def objective_fn_lstm_cnn_net(trial, **kwargs):
    """Optuna objective: sample an EEG channel subset and score ``lstm_cnn_net_v2`` on subject 6.

    One boolean parameter ``channels_<name>`` is sampled per channel (the last
    entry of the dataset's channel list is excluded). The selected channels are
    loaded at 300 Hz, split 80/20 with stratification, and the model is trained
    for up to 25 epochs with early stopping on the validation loss.

    Args:
        trial: Optuna trial used for sampling; the selected channel names, the
            validation-accuracy curve and the test accuracy are stored on it
            under the "trial_data" user attribute.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        Cost to minimise: ``(1 - last recorded val_accuracy)**2`` plus ``5e-5``
        per selected channel.

    Raises:
        optuna.TrialPruned: If the sampled mask selects no channel at all.
    """
    _SFREQ_ = 300
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = 256

    # Read the channel list once; it was previously re-queried from the dataset
    # for every use.
    all_ch_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names']

    channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in all_ch_names[:-1] }
    channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    if len(channels_idx) == 0:
        # Re-suggesting a parameter name within one trial returns the value already
        # drawn, so looping until a channel appears would never terminate. Discard
        # the trial instead.
        raise optuna.TrialPruned("no channels selected")

    X, y, _ = data_generator(fat_dataset, subjects=[6], channel_idx=channels_idx, sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]
    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))

    # The model is fed (epochs, samples, channels). X is laid out with channels on
    # axis -2 and samples on axis -1, so the two axes must be swapped. A plain
    # reshape to (-1, samples, channels) would only reinterpret the buffer and mix
    # channels and time points together.
    X_time_major = X.reshape(-1, _NUM_CHANNELS_, _NUM_SAMPLES_).transpose(0, 2, 1)
    X_train, X_test, y_train, y_test = train_test_split(X_time_major, y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded)

    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }

    model = lstm_cnn_net_v2(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"]
    )

    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=25,
        validation_split=0.2,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(all_ch_names)[channels_idx],
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    # Squared validation error plus a small per-channel penalty that favours
    # smaller channel subsets.
    return (1 - history.history["val_accuracy"][-1])**2 + 0.00005*len(channels_idx)

def heuristic_optimizer(obj_fn = lambda: None, max_iter = 15, show_progress_bar = True, seed = None, **kwargs):
    """Run an Optuna study with the NSGA-II sampler that minimizes ``obj_fn``.

    Parameters
    ----------
    obj_fn : callable
        Objective handed to ``study.optimize``. Optuna invokes the objective
        with a trial object, so the zero-argument default is only a
        placeholder; a real objective must be supplied.
    max_iter : int
        Number of trials to run (passed as ``n_trials``).
    show_progress_bar : bool
        Forwarded to ``study.optimize``.
    seed : int or None
        Seed for the NSGA-II sampler. ``None`` (the default) leaves the
        sampler unseeded, which is the previous behavior.
    **kwargs
        Additional keyword arguments forwarded to ``study.optimize``. A
        ``callbacks`` entry is used as the study callbacks instead of
        clashing with the empty default.

    Returns
    -------
    The study object returned by ``optuna.create_study`` after optimization.
    """
    # This changes Optuna's logging verbosity globally, not only for this study.
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    # Take callbacks out of kwargs so they are not passed twice to optimize().
    callbacks = list(kwargs.pop("callbacks", None) or [])

    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIISampler(seed=seed))
    study.optimize(obj_fn, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=callbacks, **kwargs)

    return study

# Run the search with heuristic_optimizer's default settings, using the LSTM-CNN
# objective; the resulting study is read by the reporting cell below.
study = heuristic_optimizer(obj_fn = objective_fn_lstm_cnn_net)

In [312]:
# Report the best trial: its non-channel parameters first, then the metrics
# and channel names stored in its "trial_data" user attribute.
best_trial = study.best_trial
best_trial_data = best_trial.user_attrs["trial_data"]

rpprint({
    param_name: param_value
    for param_name, param_value in best_trial.params.items()
    if not param_name.startswith("channels")
})
rprint("test_accuracy =", best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_trial_data["val_accuracy"]))
rprint("channels_selected =", best_trial_data["channels_selected"])

In [None]:
# Save model
# NOTE(review): the `model` built inside the objective function above is local to
# that function, so this cell relies on a module-level `model` defined elsewhere
# in the notebook — confirm it exists on a fresh Restart & Run All.
model.save("model.keras")

In [None]:
# Load model from disk
# Reads "model.keras" from the working directory. CUSTOM_OBJECTS is not defined in
# this cell and must already exist in the kernel (presumably the custom
# layers/objects the saved model needs — verify).
loaded_model = keras.saving.load_model("model.keras", custom_objects=CUSTOM_OBJECTS)

In [None]:
## Pruning and quantization test

# NOTE(review): tensorflow_model_optimization works on TensorFlow-backed Keras
# models, while this notebook sets KERAS_BACKEND="jax" — confirm this cell is run
# in an environment where the loaded model is compatible.
# The import stays local to this experimental cell so the rest of the notebook
# does not require TensorFlow.
import tensorflow_model_optimization as tfmot

baseline_model = loaded_model
pruning_params = {
    # Ramp sparsity from 0% to 80% between training steps 0 and 100.
    # PolynomialDecay is a pruning *schedule*; it is exposed through the public
    # tfmot.sparsity.keras namespace rather than the pruning_wrapper module.
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.00,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=100)
}
# prune_low_magnitude is the public entry point that accepts a whole model;
# PruneLowMagnitude is the per-layer wrapper it applies internally.
sparse_model = tfmot.sparsity.keras.prune_low_magnitude(baseline_model, **pruning_params)

sparse_model.summary()

### Data exploration, etc.

In [None]:
def create_spectrogram(epoch_data, fs=50):
    """Compute a spectrogram for every channel of every epoch.

    Parameters
    ----------
    epoch_data : iterable
        Iterated as epochs; each epoch is iterated as channels, and each
        channel is passed to ``scipy.signal.spectrogram``.
    fs : float, optional
        Sampling frequency handed to ``scipy.signal.spectrogram``. Defaults to
        50, the value that was previously hard-coded.
        NOTE(review): confirm 50 matches the sampling rate of the epochs
        passed in.

    Returns
    -------
    numpy.ndarray
        The spectrogram values (``Sxx``) stacked per epoch and per channel,
        in input order.
    """
    spectrograms = []
    for epoch in epoch_data:
        # Only the spectrogram values are kept; the frequency and time
        # vectors returned alongside them are discarded.
        channel_spectrograms = [
            sp.signal.spectrogram(channel, fs=fs)[2] for channel in epoch
        ]
        spectrograms.append(np.array(channel_spectrograms))
    return np.array(spectrograms)

# Spectrograms for every epoch/channel in X.
spectograms = create_spectrogram(X)
rprint(spectograms.shape) # (180, 21, 129, 2)

num_left_hand_epochs_to_visualize = 2
num_right_hand_epochs_to_visualize = 2

# Select epochs for visualization
# Note: np.random.choice samples with replacement by default, so the same epoch
# can be drawn twice, and the selection is not seeded.
left_hand_epochs_to_visualize = np.where((y == 'left_hand'))[0]
right_hand_epochs_to_visualize = np.where((y == 'right_hand'))[0]
left_hand_epochs_to_visualize = np.random.choice(left_hand_epochs_to_visualize, num_left_hand_epochs_to_visualize)
right_hand_epochs_to_visualize = np.random.choice(right_hand_epochs_to_visualize, num_right_hand_epochs_to_visualize)
epochs_to_visualize = np.concatenate((left_hand_epochs_to_visualize, right_hand_epochs_to_visualize))

rprint(epochs_to_visualize)

# Visualize spectrograms
# Each panel shows spectograms[idx, :, :, 0], i.e. the first time segment of one
# epoch: rows (y-axis) are channels and columns (x-axis) are frequency bins, so
# the axes are labelled accordingly (bin indices, not Hz or seconds).
fig, axs = plt.subplots(epochs_to_visualize.size, 1, figsize=(10, 10))
for i, idx in enumerate(epochs_to_visualize):
    axs[i].matshow(spectograms[idx, :, :, 0], cmap='viridis')
    axs[i].set_title(y[idx])
    axs[i].set_xlabel('Frequency bin')
    axs[i].set_ylabel('Channel')
    axs[i].set_xlim(0, 60) # Limit x-axis to the first 60 frequency bins (not 60 Hz)

plt.tight_layout()
plt.show()