# Visualisation
## Preparation

In [None]:
from pathlib import Path

import yaml

import numpy as np
import pandas as pd
from torch.nn import MSELoss
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.manifold import TSNE
from umap import UMAP

import moabb
from moabb.datasets import Schirrmeister2017
from moabb.evaluations import CrossSubjectEvaluation, WithinSessionEvaluation
from moabb.paradigms import MotorImagery, FilterBankMotorImagery
from moabb.analysis import Results

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
%matplotlib inline

from models import EEGNetv4
from skorch_frozen import FrozenNeuralNetTransformer

moabb.set_log_level("info")

### Load config

In [None]:
# Both configuration files are parsed with yaml.safe_load.
config_file = Path('config.yaml')
local_config_file = Path('local_config.yaml')

with config_file.open('r') as stream:
    config = yaml.safe_load(stream)
with local_config_file.open('r') as stream:
    local_config = yaml.safe_load(stream)

# The result-file suffix comes from the local config.
suffix = local_config['evaluation_params']['base']['suffix']

# The remaining values are read from the 'base' paradigm section.
base_paradigm_params = config['paradigm_params']['base']
n_classes = base_paradigm_params['n_classes']
channels = base_paradigm_params['channels']
resample = base_paradigm_params['resample']

# t0 and t1 are unpacked from the dataset's `interval` attribute.
dataset_interval = Schirrmeister2017().interval
t0, t1 = dataset_interval


### Dataset

In [None]:
dataset = Schirrmeister2017()

# The paradigm receives the 'base' and 'single_band' config sections as
# keyword arguments. Both are unpacked in the call itself, so a key present
# in both sections raises TypeError instead of being silently overridden.
paradigm_sections = config['paradigm_params']
paradigm = MotorImagery(**paradigm_sections['base'], **paradigm_sections['single_band'])


### Get network checkpoint paths

In [None]:
# Keys forwarded to moabb's Results when present in the local config.
# `overwrite` is pinned to False: do not use overwrite=True here!
results_param_names = ['hdf5_path', 'additional_columns']
base_evaluation_params = local_config['evaluation_params']['base']
results_params = {}
for param_name in results_param_names:
    if param_name in base_evaluation_params:
        results_params[param_name] = base_evaluation_params[param_name]
results_params['overwrite'] = False

# A throw-away Results object is created only to read its `filepath`;
# checkpoints are looked up next to that file.
fake_results = Results(CrossSubjectEvaluation, MotorImagery, **results_params)
checkpoints_root_dir = Path(fake_results.filepath).parent
del fake_results

# Exactly one *.ckpt file is expected in each per-subject directory.
checkpoints_dict = {}
for subject in dataset.subject_list:
    path = checkpoints_root_dir / str(subject)
    files = list(path.glob('*.ckpt'))
    if len(files) != 1:
        raise ValueError(f'Multiple or no checkpoint file(s) present at {path}')
    (only_checkpoint,) = files
    checkpoints_dict[subject] = str(only_checkpoint)

checkpoints_dict


### Load embedding functions

In [None]:
def _frozen_embedding(checkpoint_path):
    """Load an EEGNetv4 checkpoint and wrap its `embedding` in a FrozenNeuralNetTransformer."""
    network = EEGNetv4.load_from_checkpoint(checkpoint_path)
    return FrozenNeuralNetTransformer(network.embedding, criterion=MSELoss)


# One transformer per subject, keyed like checkpoints_dict.
embeddings_dict = {
    subject: _frozen_embedding(checkpoint_path)
    for subject, checkpoint_path in checkpoints_dict.items()
}

# embeddings_dict


### Load data

In [None]:
# Fetch the data array, the labels and the per-trial metadata for the dataset.
X, labels, metadata = paradigm.get_data(dataset, return_epochs=False)
# Rebind X to a float32 copy (the name is reused so the source array can be freed).
X = X.astype(np.float32)

### Load results

In [None]:
# Create every Results object first, then read them into dataframes.
result_stores = {
    'FBCSP': Results(WithinSessionEvaluation, FilterBankMotorImagery, suffix='', **results_params),
    'EEGNetLP': Results(WithinSessionEvaluation, MotorImagery, suffix='', **results_params),
    'FBCSP_as': Results(WithinSessionEvaluation, FilterBankMotorImagery, suffix='all_samples', **results_params),
    'EEGNetLP_as': Results(WithinSessionEvaluation, MotorImagery, suffix='all_samples', **results_params),
    'EEGNet': Results(CrossSubjectEvaluation, MotorImagery, suffix='', **results_params),
}
concat_options = dict(join='outer', ignore_index=True)

# Within-session results (empty suffix).
results_ws = pd.concat(
    [result_stores['FBCSP'].to_dataframe(), result_stores['EEGNetLP'].to_dataframe()],
    **concat_options,
)

# Within-session results stored under the 'all_samples' suffix.
results_ws_as = pd.concat(
    [result_stores['FBCSP_as'].to_dataframe(), result_stores['EEGNetLP_as'].to_dataframe()],
    **concat_options,
)

# Cross-subject results.
results_cs = result_stores['EEGNet'].to_dataframe()

# Only the dataframes are needed from here on.
del result_stores
# results_df


In [None]:
# Complete the results: names of the presentation columns added to the raw
# result frames, and the mapping from raw pipeline names to display names.
calibration_col = 'Individual calibration'
score_col = 'Accuracy'
time_col = 'Calibration time [s]'
samples_col = 'Number of calibration trials per class'
subject_col = 'Subject'
pipeline_col = 'Pipeline'
pipeline_names_map = {
    'FBCSP+LogisticReg': 'FBCSP',
    'EEGNet+LP': 'EEGNet+LP',
    'EEGNet-CrossSubject': 'EEGNet',
}


def complete_ressults(results_df, n_classes=4):
    """Add the presentation columns to a results frame, in place.

    Args:
        results_df: DataFrame with the columns ``pipeline``, ``score``,
            ``time``, ``samples`` and ``subject``. It is modified in place.
        n_classes: divisor turning ``samples`` into trials per class.
            Defaults to 4, the value that was previously hard-coded.

    Returns:
        The same ``results_df`` object, with the columns named by
        ``calibration_col``, ``score_col``, ``time_col``, ``samples_col``,
        ``subject_col`` and ``pipeline_col`` added.

    Raises:
        ValueError: if a value of ``pipeline`` has no entry in
            ``pipeline_names_map``. The columns have already been written
            to ``results_df`` at that point.
    """
    # Rows of these two pipelines are flagged as individually calibrated.
    calibrated = results_df['pipeline'].isin(['FBCSP+LogisticReg', 'EEGNet+LP'])
    results_df[calibration_col] = calibrated

    results_df[score_col] = results_df['score']

    # Rows without individual calibration get a calibration time of zero ...
    results_df[time_col] = results_df['time']
    results_df.loc[~calibrated, time_col] = 0.0

    # ... and no number of calibration trials.
    results_df[samples_col] = results_df['samples'] / n_classes
    results_df.loc[~calibrated, samples_col] = np.nan

    results_df[subject_col] = results_df['subject']

    display_names = results_df['pipeline'].map(pipeline_names_map)
    results_df[pipeline_col] = display_names
    # Raise explicitly (an assert would vanish under `python -O`) and name
    # the pipelines that have no display name.
    unmapped = display_names.isnull()
    if unmapped.any():
        unknown = sorted(results_df.loc[unmapped, 'pipeline'].astype(str).unique())
        raise ValueError(f'No display name defined for pipeline(s): {unknown}')

    return results_df

# complete_ressults modifies each frame in place, so its return value is not needed.
for frame in (results_ws, results_cs, results_ws_as):
    complete_ressults(frame)

# Combined within-session + cross-subject frame used by the plots below.
results_df = pd.concat((results_ws, results_cs), join='outer', ignore_index=True)


### Prepare export directory

In [None]:
# Directory that receives the exported figures. `exist_ok=True` makes the
# call a no-op when the directory is already there (no separate existence
# check needed); `parents=True` also creates missing parent directories.
export_dir = Path('./export')
export_dir.mkdir(parents=True, exist_ok=True)
    

## Results 

In [None]:
# Mean, standard deviation and count of the cross-subject scores.
results_cs['score'].apply(['mean', 'std', 'count'])

In [None]:
# Mean, standard deviation and count of the 'all_samples' scores of FBCSP+LogisticReg.
fbcsp_rows = results_ws_as['pipeline'] == 'FBCSP+LogisticReg'
results_ws_as[fbcsp_rows]['score'].apply(['mean', 'std', 'count'])

In [None]:
# Mean, standard deviation and count of the 'all_samples' scores of EEGNet+LP.
eegnet_lp_rows = results_ws_as['pipeline'] == 'EEGNet+LP'
results_ws_as[eegnet_lp_rows]['score'].apply(['mean', 'std', 'count'])

### Results table

In [None]:
# One row per subject (sorted numerically), one column per pipeline.
df = (
    pd.concat([results_cs, results_ws_as], ignore_index=True)
    .pivot(index=subject_col, columns=pipeline_col, values=score_col)
    .sort_index(key=pd.to_numeric)
)

# Append per-pipeline mean/std rows, then transpose so pipelines become rows.
summary_rows = df.apply(['mean', 'std'])
score_table = pd.concat([df, summary_rows]).T

latex_source = score_table.to_latex(float_format="{:0.3f}".format)
print(latex_source)
score_table

### Wilcoxon signed-rank test

In [None]:
from scipy.stats import wilcoxon

# Subject x pipeline table of the 'all_samples' within-session scores.
wil_ds = (
    results_ws_as
    .pivot(index=subject_col, columns=pipeline_col, values=score_col)
    .sort_index(key=pd.to_numeric)
)

# Paired two-sided test between the two pipelines' per-subject scores.
test_outcome = wilcoxon(wil_ds['EEGNet+LP'], wil_ds['FBCSP'], alternative='two-sided')
print(test_outcome)
print(wil_ds.mean())
wil_ds

## Performance plots

In [None]:
# Rows of the 'EEGNet-CrossSubject' pipeline. Filter first and copy the
# (small) selection afterwards: the copy is then an independent frame, so the
# assignments below neither touch results_df nor raise SettingWithCopyWarning,
# and the full frame is not duplicated needlessly.
is_cross_subject = results_df.pipeline == 'EEGNet-CrossSubject'
df0 = results_df[is_cross_subject].copy()
df1 = df0.copy()

# Place those rows at the smallest and at the largest x value so that the
# pipeline spans the whole x range of the plot.
df0.loc[:, samples_col] = results_df[samples_col].min()
df1.loc[:, samples_col] = results_df[samples_col].max()

# Mean score per (trials per class, pipeline, subject). groupby drops rows
# whose group key is NaN by default; presumably those are exactly the
# cross-subject rows re-added through df0/df1 -- TODO confirm.
df3 = results_df.groupby([samples_col, pipeline_col, subject_col], as_index=False)[score_col].agg('mean')
df = pd.concat([df3, df0, df1], join='outer', ignore_index=True)

# Sorted integer x positions that actually occur (NaN excluded).
ticks = results_df[samples_col].unique()
ticks = np.sort(ticks[~np.isnan(ticks)].astype(int)).tolist()


In [None]:
# Accuracy vs. number of calibration trials, logarithmic x axis.
sns.set_theme(style='whitegrid', context='paper')
ax = sns.lineplot(
    data=df,
    x=samples_col,
    y=score_col,
    hue=pipeline_col,
    style=pipeline_col,
    dashes=[(1, 0), (1, 0), (4, 1)],
    markers=['o', 'o', '.'],
)
ax.set(xscale="log", xticks=ticks, xticklabels=ticks)

plt.savefig('export/acc_vs_num-samples_xlog.pdf', bbox_inches='tight')


In [None]:
# Accuracy vs. number of calibration trials, linear x axis with one tick per
# value in `ticks`.
sns.set_theme(style='whitegrid', context='paper')
ax = sns.lineplot(
    data=df,
    x=samples_col,
    y=score_col,
    hue=pipeline_col,
    style=pipeline_col,
    dashes=[(1, 0), (1, 0), (4, 1)],
    markers=['o', 'o', '.'],
)
ax.set(xticks=ticks, xticklabels=ticks)

plt.savefig('export/acc_vs_num-samples_xlin_exact-ticks.pdf', bbox_inches='tight')


In [None]:
# Accuracy vs. number of calibration trials, linear x axis with default ticks.
sns.set_theme(style='whitegrid', context='paper')
ax = sns.lineplot(
    data=df,
    x=samples_col,
    y=score_col,
    hue=pipeline_col,
    style=pipeline_col,
    dashes=[(1, 0), (1, 0), (4, 1)],
    markers=['o', 'o', '.'],
)

plt.savefig('export/acc_vs_num-samples_xlin.pdf', bbox_inches='tight')


In [None]:
# ax = sns.lineplot(data=df, x=samples_col, y=time_col, hue=pipeline_col, style=pipeline_col, dashes=[(1,0),(1,0),(4,1)]) #, aspect=1.3, )

# plt.savefig(f'export/calib-time_vs_num-samples.pdf', bbox_inches='tight')

## Projection plots
### Functions

In [None]:
def sns_scatterplot(data, ax0=None, comment=None, **kwargs):
    """Scatter the 'x' and 'y' columns of *data* on axes without tick labels or axis titles.

    Args:
        data: passed to ``sns.scatterplot`` as ``data``; its 'x' and 'y'
            columns are plotted.
        ax0: forwarded to ``sns.scatterplot`` as ``ax``.
        comment: optional text, written right-aligned at the axes-fraction
            position (0.99, 0.01).
        **kwargs: forwarded unchanged to ``sns.scatterplot``.

    Returns:
        The axes returned by ``sns.scatterplot``.
    """
    axes = sns.scatterplot(data=data, x='x', y='y', ax=ax0, **kwargs)

    # Blank out the tick labels, then the axis titles.
    for clear_ticklabels in (axes.set_xticklabels, axes.set_yticklabels):
        clear_ticklabels([])
    for clear_title in (axes.set_xlabel, axes.set_ylabel):
        clear_title('')

    if comment is None:
        return axes

    # transAxes makes the coordinates fractions of the axes box.
    axes.text(0.99, 0.01, comment, transform=axes.transAxes, horizontalalignment="right")
    return axes

### Compute projections

In [None]:
# Key into embeddings_dict selecting which subject's embedding is applied.
test_subject = 2
# Transform all of X with that embedding; the %time magic reports the duration.
%time X_emb = embeddings_dict[test_subject].transform(X)


In [None]:
reduction_algo = 'TSNE'
# reduction_algo = 'UMAP'

if reduction_algo=='TSNE':
    reducer = TSNE(n_components=2, perplexity=50, random_state=12, metric='euclidean', learning_rate='auto', init='pca', n_jobs=-1)
elif reduction_algo=='UMAP':
    reducer = UMAP(n_components=2, n_neighbors=15, min_dist=.25, n_jobs=-1)
%time features_2d = reducer.fit_transform(X_emb)

df_2d = pd.DataFrame(features_2d, columns=['x','y'])
df_2d = pd.concat([df_2d, metadata], axis=1)
df_2d['im_class'] = pd.Series(labels, dtype=pd.CategoricalDtype(categories=['right_hand', 'left_hand', 'rest', 'feet']))


### Plots

In [None]:
# All trials in a single scatter plot, coloured by imagery class.
sns.set_theme(style='white', context='paper')
sns_scatterplot(df_2d, hue='im_class', alpha=.7)

merged_figure_path = f'export/{reduction_algo.lower()}_merged_test-subj-{test_subject}.pdf'
plt.savefig(merged_figure_path, bbox_inches='tight')


In [None]:
sns.set_theme(style='white', context='paper')


def _kde_of_all_trials(**_facet_kwargs):
    """Draw the KDE of the whole df_2d, ignoring the per-facet arguments FacetGrid passes in."""
    return sns.kdeplot(data=df_2d, x="x", y="y", alpha=.5, hue="im_class", levels=5)


# One facet per subject: the KDE of all trials, then that subject's own
# trials as scatter points on top.
g = sns.FacetGrid(data=df_2d, col='subject', col_wrap=4, despine=False)
g.map_dataframe(_kde_of_all_trials)
g.map_dataframe(sns.scatterplot, x='x', y='y', hue='im_class', alpha=.7, style='session')
g.set(xlabel='', ylabel='', xticks=[], yticks=[])

# Highlight the facet of the test subject in red (title and frame).
print('test subject:', test_subject)
test_axis = g.axes_dict[test_subject]
test_axis.title.set_color('red')
for frame_side in test_axis.spines.values():
    frame_side.set_edgecolor('red')

split_figure_path = f'export/{reduction_algo.lower()}_splitted_test-subj-{test_subject}.pdf'
g.savefig(split_figure_path, bbox_inches='tight')


## SANDBOX

In [None]:

from brokenaxes import brokenaxes


In [None]:
sns.set_theme(style='whitegrid', context='paper')
ax = sns.lineplot(
    data=df,
    x=samples_col,
    y=score_col,
    hue=pipeline_col,
    style=pipeline_col,
    dashes=[(1, 0), (1, 0), (4, 1)],
)

# Slanted marker used to draw an axis-break symbol; `slant` sets its angle
# (1 is presumably 45 degrees).
slant = 1
break_style = {
    'marker': [(-1, -slant), (1, slant)],
    'markersize': 5,
    'linestyle': "-",
    'color': "w",
    'mec': "#000000",
    'mew': 1.2,
    'clip_on': False,
    'zorder': 100,
}

# The x position of the break has to be chosen by hand.
xbreak = (100, 110)
ylim = ax.get_ylim()
# Draw the break symbol on the lower and on the upper edge of the axes.
for y_edge in ylim:
    ax.plot(xbreak, [y_edge, y_edge], **break_style)
# Restore the limits that were in effect before the markers were added.
ax.set_ylim(ylim)

plt.show()

In [None]:
fig, ax = plt.subplots()
ax.set_xlim(0., 100.)

# Arguments shared by both layers drawn on the same axes.
shared_plot_args = dict(x=samples_col, y=score_col, hue=pipeline_col, style=pipeline_col, ax=ax)

# Layer 1: all points, drawn as lines.
ax = sns.lineplot(data=df, dashes=[(1, 0), (1, 0), (4, 1)], **shared_plot_args)

# Layer 2: only the points beyond 50 trials per class, with markers and error bars.
beyond_50 = df[df[samples_col] > 50]
ax = sns.lineplot(
    data=beyond_50,
    legend=False,
    markers=['o', 'o', 'o'],
    dashes=False,
    err_style="bars",
    **shared_plot_args,
)
plt.savefig('TEMP.pdf')
