In [2]:
%%capture
# %%capture silences everything this cell would print.
# Load IPython's autoreload extension and select mode 2, which reloads
# imported modules before each cell execution so that edits to project
# modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# region General Imports
import os
import uuid
import shutil
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial
# endregion General Imports

from model_optim.model_optimizer import ModelOptimizer

# Dataset
from custom_datasets.fatigue_mi import FatigueMI
from custom_datasets.norm_cho import NormCho2017
from custom_datasets.opt_game_mi import OptGameMI
from custom_datasets.opt_std_mi import OptStdMI

# Seed every random number generator visible from this notebook so repeated
# runs start from the same state. TensorFlow and NumPy were already seeded;
# Python's built-in `random` module (imported above) is now seeded too.
# NOTE(review): any RNG owned by Optuna inside ModelOptimizer is not affected
# by these calls -- confirm whether it needs its own seed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [6]:
# Model identifiers, in the form accepted as `model_name` by ModelOptimizer.
MODELS_LIST = [
    "shallow_conv_net", "lstm_net", "deep_conv_net",
    "eeg_net", "lstm_cnn_net", "lstm_cnn_net_v2",
]

# Per-model training settings. The first group trains for up to 25 epochs,
# the second for up to 10.
# NOTE: "lstm_cnn_net_v2" has no "max_epochs" entry, so indexing
# MODELS_HYPERPARAMS_DICT["lstm_cnn_net_v2"]["max_epochs"] raises KeyError.
MODELS_HYPERPARAMS_DICT = {
    **{
        name: {"max_epochs": 25}
        for name in ("shallow_conv_net", "eeg_net", "deep_conv_net")
    },
    **{
        name: {"max_epochs": 10}
        for name in ("lstm_net", "lstm_cnn_net")
    },
    "lstm_cnn_net_v2": {},
}

# Dataset classes (not instances) enabled for this notebook. The commented-out
# entries are imported above but currently disabled.
DATASETS_LIST = [
    FatigueMI,
    # NormCho2017,
    # OptGameMI,
    # OptStdMI,
]

In [4]:
# from model_optim.utils import data_generator

# data_generator(
#     dataset=NormCho2017(),
#     subjects=[1],
#     channel_idx=[],
#     sfreq=128,
# )

In [7]:
# Find every saved shallow_conv_net best-trial file under ./temp_v2 and map
# each file path to the object stored inside it.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so these
# files must come from a trusted source.
subject_files = glob.glob("./temp_v2/*/*/model/shallow_conv_net_study_best_trial.npy")
subject_files_data = {
    path: np.load(path, allow_pickle=True).item()
    for path in subject_files
}


In [8]:
def _trial_test_accuracy(trial):
    """Return the test accuracy stored in a trial's ``trial_data`` user attribute."""
    return trial.user_attrs['trial_data']['test_accuracy']


# Order the loaded trials from highest to lowest test accuracy, then print a
# {file path: test accuracy} summary in that same order.
sorted_subject_files_data = dict(
    sorted(
        subject_files_data.items(),
        key=lambda path_and_trial: _trial_test_accuracy(path_and_trial[1]),
        reverse=True,
    )
)
sorted_subject_files_data_test_acc = {
    path: _trial_test_accuracy(trial)
    for path, trial in sorted_subject_files_data.items()
}
rpprint(sorted_subject_files_data_test_acc)

In [9]:
# Build the optimizer for the shallow_conv_net model on the FatigueMI dataset.
fatigue_mi_dataset = FatigueMI()
model_optimizer = ModelOptimizer(model_name="shallow_conv_net", dataset=fatigue_mi_dataset)

Original sfreq: 300.0


In [10]:
# Run a model search for every (dataset, model, subject) combination listed
# below. `model_optimizer` and `study` keep the values from the last
# combination and are read by later cells.
search_settings = {
    "max_iter": 25,
    "max_stag_count": 10,
    "rounds": 1,
    # Start from scratch: a study already saved for these subjects is removed
    # (see the "Found previous study ... removing..." lines in the output).
    "replace_previous_study_for_subjects": True,
}

for dataset in [FatigueMI]:
    for model in ["deep_conv_net"]:
        model_optimizer = ModelOptimizer(model_name=model, dataset=dataset())
        # The epoch budget is configured per model in MODELS_HYPERPARAMS_DICT.
        max_epochs = MODELS_HYPERPARAMS_DICT[model]["max_epochs"]
        for subject in [11]:
            study = model_optimizer.search_best_model(
                subjects=[subject],
                max_epochs=max_epochs,
                **search_settings,
            )

Original sfreq: 300.0
Found previous study in ./temp/FatigueMI/[11]/2e54fc7d125447c096b49ecb93fd4f9a/model/deep_conv_net_study_best_trial.npy, removing...
Found previous study in ./temp/FatigueMI/[11]/2e54fc7d125447c096b49ecb93fd4f9a/model/deep_conv_net_study.npy, removing...


  0%|          | 0/25 [00:00<?, ?it/s]

Using sfreq: None; sfreq is None: True; self.original_sfreq: 300.0; sfreq=300.0
Adding metadata with 3 columns
Epoch 1/25


2024-05-02 23:15:51.166238: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-05-02 23:15:51.166604: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Epoch 00001: val_accuracy improved from -inf to 0.55556, storing weights.
Epoch 2/25
Epoch 00002: val_accuracy did not improve
Epoch 3/25
Epoch 00003: val_accuracy did not improve
Epoch 4/25
Epoch 00004: val_accuracy improved from 0.55556 to 0.61111, storing weights.
Epoch 5/25
Epoch 00005: val_accuracy did not improve
Epoch 6/25
Epoch 00006: val_accuracy did not improve
Epoch 7/25
Epoch 00007: val_accuracy did not improve
Epoch 8/25
Epoch 00008: val_accuracy did not improve
Epoch 9/25
Epoch 00009: val_accuracy did not improve
Epoch 10/25
Epoch 00010: val_accuracy did not improve
Epoch 11/25
Epoch 00011: val_accuracy did not improve
Epoch 12/25
Epoch 00012: val_accuracy did not improve
Epoch 13/25
Epoch 00013: val_accuracy did not improve
Epoch 14/25
Epoch 00014: val_accuracy did not improve
Epoch 15/25
Epoch 00015: val_accuracy did not improve
Epoch 16/25
Epoch 00016: val_accuracy did not improve
Epoch 17/25
Epoch 00017: val_accuracy did not improve
Epoch 18/25
Epoch 00018: val_accura

In [10]:
# Show the metrics of every trial in `study`, ordered by ascending `scores`.
model_optimizer.get_study_metrics(study).sort_values("scores")

Unnamed: 0,train_acc,test_acc,val_acc,train_val_acc_diff,train_loss,val_loss,train_val_loss_diff,test_loss,scores,channels_selected,sfreq,batch_size,model_name,subjects
0,0.779412,0.545455,0.611111,0.168301,0.483761,0.680379,0.196618,1.255373,0.151985,"[P3, C3, F3, Fz, F4, C4, Cz, Pz, Fp2, T5, O2, ...",128.0,160,deep_conv_net,[11]
1,0.779412,0.681818,0.777778,0.001634,0.451416,0.575806,0.12439,0.658583,0.299783,"[P3, Fz, P4, Cz, T3, T5, O2, T6]",256.0,32,deep_conv_net,[11]
5,0.779412,0.681818,0.722222,0.05719,0.477276,0.602562,0.125286,2.556679,0.327611,"[P3, C3, Fz, F4, Fp1, O1, F8, T6, T4]",300.0,128,deep_conv_net,[11]
4,0.779412,0.545455,0.666667,0.112745,0.516096,0.679674,0.163578,1.063509,0.361561,"[F3, Fz, F4, C4, T3, T5, F8, A2, T6]",256.0,224,deep_conv_net,[11]
7,0.705882,0.5,0.666667,0.039216,0.588393,0.647678,0.059285,0.793769,0.361611,"[P3, C3, Fz, F4, P4, Cz, Fp1, Fp2, T3, T6]",128.0,256,deep_conv_net,[11]
9,0.720588,0.590909,0.666667,0.053922,0.515996,0.632985,0.11699,2.764858,0.361711,"[P3, Fz, F4, Pz, Fp1, T3, T5, O2, F8, A2, T6, T4]",300.0,256,deep_conv_net,[11]
2,0.661765,0.545455,0.611111,0.050654,0.59055,0.681202,0.090652,1.409137,0.401685,"[F3, F4, P4, Pz, Fp2, T3, O1, F8, A2]",256.0,160,deep_conv_net,[11]
3,0.720588,0.5,0.611111,0.109477,0.543822,0.6287,0.084878,1.162843,0.401785,"[P3, F3, C4, P4, Cz, Fp2, T3, O1, O2, F8, T6]",128.0,192,deep_conv_net,[11]
8,0.735294,0.590909,0.555556,0.179739,0.529332,0.704364,0.175032,0.644767,0.447931,"[P3, Fz, F4, Fp1, Fp2, T3, O2, F7]",128.0,192,deep_conv_net,[11]
6,0.661765,0.5,0.555556,0.106209,0.583539,0.700953,0.117414,0.761109,0.448031,"[P3, C3, Fz, P4, Cz, T5, F7, A2, T6, T4]",128.0,64,deep_conv_net,[11]


In [11]:
# Summarise the best trial of `study`: its non-channel parameters, test
# accuracy, highest validation accuracy, and the channels it selected.
# NOTE(review): `study` is also rebound by the study-loading cells further
# down, so this cell reports on whichever study the kernel bound last.
best_trial = study.best_trial
best_trial_data = best_trial.user_attrs["trial_data"]

non_channel_params = {
    name: value
    for name, value in best_trial.params.items()
    if not name.startswith("channels")
}
rpprint(non_channel_params)
rprint("test_accuracy =", best_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(best_trial_data["val_accuracy"]))
rprint("channels_selected =", best_trial_data["channels_selected"])

In [8]:
def parse_study_path(study_file):
    """Extract ``(subject_number, model_name)`` from a saved study/trial path.

    Args:
        study_file: Path such as
            ``./temp/FatigueMI/[9]/<run id>/model/eeg_net_study.npy``.

    Returns:
        A tuple of the integer found between the first ``[`` and the next
        ``]`` of the path, and the file name with its ``_study.npy`` suffix
        removed. The shallow_conv_net best-trial file maps to
        ``"shallow_conv_net"``.

    Raises:
        IndexError: If the path contains no ``[``.
        ValueError: If the bracketed text is not an integer.
    """
    subject_number = int(study_file.split("[")[1].split("]")[0])
    # os.path.basename honours the platform's path separators; splitting on
    # "/" returns the wrong component for the backslash-separated paths that
    # glob produces on Windows.
    file_name = os.path.basename(study_file)
    model_name = file_name.replace("_study.npy", "").replace(
        "shallow_conv_net_study_best_trial.npy", "shallow_conv_net"
    )
    return subject_number, model_name


# Gather every saved study plus the shallow_conv_net best-trial files.
# Without recursive=True, "**" matches exactly one directory level (like "*").
temp_fatigue_mi_studies = (
    glob.glob("./temp/FatigueMI/**/**/model/*_study.npy")
    + glob.glob("./temp/FatigueMI/**/**/model/shallow_conv_net_study_best_trial.npy")
)
temp_fatigue_mi_studies_dict = {}
temp_fatigue_mi_studies_file_names_dict = {}

for study_file in temp_fatigue_mi_studies:
    # NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only
    # load files produced by this project.
    # NOTE: this rebinds the notebook-level `study` that the search cell sets.
    study = np.load(study_file, allow_pickle=True).item()
    subject_number, model_name = parse_study_path(study_file)
    # Keys have the form "<subject>_<model>". When several files resolve to
    # the same key, the one listed last silently replaces the earlier ones.
    study_key = f"{subject_number}_{model_name}"
    temp_fatigue_mi_studies_dict[study_key] = study
    temp_fatigue_mi_studies_file_names_dict[study_key] = study_file

In [9]:
# Number of lowest-`scores` trials kept from each study.
TOP_TRIALS_PER_STUDY = 2

# Collect one filtered frame per study and concatenate once at the end.
# Growing a DataFrame with pd.concat inside the loop re-copies all accumulated
# rows on every iteration and concatenates against an empty frame.
filtered_study_trials_frames = []

for subject_model, study in temp_fatigue_mi_studies_dict.items():
    # Keys look like "<subject>_<model>"; the subject part is passed as a string.
    study_trials_df = model_optimizer.get_study_metrics(
        study,
        default_model_name="shallow_conv_net",
        subjects=[subject_model.split("_")[0]],
        file_path=temp_fatigue_mi_studies_file_names_dict[subject_model],
    )
    # Active filter: keep only the TOP_TRIALS_PER_STUDY trials with the
    # smallest `scores`. The further tie-breaks that were once planned
    # (train/val accuracy gap, train accuracy, test accuracy) are not applied.
    # nsmallest returns a new frame, so the insert below does not modify
    # study_trials_df.
    filtered_study_trials_df = study_trials_df.nsmallest(TOP_TRIALS_PER_STUDY, 'scores')
    # Preserve the original row label (the frame's index) as an explicit column.
    filtered_study_trials_df.insert(0, 'trial_number', filtered_study_trials_df.index)
    filtered_study_trials_frames.append(filtered_study_trials_df)

# pd.concat raises ValueError on an empty list, so fall back to an empty frame
# when no studies were loaded.
filtered_study_trials_concat_df = (
    pd.concat(filtered_study_trials_frames)
    if filtered_study_trials_frames
    else pd.DataFrame()
)
display(filtered_study_trials_concat_df)

Unnamed: 0,trial_number,train_acc,test_acc,val_acc,train_val_acc_diff,train_loss,val_loss,train_val_loss_diff,test_loss,scores,channels_selected,sfreq,batch_size,model_name,subjects,file_path
19,19,0.838235,0.818182,0.722222,0.116013,1.209507,1.407622,0.198116,1.633958,0.077361,"[C3, Fz, T6, T4]",256.0,192,eeg_net,[9],./temp/FatigueMI/[9]/fd4945e8633f4c0ab473e7520...
22,22,0.794118,0.818182,0.722222,0.071895,1.353778,1.540807,0.187029,1.678117,0.077361,"[C3, Fz, T6, T4]",256.0,160,eeg_net,[9],./temp/FatigueMI/[9]/fd4945e8633f4c0ab473e7520...
7,7,0.573529,0.545455,0.555556,0.017974,0.812133,0.654432,0.157701,0.673368,0.198031,"[P3, C3, C4, Pz, Fp1, T3, O1, O2, F8, T6]",128.0,256,deep_conv_net,[9],./temp/FatigueMI/[9]/0113b240ca004b67b81ca0fd5...
13,13,0.691176,0.863636,0.722222,0.031046,0.595995,0.62438,0.028385,0.578216,0.327611,"[P3, C3, P4, Pz, Fp1, T3, T5, A2, T6]",128.0,192,deep_conv_net,[9],./temp/FatigueMI/[9]/0113b240ca004b67b81ca0fd5...
18,18,0.867647,0.545455,0.666667,0.200980,84.76915,82.498848,2.270302,90.271187,0.111561,"[F3, Fz, F4, Fp2, T3, T5, F7, A2, T4]",256.0,224,lstm_cnn_net,[9],./temp/FatigueMI/[9]/4f2b98dbeb684bfa8f887b396...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0,1.000000,0.636364,0.833333,0.166667,0.150787,0.580237,0.42945,0.998291,0.791186,"[P3, C3, F4, C4, Pz, Fp1, O1, F8, A2, T6]",,,shallow_conv_net,[13],./temp/FatigueMI/[13]/1b189965ada44ff99e73fa14...
0,0,0.926471,0.636364,0.666667,0.259804,,,,,0.901306,"[P3, C3, F4, C4, P4, Pz, Fp2, T3, O2]",,,shallow_conv_net,[5],./temp/FatigueMI/[5]/d199c9c2ac924b238693f158e...
0,0,1.000000,0.772727,0.888889,0.111111,0.189622,0.604003,0.414381,0.733969,0.665729,"[P3, C3, F3, C4, Cz, T5, O2, F7, T6]",,,shallow_conv_net,[12],./temp/FatigueMI/[12]/96dc576945fb4f2db582d66a...
0,0,1.000000,0.636364,0.722222,0.277778,,,,,0.878924,"[P3, C3, Fz, C4, Fp1, Fp2, T3, O1]",,,shallow_conv_net,[8],./temp/FatigueMI/[8]/9fd82ec44ef3496da6307b57e...


In [12]:
subjects_to_retain = [4, 6, 9, 11, 12]
subjects_to_retain_str = [f'{subject}' for subject in subjects_to_retain]
model_names_to_retain_str = ["shallow_conv_net", "deep_conv_net", "eeg_net"]

rprint(subjects_to_retain_str)

# Normalise `subjects` to a bare label such as "11" (from "[11]", "['11']" or
# the list [11]) and store it as a categorical column.
# Stripping the bracket characters, rather than slicing off the first and last
# character, keeps this cell idempotent: re-running it leaves "11" as "11"
# instead of reducing it to "". List values are handled too, so they can match
# the bare labels in subjects_to_retain_str.
filtered_study_trials_concat_df['subjects'] = pd.Categorical(
    filtered_study_trials_concat_df['subjects']
    .astype(str)
    .str.strip("[]")
    .str.replace("'", "", regex=False)
)

# For every (subject, model_name) pair keep the single row with the highest
# test_acc, then restrict the result to the retained subjects and model names.
best_models_df = (
    filtered_study_trials_concat_df
    .groupby(['subjects', 'model_name'])
    .apply(lambda x: x.nlargest(1, 'test_acc'))
    .reset_index(drop=True)
)
best_models_df = best_models_df[
    best_models_df['subjects'].isin(subjects_to_retain_str)
    & best_models_df['model_name'].isin(model_names_to_retain_str)
]
display(best_models_df)

Unnamed: 0,trial_number,train_acc,test_acc,val_acc,train_val_acc_diff,train_loss,val_loss,train_val_loss_diff,test_loss,scores,channels_selected,sfreq,batch_size,model_name,subjects,file_path
8,1,0.779412,0.681818,0.777778,0.001634,0.451416,0.575806,0.12439,0.658583,0.299783,"[P3, Fz, P4, Cz, T3, T5, O2, T6]",256.0,32.0,deep_conv_net,11,./temp/FatigueMI/[11]/2e54fc7d125447c096b49ecb...
9,3,0.970588,0.681818,0.722222,0.248366,0.755424,1.030562,0.275138,1.222213,0.077861,"[P3, C3, F3, Fz, P4, Cz, T3, T5, O1, F7, F8, A...",256.0,256.0,eeg_net,11,./temp/FatigueMI/[11]/6772e2405e6e436faba83820...
11,0,0.941176,0.727273,0.666667,0.27451,0.501189,0.778463,0.277274,0.757358,0.85302,"[P3, C3, Fz, F4, C4, P4, F7, F8]",,,shallow_conv_net,11,./temp/FatigueMI/[11]/e0643f9a780146a4adc15ddd...
12,23,0.852941,0.727273,0.888889,0.035948,0.388678,0.429204,0.040526,0.646307,0.262646,"[C3, F3, C4, Cz, Fp2, T6]",128.0,128.0,deep_conv_net,12,./temp/FatigueMI/[12]/259f8d3cf7ce4e5283e8e407...
13,15,0.955882,0.818182,0.833333,0.122549,0.432221,0.737542,0.305321,0.739283,0.028128,"[P3, C3, F3, C4, Cz, Pz, Fp2]",128.0,32.0,eeg_net,12,./temp/FatigueMI/[12]/fa0faed8a6ce4a52a2b9ca5d...
15,0,1.0,0.772727,0.888889,0.111111,0.189622,0.604003,0.414381,0.733969,0.665729,"[P3, C3, F3, C4, Cz, T5, O2, F7, T6]",,,shallow_conv_net,12,./temp/FatigueMI/[12]/96dc576945fb4f2db582d66a...
26,19,0.823529,0.818182,0.666667,0.156863,0.426087,0.746527,0.32044,0.510645,0.111411,"[F3, P4, Pz, F7, F8, A2]",300.0,32.0,deep_conv_net,4,./temp/FatigueMI/[4]/52166a0614d541acb9b9ef965...
27,11,0.955882,0.818182,0.722222,0.23366,0.718789,1.110411,0.391622,1.152245,0.077561,"[C3, F3, P4, Cz, Pz, T3, T6, T4]",256.0,160.0,eeg_net,4,./temp/FatigueMI/[4]/72578c4a27b64a8e989fd9072...
28,0,1.0,0.863636,0.888889,0.111111,,,,,0.52052,"[C3, F3, F4, P4, Cz, Fp1, Fp2, T5, O2, A2, T6,...",,,shallow_conv_net,4,./temp/FatigueMI/[4]/3623cb4ba1ad4a908c9098f52...
33,2,0.705882,0.681818,0.666667,0.039216,0.572419,0.690558,0.118139,0.673557,0.111661,"[P3, C3, F3, Fz, F4, C4, Cz, Pz, Fp2, O1, T4]",300.0,64.0,deep_conv_net,6,./temp/FatigueMI/[6]/0614f3b1603b4442a2cc79ade...


: 

In [13]:
# Persist the selected models. to_csv does not create missing directories, so
# make sure ./final exists first (a no-op when it is already there).
os.makedirs("./final", exist_ok=True)
best_models_df.to_csv("./final/best_models.csv")

: 