In [1]:
from moabb.datasets import SSVEPExo

dataset = SSVEPExo()
# Time window passed as `interval=` to the SSVEP_CCA / SSVEP_TRCA pipelines
# (presumably the trial window in seconds — confirm against moabb docs).
interval = dataset.interval
# Sampling rate in Hz; assumed to match the SSVEPExo recordings (TODO confirm).
sfreq = 256

## Algorithm 1: Canonical correlation analysis

In [2]:
from moabb.paradigms import SSVEP

n_classes = 3
# Band-pass 10-45 Hz, epoch 0-5 s after each event, keep n_classes classes.
paradigm_bandpass = SSVEP(n_classes=n_classes, fmin=10, fmax=45, tmin=0, tmax=5)
# Maps each selected stimulation frequency (as a string) to its event id.
freqs = paradigm_bandpass.used_events(dataset)
print(freqs)

Choosing the first 3 classes from all possible events


{'13': 2, '17': 3, '21': 4}


In [3]:
from sklearn.pipeline import make_pipeline
from moabb.pipelines import SSVEP_CCA
from moabb.pipelines import ExtendedSSVEPSignal


# Canonical correlation analysis against each selected frequency in `freqs`.
# NOTE(review): the band-pass paradigm keeps only 10-45 Hz, so 3x multiples of
# 17 Hz and 21 Hz (51 Hz, 63 Hz) are filtered out before CCA — confirm whether
# n_harmonics=3 is meant to reach them.
pipelines_bandpass = {
    "CCA": make_pipeline(
        SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=3),
    ),
}

## Algorithm 2: Task-related component analysis (TRCA)

In [4]:
from moabb.pipelines import SSVEP_TRCA

pipelines_bandpass["TRCA"] = make_pipeline(
    SSVEP_TRCA(
        interval=interval,
        freqs=freqs,
        # presumably the number of filter-bank sub-bands TRCA uses internally
        # — confirm in the moabb SSVEP_TRCA docs
        n_fbands=5,
    ),
)

## Algorithm 3: Riemannian Geometry

In [5]:
from moabb.paradigms import FilterBankSSVEP

# Filter bank: a 1 Hz-wide band centred on each selected stimulation frequency
# and on its 2nd and 3rd multiples. The centre frequencies are derived from
# `freqs` (keys are frequency strings such as '13') instead of being hard-coded,
# so the bank stays in sync with the chosen classes if n_classes changes.
filterbank_harmonics = (1, 2, 3)
filter_freqs = [float(f) * h for f in freqs for h in filterbank_harmonics]
filters = [[f - .5, f + .5] for f in filter_freqs]

paradigm_filterbank = FilterBankSSVEP(n_classes=n_classes, tmin=0, tmax=5, filters=filters)

Choosing the first 3 classes from all possible events


In [6]:
from moabb.pipelines import ExtendedSSVEPSignal
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from sklearn.linear_model import LogisticRegression

# Riemannian-geometry pipeline on the filter-bank epochs: covariance matrices
# (estimator="lwf") are mapped to the tangent space and fed to a linear
# classifier. ExtendedSSVEPSignal presumably reshapes the filter-bank signal
# into the layout Covariances expects — confirm in the moabb docs.
pipelines_filterbank = dict()
pipelines_filterbank["RG+logreg"] = make_pipeline(
    ExtendedSSVEPSignal(),
    Covariances(estimator="lwf"),
    TangentSpace(),
    # multi_class="auto" is already the default and the parameter is
    # deprecated in recent scikit-learn; omitting it keeps the same
    # behaviour without the FutureWarning.
    LogisticRegression(solver="lbfgs"),
)


## Our algorithm: baseline-normalised band power + logistic regression

In [7]:
# Same filter bank, but epochs start 3 s before the event so that a
# pre-stimulus baseline window is available to BandPower.
paradigm_filterbank_baseline = FilterBankSSVEP(
    filters=filters,
    n_classes=n_classes,
    tmin=-3,
    tmax=5,
)

Choosing the first 3 classes from all possible events


In [10]:
from sklearn.base import BaseEstimator, TransformerMixin
import numpy as np
from sklearn.preprocessing import FunctionTransformer

class BandPower(BaseEstimator, TransformerMixin):
    """Baseline-normalised signal power, one value per trial and band.

    For every trial and every entry of the last axis, divides the mean squared
    amplitude from t=0 to the end of the epoch (the "stimulation" window) by
    the mean squared amplitude in a baseline window. Squares are averaged
    over axes 1 and 2 of ``X``.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of ``X`` in Hz.
    tmin : float
        Time (s, relative to the event) of the first sample of each epoch.
        Must match the ``tmin`` used to epoch the data.
    baseline : tuple of (float, float)
        Start and end time (s, relative to the event) of the baseline window.
        Must lie inside the epoch, i.e. ``baseline[0] >= tmin``.
    """

    def __init__(self, sfreq=256, tmin=-3, baseline=(-2.5, -.5)):
        self.sfreq = sfreq
        self.tmin = tmin
        self.baseline = baseline

    def fit(self, X, y=None):
        """No-op: the transform has no learned state."""
        return self

    def _to_sample(self, t):
        """Convert a time in seconds (relative to the event) to a sample index."""
        # round() instead of int(): int() truncates, so a product that lands at
        # e.g. 29.9999... due to float error would be one sample early.
        return int(round((t - self.tmin) * self.sfreq))

    def transform(self, X, y=None):
        """Return stimulation-power / baseline-power ratios.

        Parameters
        ----------
        X : array-like, 4-D
            Indexed as (trials, ?, time, band). Assumes (n_trials, n_channels,
            n_times, n_bands), with the last axis presumably the filter-bank
            bands — TODO confirm against the paradigm output.
        y : ignored

        Returns
        -------
        ndarray of shape (n_trials, n_bands)

        Raises
        ------
        ValueError
            If the baseline window is empty or does not lie inside the epoch,
            or if the epoch ends before t=0.
        """
        X = np.asarray(X)
        n_times = X.shape[2]

        baseline_start = self._to_sample(self.baseline[0])
        baseline_end = self._to_sample(self.baseline[1])
        # A negative start index would silently wrap around to the END of the
        # epoch (e.g. tmin=0 with a pre-stimulus baseline), turning the
        # "baseline" into post-stimulus data; reject that explicitly.
        if not 0 <= baseline_start < baseline_end <= n_times:
            raise ValueError(
                f"baseline {self.baseline} maps to samples "
                f"[{baseline_start}, {baseline_end}), which is empty or outside "
                f"the epoch of {n_times} samples starting at tmin={self.tmin}"
            )

        # If the epoch starts after the event, the whole epoch is stimulation.
        stim_start = max(0, self._to_sample(0))
        if stim_start >= n_times:
            raise ValueError(
                f"epoch of {n_times} samples starting at tmin={self.tmin} "
                "contains no samples at or after t=0"
            )

        baseline_power = np.mean(X[:, :, baseline_start:baseline_end, :] ** 2, axis=(1, 2))
        stimulation_power = np.mean(X[:, :, stim_start:, :] ** 2, axis=(1, 2))
        return stimulation_power / baseline_power


# Band-power pipeline: stimulation/baseline power ratio per band, log-scaled
# (ratio 1 -> 0, so increases and decreases are symmetric), then a linear
# classifier.
pipelines_filterbank_baseline = dict()
pipelines_filterbank_baseline['power+logreg'] = make_pipeline(
    # sfreq and tmin must describe the epochs of paradigm_filterbank_baseline
    # (tmin=-3 there). They equal BandPower's defaults, but passing them
    # explicitly keeps that coupling visible if either side changes.
    BandPower(sfreq=sfreq, tmin=-3),
    FunctionTransformer(np.log),
    # multi_class="auto" is already the default and the parameter is
    # deprecated in recent scikit-learn; omitting it keeps the same behaviour.
    LogisticRegression(solver="lbfgs"),
)

## Evaluate algorithm performance

In [None]:
from moabb.evaluations import WithinSessionEvaluation

# Within-session evaluation of the band-pass pipelines (CCA, TRCA).
# overwrite=False: presumably previously stored results for this suffix are
# reused instead of recomputed. The "whithin" spelling is part of the
# stored-results suffix and is intentionally left unchanged.
evaluation_bandpass = WithinSessionEvaluation(
    datasets=dataset,
    paradigm=paradigm_bandpass,
    overwrite=False,
    suffix="ssvep_workshop_whithin_session_bandpass",
)
results_bandpass = evaluation_bandpass.process(pipelines_bandpass)
results_bandpass

SSVEP Exoskeleton-WithinSession:  25%|████████████▎                                    | 3/12 [00:20<01:01,  6.79s/it]

In [None]:
# Within-session evaluation of the filter-bank (Riemannian) pipeline.
# The "whithin" spelling is part of the stored-results suffix and is
# intentionally left unchanged.
evaluation_filterbank = WithinSessionEvaluation(
    datasets=dataset,
    paradigm=paradigm_filterbank,
    overwrite=False,
    suffix="ssvep_workshop_whithin_session_filterbank",
)
results_filterbank = evaluation_filterbank.process(pipelines_filterbank)
results_filterbank

In [None]:
# Within-session evaluation of the band-power pipeline on epochs that include
# a pre-stimulus baseline. Bound to its own name instead of reusing
# `evaluation_filterbank`, which would otherwise be silently replaced and no
# longer refer to the filter-bank evaluation.
# overwrite=True presumably forces recomputation on every run, so edits to
# BandPower are always reflected in the results.
evaluation_filterbank_baseline = WithinSessionEvaluation(
    paradigm=paradigm_filterbank_baseline,
    datasets=dataset,
    suffix="ssvep_workshop_whithin_session_filterbank_baseline",
    overwrite=True,
)
results_filterbank_baseline = evaluation_filterbank_baseline.process(pipelines_filterbank_baseline)
results_filterbank_baseline

In [None]:
import seaborn as sns
import pandas as pd

# Stack the three result tables. ignore_index=True gives the combined frame a
# fresh 0..n-1 index instead of three overlapping ones (duplicate row labels).
results = pd.concat(
    [results_bandpass, results_filterbank, results_filterbank_baseline],
    ignore_index=True,
)

# Individual scores as jittered points, with pointplot's per-pipeline
# estimate drawn on the same axes.
ax = sns.stripplot(data=results, y="score", x="pipeline", alpha=.5, palette="Set1")
sns.pointplot(data=results, y="score", x="pipeline", zorder=1, palette="Set1", ax=ax)
ax.set_ylabel("Accuracy")
ax.set_ylim(0, 1)
ax.set_title(f"Within-session accuracy ({n_classes} classes)")
# Chance level, assuming balanced classes.
ax.axhline(1 / n_classes, color="k", linestyle="--", label="chance")
ax.legend();