In [None]:
%%capture
# Enable IPython's autoreload extension in mode 2 so imported modules are
# re-imported when their source changes; %%capture hides this cell's output.
%load_ext autoreload
%autoreload 2

In [None]:
# region General Imports
import os
import uuid
import shutil
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial
# endregion General Imports

from model_optim.model_optimizer import ModelOptimizer

# Dataset
from custom_datasets.fatigue_mi import FatigueMI
from custom_datasets.norm_cho import NormCho2017
from custom_datasets.opt_game_mi import OptGameMI
from custom_datasets.opt_std_mi import OptStdMI

# Seed every random number generator this notebook imports (Python's `random`,
# NumPy's legacy global RNG and TensorFlow) from a single constant so repeated
# runs start from the same state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Model identifiers used in this notebook (presumably the values accepted by
# ModelOptimizer's `model_name` argument - confirm against model_optim).
MODELS_LIST = [
    "shallow_conv_net", "lstm_net", "deep_conv_net",
    "eeg_net", "lstm_cnn_net", "lstm_cnn_net_v2",
]

# Per-model settings, keyed by model name.
# NOTE(review): "lstm_cnn_net_v2" has no "max_epochs" entry, so a direct
# MODELS_HYPERPARAMS_DICT[model]["max_epochs"] lookup raises KeyError for it.
MODELS_HYPERPARAMS_DICT = dict(
    shallow_conv_net=dict(max_epochs=10),
    eeg_net=dict(max_epochs=10),
    deep_conv_net=dict(max_epochs=50),
    lstm_net=dict(max_epochs=10),
    lstm_cnn_net=dict(max_epochs=10),
    lstm_cnn_net_v2=dict(),
)

# Dataset classes currently enabled. The other imported datasets
# (NormCho2017, OptGameMI, OptStdMI) are switched off; append them to the
# list to re-enable.
DATASETS_LIST = [FatigueMI]

In [None]:
# from model_optim.utils import data_generator

# data_generator(
#     dataset=NormCho2017(),
#     subjects=[1],
#     channel_idx=[],
#     sfreq=128,
# )

In [None]:
# Find every saved best-trial file under ./temp_v2 and load each one, keyed by
# its path. `.item()` unwraps the 0-d object array that np.load returns.
# NOTE: allow_pickle=True unpickles arbitrary objects - only load trusted files.
subject_files = glob.glob("./temp_v2/*/*/model/study_best_trial.npy")
subject_files_data = {
    path: np.load(path, allow_pickle=True).item()
    for path in subject_files
}


In [None]:
def trial_test_accuracy(trial):
    """Return the test accuracy stored under ``user_attrs['trial_data']`` of a trial."""
    return trial.user_attrs['trial_data']['test_accuracy']


# Rank the loaded best trials from highest to lowest test accuracy, then show
# just the path -> test accuracy mapping in that order.
ranked_items = sorted(
    subject_files_data.items(),
    key=lambda path_and_trial: trial_test_accuracy(path_and_trial[1]),
    reverse=True,
)
sorted_subject_files_data = dict(ranked_items)
sorted_subject_files_data_test_acc = {
    path: trial_test_accuracy(trial) for path, trial in ranked_items
}
rpprint(sorted_subject_files_data_test_acc)

In [None]:
# Run the hyperparameter search for every (dataset, model, subject) combination
# listed below. `model_optimizer` and `study` stay in the kernel namespace after
# the loops finish; later cells in this notebook read both.
for dataset in [FatigueMI]:
    for model in ["deep_conv_net"]:
        model_optimizer = ModelOptimizer(dataset=dataset(), model_name=model)
        for subject in [12]:
            # Epoch budget comes from the per-model settings table.
            max_epochs = MODELS_HYPERPARAMS_DICT[model]["max_epochs"]
            search_kwargs = dict(
                subjects=[subject],
                max_iter=25,
                max_epochs=max_epochs,
                max_stag_count=10,
                rounds=1,
                replace_previous_study_for_subjects=False,
            )
            study = model_optimizer.search_best_model(**search_kwargs)

In [None]:
# Metrics table of the current `study`, ordered by ascending "scores".
study_metrics_df = model_optimizer.get_study_metrics(study)
study_metrics_df.sort_values("scores")

In [None]:
best_trial = study.best_trial

# Show the best trial's parameters, leaving out those whose name starts with "channels".
non_channel_params = {
    name: value
    for name, value in best_trial.params.items()
    if not name.startswith("channels")
}
rpprint(non_channel_params)

# Headline numbers recorded on the trial under user_attrs["trial_data"].
trial_data = best_trial.user_attrs["trial_data"]
rprint("test_accuracy =", trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(trial_data["val_accuracy"]))
rprint("channels_selected =", trial_data["channels_selected"])

In [35]:
STUDY_FILE_SUFFIX = "_study.npy"


def parse_study_path(study_file):
    """Extract ``(subject_number, model_name)`` from a saved study file path.

    Args:
        study_file: Path to a ``<model_name>_study.npy`` file whose path
            contains the subject number in square brackets, e.g. ``[9]``.

    Returns:
        Tuple ``(subject_number, model_name)``: the integer found between the
        first ``[`` and the following ``]`` of the path, and the file name with
        its trailing ``_study.npy`` removed.

    Raises:
        IndexError: if the path contains no ``[``.
        ValueError: if the bracketed text is not an integer.
    """
    subject_number = int(study_file.split("[")[1].split("]")[0])
    # os.path.basename honours the platform's path separator, unlike split("/").
    file_name = os.path.basename(study_file)
    if file_name.endswith(STUDY_FILE_SUFFIX):
        # Strip only the trailing suffix, never an occurrence inside the name.
        model_name = file_name[: -len(STUDY_FILE_SUFFIX)]
    else:
        model_name = file_name
    return subject_number, model_name


# NOTE: glob is called without recursive=True, so each "**" behaves like "*"
# and the pattern matches exactly two directory levels below ./temp/FatigueMI.
temp_fatigue_mi_studies = glob.glob("./temp/FatigueMI/**/**/model/*_study.npy")
temp_fatigue_mi_studies_dict = {}

for study_file in temp_fatigue_mi_studies:
    # NOTE(security): allow_pickle=True unpickles arbitrary objects - only load
    # study files produced by trusted runs.
    study = np.load(study_file, allow_pickle=True).item()
    subject_number, model_name = parse_study_path(study_file)
    # Keys look like "9_eeg_net": "<subject>_<model>".
    temp_fatigue_mi_studies_dict[f"{subject_number}_{model_name}"] = study

In [67]:
# Build one table holding the single lowest-"scores" trial of every loaded study.
# NOTE: this cell uses the `model_optimizer` instance created by the search
# cell only for its get_study_metrics() helper.
#
# Only the first stage of the selection idea is active (lowest score). The
# further stages that were sketched here - smallest train/val accuracy gap,
# highest train accuracy, highest test accuracy - are not applied.
best_trial_frames = []

for subject_model, study in temp_fatigue_mi_studies_dict.items():
    study_trials_df = model_optimizer.get_study_metrics(study)
    # nsmallest already returns a new frame, so no defensive copy is needed.
    filtered_study_trials_df = study_trials_df.nsmallest(1, 'scores')
    best_trial_frames.append(filtered_study_trials_df)

# Concatenate once at the end instead of re-copying a growing frame on every
# iteration; pd.concat raises on an empty list, so fall back to an empty frame.
filtered_study_trials_concat_df = (
    pd.concat(best_trial_frames) if best_trial_frames else pd.DataFrame()
)
display(filtered_study_trials_concat_df)

Unnamed: 0,train_acc,test_acc,val_acc,train_val_acc_diff,train_loss,val_loss,train_val_loss_diff,test_loss,scores,channels_selected,sfreq,batch_size,model_name,subjects
19,0.838235,0.818182,0.722222,0.116013,1.209507,1.407622,0.198116,1.633958,0.077361,"[C3, Fz, T6, T4]",256.0,192,eeg_net,[9]
7,0.573529,0.545455,0.555556,0.017974,0.812133,0.654432,0.157701,0.673368,0.198031,"[P3, C3, C4, Pz, Fp1, T3, O1, O2, F8, T6]",128.0,256,deep_conv_net,[9]
18,0.867647,0.545455,0.666667,0.20098,84.76915,82.498848,2.270302,90.271187,0.111561,"[F3, Fz, F4, Fp2, T3, T5, F7, A2, T4]",256.0,224,lstm_cnn_net,[9]
15,0.970588,0.545455,0.777778,0.19281,14.35515,14.367508,0.012358,18.714355,0.049883,"[F3, F4, Cz, Pz, Fp1, Fp2, T5, O1, A2, T6]",256.0,128,lstm_cnn_net,[10]
15,0.823529,0.5,0.722222,0.101307,0.924573,0.995178,0.070605,1.006173,0.077561,"[F3, Fz, F4, P4, Fp2, O2, A2, T6]",300.0,224,eeg_net,[10]
14,0.661765,0.590909,0.611111,0.050654,0.636553,0.60518,0.031373,0.682946,0.151685,"[P3, C3, F4, C4, P4, Fp2, F7, A2, T4]",300.0,96,deep_conv_net,[10]
0,0.970588,0.454545,0.777778,0.19281,55.781464,51.300117,4.481346,57.484512,0.049683,"[Fz, F4, T5, F7, F8, T4]",256.0,32,lstm_cnn_net,[11]
3,0.970588,0.681818,0.722222,0.248366,0.755424,1.030562,0.275138,1.222213,0.077861,"[P3, C3, F3, Fz, P4, Cz, T3, T5, O1, F7, F8, A...",256.0,256,eeg_net,[11]
1,0.720588,0.363636,0.666667,0.053922,0.614674,0.683233,0.068558,0.750804,0.111461,"[F3, P4, Cz, Pz, T3, T5, A2]",128.0,160,deep_conv_net,[11]
13,0.941176,0.545455,0.611111,0.330065,1.290258,1.403062,0.112804,1.581581,0.151635,"[Fz, F4, P4, Cz, Fp2, O1, O2, T4]",256.0,192,eeg_net,[2]


In [None]:
# Keep only the best-trial rows whose model is eeg_net.
# NOTE(review): the output saved with this cell also lists deep_conv_net and
# lstm_cnn_net rows, so it looks stale - re-run after the aggregation cell.
filtered_study_trials_concat_df.loc[
    filtered_study_trials_concat_df["model_name"] == "eeg_net"
]

Unnamed: 0,train_acc,test_acc,val_acc,train_val_acc_diff,train_loss,val_loss,train_val_loss_diff,test_loss,scores,channels_selected,sfreq,batch_size,model_name,subjects
19,0.838235,0.818182,0.722222,0.116013,1.209507,1.407622,0.198116,1.633958,0.077361,"[C3, Fz, T6, T4]",256.0,192,eeg_net,[9]
14,0.661765,0.590909,0.611111,0.050654,0.636553,0.60518,0.031373,0.682946,0.151685,"[P3, C3, F4, C4, P4, Fp2, F7, A2, T4]",300.0,96,deep_conv_net,[10]
3,0.970588,0.681818,0.722222,0.248366,0.755424,1.030562,0.275138,1.222213,0.077861,"[P3, C3, F3, Fz, P4, Cz, T3, T5, O1, F7, F8, A...",256.0,256,eeg_net,[11]
0,0.882353,0.636364,0.722222,0.160131,13.190247,12.870646,0.3196,15.704945,0.077511,"[P3, C4, Pz, O1, O2, F7, T6]",128.0,224,lstm_cnn_net,[2]
1,0.735294,0.727273,0.666667,0.068627,0.553756,0.69879,0.145035,0.629561,0.111761,"[P3, F3, Fz, F4, P4, Pz, Fp1, Fp2, O1, F7, A2,...",300.0,224,deep_conv_net,[3]
10,0.955882,0.636364,0.722222,0.23366,0.831503,0.940101,0.108597,0.985933,0.077661,"[P3, F3, Fz, Cz, Pz, Fp1, F7, A2, T6, T4]",256.0,32,eeg_net,[6]
2,0.705882,0.681818,0.666667,0.039216,0.572419,0.690558,0.118139,0.673557,0.111661,"[P3, C3, F3, Fz, F4, C4, Cz, Pz, Fp2, O1, T4]",300.0,64,deep_conv_net,[6]
4,0.852941,0.727273,0.666667,0.186274,0.879401,0.99146,0.112059,1.045302,0.111711,"[C3, Fz, C4, P4, Pz, Fp1, T3, T5, O2, F7, A2, T4]",256.0,64,eeg_net,[5]
23,0.852941,0.727273,0.888889,0.035948,0.388678,0.429204,0.040526,0.646307,0.262646,"[C3, F3, C4, Cz, Fp2, T6]",128.0,128,deep_conv_net,[12]
14,0.779412,0.727273,0.666667,0.112745,0.741722,0.851026,0.109304,0.850154,0.111411,"[C3, F3, Fz, F4, Cz, Pz]",300.0,32,eeg_net,[4]
