# Main Body

## Import packages and set image styles

In [None]:
import glob
import os
import pickle
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches  import Rectangle, Patch
from matplotlib.lines    import Line2D
from matplotlib          import gridspec
from matplotlib.ticker   import MultipleLocator
from matplotlib          import path_effects

from scipy import stats
import seaborn as sns
import math

import warnings 
warnings.filterwarnings("ignore", category = FutureWarning)
warnings.filterwarnings("ignore", category = DeprecationWarning)

# Folder that holds the pickled result files; sub-folder names are appended to it.
root_path = "PDClassification/Results/"
# Short model codes; get_full_name() expands them to display names.
model_list = ["xeg",   "egn",   "shn",   "atc",    "con",    "etr",    "dcn",     "res"] 
# One string per entry of model_list (same order).
# NOTE(review): presumably the number of trainable parameters of each model - confirm.
param_list = ["245", "2609", "57441", "146629", "191153", "210561", "287901", "1337665"]

In [None]:
# Theme selector: "dark" or "light". Any other value leaves seaborn's style
# untouched and does not define `folder`.
style = "light"

# rc overrides applied on top of seaborn's "darkgrid" style.
custom_params_dark = {
    "figure.facecolor": "white",
    "axes.labelcolor": ".15",
    "xtick.direction": "out",
    "ytick.direction": "out",
    "xtick.color": ".15",
    "ytick.color": ".15",
    "axes.axisbelow": True,
    "grid.linestyle": "-",
    "text.color": ".15",
    "font.family": ["sans-serif"],
    "font.sans-serif": [
        "Arial",
        "DejaVu Sans",
        "Liberation Sans",
        "Bitstream Vera Sans",
        "sans-serif",
    ],
    "lines.solid_capstyle": "round",
    "patch.edgecolor": "w",
    "patch.force_edgecolor": True,
    "image.cmap": "rocket",
    "xtick.top": False,
    "ytick.right": False,
    "axes.grid": True,
    "axes.facecolor": "#EAEAF2",
    "axes.edgecolor": "#0072b2",
    "grid.color": "white",
    "axes.spines.left": True,
    "axes.spines.bottom": True,
    "axes.spines.right": True,
    "axes.spines.top": True,
    "xtick.bottom": False,
    "ytick.left": False,
}

# rc overrides applied on top of seaborn's "ticks" theme.
custom_params_light = {
    "axes.spines.right": False,
    "axes.spines.top": False,
    "axes.grid": True,
    "grid.linestyle": "-",
    "grid.color": "lightgray",
}

# Activate the requested theme and remember the matching folder name.
if style == "light":
    folder = "LightTheme"
    sns.set_theme(style="ticks", rc=custom_params_light)
elif style == "dark":
    folder = "DarkTheme"
    sns.set_style("darkgrid", rc=custom_params_dark)

# Font scaling shared by every figure in the notebook.
sns.set_context("paper", font_scale=1.5)

In [None]:
def add_median_labels(
    ax: plt.Axes,
    fmt: str = ".1f",
    fontsize = None,
    lim = (-np.inf, np.inf),
    iqr_values = None,
    iqr_offset = 1,
    fmt2: str = ".2f",
) -> None:
    """Add text labels to the median lines of a seaborn boxplot.

    Every median whose value lies strictly inside ``lim`` gets a bold white
    label outlined in the colour of the median line. When ``iqr_values`` is
    given, the median label is moved up by ``iqr_offset`` and a second label
    with the matching entry of ``iqr_values`` is written
    ``2*iqr_offset + iqr_offset/4`` below it.

    Args:
        ax: plt.Axes, e.g. the return value of sns.boxplot()
        fmt: format string for the median value
        fontsize: font size of the labels; None keeps matplotlib's default
        lim: (low, high) pair; medians outside the open interval are skipped
        iqr_values: optional sequence with one extra value per box, in the
            order the boxes were drawn (a list, array or pandas Series)
        iqr_offset: vertical shift, in data units, used when iqr_values is set
        fmt2: format string for the entries of iqr_values
    """
    lines = ax.get_lines()
    boxes = [c for c in ax.get_children() if "Patch" in str(c)]
    start = 4
    if not boxes:  # seaborn v0.13 => fill=False => no patches => +1 line
        boxes = [c for c in lines if len(c.get_xdata()) == 5]
        start += 1
    if not boxes or len(lines) < len(boxes):
        # nothing that looks like a boxplot was drawn: there is nothing to label
        return
    lines_per_box = len(lines) // len(boxes)

    if iqr_values is not None:
        # Index by position. A pandas Series with string labels would otherwise
        # rely on the deprecated integer-position fallback of Series[int].
        iqr_values = np.asarray(iqr_values)

    text_kwargs = dict(ha='center', va='center', fontweight='bold', color='white')
    if fontsize is not None:
        text_kwargs['fontsize'] = fontsize

    def _outlined_text(x_pos, y_pos, label, color):
        # white text with a median-colored border around it for contrast
        text = ax.text(x_pos, y_pos, label, **text_kwargs)
        text.set_path_effects([
            path_effects.Stroke(linewidth=3, foreground=color),
            path_effects.Normal(),
        ])

    for n, median in enumerate(lines[start::lines_per_box]):
        x, y = (data.mean() for data in median.get_data())
        # choose value depending on horizontal or vertical plot orientation
        value = x if len(set(median.get_xdata())) == 1 else y
        if iqr_values is not None:
            y = y + iqr_offset
        if not (lim[0] < value < lim[1]):
            continue
        _outlined_text(x, y, f'{value:{fmt}}', median.get_color())
        if iqr_values is not None:
            _outlined_text(
                x,
                y-2*iqr_offset-iqr_offset/4,
                f'{iqr_values[n]:{fmt2}}',
                median.get_color(),
            )

In [None]:
def get_full_name(name):
    """Translate a short model code into the model's display name.

    Args:
        name: one of the short codes listed in the table below.

    Returns:
        The display name paired with the code.

    Raises:
        ValueError: if ``name`` matches none of the known codes.
    """
    known_models = (
        ("egn", "EEGNet"),
        ("shn", "ShallowNet"),
        ("xeg", "xEEGNet"),
        ("dcn", "DeepConvNet"),
        ("hyb", "HybridNet"),
        ("con", "EEGConformer"),
        ("atc", "ATCNet"),
        ("etr", "TransformEEG"),
        ("ps4", "PSDNet"),
        ("res", "EEGResNet"),
    )
    for code, display_name in known_models:
        if name == code:
            return display_name
    raise ValueError("Wrong name")

# Display names of the models in model_list, in the same order.
full_names = [get_full_name(i) for i in model_list]

In [None]:
def QCD(x):
    """Return the quartile coefficient of dispersion, (Q3 - Q1) / (Q1 + Q3).

    The quartiles are computed on ``x*100``.
    NOTE(review): for a numeric array the factor cancels in the ratio; for a
    plain Python list ``x*100`` repeats the list instead of scaling it - confirm
    which input type is intended.
    """
    q1, q3 = np.percentile(x*100, [25, 75])
    spread = q3 - q1
    total = q1 + q3
    return spread/total

def range_5_95(x):
    """Return the distance between the 5th and the 95th percentile of ``x``."""
    bounds = np.percentile(x, [5, 95])
    return bounds[1] - bounds[0]

def range_1_99(x):
    """Return the distance between the 1st and the 99th percentile of ``x``."""
    bounds = np.percentile(x, [1, 99])
    return bounds[1] - bounds[0]
    
def get_aug_idx(augmentation_to_idx):
    """Return the integer index (0-9) associated with an augmentation name.

    Args:
        augmentation_to_idx: augmentation name, e.g. 'masking'.

    Returns:
        The position of the name in the table below, or None when the name
        is not listed.
    """
    ordered_names = (
        'flip_horizontal',
        'flip_vertical',
        'add_band_noise',
        'add_eeg_artifact',
        'add_noise_snr',
        'channel_dropout',
        'masking',
        'warp_signal',
        'random_FT_phase',
        'phase_swap',
    )
    for position, known_name in enumerate(ordered_names):
        if augmentation_to_idx == known_name:
            return position
    return None

def get_aug_score_matrix(perf, baseline):
    """Score each augmentation setting against a baseline run distribution.

    The median and the inter-quartile range of ``perf`` are taken along its
    last axis and combined as

        ((med - base_med) / base_med) * ((base_iqr - iqr) / base_iqr)

    The score is set to 0 wherever the median did not increase or the IQR did
    not decrease with respect to the baseline.

    Args:
        perf: array whose last axis enumerates the runs of one setting. A
            plain 1-D vector of runs is accepted as well.
        baseline: 1-D collection of baseline results.

    Returns:
        numpy array with the shape of ``perf`` minus its last axis (a 0-d
        array for 1-D input).
    """
    perf = np.asarray(perf, dtype=float)
    base_med = np.median(baseline)
    base_iqr = stats.iqr(baseline)
    med = np.median(perf, axis=-1)
    iqr = stats.iqr(perf, axis=-1)

    med_gain = med - base_med
    iqr_gain = base_iqr - iqr
    no_gain = np.logical_or(med_gain <= 0, iqr_gain <= 0)

    # A zero baseline IQR always lands in `no_gain` (iqr >= 0), so the
    # division warning it would raise carries no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        aug_coeff = (med_gain/base_med)*(iqr_gain/base_iqr)

    # np.where also handles scalar results, unlike in-place boolean masking.
    return np.where(no_gain, 0.0, aug_coeff)

In [None]:
def get_result_dict(
    model_list,
    root_path,
    folder_path,
    patience=20,
    metric='accuracy_weighted',
    start_string= "*",
    end_string="*",
    outer_position=4,
    inner_position=5,
    verbose = False,
):
    """Load the pickled score files of every model into fixed-size arrays.

    For each model code ``m`` the files matching
    ``{root_path}{folder_path}/{start_string}_{m}_{end_string}`` are sorted by
    two integer fields of their underscore-separated file name and read one
    by one. Slots of runs that are not found stay 0.

    Args:
        model_list: iterable of short model codes.
        root_path: path prefix; it is concatenated as is, so it must end with
            a path separator.
        folder_path: sub-folder of ``root_path`` to scan.
        patience: value subtracted from the stopping epoch to fill "epochs".
        metric: key read from each score group of the pickled dictionary.
        start_string: glob fragment placed before the model code.
        end_string: glob fragment placed after the model code.
        outer_position: index of the first (major) integer sort field.
        inner_position: index of the second (minor) integer sort field.
        verbose: print the glob pattern; a model with fewer than 100 files is
            reported and dropped only when this flag is set.

    Returns:
        dict mapping each kept model code to a dict with the arrays
        "win_05", "win_th", "sub_05", "sub_th", "epochs" (length 100) and
        "trains", "valids" (shape (100, 300)).
    """
    def _run_order(path):
        # the file name is underscore-separated; two integer fields order the runs
        fields = path.split(os.sep)[-1].split('_')
        return int(fields[outer_position]), int(fields[inner_position])

    data_base={i: None for i in model_list}
    for m in model_list:
        win_05 = np.zeros(100)
        win_th = np.zeros(100)
        sub_05 = np.zeros(100)
        sub_th = np.zeros(100)
        epochs = np.zeros(100)
        trains = np.zeros( (100,300) )
        valids = np.zeros( (100,300) )

        glob_name = f"{root_path}{folder_path}/{start_string}_{m}_{end_string}"
        if verbose:
            print(glob_name)
        path_list = sorted(glob.glob(glob_name), key=_run_order)

        # NOTE(review): more than 100 matches are only printed; the loop below
        # then fails with IndexError when it reaches the 101st file.
        if len(path_list)>100:
            print(path_list)
        # NOTE(review): incomplete models are dropped only in verbose mode,
        # otherwise they are kept with zero-filled slots - confirm this is intended.
        if len(path_list)<100 and verbose:
            print(f"Discarding model {get_full_name(m)}. Not enough trainings.")
            data_base.pop(m, None)
            continue

        for n, p in enumerate(path_list):
            # NOTE: pickle can execute arbitrary code; only load trusted result files.
            with open(p,'rb') as scorefile:
                scores = pickle.load(scorefile)

            win_05[n] = scores['th_standard'][metric]
            win_th[n] = scores['th_corrected'][metric]
            sub_05[n] = scores['subject'][metric]

            # The subject-level corrected score is optional: fall back to 0
            # when it is missing or not a number.
            try:
                sub_th[n] = scores['subject_corrected'][metric]
            except (LookupError, TypeError, ValueError):
                sub_th[n] = 0.0

            # Copy the loss curves up to the first epoch without a value and
            # remember where the scan stopped. Tracking it explicitly avoids
            # reading a loop variable that is undefined (or left over from the
            # previous file) when the progression is empty.
            stop_epoch = None
            for j, losses in scores['loss_progression'].items():
                stop_epoch = j
                if losses[0] is None:
                    break
                trains[n, j] = losses[0]
                valids[n, j] = losses[1]
            if stop_epoch is not None:
                epochs[n] = stop_epoch-patience+1

        data_base[m] = {
            "win_05": win_05,
            "win_th": win_th,
            "sub_05": sub_05,
            "sub_th": sub_th,
            "epochs": epochs,
            "trains": trains,
            "valids": valids,
        }
    return data_base

## Window and overlap effects

In [None]:
# Gather the scores of the window-length / overlap experiment for the two
# models studied in this section. Every array is indexed [window, overlap, run].
win_and_over={'ShallowNet': None, 'TransformEEG': None}
for m in ['shn', 'etr']:
    dict_key = 'ShallowNet' if m == 'shn' else 'TransformEEG'
    # 5 window tags x 4 overlap tags x 100 run slots; slots without a file stay 0
    win_05 = np.zeros((5,4,100))
    win_th = np.zeros((5,4,100))
    sub_05 = np.zeros((5,4,100))
    sub_th = np.zeros((5,4,100))
    for j, w in enumerate(['002', '004', '008', '016', '032']):
        for k, o in enumerate(['000', '025', '050', '075']):
            glob_name = f"{root_path}extra_window_and_overlap/*_{m}_*_{w}_{o}_*"
            # NOTE(review): the glob result is not sorted, so the run index n is
            # arbitrary, and more than 100 matches would overflow the last axis.
            path_list = glob.glob(glob_name)
            for n, p in enumerate(path_list):
                
                with open(p,'rb') as scorefile:
                    scores = pickle.load(scorefile)
            
                # same metric read from the four score groups of the file
                bal_acc = scores['th_standard']['accuracy_weighted']
                win_05[j, k, n] = bal_acc
                
                bal_acc_roc = scores['th_corrected']['accuracy_weighted']
                win_th[j, k, n] = bal_acc_roc
                
                bal_acc_sub = scores['subject']['accuracy_weighted']
                sub_05[j, k, n] = bal_acc_sub
                
                bal_acc_sub_roc = scores['subject_corrected']['accuracy_weighted']
                sub_th[j, k, n] = bal_acc_sub_roc

    win_and_over[dict_key] = {
        "win_05": win_05,
        "win_th": win_th,
        "sub_05": sub_05,
        "sub_th": sub_th
    }

In [None]:
# Figure: one panel per model with the scores grouped by window length (x)
# and overlap (hue), drawn as a strip plot underneath semi-transparent boxes.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'
letters      = ["A", "B"]

fig, ax = plt.subplots(2, 1, figsize=(19, 7.75*2+1))
for i, model in enumerate(['ShallowNet','TransformEEG']):
    
    # Fortran-order reshape of the (5, 4, 100) array: row r of the (20, 100)
    # table holds window index r % 5 and overlap index r // 5, which is the
    # layout the two inserted columns below reproduce.
    df = pd.DataFrame(np.reshape(win_and_over[model]["win_05"], (20,100), ('F')))
    df.insert(loc=0, column='window', value=[i for _ in range(4) for i in [2,4,8,16,32]])
    df.insert(loc=1, column='overlap', value=[i for i in [0,25,50,75] for _ in range(5)])
    # long format: one row per (window, overlap, run); scores converted to percent
    model_df = pd.melt(df, id_vars=['window', 'overlap'], value_name='Metric')
    model_df['Metric'] = model_df['Metric']*100
    
    sns.stripplot(
        x         = 'window',
        y         = 'Metric',
        data      = model_df,
        legend    = False,
        linewidth = 1,
        hue       = 'overlap',
        dodge     = True,
        ax        = ax[i],
        size      = 6,
        palette   = sns.color_palette("colorblind")[0:4], #["#56b4e9", "#e69f00"],
        alpha    = 0.9
    )
    
    sns.boxplot(
        data=model_df,
        x="window",
        y="Metric",
        hue = "overlap",
        ax=ax[i],
        fill=True,
        showfliers=False,
        linecolor = linecolor,
        flierprops = dict(
            marker='o',
            markerfacecolor=fliercolor,
            linestyle='none',
            markeredgecolor=fliercolor
        ),
        boxprops=dict(alpha=.8),
        palette=sns.color_palette("colorblind")[0:4]
    ) 
    # one IQR (in percent) per row of df, i.e. per (window, overlap) pair;
    # it is printed below each median label
    iqr_values = df.iloc[:,2:].apply(lambda x: stats.iqr(x.values), axis=1)*100
    add_median_labels(ax[i], '.2f', fontsize=font-12, lim=[50,100], iqr_values=iqr_values, iqr_offset = 0.8 )
    ax[i].set_yticks([i*5 for i in range(8, 20)])
    ax[i].set_title( f'Model: {model}',fontsize = font+3, pad=12)
    # only the bottom panel shows the x-axis label
    if i>0:
        ax[i].set_xlabel('Window Length [s]', fontsize = font, labelpad=12)
    else:
        ax[i].xaxis.label.set_visible(False)
    ax[i].set_ylabel('Balanced Accuracy %', fontsize = font)
    ax[i].set_ylim(55,96)
    ax[i].tick_params(axis='both', which='major', labelsize=font-5)
    # panel letter, positioned in data coordinates
    ax[i].text(4.25,96,f'$({letters[i]})$',fontsize = font+6)
    for axis in ['top','bottom','left','right']:
        ax[i].spines[axis].set_linewidth(1.5)
    # replace the automatic legend entries with explicit overlap labels
    ax[i].legend(
        ["overlap:   0%", "overlap: 25%", "overlap: 50%", "overlap: 75%"],
        fontsize = font - 10,
        loc = "upper left"
    )
fig.suptitle( f'Window length and overlap comparison',fontsize = font+8)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
# NOTE(review): the "Images" folder is not created by this notebook; savefig
# presumably requires it to exist - confirm.
plt.savefig(f"Images/window_and_overlap.pdf", bbox_inches='tight')
plt.show()

## Data augmentation Choice

In [None]:
# Augmentation names in the order used by get_aug_idx (index 0-9).
aug_list = [
     'flip_horizontal',
     'flip_vertical',
     'add_band_noise',
     'add_eeg_artifact',
     'add_noise_snr',
     'channel_dropout',
     'masking',
     'warp_signal',
     'random_FT_phase',
     'phase_swap',
]
# Scores of every ordered pair of augmentations; the arrays are indexed
# [first augmentation, second augmentation, run].
data_aug={'TransformEEG': None}
for m in ['etr']:
    dict_key = 'ShallowNet' if m == 'shn' else 'TransformEEG'
    win_05 = np.zeros((10, 10, 100))
    win_th = np.zeros((10, 10, 100))
    sub_05 = np.zeros((10, 10, 100))
    sub_th = np.zeros((10, 10, 100))
    epochs = np.zeros((10, 10, 100))
    for a1 in aug_list:
        idx1 = get_aug_idx(a1)
        # file names carry the 1-based augmentation index, zero-padded to 3 digits
        sidx1 = str(idx1+1).zfill(3)
        for a2 in aug_list:
            idx2 = get_aug_idx(a2)
            sidx2 = str(idx2+1).zfill(3)
            glob_name = f"{root_path}data_augmentation/*_{m}_*_{sidx1}_{sidx2}_000250*"
            
            # This is the selected combination rerun in another folder
            if m=="etr" and sidx1=="007" and sidx2=="001":
                glob_name = f"{root_path}baseline_with_aug/*_{m}_*_{sidx1}_{sidx2}_000250*"
            # NOTE(review): unsorted glob result; more than 100 matches would
            # overflow the run axis.
            path_list = glob.glob(glob_name)
            for n, p in enumerate(path_list):
                
                with open(p,'rb') as scorefile:
                    scores = pickle.load(scorefile)
            
                bal_acc = scores['th_standard']['accuracy_weighted']
                win_05[idx1, idx2, n] = bal_acc
                
                bal_acc_roc = scores['th_corrected']['accuracy_weighted']
                win_th[idx1, idx2, n] = bal_acc_roc
                
                bal_acc_sub = scores['subject']['accuracy_weighted']
                sub_05[idx1, idx2, n] = bal_acc_sub
                
                bal_acc_sub_roc = scores['subject_corrected']['accuracy_weighted']
                sub_th[idx1, idx2, n] = bal_acc_sub_roc

                # scan to the first epoch without a recorded loss; j keeps that key
                for j in scores['loss_progression'].keys():
                    if scores['loss_progression'][j][0] is None:
                        break
                # NOTE(review): 15 is presumably the early-stopping patience of
                # these runs (get_result_dict defaults to 20) - confirm. j is
                # undefined or stale when loss_progression is empty.
                epochs[idx1, idx2, n] = j-15+1
    
    data_aug[dict_key] = {
        "win_05": win_05,
        "win_th": win_th,
        "sub_05": sub_05,
        "sub_th": sub_th,
        "epochs": epochs
    }

In [None]:
# Figure: two 10x10 heatmaps over every (first, second) augmentation pair,
# (A) the median and (B) the inter-quartile range of the scores in percent.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'
letters      = ["A", "B"]
th_name      = "win_05"
# Display names of the augmentations, in the same order as aug_list.
aug_list_full = [
     'Time reverse',
     'Sign flip',
     'Band noise',
     'Signal drift',
     'SNR scaling',
     'Channel dropout',
     'Masking',
     'Signal warp',
     'Phase randomizer',
     'Phase swap',
]

fig, ax = plt.subplots(1, 2, figsize=(23, 8.75))
for i, model in enumerate(['TransformEEG']): #['ShallowNet','TransformEEG']
    
    row = i//2
    col = i%2
    
    # (A) median over the run axis
    htmp = sns.heatmap(
        #data = np.random.normal(loc=75, scale=4, size=(10,10)),
        data = np.median(data_aug[model][th_name]*100, axis=-1),
        ax = ax[col],
        cbar = True,
        robust = True,
        annot = True,
        fmt = '.1f',
        linewidth = .75,
        vmin = 74,
        vmax = 81,
        xticklabels = aug_list_full,
        yticklabels = aug_list_full,
        cmap = sns.cubehelix_palette(as_cmap=True),#sns.color_palette("ch:s=-.2,r=.6", as_cmap=True), #"Blues"
        cbar_kws = {"pad":0.02},
        annot_kws = {"size": font-13, "fontweight": "bold"}
        
    )
    colorbar = htmp.collections[0].colorbar
    colorbar.set_label("Balanced Accuracy - Median", fontsize=font-8)
    colorbar.ax.yaxis.label.set_size(font-7)  # Set the fontsize for the colorbar label
    colorbar.ax.tick_params(labelsize=font-7)
    ax[col].set_xlabel('2$^{\\text{nd}}$ augmentation', fontsize = font)
    ax[col].set_ylabel('1$^{\\text{st}}$ augmentation', fontsize = font)
    ax[col].tick_params(axis='both', which='major', labelsize=font-8)
    ax[col].set_xticklabels(ax[col].get_xticklabels(), rotation=45, ha='right')
    #ax[col].set_title( f'Model: {model}',fontsize = font+3, pad=12)

    # (B) inter-quartile range over the run axis; the colormap is reversed
    htmp2 = sns.heatmap(
        #data = np.random.normal(loc=5, scale=1, size=(10,10)),
        data = stats.iqr(data_aug[model][th_name]*100, axis=-1),
        ax   = ax[col+1],
        cbar = True,
        robust =True,
        annot=True,
        fmt  ='.1f',
        linewidth=.75,
        vmin=5,
        vmax= 9.5,
        xticklabels =aug_list_full,
        yticklabels = aug_list_full,
        cmap = sns.cubehelix_palette(reverse=True, as_cmap=True),
        cbar_kws = {"pad":0.02},
        annot_kws={"size": font-13, "fontweight": "bold"}
        
    )
    colorbar = htmp2.collections[0].colorbar
    colorbar.set_label("Balanced Accuracy - Inter-quartile range ", fontsize=font-8)
    colorbar.ax.yaxis.label.set_size(font-7)  # Set the fontsize for the colorbar label
    colorbar.ax.tick_params(labelsize=font-7) 
    ax[col+1].tick_params(axis='both', which='major', labelsize=font-8)
    # rotate this panel's own tick labels (they were previously taken from ax[i])
    ax[col+1].set_xticklabels(ax[col+1].get_xticklabels(), rotation=45, ha='right')
    #ax[col+1].set_title( f'Model: {model}',fontsize = font+3, pad=12)
    ax[col+1].set_xlabel('2$^{\\text{nd}}$ augmentation', fontsize = font)

# panel letters and title use `col` and `model` left over from the loop
ax[col].text(-3.5, -0.2,'(A)',fontsize = font+6)
ax[col+1].text(-3.5, -0.2,'(B)',fontsize = font+6)
fig.suptitle(f'Data augmentation comparison - Model: {model}           ',
             fontsize = font+8)
plt.subplots_adjust(top=0.825, hspace=0.25, wspace=0.325)
#plt.savefig(f"Images/data_augmentation_comparison.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Heatmap of the augmentation score (see get_aug_score_matrix) of every
# augmentation pair, measured against the un-augmented TransformEEG baseline.
model, th_name = 'etr', 'win_05'
# Defined here so that the cell does not depend on a later cell having been
# run first.
metric = "accuracy_weighted"
data_base = get_result_dict(model_list, root_path, "baseline", metric=metric)
# keep only the filled run slots (empty slots are 0) and convert to percent
base_vec = data_base[model][th_name][data_base[model][th_name]>0]*100
print(f'baseline results for {model}: '
      f'{np.median(base_vec):.3f}, '
      f'{stats.iqr(base_vec):.3f}')
fig, ax = plt.subplots(1, 1, figsize=(23, 8.75))

aug_coeff = get_aug_score_matrix(data_aug["TransformEEG"][th_name]*100, base_vec)
htmp = sns.heatmap(
    #data = np.random.normal(loc=75, scale=4, size=(10,10)),
    data = aug_coeff,
    ax = ax,
    cbar = True,
    robust = True,
    annot = True,
    fmt = '.5f',
    linewidth = .75,
    xticklabels = aug_list_full,
    yticklabels = aug_list_full,
    cmap = sns.cubehelix_palette(as_cmap=True),#sns.color_palette("ch:s=-.2,r=.6", as_cmap=True), #"Blues"
    cbar_kws = {"pad":0.02},
    annot_kws = {"size": font-13, "fontweight": "bold"}
    
)
colorbar = htmp.collections[0].colorbar
colorbar.set_label("Aug score", fontsize=font-8)
colorbar.ax.yaxis.label.set_size(font-7)  # Set the fontsize for the colorbar label
colorbar.ax.tick_params(labelsize=font-7)
ax.set_xlabel('Augmentation 2', fontsize = font)
ax.set_ylabel('Augmentation 1', fontsize = font)
ax.tick_params(axis='both', which='major', labelsize=font-8)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
#ax[col].set_title( f'Model: {model}',fontsize = font+3, pad=12)
plt.show()

## Baseline

In [None]:
# Baseline results (no augmentation), window-level scores.
metric = "accuracy_weighted"
threshold = "win_05"
data_base = get_result_dict(model_list, root_path, "baseline", metric=metric)

# Wide table: one column per model, one row per run slot (values not rescaled).
df = pd.DataFrame(
    np.array([data_base[i][threshold] for i in data_base.keys()]).T,
    columns=[get_full_name(i) for i in data_base.keys()]
)
# Long table for seaborn: the per-model vectors stacked one after the other.
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
model_df = pd.DataFrame(
    np.concatenate([data_base[i][threshold] for i in data_base.keys()]),
    columns=["Metric"]
)
model_df.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
)
# The param_list entry is looked up through model_list, so the label stays
# correct even if get_result_dict dropped a model from data_base.
model_df.insert(
    loc=1,
    column='ModelParam',
    value=[get_full_name(i) + "\n" + param_list[model_list.index(i)] for i in data_base.keys() for _ in range(100)]
)
# scores are plotted in percent
model_df['Metric'] = model_df['Metric']*100

In [None]:
# Figure: per-model distribution of the baseline scores (gray strip plot
# underneath light-gray boxes) with median and IQR labels on each box.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'

fig, ax = plt.subplots(figsize=(21, 7.75))
#fig, ax = plt.subplots(figsize=(21, 8.75))
sns.stripplot(
    x         = "Model",
    y         = "Metric",
    data      = model_df,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 5,
    palette   = ["gray"]*len(data_base.keys()),
    alpha    = .9
)

sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    palette=["lightgray"]*len(data_base.keys())
) 
# one IQR (in percent) per model column of df, printed below each median label
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
add_median_labels(
    ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
ax.set_yticks([i*5 for i in range(8, 20)])
ax.set_title(f'Model comparison - no data augmentation - no threshold correction',
             fontsize = font+3, pad=12)
ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize = font)
ax.set_ylim(45,96)
ax.tick_params(axis='both', which='major', labelsize=font-6.5)

# Highlight the TransformEEG tick label in bold; ticks are re-set first so
# that the labels can be assigned explicitly.
new_labels = []
for ele in ax.get_xticklabels():
    if ele.get_text()=='TransformEEG':
        ele.set_fontweight('bold')
    new_labels.append(ele)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(new_labels)

#ax.text(4.25,96,f'$({letters[i]})$',fontsize = font+6)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
#plt.savefig(f"Images/baseline_noaug_nocorr.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Width of the 1st-99th percentile interval of each model column of df
# (df is not multiplied by 100 here).
#print(df.apply(range_5_95, axis=0))
print(df.apply(range_1_99, axis=0))

## Baseline + Augmentation

In [None]:
# Results of the runs trained with data augmentation, window-level scores.
metric = "accuracy_weighted"
threshold = "win_05"

data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)
# Wide table: one column per model, one row per run slot (values not rescaled).
df = pd.DataFrame(
    np.array([data_base[i][threshold] for i in data_base.keys()]).T,
    columns=[get_full_name(i) for i in data_base.keys()]
)
# Long table for seaborn: the per-model vectors stacked one after the other.
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
model_df = pd.DataFrame(
    np.concatenate([data_base[i][threshold] for i in data_base.keys()]),
    columns=["Metric"]
)
model_df.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
)
# The param_list entry is looked up through model_list, so the label stays
# correct even if get_result_dict dropped a model from data_base.
model_df.insert(
    loc=1,
    column='ModelParam',
    value=[get_full_name(i) + "\n" + param_list[model_list.index(i)] for i in data_base.keys() for _ in range(100)]
)
# scores are plotted in percent
model_df['Metric'] = model_df['Metric']*100

In [None]:
# Figure: per-model distribution of the scores obtained with data
# augmentation, one colorblind-palette color per model, with median and IQR
# labels on each box. The figure is saved as panel (C).
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'

trafig, ax = plt.subplots(figsize=(21, 7.75))

sns.stripplot(
    x         = "Model",
    y         = "Metric",
    data      = model_df,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 5,
    #palette   = ["gray"]*len(data_base.keys()),
    palette   = sns.color_palette("colorblind")[:len(data_base.keys())],
    alpha     = .9
)

sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    #palette=["lightgray"]*len(data_base.keys())
    palette=sns.color_palette("colorblind")[:len(data_base.keys())]
) 
# one IQR (in percent) per model column of df, printed below each median label
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
add_median_labels(
    ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
ax.set_yticks([i*5 for i in range(8, 20)])
ax.set_title(f'Model comparison with data augmentation - no threshold correction',
             fontsize = font+3, pad=12)
ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize = font)
ax.set_ylim(45,96)
ax.tick_params(axis='both', which='major', labelsize=font-6.5)

# Highlight the TransformEEG tick label in bold; ticks are re-set first so
# that the labels can be assigned explicitly.
new_labels = []
for ele in ax.get_xticklabels():
    if ele.get_text()=='TransformEEG':
        ele.set_fontweight('bold')
    new_labels.append(ele)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(new_labels)

# panel letter, positioned in data coordinates
ax.text(-0.35,51.25,f'$(C)$',fontsize = font+6)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig(f"Images/baseline_aug_nocorr_with_letter.pdf", bbox_inches='tight')
plt.show()

In [None]:
# For every model other than TransformEEG ('etr'): Levene test of its 'win_05'
# scores against those of 'etr'. For every model, 'etr' included: Shapiro-Wilk
# test on its own 'win_05' scores.
for mdl_name in data_base.keys():
    if mdl_name != 'etr':
        print("Levene against ", mdl_name)
        print(stats.levene(data_base['etr']['win_05'], data_base[mdl_name]['win_05']))
    print(stats.shapiro(data_base[mdl_name]['win_05']), '\n')

In [None]:
# Width of the 1st-99th percentile interval of each model column of df
# (df is not multiplied by 100 here).
#print(df.apply(range_5_95, axis=0))
print(df.apply(range_1_99, axis=0))

## Number of epochs change

In [None]:
# Number of training epochs with and without data augmentation.
metric = "accuracy_weighted"
threshold = "win_05"
# NOTE(review): this loads the same folder as data_base_2 below; in the
# visible cells it is only used for len(data_base.keys()) in the next figure.
data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)

# --- runs without augmentation ---
data_base_1 = get_result_dict(model_list, root_path, "baseline")
df = pd.DataFrame(
    np.array([data_base_1[i]["epochs"] for i in data_base_1.keys()]).T,
    columns=[get_full_name(i) for i in data_base_1.keys()]
)
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
model_df_1 = pd.DataFrame(
    np.concatenate([data_base_1[i]["epochs"] for i in data_base_1.keys()]),
    columns=["Epochs"]
)
model_df_1.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base_1.keys() for _ in range(100)]
    #value=[get_full_name(i) + "\n" + param_list[n] for n, i in enumerate(data_base.keys()) for _ in range(100)]
)
model_df_1.insert(
    loc=2,
    column='Pipeline',
    value=["No Augmentation" for i in data_base_1.keys() for _ in range(100)]
)

# --- runs with augmentation ---
data_base_2 = get_result_dict(model_list, root_path, "baseline_with_aug")
df2 = pd.DataFrame(
    np.array([data_base_2[i]["epochs"] for i in data_base_2.keys()]).T,
    columns=[get_full_name(i) for i in data_base_2.keys()]
)
model_df_2 = pd.DataFrame(
    np.concatenate([data_base_2[i]["epochs"] for i in data_base_2.keys()]),
    columns=["Epochs"]
)
model_df_2.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base_2.keys() for _ in range(100)]
    #value=[get_full_name(i) + "\n" + param_list[n] for n, i in enumerate(data_base.keys()) for _ in range(100)]
)
model_df_2.insert(
    loc=2,
    column='Pipeline',
    value=["Augmentation" for i in data_base_2.keys() for _ in range(100)]
)
# both pipelines stacked in one long table
model_df = pd.concat( [model_df_1, model_df_2] )
# row-wise difference; the two tables are aligned on their default integer index
model_df_2["EpochsDiff"] = model_df_2["Epochs"]  - model_df_1["Epochs"]

In [None]:
# Figure: per-model distribution of the "Epochs" column of model_df_1
# (the runs without augmentation).
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'

fig, ax = plt.subplots(figsize=(21, 7.75))

sns.stripplot(
    x         = "Model",
    y         = "Epochs",
    data      = model_df_1,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 4,
    palette   = ["gray"]*len(data_base.keys()),
    alpha    = .9
)

sns.boxplot(
    data=model_df_1,
    x="Model",
    y="Epochs",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    palette=["lightgray"]*len(data_base.keys())
) 
# NOTE(review): computed but not used in this cell; the *100 scaling looks
# copied from the accuracy figures and has no obvious meaning for epoch counts.
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
ax.set_yticks([i*10 for i in range(20)])
ax.set_title(f'Training length without early stopping patience',
             fontsize = font+3, pad=12)
ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Number of epochs', fontsize = font)
ax.set_ylim(0,95)
ax.tick_params(axis='both', which='major', labelsize=font-5.1)
#ax.text(4.25,96,f'$({letters[i]})$',fontsize = font+6)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
#plt.savefig(f"Images/window_and_overlap.pdf", bbox_inches='tight')
plt.show()

## Threshold correction

In [None]:
# Results with data augmentation, read from the threshold-corrected
# window-level scores ("win_th").
metric = "accuracy_weighted"
threshold = "win_th"

data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)
# Wide table: one column per model, one row per run slot (values not rescaled).
df = pd.DataFrame(
    np.array([data_base[i][threshold] for i in data_base.keys()]).T,
    columns=[get_full_name(i) for i in data_base.keys()]
)
# Long table for seaborn: the per-model vectors stacked one after the other.
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
model_df = pd.DataFrame(
    np.concatenate([data_base[i][threshold] for i in data_base.keys()]),
    columns=["Metric"]
)
model_df.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
)
# The param_list entry is looked up through model_list, so the label stays
# correct even if get_result_dict dropped a model from data_base.
model_df.insert(
    loc=1,
    column='ModelParam',
    value=[get_full_name(i) + "\n" + param_list[model_list.index(i)] for i in data_base.keys() for _ in range(100)]
)
# scores are plotted in percent
model_df['Metric'] = model_df['Metric']*100

In [None]:
# Figure: per-model distribution of the threshold-corrected scores obtained
# with data augmentation, with median and IQR labels on each box.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'

trafig, ax = plt.subplots(figsize=(21, 7.75))

sns.stripplot(
    x         = "Model",
    y         = "Metric",
    data      = model_df,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 5,
    #palette   = ["gray"]*len(data_base.keys()),
    palette   = sns.color_palette("colorblind")[:len(data_base.keys())],
    alpha     = .9
)

sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    #palette=["lightgray"]*len(data_base.keys())
    palette=sns.color_palette("colorblind")[:len(data_base.keys())]
) 
# one IQR (in percent) per model column of df, printed below each median label
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
add_median_labels(
    ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
ax.set_yticks([i*5 for i in range(8, 20)])
ax.set_title(f'Model comparison with data augmentation and threshold correction',
             fontsize = font+3, pad=12)
ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize = font)
ax.set_ylim(45,96)
ax.tick_params(axis='both', which='major', labelsize=font-5.1)
# NOTE(review): same "(C)" panel letter as the figure without threshold
# correction - confirm it is intended for this figure too.
ax.text(-0.35,51.25,f'$(C)$',fontsize = font+6)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig(f"Images/baseline_aug_corr.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Width of the 1st-99th percentile interval of each model column of df
# (df is not multiplied by 100 here).
print(df.apply(range_1_99, axis=0))

## Baseline with only 2 datasets

In [None]:
# Results stored in the "two_dataset_baseline" folder, window-level scores.
metric = "accuracy_weighted"
threshold = "win_05"

data_base = get_result_dict(model_list, root_path, "two_dataset_baseline", metric=metric)
# Wide table: one column per model, one row per run slot (values not rescaled).
df = pd.DataFrame(
    np.array([data_base[i][threshold] for i in data_base.keys()]).T,
    columns=[get_full_name(i) for i in data_base.keys()]
)
# Long table for seaborn: the per-model vectors stacked one after the other.
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
model_df = pd.DataFrame(
    np.concatenate([data_base[i][threshold] for i in data_base.keys()]),
    columns=["Metric"]
)
model_df.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
)
# The param_list entry is looked up through model_list, so the label stays
# correct even if get_result_dict dropped a model from data_base.
model_df.insert(
    loc=1,
    column='ModelParam',
    value=[get_full_name(i) + "\n" + param_list[model_list.index(i)] for i in data_base.keys() for _ in range(100)]
)
# scores are plotted in percent
model_df['Metric'] = model_df['Metric']*100

In [None]:
# Figure: per-model distribution of the two-dataset baseline scores (gray
# strip plot underneath light-gray boxes) with median and IQR labels.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'
name         = "win_05"

fig, ax = plt.subplots(figsize=(21, 7.75))

sns.stripplot(
    x         = "Model",
    y         = "Metric",
    data      = model_df,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 5,
    palette   = ["gray"]*len(data_base.keys()),
    alpha    = .9
)

sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    palette=["lightgray"]*len(data_base.keys())
) 
# one IQR (in percent) per model column of df, printed below each median label
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
add_median_labels(
    ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
# the y axis reaches 100 here, unlike the other accuracy figures
ax.set_yticks([i*5 for i in range(8, 21)])
ax.set_title(f'Model comparison - two dataset - no data augmentation - no threshold correction',
             fontsize = font+3, pad=12)
ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize = font)
ax.set_ylim(40,101)
ax.tick_params(axis='both', which='major', labelsize=font-4)
#ax.text(4.25,96,f'$({letters[i]})$',fontsize = font+6)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
#plt.savefig(f"Images/two_dataset_baseline_noaug.pdf", bbox_inches='tight')
plt.show()
# NOTE(review): this range is multiplied by 100, while the other cells print
# the same statistic without rescaling.
print(df.apply(range_1_99, axis=0)*100)

In [None]:
# Per-run scores of every model for the "two_dataset_baseline_with_aug"
# experiment, read at the "win_05" threshold.
metric = "accuracy_weighted"
threshold = "win_05"

# Pass the metric explicitly, as the other data-preparation cells do, so the
# `metric` set above is the one that is actually loaded.
data_base = get_result_dict(
    model_list, root_path, "two_dataset_baseline_with_aug", metric=metric
)

# Wide layout: one column per model, one row per run.
df = pd.DataFrame(
    np.array([data_base[i][threshold] for i in data_base.keys()]).T,
    columns=[get_full_name(i) for i in data_base.keys()]
)

# Long layout: one row per run. The label columns repeat each model 100
# times, i.e. 100 runs per model are assumed.
model_df = pd.DataFrame(
    np.concat([data_base[i][threshold] for i in data_base.keys()]),
    columns=["Metric"]
)
model_df.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
)
model_df.insert(
    loc=1,
    column='ModelParam',
    value=[
        get_full_name(i) + "\n" + param_list[n]
        for n, i in enumerate(data_base.keys())
        for _ in range(100)
    ]
)
# Express the metric as a percentage.
model_df['Metric'] = model_df['Metric']*100

In [None]:
# Figure: per-model metric distribution for the two-dataset runs with data
# augmentation (strip plot of the single runs with a box plot drawn over it).
font = 25
fliercolor = 'gray'
linecolor = '#137'
scattercolor = "#0072b2"
boxcolor = '#e69f00'
name = "win_05"

n_models = len(data_base.keys())
flier_style = {
    "marker": 'o',
    "markerfacecolor": fliercolor,
    "linestyle": 'none',
    "markeredgecolor": fliercolor,
}

fig, ax = plt.subplots(figsize=(21, 7.75))

# One gray dot per run.
sns.stripplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["gray"] * n_models,
    size=5,
    linewidth=1,
    alpha=0.9,
    legend=False,
)

# Semi-transparent box per model; fliers are hidden (showfliers=False).
sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["lightgray"] * n_models,
    fill=True,
    showfliers=False,
    linecolor=linecolor,
    flierprops=flier_style,
    boxprops={"alpha": 0.8},
)

# Per-model inter-quartile range in percent, handed to the annotation helper
# (add_median_labels is defined elsewhere in the notebook).
iqr_values = 100 * df.apply(lambda col: stats.iqr(col.to_numpy()), axis=0)
add_median_labels(
    ax, '.2f',
    fontsize=font - 10,
    lim=[50, 100],
    iqr_values=iqr_values,
    iqr_offset=1.1,
)

# Ticks every 5 points from 40 to 100; limits are set after the ticks.
ax.set_yticks(list(range(40, 101, 5)))
ax.set_xlabel('Model', fontsize=font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize=font)
ax.set_ylim(40, 101)
ax.tick_params(axis='both', which='major', labelsize=font - 4)
ax.set_title(
    'Model comparison - two dataset - data augmentation - no threshold correction',
    fontsize=font + 3,
    pad=12,
)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(1.5)

plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
#plt.savefig(f"Images/two_dataset_baseline_with_aug.pdf", bbox_inches='tight')
plt.show()

# Table of the range_1_99 statistic per model, scaled to percent.
print(df.apply(range_1_99, axis=0) * 100)

In [None]:
# Per-run scores of every model for the "two_dataset_baseline_with_aug"
# experiment, read at the "win_th" threshold.
metric = "accuracy_weighted"
threshold = "win_th"

# Pass the metric explicitly, as the other data-preparation cells do, so the
# `metric` set above is the one that is actually loaded.
data_base = get_result_dict(
    model_list, root_path, "two_dataset_baseline_with_aug", metric=metric
)

# Wide layout: one column per model, one row per run.
df = pd.DataFrame(
    np.array([data_base[i][threshold] for i in data_base.keys()]).T,
    columns=[get_full_name(i) for i in data_base.keys()]
)

# Long layout: one row per run. The label columns repeat each model 100
# times, i.e. 100 runs per model are assumed.
model_df = pd.DataFrame(
    np.concat([data_base[i][threshold] for i in data_base.keys()]),
    columns=["Metric"]
)
model_df.insert(
    loc=1,
    column='Model',
    value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
)
model_df.insert(
    loc=1,
    column='ModelParam',
    value=[
        get_full_name(i) + "\n" + param_list[n]
        for n, i in enumerate(data_base.keys())
        for _ in range(100)
    ]
)
# Express the metric as a percentage.
model_df['Metric'] = model_df['Metric']*100

In [None]:
# Figure: per-model metric distribution for the two-dataset runs with data
# augmentation, using the "win_th" scores (strip plot of the single runs with
# a box plot drawn over it).
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'
name         = "win_th"

fig, ax = plt.subplots(figsize=(21, 7.75))

# One gray dot per run.
sns.stripplot(
    x         = "Model",
    y         = "Metric",
    data      = model_df,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 5,
    palette   = ["gray"]*len(data_base.keys()),
    alpha    = .9
)

# Semi-transparent box per model; fliers are hidden (showfliers=False).
sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    palette=["lightgray"]*len(data_base.keys())
)

# Per-model inter-quartile range in percent, handed to the annotation helper
# (add_median_labels is defined elsewhere in the notebook).
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
add_median_labels(
    ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
ax.set_yticks([i*5 for i in range(8, 21)])

ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize = font)
ax.set_ylim(40,101)
ax.tick_params(axis='both', which='major', labelsize=font-4)
# This cell plots the "win_th" scores, so the title must not claim
# "no threshold correction" as the win_05 figure does.
ax.set_title('Model comparison - two dataset - data augmentation - threshold correction',
             fontsize = font+3, pad=12)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
# NOTE(review): this file name is the same one used by the win_05 figure;
# pick a distinct name before re-enabling the save.
#plt.savefig(f"Images/two_dataset_baseline_with_aug.pdf", bbox_inches='tight')
plt.show()

# Table of the range_1_99 statistic per model, scaled to percent.
print(df.apply(range_1_99, axis=0)*100)

# Supplementary

## Window Aggregation

In [None]:
# Per-run scores of every model for the "baseline_with_aug" experiment,
# read at the "sub_05" threshold.
metric = "accuracy_weighted"
threshold = "sub_05"

data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)
model_keys = list(data_base.keys())
full_names = [get_full_name(key) for key in model_keys]
run_scores = [data_base[key][threshold] for key in model_keys]

# Wide layout: one column per model, one row per run.
df = pd.DataFrame(np.array(run_scores).T, columns=full_names)

# Long layout: one row per run, tagged with the model name and with the
# "name\n<param_list entry>" label. Column order is Metric, ModelParam, Model.
# The labels are repeated 100 times per model, i.e. 100 runs per model are
# assumed.
model_df = pd.DataFrame({
    "Metric": np.concat(run_scores),
    "ModelParam": [
        label + "\n" + param_list[pos]
        for pos, label in enumerate(full_names)
        for _ in range(100)
    ],
    "Model": [label for label in full_names for _ in range(100)],
})
# Express the metric as a percentage.
model_df["Metric"] = model_df["Metric"] * 100

In [None]:
# Figure: per-model metric distribution with data augmentation and window
# aggregation, one colorblind-palette color per model. Saved to
# Images/aggregation.pdf.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'

# The handle was misspelled `trafig`, which left `fig` pointing at the figure
# of a previously executed cell; every other cell names it `fig`.
fig, ax = plt.subplots(figsize=(21, 7.75))

# One color per model, shared by the dots and the boxes.
model_palette = sns.color_palette("colorblind")[:len(data_base.keys())]

# One dot per run.
sns.stripplot(
    x         = "Model",
    y         = "Metric",
    data      = model_df,
    legend    = False,
    linewidth = 1,
    ax        = ax,
    size      = 5,
    palette   = model_palette,
    alpha     = .9
)

# Semi-transparent box per model; fliers are hidden (showfliers=False).
sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    fill=True,
    showfliers=False,
    linecolor = linecolor,
    flierprops = dict(
        marker='o',
        markerfacecolor=fliercolor,
        linestyle='none',
        markeredgecolor=fliercolor
    ),
    boxprops=dict(alpha=.8),
    palette=model_palette
)

# Per-model inter-quartile range in percent, handed to the annotation helper
# (add_median_labels is defined elsewhere in the notebook).
iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
add_median_labels(
    ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
ax.set_yticks([i*5 for i in range(8, 20)])
ax.set_title('Model comparison with data augmentation and window aggregation',
             fontsize = font+3, pad=12)
ax.set_xlabel('Model', fontsize = font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize = font)
ax.set_ylim(45,96)
ax.tick_params(axis='both', which='major', labelsize=font-5.1)
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.5)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig("Images/aggregation.pdf", bbox_inches='tight')
plt.show()

## CSP

In [None]:
# Load the CSP experiment results into per-augmentation score matrices.
# csp_scores[<"no_aug" | "aug">][<score name>] is an (Nfilter, 100) array:
# the row is the CSP configuration, the column is the position of the result
# file in the glob listing. Slots without a matching file stay at zero.
csp_scores = {"no_aug": None, "aug": None}
Nfilter = 15
for aug in ['000_000', '007_001']:
    win_05 = np.zeros((Nfilter, 100))
    win_th = np.zeros_like(win_05)
    sub_05 = np.zeros_like(win_05)
    sub_th = np.zeros_like(win_05)
    epochs = np.zeros_like(win_05)
    # The arrays have exactly Nfilter rows, so only Nfilter configurations
    # can be stored. The previous range(Nfilter+1) indexed one row past the
    # end (IndexError) as soon as a file tagged "016" was present.
    for csp in range(Nfilter):
        # File names carry the 1-based configuration number, zero-padded to 3.
        csp_str = str(csp+1).zfill(3)
        glob_name = f"{root_path}csp/*_{csp_str}_{aug}_*000250*"
        path_list = glob.glob(glob_name)
        for n, p in enumerate(path_list):

            with open(p,'rb') as scorefile:
                scores = pickle.load(scorefile)

            # Balanced accuracy under the four evaluation settings.
            win_05[csp, n] = scores['th_standard']['accuracy_weighted']
            win_th[csp, n] = scores['th_corrected']['accuracy_weighted']
            sub_05[csp, n] = scores['subject']['accuracy_weighted']
            sub_th[csp, n] = scores['subject_corrected']['accuracy_weighted']

            # `j` ends as the first key whose first entry is None, or as the
            # last key when no such entry exists.
            for j in scores['loss_progression'].keys():
                if scores['loss_progression'][j][0] is None:
                    break
            # NOTE(review): the -20 offset presumably removes the
            # early-stopping patience - confirm against the training setup.
            epochs[csp, n] = j-20+1

    dict_key = "no_aug" if aug == "000_000" else "aug"
    csp_scores[dict_key] = {
        "win_05": win_05,
        "win_th": win_th,
        "sub_05": sub_05,
        "sub_th": sub_th,
        "epochs": epochs
    }

In [None]:
# Figure: two stacked panels (without / with data augmentation) showing the
# score distribution for each CSP configuration. Saved to
# Images/CSP_analysis.pdf.
font = 25
fliercolor = 'gray'
linecolor = '#137'
scattercolor = "#0072b2"
boxcolor = '#e69f00'
threshold = "win_05"
letters = ["A", "B"]

panel_titles = {
    "no_aug": 'without data augmentation',
    "aug": 'with data augmentation (masking + flip horizontal)',
}
flier_style = {
    "marker": 'o',
    "markerfacecolor": fliercolor,
    "linestyle": 'none',
    "markeredgecolor": fliercolor,
}

fig, ax = plt.subplots(2, 1, figsize=(19, 7.75*2+1))
for n, k in enumerate(["no_aug", "aug"]):
    panel = ax[n]

    # One column per CSP configuration ("1" .. str(Nfilter)), values in percent.
    csp_df = pd.DataFrame(
        csp_scores[k][threshold].T * 100,
        columns=[str(count) for count in range(1, Nfilter + 1)],
    )

    # One gray dot per run.
    sns.stripplot(
        data=csp_df,
        ax=panel,
        palette=["gray"] * Nfilter,
        size=4,
        linewidth=1,
        alpha=0.9,
        legend=False,
    )

    # Semi-transparent box per configuration; fliers are hidden.
    sns.boxplot(
        data=csp_df,
        ax=panel,
        palette=["lightgray"] * Nfilter,
        fill=True,
        showfliers=False,
        linecolor=linecolor,
        flierprops=flier_style,
        boxprops={"alpha": 0.8},
    )

    panel.set_title(panel_titles[k], fontsize=font + 3, pad=12)
    if k == "aug":
        # Only the bottom panel carries the x-axis label.
        panel.set_xlabel('Number of CSP Filters', fontsize=font, labelpad=12)

    panel.set_ylabel('Balanced Accuracy %', fontsize=font)
    panel.set_yticks(list(range(40, 101, 5)))
    panel.set_ylim(45, 96)
    # Tick positions 0..Nfilter-1 are labelled 2, 4, ..., 2*Nfilter.
    panel.set_xticks(list(range(Nfilter)))
    panel.set_xticklabels(list(range(2, 2 * Nfilter + 2, 2)))
    panel.tick_params(axis='both', which='major', labelsize=font - 5)
    panel.text(-0.25, 91, f'$({letters[n]})$', fontsize=font + 6)
    for side in ('top', 'bottom', 'left', 'right'):
        panel.spines[side].set_linewidth(1.5)

    # csp_df is already in percent, so the IQR needs no further scaling.
    iqr_values = csp_df.apply(lambda col: stats.iqr(col.to_numpy()), axis=0)
    add_median_labels(
        panel, '.2f',
        fontsize=font - 12,
        lim=[45, 95],
        iqr_values=iqr_values,
        iqr_offset=1.1,
    )

fig.suptitle('TransformEEG performance with Common Spatial Patterns (CSP)', fontsize=font + 8)
plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig("Images/CSP_analysis.pdf", bbox_inches='tight')
plt.show()

## PSD Addition

In [None]:
# Scores of TransformEEG ("etr" in baseline_with_aug) next to the three
# "psdnet" result sets, read at the "win_th" threshold.
metric = "accuracy_weighted"
threshold = "win_th"

data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)
# One result dictionary per file-name prefix, loaded in this order.
data_base_1, data_base_2, data_base_3 = (
    get_result_dict(
        ["ps3"], root_path, "psdnet", start_string=prefix,
        outer_position=7, inner_position=8
    )
    for prefix in ("late*", "mid_cat*", "mid_dec*")
)

variant_labels = [
    '$TransformEEG$',
    "$TransformEEG_{LateFusion}$",
    "$TransformEEG_{MidFusion}$",
    "$TransformEEG_{Decoder}$",
]
variant_scores = [
    data_base["etr"][threshold],
    data_base_1["ps3"][threshold],
    data_base_2["ps3"][threshold],
    data_base_3["ps3"][threshold],
]

# Wide layout: one column per variant, one row per run.
df = pd.DataFrame(np.array(variant_scores).T, columns=variant_labels)

# Long layout: one row per run; each label is repeated 100 times, i.e. 100
# runs per variant are assumed.
model_df = pd.DataFrame({
    "Metric": np.concat(variant_scores),
    "Model": [label for label in variant_labels for _ in range(100)],
})
# Express the metric as a percentage.
model_df["Metric"] = model_df["Metric"] * 100

In [None]:
# Figure: metric distribution of TransformEEG and its PSD variants.
# Saved to Images/transformeeg_with_psd.pdf.
font = 25
fliercolor = 'gray'
linecolor = '#137'
scattercolor = "#0072b2"
boxcolor = '#e69f00'
name = "win_th"

n_variants = model_df["Model"].nunique()
flier_style = {
    "marker": 'o',
    "markerfacecolor": fliercolor,
    "linestyle": 'none',
    "markeredgecolor": fliercolor,
}

fig, ax = plt.subplots(figsize=(16, 7.75))

# One gray dot per run.
sns.stripplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["gray"] * n_variants,
    size=5,
    linewidth=1,
    alpha=0.9,
    legend=False,
)

# Semi-transparent box per variant; fliers are hidden (showfliers=False).
sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["lightgray"] * n_variants,
    fill=True,
    showfliers=False,
    linecolor=linecolor,
    flierprops=flier_style,
    boxprops={"alpha": 0.8},
)

# Per-variant inter-quartile range in percent, handed to the annotation
# helper (add_median_labels is defined elsewhere in the notebook).
iqr_values = 100 * df.apply(lambda col: stats.iqr(col.to_numpy()), axis=0)
add_median_labels(
    ax, '.2f',
    fontsize=font - 10,
    lim=[50, 100],
    iqr_values=iqr_values,
    iqr_offset=1.1,
)

# Ticks every 5 points from 40 to 100; limits are set after the ticks.
ax.set_yticks(list(range(40, 101, 5)))
ax.set_xlabel('Model', fontsize=font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize=font)
ax.set_ylim(55, 101)
ax.tick_params(axis='both', which='major', labelsize=font - 5.1)
# Panel letter in the upper-left corner.
ax.text(-0.35, 96.5, '$(B)$', fontsize=font + 6)
ax.set_title('Addition of the PSD effects', fontsize=font + 3, pad=12)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(1.5)

plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig("Images/transformeeg_with_psd.pdf", bbox_inches='tight')
plt.show()

## Transformer Variants

In [None]:
# Scores of TransformEEG ("etr" in baseline_with_aug) next to three
# "transformer_variants" result sets, read at the "win_05" threshold.
metric = "accuracy_weighted"
threshold = "win_05"

data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)
# One result dictionary per file-name prefix, loaded in this order.
data_base_1, data_base_2, data_base_3 = (
    get_result_dict(
        ["etr"], root_path, "transformer_variants", start_string=prefix,
        outer_position=7, inner_position=8
    )
    for prefix in ("pos_pd*", "cls_pd*", "cls_and_pos*")
)

variant_labels = [
    '$TransformEEG$',
    "$TransformEEG_{pos}$",
    "$TransformEEG_{cls}$",
    "$TransformEEG_{cls+pos}$",
]
variant_scores = [
    source["etr"][threshold]
    for source in (data_base, data_base_1, data_base_2, data_base_3)
]

# Wide layout: one column per variant, one row per run.
df = pd.DataFrame(np.array(variant_scores).T, columns=variant_labels)

# Long layout: one row per run; each label is repeated 100 times, i.e. 100
# runs per variant are assumed.
model_df = pd.DataFrame({
    "Metric": np.concat(variant_scores),
    "Model": [label for label in variant_labels for _ in range(100)],
})
# Express the metric as a percentage.
model_df["Metric"] = model_df["Metric"] * 100

In [None]:
# Figure: metric distribution of TransformEEG and its positional-embedding /
# class-token variants. Saved to Images/transformer_variants.pdf.
font = 25
fliercolor = 'gray'
linecolor = '#137'
scattercolor = "#0072b2"
boxcolor = '#e69f00'
name = "win_05"

n_variants = model_df["Model"].nunique()
flier_style = {
    "marker": 'o',
    "markerfacecolor": fliercolor,
    "linestyle": 'none',
    "markeredgecolor": fliercolor,
}

fig, ax = plt.subplots(figsize=(16, 7.75))

# One gray dot per run.
sns.stripplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["gray"] * n_variants,
    size=5,
    linewidth=1,
    alpha=0.9,
    legend=False,
)

# Semi-transparent box per variant; fliers are hidden (showfliers=False).
sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["lightgray"] * n_variants,
    fill=True,
    showfliers=False,
    linecolor=linecolor,
    flierprops=flier_style,
    boxprops={"alpha": 0.8},
)

# Per-variant inter-quartile range in percent, handed to the annotation
# helper (add_median_labels is defined elsewhere in the notebook).
iqr_values = 100 * df.apply(lambda col: stats.iqr(col.to_numpy()), axis=0)
add_median_labels(
    ax, '.2f',
    fontsize=font - 10,
    lim=[50, 100],
    iqr_values=iqr_values,
    iqr_offset=1.1,
)

# Ticks every 5 points from 40 to 100; limits are set after the ticks.
ax.set_yticks(list(range(40, 101, 5)))
ax.set_xlabel('Model', fontsize=font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize=font)
ax.set_ylim(60, 101)
ax.tick_params(axis='both', which='major', labelsize=font - 5.1)
ax.set_title(
    'Positional Embedding and Class Token effects on TransformEEG',
    fontsize=font + 3,
    pad=12,
)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(1.5)

plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig("Images/transformer_variants.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Scores of TransformEEG ("etr" in baseline_with_aug) next to the "two_head*"
# and "four_head*" result sets, read at the "win_05" threshold.
metric = "accuracy_weighted"
threshold = "win_05"

data_base = get_result_dict(model_list, root_path, "baseline_with_aug", metric=metric)
# One result dictionary per file-name prefix, loaded in this order.
data_base_1, data_base_2 = (
    get_result_dict(
        ["etr"], root_path, "transformer_variants", start_string=prefix,
        outer_position=7, inner_position=8
    )
    for prefix in ("two_head*", "four_head*")
)

variant_labels = [
    '$TransformEEG$',
    "$TransformEEG_{2heads}$",
    "$TransformEEG_{4heads}$",
]
variant_scores = [
    source["etr"][threshold]
    for source in (data_base, data_base_1, data_base_2)
]

# Wide layout: one column per variant, one row per run.
df = pd.DataFrame(np.array(variant_scores).T, columns=variant_labels)

# Long layout: one row per run; each label is repeated 100 times, i.e. 100
# runs per variant are assumed.
model_df = pd.DataFrame({
    "Metric": np.concat(variant_scores),
    "Model": [label for label in variant_labels for _ in range(100)],
})
# Express the metric as a percentage.
model_df["Metric"] = model_df["Metric"] * 100

In [None]:
# Figure: metric distribution of TransformEEG and its multi-head variants.
# Saved to Images/transformer_heads.pdf.
font = 25
fliercolor = 'gray'
linecolor = '#137'
scattercolor = "#0072b2"
boxcolor = '#e69f00'
name = "win_05"

n_variants = model_df["Model"].nunique()
flier_style = {
    "marker": 'o',
    "markerfacecolor": fliercolor,
    "linestyle": 'none',
    "markeredgecolor": fliercolor,
}

fig, ax = plt.subplots(figsize=(16, 7.75))

# One gray dot per run.
sns.stripplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["gray"] * n_variants,
    size=5,
    linewidth=1,
    alpha=0.9,
    legend=False,
)

# Semi-transparent box per variant; fliers are hidden (showfliers=False).
sns.boxplot(
    data=model_df,
    x="Model",
    y="Metric",
    ax=ax,
    palette=["lightgray"] * n_variants,
    fill=True,
    showfliers=False,
    linecolor=linecolor,
    flierprops=flier_style,
    boxprops={"alpha": 0.8},
)

# Per-variant inter-quartile range in percent, handed to the annotation
# helper (add_median_labels is defined elsewhere in the notebook).
iqr_values = 100 * df.apply(lambda col: stats.iqr(col.to_numpy()), axis=0)
add_median_labels(
    ax, '.2f',
    fontsize=font - 10,
    lim=[50, 100],
    iqr_values=iqr_values,
    iqr_offset=1.1,
)

# Ticks every 5 points from 40 to 100; limits are set after the ticks.
ax.set_yticks(list(range(40, 101, 5)))
ax.set_xlabel('Model', fontsize=font, labelpad=12)
ax.set_ylabel('Balanced Accuracy %', fontsize=font)
ax.set_ylim(60, 101)
ax.tick_params(axis='both', which='major', labelsize=font - 5.1)
ax.set_title("Number of heads effect on TransformEEG", fontsize=font + 3, pad=12)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(1.5)

plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
plt.savefig("Images/transformer_heads.pdf", bbox_inches='tight')
plt.show()

## Seed

In [None]:
# For each seed, rebuild the model-comparison figure from the "seed" results
# and print the per-model range_1_99 statistic.
font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'
metric       = "accuracy_weighted"
threshold    = "win_05"

for seed in ["0001", "0012"]:
    print("Results with seed ", seed)
    # Only result files whose name ends with "<seed>.pickle" are loaded.
    data_base = get_result_dict(
        model_list, root_path, "seed", metric=metric, end_string="*"+seed+".pickle")

    # Wide layout: one column per model, one row per run.
    df = pd.DataFrame(
        np.array([data_base[i][threshold] for i in data_base.keys()]).T,
        columns=[get_full_name(i) for i in data_base.keys()]
    )
    # Long layout: one row per run. The label columns repeat each model 100
    # times, i.e. 100 runs per model are assumed.
    model_df = pd.DataFrame(
        np.concat([data_base[i][threshold] for i in data_base.keys()]),
        columns=["Metric"]
    )
    model_df.insert(
        loc=1,
        column='Model',
        value=[get_full_name(i) for i in data_base.keys() for _ in range(100)]
    )
    model_df.insert(
        loc=1,
        column='ModelParam',
        value=[
            get_full_name(i) + "\n" + param_list[n]
            for n, i in enumerate(data_base.keys())
            for _ in range(100)
        ]
    )
    # Express the metric as a percentage.
    model_df['Metric'] = model_df['Metric']*100

    fig, ax = plt.subplots(figsize=(21, 7.75))
    # One gray dot per run.
    sns.stripplot(
        x         = "Model",
        y         = "Metric",
        data      = model_df,
        legend    = False,
        linewidth = 1,
        ax        = ax,
        size      = 5,
        palette   = ["gray"]*len(data_base.keys()),
        alpha    = .9
    )

    # Semi-transparent box per model; fliers are hidden (showfliers=False).
    sns.boxplot(
        data=model_df,
        x="Model",
        y="Metric",
        ax=ax,
        fill=True,
        showfliers=False,
        linecolor = linecolor,
        flierprops = dict(
            marker='o',
            markerfacecolor=fliercolor,
            linestyle='none',
            markeredgecolor=fliercolor
        ),
        boxprops=dict(alpha=.8),
        palette=["lightgray"]*len(data_base.keys())
    )
    # Per-model inter-quartile range in percent, handed to the annotation
    # helper (add_median_labels is defined elsewhere in the notebook).
    iqr_values = df.apply(lambda x: stats.iqr(x.values), axis=0)*100
    add_median_labels(
        ax, '.2f', fontsize=font-10, lim=[50,100], iqr_values=iqr_values, iqr_offset = 1.1)
    ax.set_yticks([i*5 for i in range(8, 20)])
    ax.set_title('Model comparison - no data augmentation - no threshold correction',
                 fontsize = font+3, pad=12)
    ax.set_xlabel('Model', fontsize = font, labelpad=12)
    ax.set_ylabel('Balanced Accuracy %', fontsize = font)
    ax.set_ylim(45,96)
    ax.tick_params(axis='both', which='major', labelsize=font-5.1)
    for axis in ['top','bottom','left','right']:
        ax.spines[axis].set_linewidth(1.5)
    plt.subplots_adjust(top=0.9, hspace=0.25, wspace=0.25)
    #plt.savefig(f"Images/baseline_noaug_nocorr.pdf", bbox_inches='tight')
    plt.show()
    # `df` holds fractions while the figure is in percent; scale the printed
    # table by 100 like the other comparison cells do.
    print(df.apply(range_1_99, axis=0)*100)

## Data augmentation choice 2

In [None]:
# Identifiers of the data augmentations that are combined pairwise below.
aug_list = [
    'flip_horizontal',
    'flip_vertical',
    'add_band_noise',
    'add_eeg_artifact',
    'add_noise_snr',
    'channel_dropout',
    'masking',
    'warp_signal',
    'random_FT_phase',
    'phase_swap',
]
def get_model_da_res(Mdl):
    """Load the pairwise data-augmentation results of one model.

    For every ordered pair of entries in the global ``aug_list`` the matching
    pickles under ``{root_path}full_data_augmentation/`` are read and their
    scores are stored in ``(10, 10, 50)`` arrays indexed by
    ``[get_aug_idx(first), get_aug_idx(second), file position]``. Slots
    without a matching file stay at zero.

    Parameters
    ----------
    Mdl : str
        Model tag as it appears in the result file names.

    Returns
    -------
    dict
        ``{Mdl: {"win_05": ..., "win_th": ..., "sub_05": ..., "epochs": ...}}``
        where every value is one of the arrays described above.
    """
    # NOTE(review): 10 presumably matches len(aug_list) and 50 the number of
    # runs per augmentation pair - confirm before changing either.
    shape = (10, 10, 50)
    win_05 = np.zeros(shape)
    win_th = np.zeros(shape)
    sub_05 = np.zeros(shape)
    epochs = np.zeros(shape)

    def _run_order(path):
        # Sort key: integer value of fields 4 and 5 of the "_"-separated
        # file name.
        fields = path.split(os.sep)[-1].split('_')
        return int(fields[4]), int(fields[5])

    for a1 in aug_list:
        idx1 = get_aug_idx(a1)
        # File names carry the 1-based augmentation index, zero-padded to 3.
        sidx1 = str(idx1+1).zfill(3)
        for a2 in aug_list:
            idx2 = get_aug_idx(a2)
            sidx2 = str(idx2+1).zfill(3)
            glob_name = f"{root_path}full_data_augmentation/*_{Mdl}_*_{sidx1}_{sidx2}_000250*"
            path_list = sorted(glob.glob(glob_name), key=_run_order)
            for n, p in enumerate(path_list):

                with open(p,'rb') as scorefile:
                    scores = pickle.load(scorefile)

                win_05[idx1, idx2, n] = scores['th_standard']['accuracy_weighted']
                win_th[idx1, idx2, n] = scores['th_corrected']['accuracy_weighted']
                sub_05[idx1, idx2, n] = scores['subject']['accuracy_weighted']

                # `j` ends as the first key whose first entry is None, or as
                # the last key when no such entry exists.
                for j in scores['loss_progression'].keys():
                    if scores['loss_progression'][j][0] is None:
                        break
                # NOTE(review): the -15 offset presumably removes the
                # early-stopping patience - confirm against the training setup.
                epochs[idx1, idx2, n] = j-15+1

    return {
        Mdl: {
            "win_05": win_05,
            "win_th": win_th,
            "sub_05": sub_05,
            "epochs": epochs
        }
    }

In [None]:
# For every model: print the statistics of the un-augmented baseline and draw
# two heatmaps over all augmentation pairs - (A) the median score and (B) its
# inter-quartile range.
data_base_1 = get_result_dict(
    ['xeg', 'egn', 'shn', 'atc', 'con', 'dcn', 'res'], root_path,
    "full_data_augmentation", end_string="*000_000_000250_*.pickle",
    outer_position=7, inner_position=8
)

font         = 25
fliercolor   = 'gray'
linecolor    = '#137'
scattercolor = "#0072b2"
boxcolor     = '#e69f00'
letters      = ["A", "B"]
th_name      = "win_05"
# Axis labels of the heatmaps, one per row/column.
aug_list_full = [
     'flip horizontal',
     'flip vertical',
     'add band noise',
     'signal drift',
     'add noise SNR',
     'channel dropout',
     'masking',
     'warp',
     'random phase',
     'phase swap',
]

# Color-scale limits of the IQR heatmap per model; models not listed here
# use (7, 12).
iqr_limits = {
    "shn": (5, 12),
    "xeg": (7, 12),
    "res": (5, 10),
    "egn": (5, 12),
    "con": (5, 12),
    "dcn": (7, 15),
}

for model in ['xeg', 'egn', 'shn', 'atc', 'con', 'dcn', 'res']:

    iqr_min, iqr_max = iqr_limits.get(model, (7, 12))

    # Baseline runs in percent; non-positive entries are dropped.
    base_scores = data_base_1[model][th_name]
    base_vec = base_scores[base_scores > 0]*100
    print(f'baseline results for {model}: '
          f'{np.median(base_vec):.3f}, '
          f'{stats.iqr(base_vec):.3f}')
    data_aug = get_model_da_res(model)
    pair_scores = data_aug[model][th_name]*100

    # One figure per model: median on the left, IQR on the right. The panels
    # were previously addressed through leftover grid indices (`i`, `col`)
    # that were always zero; they are now named directly.
    fig, ax = plt.subplots(1, 2, figsize=(23, 8.75))
    median_ax, iqr_ax = ax

    htmp = sns.heatmap(
        data = np.median(pair_scores, axis=-1),
        ax = median_ax,
        cbar = True,
        robust = True,
        annot = True,
        fmt = '.2f',
        linewidth = .75,
        vmin = 74,
        vmax = 81,
        xticklabels = aug_list_full,
        yticklabels = aug_list_full,
        cmap = sns.cubehelix_palette(as_cmap=True),
        cbar_kws = {"pad":0.02},
        annot_kws = {"size": font-13, "fontweight": "bold"}
    )
    colorbar = htmp.collections[0].colorbar
    colorbar.set_label("Balanced Accuracy - Median", fontsize=font-8)
    colorbar.ax.yaxis.label.set_size(font-7)  # final size of the colorbar label
    colorbar.ax.tick_params(labelsize=font-7)
    median_ax.set_xlabel('Augmentation 2', fontsize = font)
    median_ax.set_ylabel('Augmentation 1', fontsize = font)
    median_ax.tick_params(axis='both', which='major', labelsize=font-8)
    median_ax.set_xticklabels(median_ax.get_xticklabels(), rotation=45, ha='right')

    htmp2 = sns.heatmap(
        data = stats.iqr(pair_scores, axis=-1),
        ax   = iqr_ax,
        cbar = True,
        robust = True,
        annot = True,
        fmt  = '.2f',
        linewidth = .75,
        vmin = iqr_min,
        vmax = iqr_max,
        xticklabels = aug_list_full,
        yticklabels = aug_list_full,
        cmap = sns.cubehelix_palette(reverse=True, as_cmap=True),
        cbar_kws = {"pad":0.02},
        annot_kws = {"size": font-13, "fontweight": "bold"}
    )
    colorbar = htmp2.collections[0].colorbar
    colorbar.set_label("Balanced Accuracy - Inter-quartile range ", fontsize=font-8)
    colorbar.ax.yaxis.label.set_size(font-7)  # final size of the colorbar label
    colorbar.ax.tick_params(labelsize=font-7)
    iqr_ax.tick_params(axis='both', which='major', labelsize=font-8)
    # Rotate this panel's own tick labels (both panels carry the same
    # aug_list_full labels) instead of borrowing them from the left panel.
    iqr_ax.set_xticklabels(iqr_ax.get_xticklabels(), rotation=45, ha='right')
    iqr_ax.set_xlabel('Augmentation 2', fontsize = font)

    # Panel letters, placed to the upper left of each heatmap.
    median_ax.text(-3.5, -0.2,'(A)',fontsize = font+6)
    iqr_ax.text(-3.5, -0.2,'(B)',fontsize = font+6)
    fig.suptitle( f'Data augmentation comparison - Model: {model}           ',
                 fontsize = font+8)
    plt.subplots_adjust(top=0.825, hspace=0.25, wspace=0.325)
    plt.show()

In [None]:
# For every model: print the statistics of the un-augmented baseline and draw
# one heatmap of the matrix returned by get_aug_score_matrix (defined
# elsewhere in the notebook) over all augmentation pairs.
baseline_models = ['xeg', 'egn', 'shn', 'atc', 'con', 'dcn', 'res']
data_base_1 = get_result_dict(
    ['xeg', 'egn', 'shn', 'atc', 'con', 'dcn', 'res'],
    root_path,
    "full_data_augmentation",
    end_string="*000_000_000250_*.pickle",
    outer_position=7,
    inner_position=8,
)

font = 25
fliercolor = 'gray'
linecolor = '#137'
scattercolor = "#0072b2"
boxcolor = '#e69f00'
letters = ["A", "B"]
th_name = "win_05"
# Axis labels of the heatmap, one per row/column.
aug_list_full = [
    'flip horizontal',
    'flip vertical',
    'add band noise',
    'signal drift',
    'add noise SNR',
    'channel dropout',
    'masking',
    'warp',
    'random phase',
    'phase swap',
]

i = 0
for model in baseline_models:

    # Baseline runs in percent; non-positive entries are dropped.
    keep = data_base_1[model][th_name] > 0
    base_vec = data_base_1[model][th_name][keep] * 100
    median_txt = f'{np.median(base_vec):.3f}'
    iqr_txt = f'{stats.iqr(base_vec):.3f}'
    print(f'baseline results for {model}: ' + median_txt + ', ' + iqr_txt)

    data_aug = get_model_da_res(model)
    fig, ax = plt.subplots(1, 1, figsize=(23, 8.75))

    aug_coeff = get_aug_score_matrix(data_aug[model][th_name] * 100, base_vec)
    heatmap_style = {
        "cbar": True,
        "robust": True,
        "annot": True,
        "fmt": '.5f',
        "linewidth": .75,
        "xticklabels": aug_list_full,
        "yticklabels": aug_list_full,
        "cmap": sns.cubehelix_palette(as_cmap=True),
        "cbar_kws": {"pad": 0.02},
        "annot_kws": {"size": font - 13, "fontweight": "bold"},
    }
    htmp = sns.heatmap(data=aug_coeff, ax=ax, **heatmap_style)

    colorbar = htmp.collections[0].colorbar
    colorbar.set_label("Aug score", fontsize=font - 8)
    colorbar.ax.yaxis.label.set_size(font - 7)  # final size of the colorbar label
    colorbar.ax.tick_params(labelsize=font - 7)

    ax.set_xlabel('Augmentation 2', fontsize=font)
    ax.set_ylabel('Augmentation 1', fontsize=font)
    ax.tick_params(axis='both', which='major', labelsize=font - 8)
    # Re-apply the current tick labels rotated by 45 degrees.
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    plt.show()