Workflow to fit an SVM per recording day on selected gestures (identified from Day 16/day 1).

In [None]:
import os
import sys
# Make the project root (the parent of the notebook's working directory)
# importable, so `srcs`, `utils` and `constants` resolve below.
notebook_dir = os.getcwd()
project_dir = os.path.split(notebook_dir)[0]
if project_dir not in sys.path:
    # Prepend so the project's packages take precedence in import lookup.
    sys.path[:0] = [project_dir]

from srcs.engdataset import ENGDataset, Nerve
import utils.preprocessing as pre
import utils.classify as classify
import utils.plot as uplot
from constants import *

from collections import Counter
import logging
from collections import namedtuple
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
from typing import Dict, List


# Global plot styling and log level for the rest of the notebook.
# `matplotlib.rcParams` and `plt.rcParams` are the same object, so a single
# update sets every key exactly once.
plt.rcParams.update({
    "figure.dpi": 150,
    "font.size": 10,
    "figure.figsize": (5, 3),
    "axes.axisbelow": True,
    "axes.edgecolor": COLOR_DICT['clouds'],
    "axes.linewidth": 0.4,
})
logging.getLogger().setLevel(logging.INFO)

%load_ext autoreload
%autoreload 2

# 1. Load ENG Data

In [None]:
# ---- Recording to load ---------------------------------------------------
# day: recording day (16, 17 or 23); session: session id within that day.
day, session = 16, '01'

# ---- Pipeline switches ---------------------------------------------------
preproc_plots = False   # plot figures during preprocessing
filter_signal = True    # True: filter all channels now; False: reload a saved filtered file
save_figs = True        # forwarded to ENGDataset
optimize_mem = True     # delete the raw data once it is no longer needed

# ---- Feature extraction --------------------------------------------------
feature = 'power'
wind_size = 0.100       # window size [s]
overlap_perc = 0.5      # overlap ratio
# Dataset organisation: 'flx_vs_ext_separate', 'flx_vs_ext_together' or 'flx_vs_ext_combined'
organize_strat = 'flx_vs_ext_separate'

# ---- Gesture/phase pairs kept for classification -------------------------
Gest_namedtup = namedtuple('gesture', ['id', 'phase'])
# Gesture 2 ('Close') is currently left out of the selection.
sel_gest_phase = [
    Gest_namedtup(gest_id, phase)
    for gest_id, phase in ((0, 'Open'), (1, 'Close'), (3, 'Close'), (4, 'Close'))
]

# ---- Classifier ----------------------------------------------------------
seed = 10   # random seed used for splitting the k folds
k_cv = 5    # number of cross-validation folds

In [None]:
# Create the output directories used by later cells; existing ones are kept.
directories = [FIG_DIR, CLF_FIG, FILTERED_DIR, CLF_RESULTS_DIR]
for directory in directories:
    # exist_ok avoids the check-then-create race and still raises if the
    # path exists as a regular file.
    os.makedirs(directory, exist_ok=True)

# NOTE(review): CLF_FIG is read above (presumably provided by `constants`) and
# is rebound here to <FIG_DIR>/clf — confirm both locations are intended.
CLF_FIG = os.path.join(FIG_DIR, 'clf')
os.makedirs(CLF_FIG, exist_ok=True)

In [None]:
eng_dataset = ENGDataset(day=day, session=session, load_raw_data=True, save_figs=save_figs)

# Filtering settings attached to the dataset; later cells read them back
# through `eng_dataset.filt_pipeline` (e.g. for the filtered-data file name).
pipeline = dict(
    bp_order=3,                            # band-pass filter order
    bp_cutoff_freq=np.array([300, 2000]),  # band-pass cut-offs [Hz]
    notch_bandwidth=0.5,
    notch_reject=50,
)
eng_dataset.filt_pipeline = pipeline

In [None]:
def show_matfile_vars(data):
    """Print a one-line summary for each array or list entry of a mapping.

    ``np.ndarray`` values are printed with their dtype and shape, ``list``
    values with their length; values of any other type are skipped. Two
    trailing newlines are printed at the end as a separator.

    Args:
        data: mapping of variable name to value (e.g. a loaded .mat dict).
    """
    for name, value in data.items():
        if isinstance(value, np.ndarray):
            print(f"{name}: {value.dtype} {value.shape} ")
        elif isinstance(value, list):
            print(f"{name}: list {len(value)}")
    print("\n")
# Summarise the raw and post-processed variables held by the dataset, in that order.
for _attr in ("raw_data", "post_data"):
    show_matfile_vars(getattr(eng_dataset, _attr))

# 2. Filter the raw data

In [None]:
if filter_signal:  # filter all channels; optionally plot one channel's spectra
    notch_filt_data, bp_filt_data = pre.apply_filter_pipeline(eng_dataset)

    if preproc_plots:
        # Spectrum of a single channel, top to bottom: raw, after the notch
        # filters, after notch + band-pass.
        sel_ch = 0
        xlim = [0, pipeline['bp_cutoff_freq'][-1]]  # in Hz

        spectra = [
            (pre.get_fft(np.array(eng_dataset.post_data_df[sel_ch]), ENG_FS),
             'raw', COLOR_DICT['midnight_blue']),
            (pre.get_fft(notch_filt_data[:, sel_ch], ENG_FS),
             'Notch', COLOR_DICT['pumpkin']),
            (pre.get_fft(bp_filt_data[:, sel_ch], ENG_FS),
             'BPF', COLOR_DICT['midnight_blue']),
        ]

        fig = plt.figure(figsize=(8, 4))
        gs = gridspec.GridSpec(nrows=len(spectra), ncols=1)
        for row, ((xf, yf), label, color) in enumerate(spectra):
            ax = fig.add_subplot(gs[row])
            ax.plot(xf, np.sqrt(np.abs(yf)), label=label, color=color)
            ax.set(frame_on=False)
            ax.set_xlim(xlim)
            ax.set_ylabel('amplitude [uV]')
            ax.legend()
        fig.axes[0].set_title(
            f"Signal ch:{sel_ch} after Notch filters + BP: {pipeline['bp_cutoff_freq']}",
            fontsize=8)
        ax.set_xlabel('freq [Hz]')  # only the bottom subplot carries the x label

        if save_figs:
            # Save before show(): some backends (e.g. Jupyter inline) close
            # figures once they are shown. FIG_DIR is created earlier in the notebook.
            fig.savefig(os.path.join(
                FIG_DIR,
                f"fft_bp_{pipeline['bp_cutoff_freq'][0]}_"
                f"{pipeline['bp_cutoff_freq'][1]}.png"))
        plt.show()

In [None]:
# Save the band-passed data to a pickle, or reload a previously saved one.
# The file name encodes day, session and the band-pass cut-offs.
bp_cutoff = eng_dataset.filt_pipeline['bp_cutoff_freq']
filt_filename = f"day{eng_dataset.day}{eng_dataset.session}_eng_filt_{bp_cutoff[0]}_{bp_cutoff[1]}.pkl"
filt_path = os.path.join(FILTERED_DIR, filt_filename)
if filter_signal:
    # organize filtered data in dataframe (one column per channel + time)
    filt_df = pd.DataFrame(bp_filt_data)
    # Assign the time column by position: label-based alignment would insert
    # NaNs if post_data_df does not carry a default RangeIndex, and a length
    # mismatch now raises instead of being silently padded.
    filt_df[TIME_VAR] = eng_dataset.post_data_df[TIME_VAR].to_numpy()
    filt_df.to_pickle(filt_path)
else:
    # NOTE: unpickling can execute arbitrary code; only load files this
    # notebook produced.
    logging.info(f"Loading filtered data from {filt_filename}")
    filt_df = pd.read_pickle(filt_path)
eng_dataset.filt_df = filt_df


In [None]:
# Flag channels using a threshold of 6 on the std-based criterion of
# pre.detect_bad_channels, then report them.
bad_channels, bad_channels_std = pre.detect_bad_channels(eng_dataset, std_threshold=6)
rounded_std = np.round(bad_channels_std, 3)
print("Bad channels:{}\nBad channels std:{}".format(bad_channels, rounded_std))

In [None]:
# # plot bad channels
# fig = plt.figure(figsize=(8, 4))
# gs = gridspec.GridSpec(nrows=4, ncols=2)
# bad_channels_good = bad_channels + [43,0]
# for i, ch in enumerate(bad_channels_good):
#     ax = fig.add_subplot(gs[i])
#     ax.plot(eng_dataset.filt_df[TIME_VAR], eng_dataset.filt_df[ch], label='raw', color=COLOR_DICT['midnight_blue'])
#     plt.ylabel('amplitude [uV]')
#     plt.title(f"Bad channel {ch}")
#     # ax.set_xlim([5.645, 5.8])
#     # ax.set_ylim([-800,800])

# fig.tight_layout()


# 3. Unfold each rep in time and fit SVM

In [None]:
# Save memory: when optimize_mem is set, ask the dataset to drop its raw data
# (per the optimize_mem flag, the raw df is no longer needed past this point).
# NOTE(review): `_detete_raw_data` is a private method with an unusual spelling
# ("detete") — confirm it matches the name defined in ENGDataset.
if optimize_mem:
    eng_dataset._detete_raw_data()

In [None]:
(bad_channels, wind_size)  # display detected bad channels and the window size

In [None]:
input_df, labels_map = classify.prepare_input_df(eng_dataset, feature, organize_strat, wind_size, overlap_perc)

# Mean number of feature windows per repetition, for each class label.
per_label = input_df.groupby([LABEL_COL], as_index=True)
avg_win_per_class = per_label[FEAT_WIN_COL].count() / per_label[REP_ID_COL].nunique()
print(f"Average number of windows per class\n{avg_win_per_class}\n")

# Map the selected (gesture, phase) pairs to class labels.
select_class, select_class_labels = pre.encode_gest_phase(eng_dataset, sel_gest_phase, labels_map)
print(f"Selected classes:{select_class} with names:\n{select_class_labels}")

results_df = classify.fit_svm(
    input_df, labels_map, select_class, eng_dataset,
    annotate_cm=False,
    seed=seed,
    is_temporal=True,
    k_cv=k_cv,
    # NOTE(review): the detected bad_channels are not passed here; an older
    # setting used [1, 2, 43] for day 16, session 01.
    bad_channels=[],
    bin_width=wind_size,
    bin_stat=feature,
    exp_var=None,
)

In [None]:
# Display the results table returned by classify.fit_svm
# (presumably one row per CV fold — confirm against fit_svm).
results_df

In [None]:
# Column-wise mean of the results. numeric_only=True skips non-numeric
# columns, which pandas >= 2.0 would otherwise reject with a TypeError.
results_df.mean(numeric_only=True)

In [None]:
# Report mean ± std of the validation accuracy across rows of results_df.
# Scale to percent first, then format: rounding before multiplying by 100
# can print float artifacts such as 85.66999999999999.
acc_mean_pct = results_df['acc_val'].mean() * 100
acc_std_pct = results_df['acc_val'].std() * 100
print(f"Mean balanced acc across all classes:{acc_mean_pct:.2f} % +- {acc_std_pct:.2f} %")