# MOABB Classifier comparison

This notebook takes multiple files and runs the following pipeline:
1) Import data and pre-process
2) For each stimulus type 
    1) Run classifiers (i.e., CCA, MEC, MSI, and RG)
    2) Store accuracy values

Finally, one plot is produced for each stimulus type.

In [1]:
# Default libraries
import re
import mne
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline

from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace

from moabb.datasets import Wang2016, SSVEPExo
from moabb.paradigms import FilterBankSSVEP, SSVEP
from moabb.pipelines import ExtendedSSVEPSignal
from sklearn.linear_model import LogisticRegression
from moabb.evaluations import CrossSubjectEvaluation

In [None]:
# Import custom libraries
from ncantest import data_tools
from ncantest import processing
from ncantest import classification

In [None]:
from FeatureExtractorSSVEP import FeatureExtractorCCA as CCA
from FeatureExtractorSSVEP import FeatureExtractorMSI as MSI
from FeatureExtractorSSVEP import FeatureExtractorMEC as MEC

# Magic command to reload libraries
%reload_ext autoreload

## Settings

Note that if the `Wang2016` dataset has not already been downloaded, it will be downloaded the first time you run this section of the code.

In [None]:
# Store the data directory under the "MNE_DATA" key of the MNE configuration.
# NOTE(review): the value contains a literal "~" - confirm that MNE/MOABB expand
# it to the user's home directory when resolving the download path.
mne.set_config(key="MNE_DATA", value="~/.mne")

In [None]:
# Display the value currently stored under the "MNE_DATA" configuration key
mne.get_config(key="MNE_DATA")

In [None]:
# Instantiate the Wang2016 dataset and keep only the first two entries
# of its subject list
dataset = Wang2016()
first_subjects = dataset.subject_list[:2]
dataset.subject_list = first_subjects

In [None]:
# Load the recordings of every subject currently in `dataset.subject_list`
data = dataset.get_data(subjects=dataset.subject_list)

In [None]:
# EEG channels passed to the epoch-extraction step
eeg_channels = ["O1", "Oz", "O2"]

# Information from dataset description
nsubjects = len(data)
stimulus_freqs = [float(freq) for freq in dataset.event_id.keys()]  # Stimulus frequencies [Hz]

# Read the sampling rate from the first loaded subject instead of a hard-coded
# subject id, so this cell keeps working when `dataset.subject_list` is changed
first_subject_data = next(iter(data.values()))
srate = first_subject_data["0"]["0"].info["sfreq"]  # Sampling rate [Hz]
tmin = 0.5  # Time of start of SSVEP stimulus [sec]
tmax = 5.5  # Time of end of SSVEP stimulus [sec]

# Classifier settings
# - Create CCA subbands like in Chen et al. (2015) paper:
#   one row per subband, [8, 88], [16, 88], ..., [80, 88]
first_column = np.arange(1, 11) * 8
second_column = np.full(10, 88)
cca_subbands = np.column_stack((first_column, second_column))
harmonic_count = 2
classifiers = ["fbCCA", "MSI", "MEC", "RG_logreg"]

# Create an empty dataframe to store the accuracies
# (one row per loaded subject, one column per classifier)
accuracy_df = pd.DataFrame(
    index = np.arange(0, nsubjects),
    columns = classifiers
    )
accuracy_df.index.name = "Subject"

## Separate epochs

From the raw recording datasets, create the EEG epochs using just the periods where the SSVEP stimulus was active.

In [None]:
# Collect the events and epochs of every subject, in the order in which the
# subjects appear in `data`. Appending (instead of writing to index
# `subject_id - 1`) keeps this cell valid for any subject selection, not only
# for subject ids that run 1..N.
events_list = []
epochs_list = []

# Obtain epochs and events
for subject in data.values():
    subject_events, subject_epochs = data_tools.moabb_events_to_np(
        mne_raw = subject["0"]["0"],
        tmin = tmin,
        tmax = tmax,
        events_dict = dataset.event_id,
        chans = eeg_channels
        )
    events_list.append(subject_events)
    epochs_list.append(subject_epochs)

# Convert lists to np.ndarrays
epochs_np = np.asarray(epochs_list, dtype=np.float32)
# Labels are taken from the first subject only and reused for all subjects.
# NOTE(review): this assumes every subject shares the same trial order, and that
# column 2 of the events array holds the event code - confirm against
# `data_tools.moabb_events_to_np`.
events_np = np.array(events_list[0][:,2]) - 1   # The `-1` is to make the labels start at 0

## Classifiers

### Riemannian geometry + logistic regression

### fbCCA

In [None]:
# Build and configure the CCA feature extractor
cca = CCA()
cca.setup_feature_extractor(
    harmonics_count = harmonic_count,
    targets_frequencies = stimulus_freqs,
    sampling_frequency = srate,
    samples_count = epochs_np.shape[-1],
    filter_order = 12,
    subbands = cca_subbands
    )

# One accuracy value per subject
cca_accuracies = np.zeros(nsubjects)

# Classify all epochs per subject
for s, subject in enumerate(epochs_np):
    cca_features = cca.extract_features(subject)

    # Drop singleton axes, take the maximum along axis 1, then pick the index
    # of the largest remaining value along axis 1 as the predicted label.
    # NOTE(review): presumably axis 1 is the subband axis and the final axis
    # indexes target frequencies - confirm against FeatureExtractorCCA.
    squeezed_features = np.squeeze(cca_features)
    best_scores = squeezed_features.max(axis=1)
    cca_predictions = best_scores.argmax(axis=1)

    cca_accuracies[s] = accuracy_score(events_np, cca_predictions)


In [None]:

# Store the per-subject fbCCA accuracies in the "fbCCA" column of the results table
accuracy_df["fbCCA"] = cca_accuracies
    

### MEC

### MSI

## Visualization

### Boxplot

## Save results
