# Figure 2 - Performance summary of the spelling system during the copy-typing task

In [None]:
import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation

from silent_spelling.utils import plotting_defaults, holm_bonferroni_correction, bootstrap_confidence_intervals

%load_ext autoreload
%autoreload 2

In [None]:
# ---- Analysis parameters ----
subject = 'bravo1'  # participant ID
# Corrected p-values above this threshold are left out of the stats table and
# are not annotated on the figure.
sig_thresh = 0.01
# [upper bound, label] pairs handed to statannot for star annotations
pvalue_thresholds = [
    [1e-4, "***"],
    [0.001, "**"],
    [0.01, "*"],
    [1, "ns"],
]

# ---- I/O switches ----
fig_dir = 'saved_figures'
load_from_RT = False   # True: fetch results through the lab's RT package; False: read the Excel source data
save_to_excel = True   # only consulted when load_from_RT is True

# Folder holding the result files
result_folder_name = 'spelling_paper_signal_analyses'

# Sub-result number of each dataset inside that folder
result_nums = dict(system_cer=31, system_wer=32, rates=36, length=52)

# Source-data workbook lives in ../source_data relative to the working directory
excel_filepath = os.path.join(os.path.dirname(os.getcwd()), 'source_data', 'source_data.xlsx')

## Load data

In [None]:
if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    def _load_rt_result(result_key):
        """Return the HDF5 sub-result ``result_nums[result_key]`` stored under ``result_folder_name``."""
        result_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            sub_result_num=result_nums[result_key]
        )
        return pd.read_hdf(result_path)

    # Load the data with RT (same order as the individual loads used previously)
    system_wer = _load_rt_result('system_wer')  # WER
    system_cer = _load_rt_result('system_cer')  # CER
    rates = _load_rt_result('rates')            # WPM and CPM
    lens = _load_rt_result('length')            # Sentence length

    if save_to_excel:

        # Append to an existing workbook (it may hold other sheets), otherwise create it
        mode = 'a' if os.path.exists(excel_filepath) else 'w'
        # In append mode pandas (>= 1.3) raises by default when a sheet already
        # exists, so overwrite this figure's sheets on re-export. The option is
        # only accepted in append mode, hence it is added conditionally.
        writer_kwargs = {'if_sheet_exists': 'replace'} if mode == 'a' else {}

        # openpyxl is the only pandas engine that supports append mode
        with pd.ExcelWriter(excel_filepath, mode=mode, engine='openpyxl', **writer_kwargs) as writer:
            system_cer.to_excel(writer, sheet_name='Fig 2A', index=False)
            system_wer.to_excel(writer, sheet_name='Fig 2B', index=False)
            rates.to_excel(writer, sheet_name='Fig 2CD', index=False)
            lens.to_excel(writer, sheet_name='Fig 2E', index=False)

else:
    # Read all four sheets in a single pass; returns {sheet_name: DataFrame}
    fig2_sheets = pd.read_excel(excel_filepath, sheet_name=['Fig 2A', 'Fig 2B', 'Fig 2CD', 'Fig 2E'],
                                engine='openpyxl')
    system_cer = fig2_sheets['Fig 2A']
    system_wer = fig2_sheets['Fig 2B']
    rates = fig2_sheets['Fig 2CD']
    lens = fig2_sheets['Fig 2E']

## Ablations - WER

In [None]:
# Express WER as a percentage. NOTE: this rescales the column in place, so
# re-running this cell without reloading the data multiplies it by 100 again.
system_wer['Word Error Rate'] = system_wer['Word Error Rate'].values * 100

# Wilcoxon rank-sum test (scipy.stats.ranksums) for every pair of paradigms,
# keyed by 'paradigm1&paradigm2'
system_wer_pvals_float = {}
system_wer_stats_float = {}
for c1, c2 in itertools.combinations(system_wer['Paradigm'].unique(), 2):
    key = f'{c1}&{c2}'
    group1 = system_wer.loc[system_wer['Paradigm'] == c1]['Word Error Rate'].values
    group2 = system_wer.loc[system_wer['Paradigm'] == c2]['Word Error Rate'].values
    system_wer_stats_float[key], system_wer_pvals_float[key] = stats.ranksums(group1, group2)

# Holm-Bonferroni correction across all pairwise comparisons
system_wer_pvals_float_corrected = holm_bonferroni_correction(system_wer_pvals_float)

# LaTeX headers of the statistics table. These must be raw strings: in a normal
# literal '\t' (in '\tnote', '\textit') is a TAB, '\n' (in '\num') is a newline,
# '\\' collapses to a single backslash and '\m' is an invalid escape.
stats_table_cols = (
    r'Statistical comparison \tnote{1}',
    r'$\mid$ \textit{z}-value $\mid$',
    r'\textit{P}-value \\ & & (corrected\tnote{2}  )',
)
comparison_col, z_col, p_col = stats_table_cols

# Keep only comparisons whose corrected p-value is <= sig_thresh: they are
# listed in the table and annotated on the figure.
system_wer_pvals, system_wer_box_pairs = [], []
wer_stats_df = {col: [] for col in stats_table_cols}
for key, val in system_wer_pvals_float_corrected.items():
    c1, c2 = key.split('&')

    if val > sig_thresh:
        continue

    wer_stats_df[comparison_col].append(f'{c1} vs. {c2}')
    wer_stats_df[z_col].append('{:0.2f}'.format(abs(system_wer_stats_float[key])))
    # siunitx-style \num{...}; doubled braces are literal braces for str.format
    wer_stats_df[p_col].append(r'\num{{{:0.3g}}}'.format(val))

    system_wer_box_pairs.append((c1, c2))
    system_wer_pvals.append(val)

wer_stats_df = pd.DataFrame(data=wer_stats_df)

In [None]:
# Rank-sum test statistic for each paradigm pair (key: 'paradigm1&paradigm2')
system_wer_stats_float

In [None]:
# Holm-Bonferroni corrected p-value for each paradigm pair
system_wer_pvals_float_corrected

In [None]:
# Median WER (%) and bootstrap confidence interval (project helper) per paradigm
for para in system_wer['Paradigm'].unique():
    in_para = system_wer['Paradigm'] == para
    cur_wer = system_wer.loc[in_para, 'Word Error Rate'].values
    print(para, np.median(cur_wer), bootstrap_confidence_intervals(cur_wer))

In [None]:
# Number of WER rows per paradigm
system_wer['Paradigm'].value_counts()

## Ablations - CER

In [None]:
# Express CER as a percentage. NOTE: this rescales the column in place, so
# re-running this cell without reloading the data multiplies it by 100 again.
system_cer['Character Error Rate'] = system_cer['Character Error Rate'].values * 100

# Wilcoxon rank-sum test (scipy.stats.ranksums) for every pair of paradigms,
# keyed by 'paradigm1&paradigm2'
system_cer_pvals_float, system_cer_stats_float = {}, {}
for c1, c2 in itertools.combinations(system_cer['Paradigm'].unique(), 2):
    key = f'{c1}&{c2}'
    group1 = system_cer.loc[system_cer['Paradigm'] == c1]['Character Error Rate'].values
    group2 = system_cer.loc[system_cer['Paradigm'] == c2]['Character Error Rate'].values
    system_cer_stats_float[key], system_cer_pvals_float[key] = stats.ranksums(group1, group2)

# Holm-Bonferroni correction across all pairwise comparisons
system_cer_pvals_float_corrected = holm_bonferroni_correction(system_cer_pvals_float)

# LaTeX headers of the statistics table. These must be raw strings: in a normal
# literal '\t' (in '\tnote', '\textit') is a TAB, '\n' (in '\num') is a newline,
# '\\' collapses to a single backslash and '\m' is an invalid escape.
stats_table_cols = (
    r'Statistical comparison \tnote{1}',
    r'$\mid$ \textit{z}-value $\mid$',
    r'\textit{P}-value \\ & & (corrected\tnote{2}  )',
)
comparison_col, z_col, p_col = stats_table_cols

# Keep only comparisons whose corrected p-value is <= sig_thresh: they are
# listed in the table and annotated on the figure.
system_cer_pvals, system_cer_box_pairs = [], []
cer_stats_df = {col: [] for col in stats_table_cols}
for key, val in system_cer_pvals_float_corrected.items():
    c1, c2 = key.split('&')

    if val > sig_thresh:
        continue

    cer_stats_df[comparison_col].append(f'{c1} vs. {c2}')
    cer_stats_df[z_col].append('{:0.2f}'.format(abs(system_cer_stats_float[key])))
    # siunitx-style \num{...}; doubled braces are literal braces for str.format
    cer_stats_df[p_col].append(r'\num{{{:0.3g}}}'.format(val))

    system_cer_box_pairs.append((c1, c2))
    system_cer_pvals.append(val)

cer_stats_df = pd.DataFrame(data=cer_stats_df)

In [None]:
# Rank-sum test statistic for each paradigm pair (key: 'paradigm1&paradigm2')
system_cer_stats_float

In [None]:
# Holm-Bonferroni corrected p-value for each paradigm pair
system_cer_pvals_float_corrected

In [None]:
# Median CER (%) and bootstrap confidence interval (project helper) per paradigm
for para in system_cer['Paradigm'].unique():
    in_para = system_cer['Paradigm'] == para
    cur_cer = system_cer.loc[in_para, 'Character Error Rate'].values
    print(para, np.median(cur_cer), bootstrap_confidence_intervals(cur_cer))

In [None]:
# Number of CER rows per paradigm
system_cer['Paradigm'].value_counts()

## CPM and WPM

In [None]:
# Median, maximum and bootstrap confidence interval (project helper) of the
# characters-per-minute and words-per-minute rates, in that order
for rate_col in ('CPM', 'WPM'):
    rate_vals = rates[rate_col].values
    print(np.median(rate_vals), np.max(rate_vals), bootstrap_confidence_intervals(rate_vals))

## Decoded sentence length

In [None]:
# Per-trial length error: predicted minus real length (positive = too long)
lens['off'] = lens['length pred'].values - lens['length real'].values

# Histogram of the length error: number and percentage of trials per value
lengths, length_counts = np.unique(lens['off'].values, return_counts=True)
n_trials = len(lens)
length_percents = length_counts / n_trials * 100
length_columns = {
    'lengths': lengths,
    'percent': length_percents,
    'count': length_counts,
}
length_df = pd.DataFrame(length_columns)

In [None]:
# Count and percentage of trials for each (predicted - real) length difference
length_df

## Overall Figure

In [None]:
# Plot order and x tick labels for the CER/WER panels: each key is a value of
# the 'Paradigm' column, each value the (multi-line) label drawn under its box.
_paradigm_label_pairs = [
    ('Chance', 'Chance'),
    ('Only Neural Decoding', 'Only\nneural\ndecoding'),
    ('+ Vocab. Constraints', '+Vocab.\nconstr.'),
    ('+ LM (Realtime results)', '+LM\n(Real-time\nresults)'),
]
exp_names = dict(_paradigm_label_pairs)

In [None]:
# ---- Figure styling ----
linewidth = 1
annot_linewidth = 1
plotting_defaults(font='Arial', fontsize=7, linewidth=linewidth)
panel_label_fontsize = 7
boxplot_kwargs = {'fliersize': 2, 'width': 0.5}
mm = 1 / 25.4  # inches per millimetre (matplotlib figure sizes are in inches)
mm_figsize = [mm*180, mm*110]  # 180 mm x 110 mm

# 2 x 12 grid: the top row holds panels a-d, the bottom row holds panel e
fig = plt.figure(figsize=mm_figsize)
gs = mpl.gridspec.GridSpec(2, 12, figure=fig, height_ratios=[1.3, 1])
axs = {}

colors = sns.color_palette('Set2')

# NOTE(review): system_paradigm_labels and stripplot_kwargs are not used in this cell
system_paradigms = system_cer.Paradigm.unique()
system_paradigm_labels = [label.replace('+ ', '+').replace(' ', '\n').replace('Realtime', 'Real-time') for label in system_paradigms]

stripplot_kwargs = {
    'color': 'k',
    'alpha': 0.6,
    's': 3
}

##### ----- System CER (panel a): one box per paradigm + significance stars
axs['system_cer'] = fig.add_subplot(gs[0, :4])
axs['system_cer'] = sns.boxplot(data=system_cer, x='Paradigm', y='Character Error Rate',
                                order=exp_names.keys(), ax=axs['system_cer'], palette='Set2', **boxplot_kwargs)
axs['system_cer'].axes.set(xticklabels=exp_names.values(), xlabel='', ylim=(None, 100))
axs['system_cer'].axes.set_ylabel('Character error rate (%)', labelpad=-2)
# Annotate only the pre-filtered significant pairs with their corrected p-values
add_stat_annotation(axs['system_cer'], data=system_cer, x='Paradigm', y='Character Error Rate',
                    order=exp_names.keys(), box_pairs=system_cer_box_pairs, perform_stat_test=False,
                    pvalues=system_cer_pvals, text_format='star', loc='outside', pvalue_thresholds=pvalue_thresholds,
                    linewidth=annot_linewidth, text_offset=-3)


##### ----- System WER (panel b): one box per paradigm + significance stars
axs['system_wer'] = fig.add_subplot(gs[0, 4:8])
axs['system_wer'] = sns.boxplot(data=system_wer, x='Paradigm', y='Word Error Rate',
                                order=exp_names.keys(), ax=axs['system_wer'], palette='Set2', **boxplot_kwargs)
axs['system_wer'].axes.set(xticklabels=exp_names.values(), xlabel='', ylim=(None, 140), yticks=np.arange(0, 160, 20))
axs['system_wer'].axes.set_ylabel('Word error rate (%)', labelpad=0)
add_stat_annotation(axs['system_wer'], data=system_wer, x='Paradigm', y='Word Error Rate',
                    order=exp_names.keys(), box_pairs=system_wer_box_pairs, perform_stat_test=False,
                    pvalues=system_wer_pvals, text_format='star', loc='outside', pvalue_thresholds=pvalue_thresholds,
                    linewidth=annot_linewidth, text_offset=-3)


##### ----- CPM (panel c)
axs['cpm'] = fig.add_subplot(gs[0, 8:10])
axs['cpm'] = sns.boxplot(data=rates, x='blocks', y='CPM',
                         ax=axs['cpm'], palette=[colors[3]], **boxplot_kwargs)
axs['cpm'].axes.set(xlabel='Real-time\nresults', xticks=[], ylim=(27, 31), yticks=range(27, 32))
axs['cpm'].axes.set_ylabel('Characters per minute', labelpad=1)


##### ----- WPM (panel d)
axs['wpm'] = fig.add_subplot(gs[0, 10:])
axs['wpm'] = sns.boxplot(data=rates, x='blocks', y='WPM',
                         ax=axs['wpm'], palette=[colors[3]], **boxplot_kwargs)
axs['wpm'].axes.set(xlabel='Real-time\nresults', xticks=[], ylim=(4, 9))
axs['wpm'].axes.set_ylabel('Words per minute', labelpad=1)


##### ----- Sentence length (panel e)
# seaborn draws one categorical bar per row of length_df at x = 0..k-1, so the
# ticks and labels are taken from the data instead of assuming the length
# errors are exactly -2..2 (any other set would be mislabelled).
n_length_bins = len(length_df)
length_tick_labels = [f'{value:g}' for value in length_df['lengths']]
axs['length'] = fig.add_subplot(gs[1, :3])
axs['length'] = sns.barplot(data=length_df, x='lengths', y='percent', color=colors[3], ax=axs['length'])
axs['length'].axes.set(ylim=(0, 100), xlim=(-1, n_length_bins), xticks=range(n_length_bins),
                       xticklabels=length_tick_labels, xlabel='No. of excess\ncharacters')
axs['length'].axes.set_ylabel('Percent of trials', labelpad=1)


#### ----- Figure panel labels
axs['system_cer'].annotate('a', (-0.2, 1.1), xycoords='axes fraction', ha='right', fontsize=panel_label_fontsize, weight='bold')
axs['system_wer'].annotate('b', (-0.2, 1.1), xycoords='axes fraction', ha='right', fontsize=panel_label_fontsize, weight='bold')
axs['cpm'].annotate('c', (-0.2, 1.1), xycoords='axes fraction', ha='right', fontsize=panel_label_fontsize, weight='bold')
axs['wpm'].annotate('d', (-0.2, 1.1), xycoords='axes fraction', ha='right', fontsize=panel_label_fontsize, weight='bold')
axs['length'].annotate('e', (-0.2, 1.1), xycoords='axes fraction', ha='right', fontsize=panel_label_fontsize, weight='bold')

# Trailing semicolon suppresses the notebook's echo of the return value
fig.tight_layout();

### Save figure

In [None]:
figure_dpi = 300

# savefig does not create missing directories, so make sure fig_dir exists
os.makedirs(fig_dir, exist_ok=True)

for ext in ['png', 'pdf']:
    # Transparent-background version
    fig.savefig(os.path.join(fig_dir, f'figure2_realtime_decoding.{ext}'),
                transparent=True, bbox_inches='tight', dpi=figure_dpi)
    # Non-transparent version (file suffix '_white')
    fig.savefig(os.path.join(fig_dir, f'figure2_realtime_decoding_white.{ext}'),
                transparent=False, bbox_inches='tight', dpi=figure_dpi)