In [1]:
# region General Imports
import os
import re
import uuid
import shutil
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial
# endregion General Imports

import tempfile
import tensorflow as tf
import numpy as np
from tensorflow import keras

import tensorflow_model_optimization as tfmot
%load_ext tensorboard
    
# os.environ["KERAS_BACKEND"] = "tf"
# os.environ["TF_USE_LEGACY_KERAS"] = "0"
# import jax
# import jax.numpy as jnp
# import keras
    
# region Keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import (
    Conv2D,
    MaxPooling2D,
    AveragePooling2D,
    Conv1D,
    MaxPooling1D,
    AveragePooling1D,
)
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K
# endregion Keras

from custom_datasets.fatigue_mi import FatigueMI

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
from sklearn import preprocessing

from model_optim.utils import channels_to_channels_idx



To use the get_shape_from_baseconcar, InputShapeSetterEEG, BraindecodeDatasetLoaderyou need to install `braindecode`.`pip install braindecode` or Please refer to `https://braindecode.org`.




In [2]:
SKLRNG = 42
# RNG = jax.random.PRNGKey(SKLRNG)

In [3]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Element-wise square, used as the first activation in shallow_conv_net."""
    squared = tf.square(x)
    return squared

def shallow_conv_net_log_layer(x):
    """Natural log of `x` after clipping it to [1e-7, 10000] so the log stays finite."""
    bounded = tf.clip_by_value(x, clip_value_min=1e-7, clip_value_max=10000)
    return tf.math.log(bounded)

# Function-name -> callable lookup handed to Keras whenever a model that uses
# these custom activations is deserialized or cloned.
CUSTOM_OBJECTS = {
    fn.__name__: fn
    for fn in (shallow_conv_net_square_layer, shallow_conv_net_log_layer)
}
# endregion Helper funcs

# region Models
def shallow_conv_net(
    nb_classes, channels, samples, **kwargs
):
    """
    Build a ShallowConvNet-style Keras model.

    From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    Args:
      nb_classes: number of units of the final softmax layer.
      channels: size of the first input axis.
      samples: size of the second input axis.
      **kwargs: optional hyper-parameters (defaults in parentheses):
        pool_size_d2 (35), strides_d2 (7), conv_filters_d2 (13): second-axis
          sizes used to build the default `pool_size`, `strides` and
          `conv_filters` tuples.
        pool_size ((1, pool_size_d2)), strides ((1, strides_d2)),
          conv_filters ((1, conv_filters_d2)): explicit 2-D overrides.
        conv2d_1_units (40), conv2d_2_units (40): filter counts of the two
          convolutions.
        l2_reg_1, l2_reg_2, l2_reg_3 (0.01 each): L2 factors of the first
          convolution, second convolution and dense layer.
        dropout_rate (0.5): dropout applied before flattening.

    Returns:
      An uncompiled `Model` mapping a (channels, samples, 1) input to
      `nb_classes` softmax probabilities.
    """
    option = kwargs.get

    # Scalar second-axis sizes; they only matter when the tuple form is not given.
    pool_d2 = option("pool_size_d2", 35)
    stride_d2 = option("strides_d2", 7)
    kernel_d2 = option("conv_filters_d2", 13)

    pool_size = option("pool_size", (1, pool_d2))
    pool_strides = option("strides", (1, stride_d2))
    first_kernel = option("conv_filters", (1, kernel_d2))

    n_filters_1 = option("conv2d_1_units", 40)
    n_filters_2 = option("conv2d_2_units", 40)
    l2_first = option("l2_reg_1", 0.01)
    l2_second = option("l2_reg_2", 0.01)
    l2_dense = option("l2_reg_3", 0.01)
    drop_rate = option("dropout_rate", 0.5)

    in_shape = (channels, samples, 1)
    inputs = Input(shape=in_shape)

    # First convolution: kernel spans only the second (samples) axis.
    features = Conv2D(
        filters=n_filters_1,
        kernel_size=first_kernel,
        input_shape=in_shape,
        kernel_constraint=max_norm(2.0, axis=(0, 1, 2)),
        kernel_regularizer=keras.regularizers.L2(l2_first),
    )(inputs)

    # Second convolution: kernel spans the whole first (channels) axis, no bias.
    features = Conv2D(
        filters=n_filters_2,
        kernel_size=(channels, 1),
        use_bias=False,
        kernel_constraint=max_norm(2.0, axis=(0, 1, 2)),
        kernel_regularizer=keras.regularizers.L2(l2_second),
    )(features)

    # square -> average pool -> log, followed by dropout.
    features = BatchNormalization(epsilon=1e-05, momentum=0.9)(features)
    features = Activation(shallow_conv_net_square_layer)(features)
    features = AveragePooling2D(pool_size=pool_size, strides=pool_strides)(features)
    features = Activation(shallow_conv_net_log_layer)(features)
    features = Dropout(rate=drop_rate)(features)

    # Classification head.
    flat = Flatten()(features)
    logits = Dense(
        units=nb_classes,
        kernel_constraint=max_norm(0.5),
        kernel_regularizer=keras.regularizers.L2(l2_dense),
    )(flat)
    probabilities = Activation("softmax")(logits)

    return Model(inputs=inputs, outputs=probabilities)

# endregion Models

In [230]:
# Collect every saved best-trial file and order them by the number found
# between square brackets in the path; files sharing that number are ordered
# by modification time (oldest first).
subject_best_trials = sorted(
    glob.glob('./temp_v2/**/model/study_best_trial.npy', recursive=True),
    key=lambda trial_path: (
        int(re.compile(r"\[.*\]").search(trial_path).group(0).strip("[]")),
        os.path.getmtime(trial_path),
    ),
)

# subject_best_trials = glob.glob('./temp/**/model/study_best_trial.npy', recursive=True)

In [231]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250):
    """
    Band-pass filter, epoch and resample the recordings of `subjects`.

    Args:
      dataset: object exposing `get_data(subjects=...)`, `event_id` and
        `interval` (a `FatigueMI` instance in this notebook).
      subjects: subject ids forwarded to `dataset.get_data`.
      channel_idx: positions into `raw.info['ch_names']` to keep; an empty
        sequence keeps every channel.
      filters: iterable of `(fmin, fmax)` pairs; each is applied in turn to
        the EEG channels with an IIR filter.
      sfreq: target sampling frequency passed to `Epochs.resample`.

    Returns:
      Tuple `(X, y, metadata)`:
        X: `np.concatenate` (axis 0) of the per-run `mne.Epochs` objects.
        y: concatenated event-name labels, one per epoch.
        metadata: DataFrame with `subject`, `session` and `run` columns,
          one row per epoch.

    NOTE(review): the standard-scaled array computed below is only used to
    count epochs; the un-scaled `x` is what gets appended to `X` (the scaled
    append is commented out). Confirm this matches how the models were trained.
    """

    # Use the stim channel when the recording has one, otherwise fall back to
    # events derived from annotations.
    find_events = lambda raw, event_id: mne.find_events(raw, shortest_event=0, verbose=False) if len(mne.utils._get_stim_channel(None, raw.info, raise_error=False)) > 0 else mne.events_from_annotations(raw, event_id=event_id, verbose=False)[0]
    
    data = dataset.get_data(subjects=subjects)
    
    X = []
    y = []
    metadata = []

    # `data` is nested as subject -> session -> run -> raw recording.
    for subject_id in data.keys():
        for session_id in data[subject_id].keys():
            for run_id in data[subject_id][session_id].keys():
                raw = data[subject_id][session_id][run_id]
                
                # Apply every requested band one after another.
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)
                
                events = find_events(raw, dataset.event_id)

                # Epoch window (relative to each event) comes from the dataset.
                tmin = dataset.interval[0]
                tmax = dataset.interval[1]

                # Channel names to keep: the requested subset, or all of them.
                channels = np.asarray(raw.info['ch_names'])[channel_idx] if len(channel_idx) > 0 else np.asarray(raw.info['ch_names'])

                # rpprint(channels)
                
                # Stim channels are excluded from the picks.
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                x = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                x_events = x.events
                # Map numeric event codes back to their names to use as labels.
                inv_events = {k: v for v, k in dataset.event_id.items()}
                labels = [inv_events[e] for e in x_events[:, -1]]

                # rpprint({
                #     "X": np.asarray(x.get_data(copy=False)).shape,
                #     "y": np.asarray(labels).shape,
                #     "channels selected": np.asarray(raw.info['ch_names'])[channel_idx]
                # })

                # x.plot(scalings="auto")
                # display(x.info)
                
                # presumably resamples `x` in place, since `x` (not
                # `x_resampled`) is appended below — TODO confirm
                x_resampled = x.resample(sfreq) # Resampler_Epoch
                x_resampled_data = x_resampled.get_data(copy=False) # Convert_Epoch_Array
                # A fresh StandardScaler is fitted on each epoch separately.
                x_resampled_data_standard_scaler = np.asarray([
                    StandardScaler().fit_transform(x_resampled_data[i])
                    for i in np.arange(x_resampled_data.shape[0])
                ]) # Standard_Scaler_Epoch

                # x_resampled.plot(scalings="auto")
                # display(x_resampled.info)

                # Number of epochs in this run; one metadata row per epoch.
                n = x_resampled_data_standard_scaler.shape[0]
                # n = x.get_data(copy=False).shape[0]
                met = pd.DataFrame(index=range(n))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id
                x.metadata = met.copy()
                
                # The Epochs object itself (not the scaled array) is collected.
                # X.append(x_resampled_data_standard_scaler)
                X.append(x)
                y.append(labels)
                metadata.append(met)

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

# Shared dataset handle; read as a notebook global by load_best_trial and
# passed to data_generator.
fat_dataset = FatigueMI()

In [232]:
# Show the ordered list of best-trial files that were found.
rpprint(subject_best_trials)

In [233]:
def unit_prune_dense_layer(k_weights, b_weights, k_sparsity):
    """
    Takes in matrices of kernel and bias weights (for a dense
      layer) and returns the unit-pruned versions of each.
    Whole units (kernel columns) with the smallest L2 norm are zeroed,
    together with their bias entries.
    Args:
      k_weights: 2D matrix of the kernel weights of a dense layer
        (one column per unit)
      b_weights: 1D matrix of the biases of a dense layer, or None
        when the layer has no bias
      k_sparsity: fraction (0 to 1) of units to set to 0
    Returns:
      kernel_weights: sparse matrix with same shape as the original
        kernel weight matrix
      bias_weights: sparse array with same shape as the original
        bias array, or None when `b_weights` is None
    """

    # Copy the kernel weights (the input is left untouched) and get the
    # indices of the columns ranked by their L2 norm, smallest first
    kernel_weights = np.copy(k_weights)
    ind = np.argsort(np.linalg.norm(kernel_weights, axis=0))

    # Number of units to set to 0
    cutoff = int(len(ind)*k_sparsity)
    # The columns of the 2D kernel weight matrix to set to 0
    sparse_cutoff_inds = ind[0:cutoff]
    kernel_weights[:,sparse_cutoff_inds] = 0.

    # Bias-free layers: nothing else to prune (mirrors
    # weight_prune_dense_layer, which also accepts b_weights=None)
    if b_weights is None:
        return kernel_weights, None

    # Copy the bias weights and zero the entries of the removed units
    bias_weights = np.copy(b_weights)
    bias_weights[sparse_cutoff_inds] = 0.

    return kernel_weights, bias_weights

In [234]:
def load_best_trial(subject_best_trial):
    """
    Load one saved best trial and rebuild its Keras model.

    Args:
      subject_best_trial: path to a `study_best_trial.npy` file holding a
        pickled trial object with `params` and `user_attrs["trial_data"]`.

    Returns:
      dict with keys `subject`, `sfreq`, `batch_size`, `channels_selected`,
      `channels_idx_selected`, `model` (rebuilt from JSON, with the stored
      weights applied when present), `test_acc` and `model_name`.
    """
    # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
    # files produced by this project.
    # Load the file that was asked for (previously the argument was ignored
    # and the first entry of the global list was always loaded).
    trial = np.load(subject_best_trial, allow_pickle=True).item()
    trial_data = trial.user_attrs["trial_data"]

    # NOTE(review): `trial_data` is indexed like a dict, and hasattr() on a
    # dict tests attributes, not keys, so these two checks are presumably
    # always False and the fallbacks are always used. Kept as-is to preserve
    # current results; confirm whether a key lookup was intended.
    if hasattr(trial_data, "subject"):
        subject = trial_data["subject"]
    else:
        # Subject number is taken from the bracketed part of the data path.
        subject = int(re.compile(r"\[.*\]").search(trial_data["data_path"]).group(0).strip("[]"))
    if hasattr(trial_data, "model_name"):
        model_name = trial_data["model_name"]
    else:
        model_name = "shallow_conv_net"

    # Channel names of a reference recording with the last entry dropped
    # (presumably the stim channel — TODO confirm).
    reference_ch_names = fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1]

    model_info = {
        "subject": subject,
        "sfreq": trial.params["sfreq"] if "sfreq" in trial.params else 128,
        "batch_size": trial.params["batch_size"] if "batch_size" in trial.params else 128,
        "channels_selected": trial_data["channels_selected"],
        "channels_idx_selected": channels_to_channels_idx(trial_data["channels_selected"], reference_ch_names),
        "model": tf.keras.models.model_from_json(trial_data["model"], custom_objects=CUSTOM_OBJECTS),
        "test_acc": trial_data["test_accuracy"],
        "model_name": model_name,
    }

    # The weights were stored under either of two keys, depending on the run.
    if "weights" in trial_data:
        model_info["model"].set_weights(trial_data["weights"])
    elif "model_weights" in trial_data:
        model_info["model"].set_weights(trial_data["model_weights"])

    return model_info

In [235]:
# Only the first best-trial file is processed (slice [0:1]). The names
# assigned inside the loop stay defined afterwards and are read by later cells.
for subject_best_trial in subject_best_trials[0:1]:

    rprint("\n\n\nLoading best trial model...", subject_best_trial)

    # Rebuild the stored model together with its preprocessing settings.
    model_info = load_best_trial(subject_best_trial)
    subject = model_info["subject"]
    model_name = model_info["model_name"]

    # Epoch this subject's data with the channels and sampling rate of the trial.
    X, y, _ = data_generator(fat_dataset, subjects=[model_info["subject"]], channel_idx=model_info["channels_idx_selected"], sfreq=model_info["sfreq"])
    # Turn the label names into integer class ids.
    y_encoded = LabelEncoder().fit_transform(y)
    # Stratified 80/20 split, seeded with SKLRNG for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=SKLRNG, shuffle=True, stratify=y_encoded)
    train_test_data = { "X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test }

Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns


In [258]:
def weight_prune_dense_layer(k_weights, b_weights, k_sparsity):
    """
    Magnitude-prune a layer's kernel and (optionally) bias.

    Args:
      k_weights: kernel weight array of any shape.
      b_weights: bias weight array, or None when the layer has no bias.
      k_sparsity: fraction of entries to zero in each array.

    Returns:
      (kernel_weights, bias_weights): pruned copies with the original
      shapes; `bias_weights` is None when `b_weights` is None. The inputs
      are not modified.
    """

    def _zero_smallest(weights):
        # Work on a copy, rank every entry by absolute value over the
        # flattened array, then zero the lowest-ranked fraction.
        pruned = np.copy(weights)
        ranked = np.argsort(np.abs(pruned), axis=None)
        n_drop = int(len(ranked) * k_sparsity)
        pruned.flat[ranked[:n_drop]] = 0
        return pruned

    pruned_kernel = _zero_smallest(k_weights)
    pruned_bias = None if b_weights is None else _zero_smallest(b_weights)
    return pruned_kernel, pruned_bias


def sparsify_model(model, x_test, y_test, k_sparsity, pruning='weight', ignore_batch_norm = True):
    """
    Clone `model`, magnitude-prune the weights of its prunable layers and
    evaluate the pruned clone.

    Args:
      model: Keras model to prune; it is left unmodified.
      x_test, y_test: evaluation inputs and integer class labels (the clone
        is compiled with sparse categorical cross-entropy).
      k_sparsity: fraction of weights to zero, forwarded to
        `weight_prune_dense_layer`.
      pruning: pruning strategy; only 'weight' is supported.
      ignore_batch_norm: must be truthy; batch-normalization layers are
        always skipped.

    Returns:
      (sparse_model, score): the compiled pruned clone and the
      `[loss, accuracy]` list returned by `evaluate`.

    Raises:
      AssertionError: if `ignore_batch_norm` is falsy.
      ValueError: if `pruning` is not 'weight'.
    """
    # Validate up front, before the model is cloned.
    assert ignore_batch_norm, "Batch Normalization is not supported yet"
    if pruning != 'weight':
        # Previously an unsupported value left `kernel_weights` unbound
        # (NameError) or silently reused the previous layer's values.
        raise ValueError("Unsupported pruning strategy: {!r} (only 'weight' is implemented)".format(pruning))

    # The custom activations must be registered for clone_model to rebuild them.
    keras.utils.get_custom_objects().update(CUSTOM_OBJECTS)
    sparse_model = tf.keras.models.clone_model(model)
    sparse_model.set_weights(model.get_weights())

    for layer in model.layers:
        if "input" in layer.name or len(layer.trainable_weights) == 0:
            continue
        if "batch_normalization" in layer.name:
            continue

        layer_weights = layer.get_weights()
        W = layer_weights[0]
        # Layers without a `use_bias` attribute are treated as bias-free
        # instead of raising AttributeError.
        b = layer_weights[1] if getattr(layer, "use_bias", False) else None

        # rprint(layer.name, W.shape)
        kernel_weights, bias_weights = weight_prune_dense_layer(W, b, k_sparsity)

        # Replace only the pruned entries so any further weights the layer
        # holds are written back unchanged.
        new_weights = list(layer_weights)
        new_weights[0] = kernel_weights
        if b is not None:
            new_weights[1] = bias_weights
        sparse_model.get_layer(layer.name).set_weights(new_weights)

    sparse_model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'])

    # Print the loss and accuracy obtained at this sparsity level.
    score = sparse_model.evaluate(x_test, y_test, verbose=0)
    # A single '%' is correct here: str.format does not treat '%%' as an escape.
    print('k% weight sparsity: ', k_sparsity,
          '\tTest loss: {:07.5f}'.format(score[0]),
          '\tTest accuracy: {:05.2f} %'.format(score[1]*100.))

    return sparse_model, score


# Prune 40% of the weights of the loaded model and evaluate it on the
# held-out split; the returned (model, score) tuple is the cell output.
sparsify_model(
    model = model_info["model"],
    x_test = train_test_data["X_test"],
    y_test = train_test_data["y_test"],
    k_sparsity = 0.40,
    pruning = 'weight'
)

k% weight sparsity:  0.4 	Test loss: 24.94406 	Test accuracy: 68.18 %%


(<tf_keras.src.engine.functional.Functional at 0x7f60d8b26d90>,
 [24.944063186645508, 0.6818181872367859])

In [259]:
# Inspect the settings and model recovered from the best trial.
rpprint(model_info)

In [241]:
# Layer-by-layer overview (output shapes, parameter counts) of the loaded model.
model_info["model"].summary()

Model: "model_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_10 (InputLayer)       [(None, 10, 601, 1)]      0         
                                                                 
 conv2d_18 (Conv2D)          (None, 10, 562, 180)      7380      
                                                                 
 conv2d_19 (Conv2D)          (None, 1, 562, 30)        54000     
                                                                 
 batch_normalization_9 (Bat  (None, 1, 562, 30)        120       
 chNormalization)                                                
                                                                 
 activation_27 (Activation)  (None, 1, 562, 30)        0         
                                                                 
 average_pooling2d_9 (Avera  (None, 1, 33, 30)         0         
 gePooling2D)                                             