In [35]:
# region General Imports
import os
import re
import uuid
import shutil
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial
# endregion General Imports

import tempfile
import tensorflow as tf
import numpy as np
from tensorflow import keras
%load_ext tensorboard
    
# os.environ["KERAS_BACKEND"] = "tf"
# os.environ["TF_USE_LEGACY_KERAS"] = "0"
# import jax
# import jax.numpy as jnp
# import keras
    
# region Keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import (
    Conv2D,
    MaxPooling2D,
    AveragePooling2D,
    Conv1D,
    MaxPooling1D,
    AveragePooling1D,
)
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
# endregion Keras

from custom_datasets.fatigue_mi import FatigueMI

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
from sklearn import preprocessing

from model_optim.utils import channels_to_channels_idx

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [36]:
# Seed handed to scikit-learn as `random_state` (see the train/test split further down).
SKLRNG = 42
# RNG = jax.random.PRNGKey(SKLRNG)

In [37]:
# Collect every stored best-trial file below ./temp, ordered from the oldest
# to the most recently modified one.
subject_models = sorted(
    glob.glob('./temp/**/model/study_best_trial.npy', recursive=True),
    key=os.path.getmtime,
)

rpprint(subject_models)

In [38]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Element-wise square, used as an activation inside `shallow_conv_net`."""
    squared = tf.square(x)
    return squared

def shallow_conv_net_log_layer(x):
    """Natural log of the input after clipping it to [1e-7, 10000], so the log stays finite."""
    clipped = tf.clip_by_value(x, clip_value_min=1e-7, clip_value_max=10000)
    return tf.math.log(clipped)

# Custom activation callables keyed by their function names; passed as
# `custom_objects` when a serialized model is rebuilt from JSON.
CUSTOM_OBJECTS = dict(
    shallow_conv_net_square_layer=shallow_conv_net_square_layer,
    shallow_conv_net_log_layer=shallow_conv_net_log_layer,
)
# endregion Helper funcs

# region Models
def shallow_conv_net(
    nb_classes, channels, samples, **kwargs
):
    """
    Build the ShallowConvNet classifier as a Keras functional model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    Args:
        nb_classes: Number of units of the final dense layer (softmax outputs).
        channels: Size of the first input axis; also the height of the second
            convolution's kernel.
        samples: Size of the second input axis.
        **kwargs: Optional hyper-parameters (default in parentheses):
            pool_size_d2 (35), strides_d2 (7), conv_filters_d2 (13): second
            component of the pooling window, the pooling stride and the first
            convolution kernel, each combined with a leading 1;
            pool_size, strides, conv_filters: complete 2-tuples that take
            precedence over the `*_d2` shortcuts;
            conv2d_1_units (40), conv2d_2_units (40): filter counts;
            l2_reg_1, l2_reg_2, l2_reg_3 (0.01 each): L2 factors of the two
            convolutions and the dense layer;
            dropout_rate (0.5).

    Returns:
        An uncompiled `Model` taking `(channels, samples, 1)` inputs and
        producing `nb_classes` softmax outputs.
    """
    option = kwargs.get

    # Explicit 2-tuples win over the single-value "*_d2" shortcuts.
    pool_window = option("pool_size", (1, option("pool_size_d2", 35)))
    pool_stride = option("strides", (1, option("strides_d2", 7)))
    first_kernel = option("conv_filters", (1, option("conv_filters_d2", 13)))

    first_units = option("conv2d_1_units", 40)
    second_units = option("conv2d_2_units", 40)
    l2_first = option("l2_reg_1", 0.01)
    l2_second = option("l2_reg_2", 0.01)
    l2_dense = option("l2_reg_3", 0.01)
    drop_prob = option("dropout_rate", 0.5)

    net_input = Input(shape=(channels, samples, 1))

    # Two stacked convolutions: the first with the configurable kernel, the
    # second spanning the whole first axis with a (channels, 1) kernel.
    features = Conv2D(
        first_units,
        first_kernel,
        input_shape=(channels, samples, 1),
        kernel_constraint=max_norm(2.0, axis=(0, 1, 2)),
        kernel_regularizer=keras.regularizers.L2(l2_first),
    )(net_input)
    features = Conv2D(
        second_units,
        (channels, 1),
        use_bias=False,
        kernel_constraint=max_norm(2.0, axis=(0, 1, 2)),
        kernel_regularizer=keras.regularizers.L2(l2_second),
    )(features)

    # Batch-norm -> square -> average pooling -> log -> dropout -> flatten.
    # The layers are instantiated in the same order in which they are applied.
    head_layers = (
        BatchNormalization(epsilon=1e-05, momentum=0.9),
        Activation(shallow_conv_net_square_layer),
        AveragePooling2D(pool_size=pool_window, strides=pool_stride),
        Activation(shallow_conv_net_log_layer),
        Dropout(drop_prob),
        Flatten(),
    )
    for layer in head_layers:
        features = layer(features)

    logits = Dense(
        nb_classes,
        kernel_constraint=max_norm(0.5),
        kernel_regularizer=keras.regularizers.L2(l2_dense),
    )(features)
    probabilities = Activation("softmax")(logits)

    return Model(inputs=net_input, outputs=probabilities)

# endregion Models

In [39]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250, standardize = False):
    """
    Epoch the recordings of `subjects` and return data, labels and per-epoch metadata.

    For every subject / session / run returned by `dataset.get_data` the raw
    recording is filtered, events are located, epochs are cut over
    `dataset.interval` and then resampled to `sfreq`.

    Args:
        dataset: Dataset object exposing `get_data(subjects=...)`, `event_id`
            (label -> event code mapping) and `interval` (tmin, tmax).
        subjects: Subjects forwarded to `dataset.get_data`. The list default is
            only read, never mutated.
        channel_idx: Indices into `raw.info['ch_names']` selecting the channels
            to keep; an empty sequence keeps every channel (stim channels are
            always excluded). The list default is only read, never mutated.
        filters: Iterable of `(l_freq, h_freq)` pairs, applied one after the
            other with an IIR filter on the EEG channels.
        sfreq: Target sampling frequency passed to `Epochs.resample`.
        standardize: When True, each epoch is additionally passed through a
            freshly fitted `StandardScaler` and the scaled arrays are returned.
            Defaults to False, which keeps the previous output unchanged.

    Returns:
        Tuple `(X, y, metadata)`: the epochs of all runs concatenated along the
        first axis, the matching label names, and a DataFrame with the
        `subject`, `session` and `run` of every epoch.

    Raises:
        ValueError: If `dataset.get_data` yields no run at all.
    """
    # NOTE(review): before `standardize` existed, the per-epoch scaling was
    # computed for every run and then discarded. Confirm which preprocessing
    # the stored models were trained with before relying on the default.

    def _find_events(raw, stim_channels):
        """Read events from the stim channel when there is one, else from the annotations."""
        if len(stim_channels) > 0:
            return mne.find_events(raw, shortest_event=0, verbose=False)
        return mne.events_from_annotations(raw, event_id=dataset.event_id, verbose=False)[0]

    data = dataset.get_data(subjects=subjects)

    # Loop invariants: epoch window and the event-code -> label lookup.
    tmin, tmax = dataset.interval[0], dataset.interval[1]
    code_to_label = {code: label for label, code in dataset.event_id.items()}

    X = []
    y = []
    metadata = []

    for subject_id, sessions in data.items():
        for session_id, runs in sessions.items():
            for run_id, raw in runs.items():
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)

                # Looked up once per run; used for event detection and channel exclusion.
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                events = _find_events(raw, stim_channels)

                all_names = np.asarray(raw.info['ch_names'])
                channels = all_names[channel_idx] if len(channel_idx) > 0 else all_names
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                epochs = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                labels = [code_to_label[code] for code in epochs.events[:, -1]]

                # `Epochs.resample` returns the resampled object; MNE documents it
                # as operating in place, so `epochs` is expected to carry the new
                # rate as well.
                resampled = epochs.resample(sfreq)

                met = pd.DataFrame(
                    {"subject": subject_id, "session": session_id, "run": run_id},
                    index=range(len(labels)),
                )
                epochs.metadata = met.copy()

                if standardize:
                    # One scaler per epoch, fitted on that epoch's 2-D array only.
                    X.append(np.asarray([
                        StandardScaler().fit_transform(epoch)
                        for epoch in resampled.get_data(copy=False)
                    ]))
                else:
                    X.append(epochs)
                y.append(labels)
                metadata.append(met)

    if not X:
        raise ValueError(f"No runs were returned by the dataset for subjects={subjects!r}")

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

# Dataset handle reused by the following cells to fetch recordings and channel names.
fat_dataset = FatigueMI()

In [34]:
# Load the stored best trial of the first (oldest) file and rebuild its Keras model.
# NOTE: `allow_pickle=True` unpickles arbitrary Python objects - only load trusted files.
model = np.load(subject_models[0], allow_pickle=True).item()
trial_data = model.user_attrs["trial_data"]

# `trial_data` is accessed like a mapping everywhere else, so the stored subject
# must be detected by key membership. `hasattr` inspects attributes, never sees
# mapping keys, and therefore always fell through to the path-based fallback.
if "subject" in trial_data:
    subject = trial_data["subject"]
else:
    # Fallback: take the bracketed id embedded in the data path.
    subject_match = re.search(r"\[.*\]", trial_data["data_path"])
    if subject_match is None:
        raise ValueError(f"Cannot infer the subject from data_path={trial_data['data_path']!r}")
    subject = int(subject_match.group(0).strip("[]"))

# Channel names of the reference recording, without the last entry
# (presumably the stim channel - TODO confirm).
reference_ch_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]

# NOTE(review): the output recorded for this cell shows `model_from_json` failing
# with "Unrecognized keyword arguments: ['batch_shape']". That looks like a mismatch
# between the Keras version that serialized the model and `tf.keras` - confirm the
# environment before relying on this cell.
model_info = {
    "subject": subject,
    "sfreq": model.params.get("sfreq", 128),
    "batch_size": model.params.get("batch_size", 128),
    "channels_selected": trial_data["channels_selected"],
    "channels_idx_selected": channels_to_channels_idx(trial_data["channels_selected"], reference_ch_names),
    "model": tf.keras.models.model_from_json(trial_data["model"], custom_objects=CUSTOM_OBJECTS),
}

# Restore the trained weights; the first of the two known keys that is present is used.
for weights_key in ("weights", "model_weights"):
    if weights_key in trial_data:
        model_info["model"].set_weights(trial_data[weights_key])
        break

rpprint(model_info)

TypeError: Error when deserializing class 'InputLayer' using config={'batch_shape': [None, 10, 601, 1], 'dtype': 'float32', 'sparse': False, 'name': 'input_layer_653'}.

Exception encountered: Unrecognized keyword arguments: ['batch_shape']

In [20]:
# Epoch the loaded subject's data with the trial's channel subset and sampling rate.
X, y, _ = data_generator(
    fat_dataset,
    subjects=[model_info["subject"]],
    channel_idx=model_info["channels_idx_selected"],
    sfreq=model_info["sfreq"],
)
y_encoded = LabelEncoder().fit_transform(y)

# Stratified 80/20 hold-out split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=0.2,
    random_state=SKLRNG,
    shuffle=True,
    stratify=y_encoded,
)

# Compile, then score on the complete set (not on the hold-out split).
eval_model = model_info["model"]
eval_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
eval_model.evaluate(X, y_encoded, batch_size=model_info["batch_size"])

Adding metadata with 3 columns


2024-03-02 18:54:17.463740: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8907
2024-03-02 18:54:21.284834: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2024-03-02 18:54:22.235338: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory




[0.44489094614982605, 0.8703703880310059]

In [21]:
# Earlier experiment, kept for reference: rebuild a plain functional model
# around the same inputs/outputs and copy the stored weights into it.
# temp_model = tf.keras.Model(
#     inputs=model_info["model"].inputs,
#     outputs=model_info["model"].outputs
# )
# temp_model.set_weights(model.user_attrs["trial_data"]["weights"])

# Sanity check: is the loaded object a tf.keras Model, and is it a graph network?
(
    isinstance(model_info["model"], tf.keras.Model),
    model_info["model"]._is_graph_network,
)

(True, True)

In [22]:
import tensorflow_model_optimization as tfmot

# Magnitude-based pruning wrapper from the TF Model Optimization toolkit.
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = model_info["batch_size"]
epochs = 2

# Total number of training steps over all epochs; the sparsity ramp ends there.
steps_per_epoch = np.ceil(len(X_train) / batch_size).astype(np.int32)
end_step = steps_per_epoch * epochs

# Sparsity schedule: from 50% to 80% between step 0 and `end_step`.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.50,
    final_sparsity=0.80,
    begin_step=0,
    end_step=end_step,
)
pruning_params = {'pruning_schedule': pruning_schedule}

# Wrap the loaded model for pruning; the wrapped model has to be compiled again.
model_for_pruning = prune_low_magnitude(model_info["model"], **pruning_params)
model_for_pruning.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

model_for_pruning.summary()

Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 14, 513, 1)]      0         
                                                                 
 prune_low_magnitude_conv2d  (None, 14, 503, 120)      2762      
 _8 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_conv2d  (None, 1, 503, 50)        168002    
 _9 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_batch_  (None, 1, 503, 50)        201       
 normalization_4 (PruneLowM                                      
 agnitude)                                                       
                                                                 
 prune_low_magnitude_activa  (None, 1, 503, 50)        1   