In [1]:
import mne
import numpy as np
import os
import seaborn as sb
import pandas as pd
import matplotlib.pyplot as plt
from mne.parallel import parallel_func
from mne.decoding import CSP
from scipy import signal
from scipy.signal import resample
from datetime import datetime
from autoreject import Ransac, AutoReject
from mne.time_frequency import tfr_stockwell, psd_array_welch, psd_welch, psd_multitaper
import warnings
from autoreject import get_rejection_threshold
import sys
sys.path.append('./')
from surface_laplacian import surface_laplacian
warnings.filterwarnings("ignore")

In [2]:
def _get_file_path(subject):
    """
    Build the EEG and Force data locations of a subject

    Both folders are listed from disk, so they must exist relative to the
    current working directory.

    Parameter
    ----------
    subject : string of subject ID e.g. 7707

    Returns
    ----------
    eeg_path        : folder holding the subject's EEG files ('../EEG Data/<subject>/')
    eeg_file_path   : path of the second '.edf' file of that folder in sorted order
    trial_path      : folder holding the subject's Force files ('../Force Data/<subject>/')
    trial_file_path : list with the paths of the first four '.csv' files of that folder in sorted order

    """
    # EEG recordings: the second .edf (sorted by name) is the one used downstream.
    # NOTE(review): presumably the decontaminated recording always sorts second -- confirm naming.
    eeg_path = ''.join(('../EEG Data/', subject, '/'))
    edf_names = sorted(name for name in os.listdir(eeg_path) if name.endswith('.edf'))
    eeg_file_path = eeg_path + edf_names[1]

    # Force recordings: the first four .csv files (sorted by name).
    # NOTE(review): expected to be HighFine, HighGross, LowFine, LowGross -- confirm naming.
    trial_path = ''.join(('../Force Data/', subject, '/'))
    csv_names = sorted(name for name in os.listdir(trial_path) if name.endswith('.csv'))
    trial_file_path = [trial_path + name for name in csv_names[:4]]

    return eeg_path, eeg_file_path, trial_path, trial_file_path


def _get_time(eeg_file_path, trial_file_path, trial):
    """
    Locate a trial on the EEG time axis

    Parameter
    ----------
    eeg_file_path   : path to eeg file; the recording start is parsed from the path text itself
                      (fields 3 and 4 of the path split on '.'), the file is not opened
    trial_file_path : list of paths to the trial files (force data folder), ordered
                      HighFine, HighGross, LowFine, LowGross
    trial           : trial (str), one of the four names above

    Returns
    ----------
    start_time : seconds from the EEG recording start to the first retained force time stamp
    end_time   : seconds from the EEG recording start to the last retained force time stamp

    """
    trial_names = ['HighFine', 'HighGross', 'LowFine', 'LowGross']

    # Recording start: two dot-separated path fields, padded with '0000' and parsed as
    # day, month, 2-digit year, hour, minute, second, fraction.
    stamp_fields = eeg_file_path.split('.')[3:5]
    eeg_time = datetime.strptime(''.join(stamp_fields) + '0000', '%d%m%y%H%M%S%f')

    # First column of the trial file, without the first 100 and the last 150 rows
    stamps = np.genfromtxt(trial_file_path[trial_names.index(trial)], dtype=str, delimiter=',',
                           usecols=0, skip_footer=150, skip_header=100).tolist()

    def _seconds_since_recording_start(stamp):
        # The force log stores time of day only, so the date is borrowed from the EEG recording.
        moment = datetime.strptime(stamp, '%H:%M:%S:%f')
        moment = moment.replace(year=eeg_time.year, month=eeg_time.month, day=eeg_time.day)
        return (moment - eeg_time).total_seconds()

    return _seconds_since_recording_start(stamps[0]), _seconds_since_recording_start(stamps[-1])


def _get_eeg_data(eeg_file_path):
    """
    Read an edf file without its auxiliary channels and re-wrap the samples with a fixed channel layout

    Parameter
    ----------
    eeg_file_path   : path to eeg file

    Returns
    ----------
    raw_selected : mne RawArray holding the samples of the edf file described by a
                   hand-built info (20 eeg channels + 1 stim channel, 256 Hz, standard_1020 montage)

    """
    # Channel layout imposed on the data.
    # NOTE(review): assumes the edf channels left after exclusion come in exactly this order -- confirm.
    eeg_names = ['Fp1', 'F7', 'F8', 'T4', 'T6', 'T5', 'T3', 'Fp2', 'O1',
                 'P3', 'Pz', 'F3', 'Fz', 'F4', 'C4', 'P4', 'POz', 'C3', 'Cz', 'O2']
    info = mne.create_info(ch_names=eeg_names + ['STI 014'],
                           ch_types=['eeg'] * 20 + ['stim'],
                           sfreq=256.0,
                           montage="standard_1020")

    # Channels of the recording that are dropped while reading
    unused_channels = ['ECG', 'AUX1', 'AUX2', 'AUX3', 'ESUTimestamp', 'SystemTimestamp',
                       'Tilt X', 'Tilt Y', 'Tilt Z']
    raw = mne.io.read_raw_edf(eeg_file_path, preload=True, exclude=unused_channels, verbose=False)

    # Keep only the samples and attach the info built above
    return mne.io.RawArray(raw.get_data(), info, verbose=False)


def _get_epoch_data(subject, read_path, trial, preload=False):
    """
    Load the cleaned epochs of one subject and trial from a fif file

    The file name is '<read_path><subject>_<trial>_<epoch_length>_cleaned_epo.fif',
    where epoch_length is a module-level variable defined outside this function.

    Parameter
    ----------
    subject     : string of subject ID e.g. 7707
    read_path   : path prefix of the epoched data files (joined as-is, so include the trailing separator)
    trial       : trial (str)
    preload     : passed on to mne.read_epochs, default False

    Returns
    ----------
    epochs : epoched data

    """
    file_name = '_'.join((subject, trial, str(epoch_length), 'cleaned_epo.fif'))
    return mne.read_epochs(read_path + file_name, preload=preload, verbose=False)


def _create_filtered_epochs(raw_eeg, trial_start, trial_end):
    """
    Cut one trial out of the raw eeg, re-reference and filter it, and slice it into fixed-length epochs

    The input is copied first, so raw_eeg itself is left untouched. The epoch
    duration comes from the module-level variable epoch_length, which is
    defined outside this function.

    Parameter
    ----------
    raw_eeg     : Raw eeg file containing data
    trial_start : start time of the trial with eeg time as reference
    trial_end   : end time of the trial with eeg time as reference

    Returns
    ----------
    epochs      : Epochs eeg data (preloaded)

    """
    # Work on a cropped copy of the recording
    segment = raw_eeg.copy().crop(tmin=trial_start, tmax=trial_end)

    # Average reference, then 60 Hz notch, then 1-50 Hz band pass (order kept as is)
    segment.set_eeg_reference('average')
    segment.notch_filter(60, filter_length='auto', phase='zero', verbose=False)
    segment.filter(l_freq=1, h_freq=50, fir_design='firwin', verbose=False)

    # Fixed-length events, one epoch per event
    fixed_events = mne.make_fixed_length_events(segment, duration=epoch_length)
    return mne.Epochs(segment, fixed_events, tmin=0, tmax=epoch_length,
                      verbose=False, preload=True)


def _autoreject_epochs(epochs):
    """
    Clean the epochs with the AutoReject algorithm

    Parameter
    ----------
    epochs : Epoched, filtered eeg data

    Returns
    ----------
    cleaned_epochs : epochs returned by AutoReject.fit_transform

    """
    eeg_picks = mne.pick_types(epochs.info, eeg=True)  # indices of all EEG channels
    rejector = AutoReject(n_interpolate=[1, 4, 8], n_jobs=6, picks=eeg_picks,
                          thresh_func='bayesian_optimization', cv=10,
                          random_state=42, verbose=False)

    # The reject log is requested but not used here; to inspect the rejected epochs call
    # _reject_log.plot_epochs(epochs, scalings=dict(eeg=40e-6)).
    cleaned_epochs, _reject_log = rejector.fit_transform(epochs, return_log=True)

    return cleaned_epochs


def _run_ica(epochs, reject):
    """
    Fit an ICA decomposition (picard method) on the EEG channels of the epochs

    Parameter
    ----------
    epochs : Epoched, filtered, and autorejected eeg data
    reject : rejection parameter handed unchanged to ICA.fit

    Returns
    ----------
    ica : fitted ICA object from mne

    """
    # EEG channels only, leaving out the channels marked as bad
    eeg_picks = mne.pick_types(epochs.info, meg=False, eeg=True, eog=False, stim=False, exclude='bads')

    ica = mne.preprocessing.ICA(n_components=None, method="picard", verbose=False)
    ica.fit(epochs, picks=eeg_picks, reject=reject)
    return ica

def _find_eog(epochs, ica, ch_names=('Fp1', 'Fp2'), threshold=0.65):
    """
    Mark the ICA components that look like eye blink artifacts

    For every reference channel the component scores returned by
    ica.find_bads_eog are inspected, and each component whose absolute score
    reaches the threshold is added to ica.exclude. A component that passes for
    several channels, or that is already listed, is added only once.

    Parameter
    ----------
    epochs    : Epoched, filtered, and autorejected eeg data
    ica       : ica object from mne (modified in place)
    ch_names  : channels used as eye blink reference, default ('Fp1', 'Fp2')
    threshold : smallest absolute score that marks a component, default 0.65

    Returns
    ----------
    ica : the same ICA object with the eog indices appended to ica.exclude

    """
    for ch_name in ch_names:
        # Only the scores are used; the indices proposed by find_bads_eog are ignored
        # because the selection is made with the threshold below.
        _, scores_eog = ica.find_bads_eog(epochs, ch_name=ch_name, verbose=False)
        flagged = [i for i, score in enumerate(scores_eog) if abs(score) >= threshold]

        # Skip components that are already excluded so no index appears twice
        ica.exclude += [i for i in flagged if i not in ica.exclude]

    return ica

   

In [1]:
"""---------------------------------------------------Force Data Related Function-----------------------------------------------------------"""

def _get_dropped_id(x):
    """
    Get the ids of the dropped eeg epochs

    An epoch counts as dropped when its entry is truthy (for example a
    non-empty list of drop reasons).

    Parameter
    ----------
    x   : iterable with one entry per epoch, e.g. the drop_log of the epochs

    Returns
    ----------
    drop_id : list with the positions of the dropped epochs, in ascending order

    """
    return [i for i, entry in enumerate(x) if entry]


def _resample_force_moment(x, freq_in, freq_out):
    """
    Bring a signal (force or any general vector x) to a new sampling frequency

    The new length is the old length scaled by freq_out/freq_in and rounded;
    the samples themselves are computed by scipy.signal.resample with its
    default axis.

    Parameter
    ----------
    x        : signal to resample
    freq_in  : frequency of x signal
    freq_out : desired frequency of x

    Returns
    ----------
    out : resampled signal with freq_out frequency

    """
    target_length = round(len(x) * freq_out / freq_in)
    return resample(x, target_length)


def _get_force_moment_data(subject, trial_file_path, trial):
    """
    Load the force and moment signals of one trial and resample them to the eeg sampling rate

    Parameter
    ----------
    subject         : string of subject ID e.g. 7707 (not used inside this function)
    trial_file_path : list of paths to the trial files (force data folder), ordered
                      HighFine, HighGross, LowFine, LowGross
    trial           : trial (str), one of the four names above

    Returns
    ----------
    force_moment_data : numpy array of shape (9, n_samples) with the rows x, y, force_x, force_y,
                        total_force, moment_x, moment_y, total_moment, total_moment_scaled
    start_time        : time of the first retained sample relative to itself, i.e. always 0.0
    end_time          : seconds between the first and the last retained time stamp

    Raises
    ----------
    ValueError : if trial is not a known trial name, if fewer than two time stamps are left
                 after trimming, or if the time stamps do not increase

    """
    idx = ['HighFine', 'HighGross', 'LowFine', 'LowGross'].index(trial)

    # Rows trimmed from the top and the bottom of the file. The same window is
    # used for the signals and for the time stamps (and matches _get_time), so
    # the samples and the times they are paired with come from the same rows.
    # The signals were previously read with the two values swapped, which
    # shifted them 50 rows later than their time stamps.
    skip_header, skip_footer = 100, 150

    force_moment_data = np.genfromtxt(trial_file_path[idx], dtype=float, delimiter=',',
                                      usecols=[13, 14, 15, 16, 17, 18, 19, 20],
                                      skip_header=skip_header, skip_footer=skip_footer).tolist()
    time_data = np.genfromtxt(trial_file_path[idx], dtype=str, delimiter=',', usecols=0,
                              skip_header=skip_header, skip_footer=skip_footer).tolist()

    # Get the sampling frequency from the average spacing of the time stamps
    stamps = [datetime.strptime(item, '%H:%M:%S:%f') for item in time_data]
    if len(stamps) < 2:
        raise ValueError('at least two time stamps are needed to estimate the sampling frequency')
    dt = (stamps[-1] - stamps[0]) / (len(stamps) - 1)  # equals the mean of the successive differences
    if dt.total_seconds() <= 0:
        raise ValueError('time stamps of the trial file do not increase')
    freq_in = 1 / dt.total_seconds()
    freq_out = 256.0  # according to eeg sampling rate

    force_moment_resampled = _resample_force_moment(force_moment_data, freq_in, freq_out)

    # Required data (columns 2 and 5 of the selection are read but not used)
    force_x = force_moment_resampled[:, 0]
    force_y = force_moment_resampled[:, 1]
    total_force = np.linalg.norm(force_moment_resampled[:, 0:2], axis=1)
    moment_x = force_moment_resampled[:, 3]
    moment_y = force_moment_resampled[:, 4]
    total_moment = np.linalg.norm(force_moment_resampled[:, 3:5], axis=1)
    x = force_moment_resampled[:, 6]
    y = force_moment_resampled[:, 7]
    # Total moment rescaled so that its mean equals the mean of the total force
    total_moment_scaled = np.mean(total_force) / np.mean(total_moment) * total_moment

    # Stack all the vectors, one per row
    force_moment_data = np.vstack((x, y, force_x, force_y, total_force,
                                   moment_x, moment_y, total_moment, total_moment_scaled))
    start_time = (stamps[0] - stamps[0]).total_seconds()
    end_time = (stamps[-1] - stamps[0]).total_seconds()

    return force_moment_data, start_time, end_time

In [None]:
def _get_n_params(model):
    """
    Count the trainable parameters of a model

    Parameter
    ----------
    model : pytorch model

    Returns
    ----------
    n_p  : total number of elements over all parameters that have requires_grad set

    """
    n_p = 0
    for param in model.parameters():
        if param.requires_grad:
            n_p += param.numel()
    return n_p


def _make_balanced(x, y):
    """
    Reduce the share of the 'normal' class by keeping only a random half of its samples

    The class of a sample is the argmax of its row in y: 0 is treated as low,
    1 as normal and 2 as high. Low and high samples are kept as they are; the
    normal samples are passed through train_test_split with test_size=0.50 and
    only the train part is kept. Samples of any other class are dropped.

    NOTE(review): train_test_split is not defined in the visible part of this
    file; presumably scikit-learn's -- confirm it is imported before use.

    Parameter
    ----------
    x : input x, indexed as x[samples, :, :]
    y : output y, one row per sample (argmax is taken along axis 1)

    Returns
    ----------
    x_balanced  : input balanced, stacked in the order low, normal, high
    y_balanced  : output balanced, stacked in the same order

    """
    labels = np.argmax(y, axis=1)
    is_low, is_normal, is_high = labels == 0, labels == 1, labels == 2

    x_low, y_low = x[is_low, :, :], y[is_low]
    x_high, y_high = x[is_high, :, :], y[is_high]

    # Keep the train half of the normal class; the test half is discarded
    x_normal, _, y_normal, _ = train_test_split(x[is_normal, :, :], y[is_normal], test_size=0.50)

    x_balanced = np.vstack((x_low, x_normal, x_high))
    y_balanced = np.vstack((y_low, y_normal, y_high))

    return x_balanced, y_balanced