In [1]:
# General imports
import os
import uuid
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

/home/arazzz/anaconda3/envs/moabb_model_optimization/lib/python3.11/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


In [2]:
# Integer seed; passed as `random_state` to scikit-learn's train_test_split further down.
SKLRNG = 42
# JAX PRNG key built from the same seed.
RNG = jax.random.PRNGKey(SKLRNG)

2024-03-02 17:34:53.637499: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250):
    """Filter, epoch and resample the recordings of the requested subjects.

    Every raw recording returned by ``dataset.get_data`` (nested as
    subject -> session -> run) is band-pass filtered, cut into epochs on
    ``dataset.event_id`` over ``dataset.interval``, restricted to the selected
    channels and resampled to ``sfreq``.

    Parameters
    ----------
    dataset : object
        Must expose ``get_data(subjects=...)``, ``event_id`` (label -> code)
        and ``interval`` (``[tmin, tmax]``).
    subjects : list
        Subject identifiers forwarded to ``dataset.get_data``.
    channel_idx : sequence of int
        Positions into ``raw.info['ch_names']`` to keep. An empty sequence
        keeps every channel. Stim channels are always excluded.
    filters : iterable of (fmin, fmax)
        Band edges. Each pair is applied to the same recording one after the
        other, and ``raw.filter`` modifies the recording in place.
    sfreq : float
        Target sampling frequency of the epochs.

    Returns
    -------
    X : numpy.ndarray
        Epoch data of all runs concatenated along the first axis. The data is
        NOT standardized.
    y : numpy.ndarray
        Label name of every epoch, looked up from ``dataset.event_id``.
    metadata : pandas.DataFrame
        One row per epoch with ``subject``, ``session`` and ``run`` columns.

    Raises
    ------
    ValueError
        If ``dataset.get_data`` yields no run at all.
    """

    def _find_events(raw, stim_channels):
        # Prefer the stim channel when the recording has one, otherwise fall
        # back to the annotations.
        if len(stim_channels) > 0:
            return mne.find_events(raw, shortest_event=0, verbose=False)
        return mne.events_from_annotations(raw, event_id=dataset.event_id, verbose=False)[0]

    data = dataset.get_data(subjects=subjects)

    tmin = dataset.interval[0]
    tmax = dataset.interval[1]
    # Reverse lookup: event code -> label name.
    label_of_code = {code: name for name, code in dataset.event_id.items()}

    X = []
    y = []
    metadata = []

    for subject_id, sessions in data.items():
        for session_id, runs in sessions.items():
            for run_id, raw in runs.items():
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)

                # NOTE: private MNE helper; looked up once and reused for both
                # event detection and channel picking.
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                events = _find_events(raw, stim_channels)

                ch_names = np.asarray(raw.info['ch_names'])
                channels = ch_names[channel_idx] if len(channel_idx) > 0 else ch_names
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                epochs = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                labels = [label_of_code[code] for code in epochs.events[:, -1]]

                # Resampling happens in place on the Epochs object.
                epochs.resample(sfreq)
                # Hand numpy a real array instead of relying on it iterating
                # the Epochs object implicitly during concatenation.
                epoch_data = epochs.get_data(copy=False)

                met = pd.DataFrame(
                    {"subject": subject_id, "session": session_id, "run": run_id},
                    index=range(epoch_data.shape[0]),
                )

                X.append(epoch_data)
                y.append(labels)
                metadata.append(met)

    if not X:
        raise ValueError(f"No runs were returned by the dataset for subjects={subjects!r}")

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

# Instantiate the project-local FatigueMI dataset and build the arrays for subject 1.
# An empty channel_idx keeps every channel; epochs are resampled to 128 Hz.
# The third return value (per-epoch metadata) is not needed here.
fat_dataset = FatigueMI()
X, y, _ = data_generator(fat_dataset, subjects=[1], channel_idx=[], sfreq=128)

Adding metadata with 3 columns


In [4]:
# Integer-encode the label names.
# 0 = left, 1 = right, 2 = unlabelled (!NOTE: should be ignored for binary classification)
y_encoded = LabelEncoder().fit_transform(y)
rprint({"X": X.shape, "y": y_encoded.shape})

# The last two axes of X are taken as (channels, samples).
NUM_CHANNELS, NUM_SAMPLES = X.shape[-2], X.shape[-1]
NUM_CLASSES = len(np.unique(y_encoded))

rpprint(dict(NUM_SAMPLES=NUM_SAMPLES, NUM_CHANNELS=NUM_CHANNELS, NUM_CLASSES=NUM_CLASSES))

# sns.histplot(y_encoded); # Plot the class distribution

In [5]:
TRAIN_SIZE = 0.8
TEST_SIZE = 0.2

# Shuffled hold-out split, stratified on the encoded labels.
# Only TEST_SIZE is handed to scikit-learn; TRAIN_SIZE is informational.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    test_size=TEST_SIZE,
    random_state=SKLRNG,
    shuffle=True,
    stratify=y_encoded,
)

# Report how many epochs ended up on each side of the split.
total = len(X_train) + len(X_test)
rpprint(
    {
        name: f"{len(part)} ({len(part) / total * 100:.2f}%)"
        for name, part in (("X_train", X_train), ("X_test", X_test))
    },
    expand_all=True,
)

In [6]:
# region Helper funcs
def shallow_conv_net_square_layer(x):
    """Element-wise square of ``x``, used as an activation function."""
    squared = jnp.square(x)
    return squared

def shallow_conv_net_log_layer(x):
    """Natural log of ``x`` after clipping it to ``[1e-7, 10000]``.

    The lower clip bound keeps the logarithm finite for zero or negative input.
    """
    clipped = jnp.clip(x, 1e-7, 10000)
    return jnp.log(clipped)

# Name -> callable mapping handed to Keras as `custom_objects` when a model
# that uses these activations is rebuilt from its JSON description.
CUSTOM_OBJECTS = dict(
    shallow_conv_net_square_layer=shallow_conv_net_square_layer,
    shallow_conv_net_log_layer=shallow_conv_net_log_layer,
)
# endregion Helper funcs

# region Models
def shallow_conv_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """
        Build the shallow convolutional network as an uncompiled Keras functional Model.

        From: https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

        The model takes input of shape (channels, samples, 1) and ends in a
        softmax over `nb_classes`.

        NOTE: the defaults of `nb_classes`, `channels` and `samples` are read from
        the notebook globals once, when this cell is executed. Pass them
        explicitly if the data changed afterwards.

        Optional keyword arguments (default in brackets):
            pool_size_d2 [35], strides_d2 [7], conv_filters_d2 [13]
                second dimension of the pooling window, pooling stride and first
                convolution kernel; only used to build the tuples below.
            pool_size [(1, pool_size_d2)], strides [(1, strides_d2)],
            conv_filters [(1, conv_filters_d2)]
                full tuples; when given they take precedence over the *_d2 values.
            conv2d_1_units [40], conv2d_2_units [40]
                number of filters of the first and second convolution.
            l2_reg_1, l2_reg_2, l2_reg_3 [0.01]
                L2 factors of the first convolution, second convolution and
                final dense layer.
            dropout_rate [0.5]
    """
    opt = kwargs.get

    pool_size_d2 = opt("pool_size_d2", 35)
    strides_d2 = opt("strides_d2", 7)
    conv_filters_d2 = opt("conv_filters_d2", 13)

    pool_size = opt("pool_size", (1, pool_size_d2))
    strides = opt("strides", (1, strides_d2))
    conv_filters = opt("conv_filters", (1, conv_filters_d2))

    conv2d_1_units = opt("conv2d_1_units", 40)
    conv2d_2_units = opt("conv2d_2_units", 40)
    l2_reg_1 = opt("l2_reg_1", 0.01)
    l2_reg_2 = opt("l2_reg_2", 0.01)
    l2_reg_3 = opt("l2_reg_3", 0.01)
    dropout_rate = opt("dropout_rate", 0.5)

    input_main   = Input(shape=(channels, samples, 1))
    # The input shape comes from `input_main`; no `input_shape=` is passed to
    # the layer, which Keras 3 warns about.
    block1       = Conv2D(conv2d_1_units, conv_filters,
                          kernel_constraint = max_norm(2., axis=(0,1,2)),
                          kernel_regularizer = keras.regularizers.L2(l2_reg_1))(input_main)
    # Convolution spanning the whole channel axis.
    block1       = Conv2D(conv2d_2_units, (channels, 1), use_bias=False,
                          kernel_constraint = max_norm(2., axis=(0,1,2)),
                          kernel_regularizer = keras.regularizers.L2(l2_reg_2))(block1)
    block1       = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    # square -> average pooling -> log
    block1       = Activation(shallow_conv_net_square_layer)(block1)
    block1       = AveragePooling2D(pool_size=pool_size, strides=strides)(block1)
    block1       = Activation(shallow_conv_net_log_layer)(block1)
    block1       = Dropout(dropout_rate)(block1)
    flatten      = Flatten()(block1)
    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5),
                         kernel_regularizer = keras.regularizers.L2(l2_reg_3))(flatten)
    softmax      = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

def lstm_cnn_net(nb_classes = NUM_CLASSES, channels = NUM_CHANNELS, samples = NUM_SAMPLES, **kwargs):
    """Build an uncompiled Sequential model: one Conv1D stage, one LSTM stage, softmax.

    The model takes input of shape (samples, channels).

    Optional keyword arguments (default in brackets):
        conv1d_1_units [40], conv1d_1_kernel_size [20], conv1d_1_strides [4]
        conv1d_1_maxpool_size [4], conv1d_1_maxpool_strides [4]
        lstm_1_units [50]
        l2_reg_1, l2_reg_2, l2_reg_3 [0.01]
            L2 factors of the Conv1D, LSTM and Dense kernels.
        dropout_rate_1 [0.5], dropout_rate_2 [0.5]
            dropout after the Conv1D stage and after the LSTM.
    """
    opt = kwargs.get

    conv_units = opt("conv1d_1_units", 40)
    conv_kernel_size = opt("conv1d_1_kernel_size", 20)
    conv_strides = opt("conv1d_1_strides", 4)
    maxpool_size = opt("conv1d_1_maxpool_size", 4)
    maxpool_strides = opt("conv1d_1_maxpool_strides", 4)
    lstm_units = opt("lstm_1_units", 50)
    l2_conv = opt("l2_reg_1", 0.01)
    l2_lstm = opt("l2_reg_2", 0.01)
    l2_dense = opt("l2_reg_3", 0.01)
    dropout_conv = opt("dropout_rate_1", 0.5)
    dropout_lstm = opt("dropout_rate_2", 0.5)

    layer_stack = [
        Input(shape=(samples, channels)),

        # 1-layer CNN
        Conv1D(
            conv_units,
            kernel_size=conv_kernel_size,
            strides=conv_strides,
            kernel_regularizer=keras.regularizers.L2(l2_conv),
        ),
        Activation('relu'),
        Dropout(dropout_conv),
        BatchNormalization(),
        MaxPooling1D(pool_size=maxpool_size, strides=maxpool_strides),

        # 1-layer LSTM; the full sequence is kept and flattened for the classifier
        LSTM(
            lstm_units,
            return_sequences=True,
            stateful=False,
            kernel_regularizer=keras.regularizers.L2(l2_lstm),
        ),
        Dropout(dropout_lstm),
        BatchNormalization(),
        Flatten(),
        Dense(nb_classes, activation='softmax', kernel_regularizer=keras.regularizers.L2(l2_dense)),
    ]

    model = Sequential()
    for layer in layer_stack:
        model.add(layer)

    return model
# endregion Models

In [38]:
# Load subject model
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, which can execute code.
# Only load files that come from a trusted source.
# The loaded object is used through `.user_attrs` and `.params` — presumably an Optuna trial; confirm.
# subject_data = np.load("best_subjects_search_results/subject_12_shallow_conv_net_3a2d5dc0-13c5-4bdb-86dc-b9a4afb56ce2.npy", allow_pickle=True).item()
subject_data = np.load("../temp/[1]/292cbc92b8cf46da9986fe7d8447819f/model/study_best_trial.npy", allow_pickle=True).item()
subject_trial_data = subject_data.user_attrs['trial_data']
# Rebuild the architecture from its JSON description; CUSTOM_OBJECTS resolves the custom activations by name.
subject_model_json = subject_trial_data['model']
subject_model = keras.models.model_from_json(subject_model_json, custom_objects=CUSTOM_OBJECTS)
# Restore the stored weights into the rebuilt model.
subject_model_weights = subject_trial_data['weights']
subject_model.set_weights(subject_model_weights)

def channels_to_channels_idx(channels, dataset):
    """Translate channel names into their positions in the recording's channel list.

    The reference channel list is read from subject 1, session '0', run '0' of
    ``dataset`` with its last entry dropped. Positions are returned in
    recording order, regardless of the order of ``channels``; names that are
    not in the reference list are ignored.
    """
    reference_raw = dataset.get_data(subjects=[1])[1]['0']['0']
    all_channels = reference_raw.info['ch_names'][:-1]
    # dict.fromkeys keeps first-seen order and collapses repeated names.
    unique_names = dict.fromkeys(all_channels)
    return [position for position, name in enumerate(unique_names) if name in channels]

# Settings of the stored trial that are needed to regenerate matching input data.
subject_channels_idx = channels_to_channels_idx(subject_trial_data['channels_selected'], fat_dataset)
subject_sfreq = subject_data.params.get('sfreq', 128)  # fall back to 128 when the trial did not record one
subject_data_hyper_params = {
    name: value
    for name, value in subject_data.params.items()
    if not name.startswith('channels_')
}
subject_batch_size = subject_data.params['batch_size']

# Report the trial parameters (channel entries left out) and its stored scores.
rpprint({
    name: value
    for name, value in subject_data.params.items()
    if not name.startswith("channels")
})
rprint("test_accuracy =", subject_trial_data["test_accuracy"])
rprint("val_accuracy =", np.max(subject_trial_data["val_accuracy"]))
rprint("channels_selected =", subject_trial_data["channels_selected"])

In [39]:
# Regenerate subject 1's data with the channels and sampling rate stored in the loaded trial.
# NOTE: this rebinds the notebook-level X, y and y_encoded defined in earlier cells.
X, y, _ = data_generator(fat_dataset, subjects=[1], channel_idx=subject_channels_idx, sfreq=subject_sfreq)
# X, y = X[0:50].reshape(len(X[0:50]), len(subject_channels_idx), -1), y[0:50]
y_encoded = LabelEncoder().fit_transform(y)

# X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=SKLRNG, shuffle=True, stratify=y_encoded)

# subject_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Evaluates on every epoch returned above, not on a held-out split.
# NOTE(review): if the model was trained on part of this data, the score is optimistic — confirm.
subject_model.evaluate(X, y_encoded, batch_size=subject_batch_size)

Sampling frequency of the instance is already 300.0, returning unmodified.
Adding metadata with 3 columns
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.9074 - loss: 50.4109


[50.41094970703125, 0.9074074029922485]