In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from matplotlib import patches
import seaborn as sns
sns.set_theme()
import pandas as pd
import mne
from scipy.fftpack import fft
from numpy.fft import rfft, irfft, rfftfreq
from sklearn.cross_decomposition import CCA
from sklearn.decomposition import PCA
import tensorflow as tf
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import confusion_matrix, accuracy_score
from meegkit import dss, ress
from meegkit import sns as msns
from meegkit.utils import unfold, rms, fold, tscov, matmul3d, snr_spectrum
from brainda.paradigms import SSVEP
from brainda.algorithms.utils.model_selection import (
    set_random_seeds, 
    generate_loo_indices, match_loo_indices)
from brainda.algorithms.decomposition import (
    FBTRCA, FBTDCA, FBSCCA, FBECCA, FBDSP,
    generate_filterbank, generate_cca_references)
from collections import OrderedDict
import numpy as np
from scipy.signal import sosfiltfilt
from sklearn.pipeline import clone
from sklearn.metrics import balanced_accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


## Load Square Data

In [15]:
def load_data_temp_function(eeg, meta, classes, stim_duration=4, filter=True,
                            srate=300, baseline_samples=225):
    """Cut a continuous EEG recording into epochs and group them by class.

    Parameters
    ----------
    eeg : pandas.DataFrame
        Continuous recording. Must contain a ``'time'`` column and a
        ``' TRG'`` column (note the leading space); every other column is
        kept as a signal channel. The frame is NOT modified.
    meta : numpy.ndarray
        2-D array whose column 2 holds timestamps and whose columns 0-1
        identify the target of each trial. Row 0 only supplies the reference
        timestamp; rows 1.. are the trials.
    classes : array-like
        Rows that are compared against ``meta[1:, :2]``; a trial belongs to
        class ``j`` when its two columns equal ``classes[j]`` exactly.
    stim_duration : float, default 4
        Length, in seconds, of the part of each epoch that is returned.
    filter : bool, default True
        When true, band-pass the epochs (5-49 Hz FIR) with
        ``mne.filter.filter_data``. The name shadows the builtin but is kept
        so existing keyword callers keep working.
    srate : float, default 300
        Sampling rate used to convert ``stim_duration`` to samples and handed
        to the filter.
    baseline_samples : int, default 225
        Number of samples taken after each trial timestamp and then discarded
        from the front of the epoch.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(repetitions, len(classes), channels, samples)``.
        NOTE(review): this relies on every class having the same number of
        trials; otherwise the final ``np.array`` is ragged.

    Raises
    ------
    ValueError
        If fewer than ``stim_duration * srate + baseline_samples`` samples
        follow a trial timestamp.
    """
    trials = meta[1:, :2]
    # Trial timestamps relative to the first meta row; that row is dropped.
    onsets = (meta[:, 2] - meta[0, 2])[1:]
    stim_end = int(stim_duration * srate + baseline_samples)

    # Derive the relative time axis and the channel matrix once, without
    # writing back into the caller's DataFrame.
    rel_time = eeg['time'].to_numpy() - eeg['time'].iloc[0]
    signals = eeg.drop(columns=['time', ' TRG']).to_numpy()

    epochs = []
    for trial_no, onset in enumerate(onsets):
        # First `stim_end` rows recorded strictly after the trial timestamp.
        rows = np.flatnonzero(rel_time > onset)[:stim_end]
        if rows.size < stim_end:
            raise ValueError(
                "trial {} has only {} samples after its onset, {} required".format(
                    trial_no, rows.size, stim_end))
        epochs.append(signals[rows].T)
    # (trials, channels, samples) with the leading samples removed.
    epochs = np.array(epochs)[:, :, baseline_samples:]

    if filter:
        epochs = mne.filter.filter_data(
            epochs, sfreq=srate, l_freq=5, h_freq=49, verbose=0, method='fir')

    # Bucket the epochs by the class whose row matches the trial's target.
    grouped = [[] for _ in classes]
    for epoch, trial in zip(epochs, trials):
        for j, target in enumerate(classes):
            if (trial == target).all():
                grouped[j].append(epoch)
    # (classes, repetitions, ch, samples) -> (repetitions, classes, ch, samples)
    return np.array(grouped).transpose(1, 0, 2, 3)

# Lookup from each target (a row of the first two meta columns, as a tuple)
# to its integer class index.
target_tab = {}
data_path = "../data/eeg_recordings/pilot_data/simon/n12/"

# Keep the raw recording under its own name so `eeg` only ever refers to the
# epoched array produced below.
eeg_raw = pd.read_csv(data_path + 'eeg.csv').astype(float)
meta = np.loadtxt(data_path + 'meta.csv', delimiter=',', dtype=float)

# Row 0 of meta is skipped; the remaining rows describe one trial each.
trials = meta[1:,:2]
classes = np.unique(trials, axis=0)
more_targets = {tuple(target):index for index,target in enumerate(classes)}
target_tab.update(more_targets)

eeg = load_data_temp_function(eeg_raw, meta, classes, stim_duration=4,filter=False)

# One copy of the target list per repetition; the count comes from the data
# (axis 0 of the epoched array) instead of a hard-coded 15.
target_by_trial = [list(target_tab.keys())] * eeg.shape[0]
eeg.shape, np.array(target_by_trial).shape

((15, 12, 7, 1200), (15, 12, 2))

## Apply Brainda

In [16]:
# ---- Settings -------------------------------------------------------------
srate = 300                          # rate given to the references and filter bank
duration = 3.5                       # analysis window in seconds
n_trials = 15
n_classes = 12
delay_samples = int(0.14 * srate)    # samples removed from the start of every epoch

# ---- Data and labels (class-major order) ----------------------------------
# Every class index repeated n_trials times, matching the row order of X.
y = np.repeat(list(target_tab.values()), n_trials)
epochs_by_class = np.swapaxes(eeg[:n_trials], 0, 1)
X = epochs_by_class.reshape((-1,) + eeg.shape[2:])[..., delay_samples:]

# ---- Reference signals ----------------------------------------------------
target_grid = np.array(target_by_trial)
freq_targets = target_grid[0, :, 0]
phase_targets = target_grid[0, :, 1]
n_harmonics = 5
n_bands = 3
Yf = generate_cca_references(
    freq_targets, srate, duration,
    phases=phase_targets,
    n_harmonics=n_harmonics)

# ---- Filter bank ----------------------------------------------------------
band_ids = range(1, n_bands + 1)
wp = [[8 * band, 90] for band in band_ids]
ws = [[8 * band - 2, 95] for band in band_ids]
filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
filterweights = np.power(np.arange(1, len(filterbank) + 1), -1.25) + 0.25

# ---- Models ---------------------------------------------------------------
set_random_seeds(64)
l = 5
models = OrderedDict()
models['fbscca'] = FBSCCA(filterbank, filterweights=filterweights)
models['fbecca'] = FBECCA(filterbank, filterweights=filterweights)
models['fbdsp'] = FBDSP(filterbank, filterweights=filterweights)
models['fbtrca'] = FBTRCA(filterbank, filterweights=filterweights)
models['fbtdca'] = FBTDCA(
    filterbank, l, n_components=8, filterweights=filterweights)

# ---- Leave-one-out bookkeeping --------------------------------------------
# Event labels are laid out class-major as well, one string per row of X.
events = np.array([
    str(target_by_trial[i_trial][j_class])
    for j_class in range(n_classes)
    for i_trial in range(n_trials)
])
subjects = ['1'] * (n_classes*n_trials)
meta = pd.DataFrame(data=np.array([subjects,events]).T, columns=["subject", "event"])
set_random_seeds(42)
loo_indices = generate_loo_indices(meta)

# ---- Evaluation -----------------------------------------------------------
window_samples = int(srate * duration)
for model_name, estimator in models.items():
    # The 'fbtdca' entry is given `l` samples beyond the window; every other
    # model sees exactly the window.
    n_keep = window_samples + l if model_name == 'fbtdca' else window_samples
    filterX, filterY = np.copy(X[..., :n_keep]), np.copy(y)

    # Remove the per-channel mean of each epoch.
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    n_loo = len(loo_indices['1'][events[0]])
    loo_accs = []
    for fold_no in range(n_loo):
        train_ind, validate_ind, test_ind = match_loo_indices(
            fold_no, meta, loo_indices)
        # The validation split is folded into the training split.
        train_ind = np.concatenate([train_ind, validate_ind])

        trainX, trainY = filterX[train_ind], filterY[train_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        model = clone(estimator).fit(trainX, trainY, Yf=Yf)
        pred_labels = model.predict(testX)
        loo_accs.append(balanced_accuracy_score(testY, pred_labels))
    print("Model:{} LOO Acc:{:.2f}".format(model_name, np.mean(loo_accs)))

Model:fbscca LOO Acc:0.52
Model:fbecca LOO Acc:0.38
Model:fbdsp LOO Acc:0.26
Model:fbtrca LOO Acc:0.16
Model:fbtdca LOO Acc:0.23


## Load Sine Dataset

In [17]:
def load_data_temp_function(eeg, meta, classes, stim_duration=4, filter=True,
                            srate=300, baseline_samples=225):
    """Cut a continuous EEG recording into epochs and group them by class.

    Parameters
    ----------
    eeg : pandas.DataFrame
        Continuous recording. Must contain a ``'time'`` column and a
        ``' TRG'`` column (note the leading space); every other column is
        kept as a signal channel. The frame is NOT modified.
    meta : numpy.ndarray
        2-D array whose column 2 holds timestamps and whose columns 0-1
        identify the target of each trial. Row 0 only supplies the reference
        timestamp; rows 1.. are the trials.
    classes : array-like
        Rows that are compared against ``meta[1:, :2]``; a trial belongs to
        class ``j`` when its two columns equal ``classes[j]`` exactly.
    stim_duration : float, default 4
        Length, in seconds, of the part of each epoch that is returned.
    filter : bool, default True
        When true, band-pass the epochs (5-49 Hz FIR) with
        ``mne.filter.filter_data``. The name shadows the builtin but is kept
        so existing keyword callers keep working.
    srate : float, default 300
        Sampling rate used to convert ``stim_duration`` to samples and handed
        to the filter.
    baseline_samples : int, default 225
        Number of samples taken after each trial timestamp and then discarded
        from the front of the epoch.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(repetitions, len(classes), channels, samples)``.
        NOTE(review): this relies on every class having the same number of
        trials; otherwise the final ``np.array`` is ragged.

    Raises
    ------
    ValueError
        If fewer than ``stim_duration * srate + baseline_samples`` samples
        follow a trial timestamp.
    """
    trials = meta[1:, :2]
    # Trial timestamps relative to the first meta row; that row is dropped.
    onsets = (meta[:, 2] - meta[0, 2])[1:]
    stim_end = int(stim_duration * srate + baseline_samples)

    # Derive the relative time axis and the channel matrix once, without
    # writing back into the caller's DataFrame.
    rel_time = eeg['time'].to_numpy() - eeg['time'].iloc[0]
    signals = eeg.drop(columns=['time', ' TRG']).to_numpy()

    epochs = []
    for trial_no, onset in enumerate(onsets):
        # First `stim_end` rows recorded strictly after the trial timestamp.
        rows = np.flatnonzero(rel_time > onset)[:stim_end]
        if rows.size < stim_end:
            raise ValueError(
                "trial {} has only {} samples after its onset, {} required".format(
                    trial_no, rows.size, stim_end))
        epochs.append(signals[rows].T)
    # (trials, channels, samples) with the leading samples removed.
    epochs = np.array(epochs)[:, :, baseline_samples:]

    if filter:
        epochs = mne.filter.filter_data(
            epochs, sfreq=srate, l_freq=5, h_freq=49, verbose=0, method='fir')

    # Bucket the epochs by the class whose row matches the trial's target.
    grouped = [[] for _ in classes]
    for epoch, trial in zip(epochs, trials):
        for j, target in enumerate(classes):
            if (trial == target).all():
                grouped[j].append(epoch)
    # (classes, repetitions, ch, samples) -> (repetitions, classes, ch, samples)
    return np.array(grouped).transpose(1, 0, 2, 3)

# Lookup from each target (a row of the first two meta columns, as a tuple)
# to its integer class index.
target_tab = {}
data_path = "../data/eeg_recordings/pilot_data/simon/n12_2/"

# Keep the raw recording under its own name so `eeg` only ever refers to the
# epoched array produced below.
eeg_raw = pd.read_csv(data_path + 'eeg.csv').astype(float)
meta = np.loadtxt(data_path + 'meta.csv', delimiter=',', dtype=float)

# Row 0 of meta is skipped; the remaining rows describe one trial each.
trials = meta[1:,:2]
classes = np.unique(trials, axis=0)
more_targets = {tuple(target):index for index,target in enumerate(classes)}
target_tab.update(more_targets)

eeg = load_data_temp_function(eeg_raw, meta, classes, stim_duration=4,filter=False)

# One copy of the target list per repetition; the count comes from the data
# (axis 0 of the epoched array) instead of a hard-coded 15.
target_by_trial = [list(target_tab.keys())] * eeg.shape[0]
eeg.shape, np.array(target_by_trial).shape

((15, 12, 7, 1200), (15, 12, 2))

## Apply Brainda

### 3.5s

In [14]:
# ---- Settings -------------------------------------------------------------
srate = 300                          # rate given to the references and filter bank
duration = 3.5                       # analysis window in seconds
n_trials = 15
n_classes = 12
delay_samples = int(0.14 * srate)    # samples removed from the start of every epoch

# ---- Data and labels (class-major order) ----------------------------------
# Every class index repeated n_trials times, matching the row order of X.
y = np.repeat(list(target_tab.values()), n_trials)
epochs_by_class = np.swapaxes(eeg[:n_trials], 0, 1)
X = epochs_by_class.reshape((-1,) + eeg.shape[2:])[..., delay_samples:]

# ---- Reference signals ----------------------------------------------------
target_grid = np.array(target_by_trial)
freq_targets = target_grid[0, :, 0]
phase_targets = target_grid[0, :, 1]
n_harmonics = 5
n_bands = 3
Yf = generate_cca_references(
    freq_targets, srate, duration,
    phases=phase_targets,
    n_harmonics=n_harmonics)

# ---- Filter bank ----------------------------------------------------------
band_ids = range(1, n_bands + 1)
wp = [[8 * band, 90] for band in band_ids]
ws = [[8 * band - 2, 95] for band in band_ids]
filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
filterweights = np.power(np.arange(1, len(filterbank) + 1), -1.25) + 0.25

# ---- Models ---------------------------------------------------------------
set_random_seeds(64)
l = 5
models = OrderedDict()
models['fbscca'] = FBSCCA(filterbank, filterweights=filterweights)
models['fbecca'] = FBECCA(filterbank, filterweights=filterweights)
models['fbdsp'] = FBDSP(filterbank, filterweights=filterweights)
models['fbtrca'] = FBTRCA(filterbank, filterweights=filterweights)
models['fbtdca'] = FBTDCA(
    filterbank, l, n_components=8, filterweights=filterweights)

# ---- Leave-one-out bookkeeping --------------------------------------------
# Event labels are laid out class-major as well, one string per row of X.
events = np.array([
    str(target_by_trial[i_trial][j_class])
    for j_class in range(n_classes)
    for i_trial in range(n_trials)
])
subjects = ['1'] * (n_classes*n_trials)
meta = pd.DataFrame(data=np.array([subjects,events]).T, columns=["subject", "event"])
set_random_seeds(42)
loo_indices = generate_loo_indices(meta)

# ---- Evaluation -----------------------------------------------------------
window_samples = int(srate * duration)
for model_name, estimator in models.items():
    # The 'fbtdca' entry is given `l` samples beyond the window; every other
    # model sees exactly the window.
    n_keep = window_samples + l if model_name == 'fbtdca' else window_samples
    filterX, filterY = np.copy(X[..., :n_keep]), np.copy(y)

    # Remove the per-channel mean of each epoch.
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    n_loo = len(loo_indices['1'][events[0]])
    loo_accs = []
    for fold_no in range(n_loo):
        train_ind, validate_ind, test_ind = match_loo_indices(
            fold_no, meta, loo_indices)
        # The validation split is folded into the training split.
        train_ind = np.concatenate([train_ind, validate_ind])

        trainX, trainY = filterX[train_ind], filterY[train_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        model = clone(estimator).fit(trainX, trainY, Yf=Yf)
        pred_labels = model.predict(testX)
        loo_accs.append(balanced_accuracy_score(testY, pred_labels))
    print("Model:{} LOO Acc:{:.2f}".format(model_name, np.mean(loo_accs)))

Model:fbscca LOO Acc:0.79
Model:fbecca LOO Acc:0.56
Model:fbdsp LOO Acc:0.10
Model:fbtrca LOO Acc:0.11
Model:fbtdca LOO Acc:0.18


### 2s

In [18]:
# ---- Settings -------------------------------------------------------------
srate = 300                          # rate given to the references and filter bank
duration = 2                         # analysis window in seconds
n_trials = 15
n_classes = 12
delay_samples = int(0.14 * srate)    # samples removed from the start of every epoch

# ---- Data and labels (class-major order) ----------------------------------
# Every class index repeated n_trials times, matching the row order of X.
y = np.repeat(list(target_tab.values()), n_trials)
epochs_by_class = np.swapaxes(eeg[:n_trials], 0, 1)
X = epochs_by_class.reshape((-1,) + eeg.shape[2:])[..., delay_samples:]

# ---- Reference signals ----------------------------------------------------
target_grid = np.array(target_by_trial)
freq_targets = target_grid[0, :, 0]
phase_targets = target_grid[0, :, 1]
n_harmonics = 5
n_bands = 3
Yf = generate_cca_references(
    freq_targets, srate, duration,
    phases=phase_targets,
    n_harmonics=n_harmonics)

# ---- Filter bank ----------------------------------------------------------
band_ids = range(1, n_bands + 1)
wp = [[8 * band, 90] for band in band_ids]
ws = [[8 * band - 2, 95] for band in band_ids]
filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
filterweights = np.power(np.arange(1, len(filterbank) + 1), -1.25) + 0.25

# ---- Models ---------------------------------------------------------------
set_random_seeds(64)
l = 5
models = OrderedDict()
models['fbscca'] = FBSCCA(filterbank, filterweights=filterweights)
models['fbecca'] = FBECCA(filterbank, filterweights=filterweights)
models['fbdsp'] = FBDSP(filterbank, filterweights=filterweights)
models['fbtrca'] = FBTRCA(filterbank, filterweights=filterweights)
models['fbtdca'] = FBTDCA(
    filterbank, l, n_components=8, filterweights=filterweights)

# ---- Leave-one-out bookkeeping --------------------------------------------
# Event labels are laid out class-major as well, one string per row of X.
events = np.array([
    str(target_by_trial[i_trial][j_class])
    for j_class in range(n_classes)
    for i_trial in range(n_trials)
])
subjects = ['1'] * (n_classes*n_trials)
meta = pd.DataFrame(data=np.array([subjects,events]).T, columns=["subject", "event"])
set_random_seeds(42)
loo_indices = generate_loo_indices(meta)

# ---- Evaluation -----------------------------------------------------------
window_samples = int(srate * duration)
for model_name, estimator in models.items():
    # The 'fbtdca' entry is given `l` samples beyond the window; every other
    # model sees exactly the window.
    n_keep = window_samples + l if model_name == 'fbtdca' else window_samples
    filterX, filterY = np.copy(X[..., :n_keep]), np.copy(y)

    # Remove the per-channel mean of each epoch.
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    n_loo = len(loo_indices['1'][events[0]])
    loo_accs = []
    for fold_no in range(n_loo):
        train_ind, validate_ind, test_ind = match_loo_indices(
            fold_no, meta, loo_indices)
        # The validation split is folded into the training split.
        train_ind = np.concatenate([train_ind, validate_ind])

        trainX, trainY = filterX[train_ind], filterY[train_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        model = clone(estimator).fit(trainX, trainY, Yf=Yf)
        pred_labels = model.predict(testX)
        loo_accs.append(balanced_accuracy_score(testY, pred_labels))
    print("Model:{} LOO Acc:{:.2f}".format(model_name, np.mean(loo_accs)))

Model:fbscca LOO Acc:0.53
Model:fbecca LOO Acc:0.43
Model:fbdsp LOO Acc:0.14
Model:fbtrca LOO Acc:0.11
Model:fbtdca LOO Acc:0.18


### 1.5s

In [19]:
# ---- Settings -------------------------------------------------------------
srate = 300                          # rate given to the references and filter bank
duration = 1.5                       # analysis window in seconds
n_trials = 15
n_classes = 12
delay_samples = int(0.14 * srate)    # samples removed from the start of every epoch

# ---- Data and labels (class-major order) ----------------------------------
# Every class index repeated n_trials times, matching the row order of X.
y = np.repeat(list(target_tab.values()), n_trials)
epochs_by_class = np.swapaxes(eeg[:n_trials], 0, 1)
X = epochs_by_class.reshape((-1,) + eeg.shape[2:])[..., delay_samples:]

# ---- Reference signals ----------------------------------------------------
target_grid = np.array(target_by_trial)
freq_targets = target_grid[0, :, 0]
phase_targets = target_grid[0, :, 1]
n_harmonics = 5
n_bands = 3
Yf = generate_cca_references(
    freq_targets, srate, duration,
    phases=phase_targets,
    n_harmonics=n_harmonics)

# ---- Filter bank ----------------------------------------------------------
band_ids = range(1, n_bands + 1)
wp = [[8 * band, 90] for band in band_ids]
ws = [[8 * band - 2, 95] for band in band_ids]
filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
filterweights = np.power(np.arange(1, len(filterbank) + 1), -1.25) + 0.25

# ---- Models ---------------------------------------------------------------
set_random_seeds(64)
l = 5
models = OrderedDict()
models['fbscca'] = FBSCCA(filterbank, filterweights=filterweights)
models['fbecca'] = FBECCA(filterbank, filterweights=filterweights)
models['fbdsp'] = FBDSP(filterbank, filterweights=filterweights)
models['fbtrca'] = FBTRCA(filterbank, filterweights=filterweights)
models['fbtdca'] = FBTDCA(
    filterbank, l, n_components=8, filterweights=filterweights)

# ---- Leave-one-out bookkeeping --------------------------------------------
# Event labels are laid out class-major as well, one string per row of X.
events = np.array([
    str(target_by_trial[i_trial][j_class])
    for j_class in range(n_classes)
    for i_trial in range(n_trials)
])
subjects = ['1'] * (n_classes*n_trials)
meta = pd.DataFrame(data=np.array([subjects,events]).T, columns=["subject", "event"])
set_random_seeds(42)
loo_indices = generate_loo_indices(meta)

# ---- Evaluation -----------------------------------------------------------
window_samples = int(srate * duration)
for model_name, estimator in models.items():
    # The 'fbtdca' entry is given `l` samples beyond the window; every other
    # model sees exactly the window.
    n_keep = window_samples + l if model_name == 'fbtdca' else window_samples
    filterX, filterY = np.copy(X[..., :n_keep]), np.copy(y)

    # Remove the per-channel mean of each epoch.
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    n_loo = len(loo_indices['1'][events[0]])
    loo_accs = []
    for fold_no in range(n_loo):
        train_ind, validate_ind, test_ind = match_loo_indices(
            fold_no, meta, loo_indices)
        # The validation split is folded into the training split.
        train_ind = np.concatenate([train_ind, validate_ind])

        trainX, trainY = filterX[train_ind], filterY[train_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        model = clone(estimator).fit(trainX, trainY, Yf=Yf)
        pred_labels = model.predict(testX)
        loo_accs.append(balanced_accuracy_score(testY, pred_labels))
    print("Model:{} LOO Acc:{:.2f}".format(model_name, np.mean(loo_accs)))

Model:fbscca LOO Acc:0.34
Model:fbecca LOO Acc:0.29
Model:fbdsp LOO Acc:0.11
Model:fbtrca LOO Acc:0.10
Model:fbtdca LOO Acc:0.15


### 0.8s

In [21]:
# ---- Settings -------------------------------------------------------------
srate = 300                          # rate given to the references and filter bank
duration = 0.8                       # analysis window in seconds
n_trials = 15
n_classes = 12
delay_samples = int(0.14 * srate)    # samples removed from the start of every epoch

# ---- Data and labels (class-major order) ----------------------------------
# Every class index repeated n_trials times, matching the row order of X.
y = np.repeat(list(target_tab.values()), n_trials)
epochs_by_class = np.swapaxes(eeg[:n_trials], 0, 1)
X = epochs_by_class.reshape((-1,) + eeg.shape[2:])[..., delay_samples:]

# ---- Reference signals ----------------------------------------------------
target_grid = np.array(target_by_trial)
freq_targets = target_grid[0, :, 0]
phase_targets = target_grid[0, :, 1]
n_harmonics = 5
n_bands = 3
Yf = generate_cca_references(
    freq_targets, srate, duration,
    phases=phase_targets,
    n_harmonics=n_harmonics)

# ---- Filter bank ----------------------------------------------------------
band_ids = range(1, n_bands + 1)
wp = [[8 * band, 90] for band in band_ids]
ws = [[8 * band - 2, 95] for band in band_ids]
filterbank = generate_filterbank(wp, ws, srate, order=4, rp=1)
filterweights = np.power(np.arange(1, len(filterbank) + 1), -1.25) + 0.25

# ---- Models ---------------------------------------------------------------
set_random_seeds(64)
l = 5
models = OrderedDict()
models['fbscca'] = FBSCCA(filterbank, filterweights=filterweights)
models['fbecca'] = FBECCA(filterbank, filterweights=filterweights)
models['fbdsp'] = FBDSP(filterbank, filterweights=filterweights)
models['fbtrca'] = FBTRCA(filterbank, filterweights=filterweights)
models['fbtdca'] = FBTDCA(
    filterbank, l, n_components=8, filterweights=filterweights)

# ---- Leave-one-out bookkeeping --------------------------------------------
# Event labels are laid out class-major as well, one string per row of X.
events = np.array([
    str(target_by_trial[i_trial][j_class])
    for j_class in range(n_classes)
    for i_trial in range(n_trials)
])
subjects = ['1'] * (n_classes*n_trials)
meta = pd.DataFrame(data=np.array([subjects,events]).T, columns=["subject", "event"])
set_random_seeds(42)
loo_indices = generate_loo_indices(meta)

# ---- Evaluation -----------------------------------------------------------
window_samples = int(srate * duration)
for model_name, estimator in models.items():
    # The 'fbtdca' entry is given `l` samples beyond the window; every other
    # model sees exactly the window.
    n_keep = window_samples + l if model_name == 'fbtdca' else window_samples
    filterX, filterY = np.copy(X[..., :n_keep]), np.copy(y)

    # Remove the per-channel mean of each epoch.
    filterX = filterX - np.mean(filterX, axis=-1, keepdims=True)

    n_loo = len(loo_indices['1'][events[0]])
    loo_accs = []
    for fold_no in range(n_loo):
        train_ind, validate_ind, test_ind = match_loo_indices(
            fold_no, meta, loo_indices)
        # The validation split is folded into the training split.
        train_ind = np.concatenate([train_ind, validate_ind])

        trainX, trainY = filterX[train_ind], filterY[train_ind]
        testX, testY = filterX[test_ind], filterY[test_ind]

        model = clone(estimator).fit(trainX, trainY, Yf=Yf)
        pred_labels = model.predict(testX)
        loo_accs.append(balanced_accuracy_score(testY, pred_labels))
    print("Model:{} LOO Acc:{:.2f}".format(model_name, np.mean(loo_accs)))

Model:fbscca LOO Acc:0.11
Model:fbecca LOO Acc:0.14
Model:fbdsp LOO Acc:0.13
Model:fbtrca LOO Acc:0.17
Model:fbtdca LOO Acc:0.14
