# Setup

## Download required packages

In [1]:
from collections import OrderedDict

import numpy as np
from numpy.fft import rfft, irfft, rfftfreq

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from matplotlib import patches
import seaborn as sns
sns.set_theme()
import pandas as pd

import mne

from scipy.fftpack import fft
from sklearn.cross_decomposition import CCA
from sklearn.decomposition import PCA
import tensorflow as tf
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import confusion_matrix, accuracy_score
from scipy.signal import sosfiltfilt
from sklearn.pipeline import clone
from sklearn.metrics import balanced_accuracy_score

from meegkit import dss, ress
from meegkit import sns as msns
from meegkit.utils import unfold, rms, fold, tscov, matmul3d

from brainda.paradigms import SSVEP
from brainda.algorithms.utils.model_selection import (
    set_random_seeds, 
    generate_loo_indices, match_loo_indices)
from brainda.algorithms.decomposition import (
    FBTRCA, FBTDCA, FBSCCA, FBECCA, FBDSP,
    generate_filterbank, generate_cca_references)

  from .autonotebook import tqdm as notebook_tqdm


## Experimental constants

In [2]:
sub_dirs = ['run1/','run2/','run3/','run4/','run5/','run6/','run7/','run8/','run9/']  # each folder is a single independent run
duration = 1.5  # duration of a trial in seconds
n_trials = 2  
n_classes = 32  # number of classes, 8 freq, 4 phase, total of 8 * 4 = 32 targets
n_channels = 19  # number of recording channels

## Data wrangling

In [3]:
def load_data_temp_function(eeg, meta, classes, stim_duration=5, filter=True):
    """
    
    """
    trials = meta[1:,:2]
    times = []
    duration_samples = int(stim_duration*300)
    for index, row in eeg.loc[eeg[' TRG']==16.0].iterrows():
        if index > 0 and eeg.iloc[index-1][' TRG'] == 0 and (not eeg.iloc[index:index+duration_samples][' TRG'].isin([18.0]).any()):
            times.append(row['time'])
    times = np.array(times)
    
    eeg = np.array([eeg.loc[eeg['time']>t].drop(columns=['time',' TRG',' X1',' X2',' X3',' A2']).to_numpy()[:duration_samples].T for t in times])
    if filter:
        eeg = mne.filter.filter_data(eeg, sfreq=300, l_freq=5, h_freq=49, verbose=0, method='fir')
    eeg_temp = []
    for i in range(len(classes)):
        eeg_temp.append([])
    for i,freq in enumerate(trials):
        for j,target in enumerate(classes):
            if (freq==target).all():
                eeg_temp[j].append(eeg[i])
    eeg = np.array(eeg_temp).transpose(1,0,2,3)
    return eeg

In [4]:
# load experimental dataset
eeg_whole = np.zeros((n_trials*len(sub_dirs),n_classes,n_channels,int(duration*300)))
target_tab = {}
# i_class = 0
for i_dir,sub_dir in enumerate(sub_dirs):
    print(i_dir)
    data_path = "../data/" + sub_dir
    eeg = pd.read_csv(data_path + 'eeg.csv').astype(float)
    meta = np.loadtxt(data_path + 'meta.csv', delimiter=',', dtype=float)
    trials = meta[1:,:2]
    classes = np.unique(trials, axis=0)
    more_targets = {tuple(target):index for index,target in enumerate(classes)}
    target_tab.update(more_targets)
    eeg = load_data_temp_function(eeg, meta, classes, stim_duration=duration,filter=False)
    eeg_whole[i_dir*n_trials:(i_dir+1)*n_trials,:,:,:] = eeg
    # i_class+=3
eeg = eeg_whole
target_by_trial = [list(target_tab.keys())] * n_trials*len(sub_dirs)
eeg.shape, np.array(target_by_trial).shape 

0
1
2
3
4
5
6
7
8


((18, 32, 19, 450), (18, 32, 2))

## Will be fixed 8/9/2022 

In [7]:
n_trials = eeg.shape[0]
classes = range(n_classes)
n_classes = len(classes)
y = np.array([list(target_tab.values())] * n_trials).T.reshape(-1)
eeg_temp = eeg[:n_trials,classes,:,40:]
X = eeg_temp.swapaxes(0,1).reshape(-1,*eeg_temp.shape[2:])


freq_targets = np.array(target_by_trial)[0,:,0]
phase_targets = np.array(target_by_trial)[0,:,1]
n_harmonics = 5
n_bands = 3
srate = 300  # sampling rate
duration = 1  # duration to use for modeling
Yf = generate_cca_references(
    freq_targets, srate, duration, 
    phases=phase_targets, 
    n_harmonics=n_harmonics)
wp = [[8*i, 90] for i in range(1, n_bands+1)]
ws = [[8*i-2, 95] for i in range(1, n_bands+1)]
filterbank = generate_filterbank(
    wp, ws, srate, order=4, rp=1)
filterweights = np.arange(1, len(filterbank)+1)**(-1.25) + 0.25
set_random_seeds(64)
l = 5
models = OrderedDict([
    # ('fbscca', FBSCCA(
    #         filterbank, filterweights=filterweights)),
    # ('fbecca', FBECCA(
    #         filterbank, filterweights=filterweights)),
    ('fbdsp', FBDSP(
            filterbank, filterweights=filterweights)),
    ('fbtrca', FBTRCA(
            filterbank, filterweights=filterweights)),
    ('fbtdca', FBTDCA(
            filterbank, l, n_components=8, 
            filterweights=filterweights)),
])
events = []
for j_class in classes:
    events.extend([str(target_by_trial[i_trial][j_class]) for i_trial in range(n_trials)])
events = np.array(events)
subjects = ['1'] * (n_classes*n_trials)
meta = pd.DataFrame(data=np.array([subjects,events]).T, columns=["subject", "event"])
set_random_seeds(42)
loo_indices = generate_loo_indices(meta)

for model_name in models:
    if model_name == 'fbtdca':
        filterX, filterY = np.copy(X[..., :int(srate*duration)+l]), np.copy(y)
    else:
        filterX, filterY = np.copy(X[..., :int(srate*duration)]), np.copy(y)
    
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    n_loo = len(loo_indices['1'][events[0]])
    loo_accs = []
    for k in range(n_loo):
        train_ind, validate_ind, test_ind = match_loo_indices(
            k, meta, loo_indices)
        train_ind = np.concatenate([train_ind, validate_ind])

        trainX, trainY = filterX[train_ind], filterY[train_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        model = clone(models[model_name]).fit(
            trainX, trainY,
            Yf=Yf
        )
        pred_labels = model.predict(testX)
        loo_accs.append(
            balanced_accuracy_score(testY, pred_labels))
    print("Model:{} LOO Acc:{:.2f}".format(model_name, np.mean(loo_accs)))

AttributeError: 'FBTDCA' object has no attribute 'l'