In [1]:
# region General Imports
import os
import re
import uuid
import shutil
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial
# endregion General Imports

import tempfile
import tensorflow as tf
import numpy as np
from tensorflow import keras
%load_ext tensorboard
    
# os.environ["KERAS_BACKEND"] = "tf"
# os.environ["TF_USE_LEGACY_KERAS"] = "0"
# import jax
# import jax.numpy as jnp
# import keras
    
# region Keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import (
    Conv2D,
    MaxPooling2D,
    AveragePooling2D,
    Conv1D,
    MaxPooling1D,
    AveragePooling1D,
)
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
# endregion Keras

from custom_datasets.fatigue_mi import FatigueMI

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
from sklearn import preprocessing

from model_optim.utils import channels_to_channels_idx







In [2]:
SKLRNG = 42
# RNG = jax.random.PRNGKey(SKLRNG)

In [3]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    return tf.math.square(x)

def shallow_conv_net_log_layer(x):
    return tf.math.log(tf.clip_by_value(x, 1e-7, 10000))

CUSTOM_OBJECTS = {
    "shallow_conv_net_square_layer": shallow_conv_net_square_layer, 
    "shallow_conv_net_log_layer": shallow_conv_net_log_layer 
}
# endregion Helper funcs

# region Models
def shallow_conv_net(
    nb_classes, channels, samples, **kwargs
):
    """
    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py
    """

    _POOL_SIZE_D2_ = kwargs.get("pool_size_d2", 35)
    _STRIDES_D2_ = kwargs.get("strides_d2", 7)
    _CONV_FILTERS_D2_ = kwargs.get("conv_filters_d2", 13)

    _POOL_SIZE_ = kwargs.get("pool_size", (1, _POOL_SIZE_D2_))
    _STRIDES_ = kwargs.get("strides", (1, _STRIDES_D2_))
    _CONV_FILTERS_ = kwargs.get("conv_filters", (1, _CONV_FILTERS_D2_))

    _CONV2D_1_UNITS_ = kwargs.get("conv2d_1_units", 40)
    _CONV2D_2_UNITS_ = kwargs.get("conv2d_2_units", 40)
    _L2_REG_1_ = kwargs.get("l2_reg_1", 0.01)
    _L2_REG_2_ = kwargs.get("l2_reg_2", 0.01)
    _L2_REG_3_ = kwargs.get("l2_reg_3", 0.01)
    _DROPOUT_RATE_ = kwargs.get("dropout_rate", 0.5)

    input_main = Input(shape=(channels, samples, 1))
    block1 = Conv2D(
        _CONV2D_1_UNITS_,
        _CONV_FILTERS_,
        input_shape=(channels, samples, 1),
        kernel_constraint=max_norm(2.0, axis=(0, 1, 2)),
        kernel_regularizer=keras.regularizers.L2(_L2_REG_1_),
    )(input_main)
    # block1       = Conv2D(40, (channels, 1), use_bias=False,
    #                       kernel_constraint = max_norm(2., axis=(0,1,2)))(block1)
    block1 = Conv2D(
        _CONV2D_2_UNITS_,
        (channels, 1),
        use_bias=False,
        kernel_constraint=max_norm(2.0, axis=(0, 1, 2)),
        kernel_regularizer=keras.regularizers.L2(_L2_REG_2_),
    )(block1)
    block1 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    block1 = Activation(shallow_conv_net_square_layer)(block1)
    block1 = AveragePooling2D(pool_size=_POOL_SIZE_, strides=_STRIDES_)(block1)
    block1 = Activation(shallow_conv_net_log_layer)(block1)
    block1 = Dropout(_DROPOUT_RATE_)(block1)
    flatten = Flatten()(block1)
    # dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    dense = Dense(
        nb_classes,
        kernel_constraint=max_norm(0.5),
        kernel_regularizer=keras.regularizers.L2(_L2_REG_3_),
    )(flatten)
    softmax = Activation("softmax")(dense)

    return Model(inputs=input_main, outputs=softmax)

# endregion Models

In [4]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250):

    find_events = lambda raw, event_id: mne.find_events(raw, shortest_event=0, verbose=False) if len(mne.utils._get_stim_channel(None, raw.info, raise_error=False)) > 0 else mne.events_from_annotations(raw, event_id=event_id, verbose=False)[0]
    
    data = dataset.get_data(subjects=subjects)
    
    X = []
    y = []
    metadata = []

    for subject_id in data.keys():
        for session_id in data[subject_id].keys():
            for run_id in data[subject_id][session_id].keys():
                raw = data[subject_id][session_id][run_id]
                
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)
                
                events = find_events(raw, dataset.event_id)

                tmin = dataset.interval[0]
                tmax = dataset.interval[1]

                channels = np.asarray(raw.info['ch_names'])[channel_idx] if len(channel_idx) > 0 else np.asarray(raw.info['ch_names'])

                # rpprint(channels)
                
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                x = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                x_events = x.events
                inv_events = {k: v for v, k in dataset.event_id.items()}
                labels = [inv_events[e] for e in x_events[:, -1]]

                # rpprint({
                #     "X": np.asarray(x.get_data(copy=False)).shape,
                #     "y": np.asarray(labels).shape,
                #     "channels selected": np.asarray(raw.info['ch_names'])[channel_idx]
                # })

                # x.plot(scalings="auto")
                # display(x.info)
                
                x_resampled = x.resample(sfreq) # Resampler_Epoch
                x_resampled_data = x_resampled.get_data(copy=False) # Convert_Epoch_Array
                x_resampled_data_standard_scaler = np.asarray([
                    StandardScaler().fit_transform(x_resampled_data[i])
                    for i in np.arange(x_resampled_data.shape[0])
                ]) # Standard_Scaler_Epoch

                # x_resampled.plot(scalings="auto")
                # display(x_resampled.info)

                n = x_resampled_data_standard_scaler.shape[0]
                # n = x.get_data(copy=False).shape[0]
                met = pd.DataFrame(index=range(n))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id
                x.metadata = met.copy()
                
                # X.append(x_resampled_data_standard_scaler)
                X.append(x)
                y.append(labels)
                metadata.append(met)

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

fat_dataset = FatigueMI()

In [55]:
def reconstruct_and_save_model(
    study_best_trial_path,
    dataset,
    data_params,
    model_params,
    weights,
    model_fn=lambda: None,
):
    
    study_best_trial = np.load(study_best_trial_path, allow_pickle=True).item()

    _SFREQ_ = data_params["sfreq"]
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = data_params["batch_size"]

    X, y, _ = data_generator(dataset, subjects=data_params["subjects"], channel_idx=data_params["channels_idx"], sfreq=_SFREQ_)

    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]

    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_encoded,
        test_size=_TEST_SIZE_,
        random_state=SKLRNG,
        shuffle=True,
        stratify=y_encoded,
    )

    final_model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
        **model_params
    }
    model = model_fn(**final_model_params)
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="rmsprop",
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    )
    model.set_weights(weights)

    study_best_trial.user_attrs["trial_data"]["model"] = model.to_json()

    # Create same folder as study_best_trial_path under ./temp_v2
    model_dir = "./temp_v2"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    new_save_path = f"{model_dir}/{'/'.join(study_best_trial_path.split('/')[2:])}".split("study_best_trial.npy")[0]
    # Create new_save_path if not exists
    if not os.path.exists(new_save_path):
        os.makedirs(new_save_path, exist_ok=True)
    
    np.save(new_save_path+'/study_best_trial.npy', study_best_trial, allow_pickle=True)

    # test_eval = model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)
    # return test_eval



all_old_subject_files = glob.glob("./_temp/**/**/model/study_best_trial.npy")

for old_subject_file in all_old_subject_files:

    rprint(f"Reconstructing model from {old_subject_file}")

    study_best_trial_path = old_subject_file
    model = np.load(study_best_trial_path, allow_pickle=True).item()

    all_channels = fat_dataset.get_data(subjects=[1])[1]["0"]["0"].info["ch_names"][:-1]
    channel_params = [v.replace("channels_", "") for v in model.params.keys() if "channels_" in v and model.params[v] == 1]
    channels_idx = channels_to_channels_idx(channel_params, ch_names = all_channels)

    data_params = {
        "subjects": [int(study_best_trial_path.split("/[")[1][0])],
        "sfreq": model.params["sfreq"],
        "batch_size": model.params["batch_size"],
        "channels_idx": channels_idx
    }

    model_params = {
        "pool_size_d2": model.params["pool_size_d2"],
        "strides_d2": model.params["strides_d2"],
        "conv_filters_d2": model.params["conv_filters_d2"], 
        "conv2d_1_units": model.params["conv2d_1_units"],
        "conv2d_2_units": model.params["conv2d_2_units"],
        "l2_reg_1": model.params["l2_reg_1"],
        "l2_reg_2": model.params["l2_reg_2"],
        "l2_reg_3": model.params["l2_reg_3"],
        "dropout_rate": model.params["dropout_rate"]
    }

    weights = model.user_attrs["trial_data"]["weights"]

    reconstruct_and_save_model(study_best_trial_path, fat_dataset, data_params, model_params, weights, shallow_conv_net)

Adding metadata with 3 columns


Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns


Adding metadata with 3 columns


Adding metadata with 3 columns


Adding metadata with 3 columns


Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns


Adding metadata with 3 columns


Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns


Adding metadata with 3 columns


Adding metadata with 3 columns


Adding metadata with 3 columns


Adding metadata with 3 columns


Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns


Adding metadata with 3 columns


In [6]:
# prev_best = prev_best_full.user_attrs['trial_data']

# rpprint({
#     "sfreq": prev_best_full.params["sfreq"],
#     "batch_size": prev_best_full.params["batch_size"],
#     "val_acc": np.max(prev_best["val_accuracy"]),
#     "test_acc": prev_best["test_accuracy"],
#     "channels": prev_best["channels_selected"],
# })

model_info = {
    "subject": model.user_attrs["trial_data"]["subject"] if hasattr(model.user_attrs["trial_data"], "subject") else int(re.compile(r"\[.*\]").search(model.user_attrs["trial_data"]["data_path"]).group(0).strip("[]")),
    "sfreq": model.params["sfreq"] if "sfreq" in model.params else 128,
    "batch_size": model.params["batch_size"] if "batch_size" in model.params else 128,
    "channels_selected": model.user_attrs["trial_data"]["channels_selected"],
    "channels_idx_selected": channels_to_channels_idx(model.user_attrs["trial_data"]["channels_selected"], fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]),
    "model": tf.keras.models.model_from_json(model.user_attrs["trial_data"]["model"], custom_objects=CUSTOM_OBJECTS),
    "test_acc": model.user_attrs["trial_data"]["test_accuracy"]
}
if "weights" in model.user_attrs["trial_data"]:
    model_info["model"].set_weights(model.user_attrs["trial_data"]["weights"])
elif "model_weights" in model.user_attrs["trial_data"]:
    model_info["model"].set_weights(model.user_attrs["trial_data"]["model_weights"])

rpprint(model_info)

2024-03-24 16:38:06.863189: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-03-24 16:38:06.863610: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [11]:
# np.save('temp_config.npy', model_info["model"].get_config(), allow_pickle=True)