# Figure 6 - The spelling approach can generalize to larger vocabularies and conversational settings

In [None]:
import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation
import matplotlib.font_manager

from silent_spelling.utils import plotting_defaults, holm_bonferroni_correction, bootstrap_confidence_intervals

# Configure plotting via the project helper from silent_spelling.utils,
# requesting Arial at 16 pt for the exploratory cells. The helper is called
# again with publication-sized settings in the "Overall Figure" cell.
# NOTE(review): presumably this sets global matplotlib defaults - confirm in
# silent_spelling.utils.
plotting_defaults(font='Arial', fontsize=16)

%load_ext autoreload
%autoreload 2

In [None]:
# ---- Notebook configuration -------------------------------------------------

# Subject identifier and threshold value.
# NOTE(review): sig_thresh is presumably the significance level; it is not
# referenced by the other cells visible in this notebook.
subject = 'bravo1'
sig_thresh = 0.01

# Directory the figure files are written to.
fig_dir = 'saved_figures'

# Data source switches: when load_from_RT is True the results are read through
# the lab's RT file handler (and optionally exported to the source-data
# workbook); otherwise they are read back from that workbook.
load_from_RT = False
save_to_excel = True

# Name of the folder that contains the stored analysis results
result_folder_name = 'spelling_paper_signal_analyses'

# Result file number for each dataset used in this figure
result_nums = dict(convo=33, vocab_cer=34, vocab_wer=35)

# Human-readable names of the utterance sets
vocab_names = dict(
    alphabet1_1='English alphabet',
    alphabet1_2='NATO code words',
)

# Tick label shown for each 'Paradigm' value on the vocabulary panels
# (the axis is labelled "Vocabulary size (words)").
lm_names = {
    'Google 9k': '9,170',
    'Oxford 5k': '5,249',
    'Oxford 3k': '3,303',
    'AAC 1k': '1,152',
}

# The source-data workbook lives in <parent of the working directory>/source_data/
parent_dir = os.path.dirname(os.getcwd())
excel_filepath = os.path.join(parent_dir, 'source_data', 'source_data.xlsx')

## Load data

In [None]:
# Worksheet of the source-data workbook that backs each Figure 6 dataset.
sheet_names = {
    'vocab_cer': 'Fig 6A',
    'vocab_wer': 'Fig 6B',
    'convo': 'Fig 6C',
}

if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    def _load_result(result_key):
        """Read the stored analysis result registered under ``result_nums[result_key]``.

        The file path is resolved by the RT file handler from the result
        folder name and the result number; the file is read with
        ``pd.read_hdf`` and the loaded object is returned.
        """
        result_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            sub_result_num=result_nums[result_key]
        )
        return pd.read_hdf(result_path)

    # Load the data with RT (same order as before: WER, CER, conversation)
    vocab_wer = _load_result('vocab_wer')
    vocab_cer = _load_result('vocab_cer')
    convo_df = _load_result('convo')

    if save_to_excel:

        # Make sure the destination directory exists before writing.
        excel_dir = os.path.dirname(excel_filepath)
        if excel_dir:
            os.makedirs(excel_dir, exist_ok=True)

        # Append to an existing workbook so sheets written for other figures
        # are kept. 'if_sheet_exists' is only accepted in append mode; using
        # 'replace' lets this cell be re-run without failing because the
        # "Fig 6*" sheets are already present (requires pandas >= 1.3).
        if os.path.exists(excel_filepath):
            writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
        else:
            writer_kwargs = {'mode': 'w'}

        # Append mode is only implemented by the openpyxl writer, so the
        # engine is pinned instead of relying on pandas' automatic choice.
        with pd.ExcelWriter(excel_filepath, engine='openpyxl', **writer_kwargs) as writer:
            vocab_cer.to_excel(writer, sheet_name=sheet_names['vocab_cer'], index=False)
            vocab_wer.to_excel(writer, sheet_name=sheet_names['vocab_wer'], index=False)
            convo_df.to_excel(writer, sheet_name=sheet_names['convo'], index=False)

else:
    # Read the three sheets in a single pass over the workbook; passing a list
    # of sheet names makes read_excel return a dict keyed by sheet name.
    _sheets = pd.read_excel(excel_filepath, sheet_name=list(sheet_names.values()),
                            engine='openpyxl')
    vocab_cer = _sheets[sheet_names['vocab_cer']]
    vocab_wer = _sheets[sheet_names['vocab_wer']]
    convo_df = _sheets[sheet_names['convo']]

## Vocab CER

In [None]:
# Convert the character error rate to percent.
# NOTE: the column is overwritten in place, so running this cell a second time
# without reloading the data scales the values by 100 again.
cer_col = 'Character Error Rate'
vocab_cer[cer_col] = 100 * vocab_cer[cer_col].values

# Error-rate samples of every paradigm, in order of first appearance.
cer_paradigms = vocab_cer['Paradigm'].unique()
cer_by_paradigm = {
    paradigm: vocab_cer.loc[vocab_cer['Paradigm'] == paradigm, cer_col].values
    for paradigm in cer_paradigms
}

# Rank-sum test for every unordered pair of paradigms; results are keyed
# by '<paradigm 1>&<paradigm 2>'.
vocab_cer_pvals_float = {}
vocab_cer_stats_float = {}
for first, second in itertools.combinations(cer_paradigms, 2):
    pair_key = f'{first}&{second}'
    statistic, pvalue = stats.ranksums(cer_by_paradigm[first], cer_by_paradigm[second])
    vocab_cer_stats_float[pair_key] = statistic
    vocab_cer_pvals_float[pair_key] = pvalue

# Multiple-comparison correction via the project helper.
vocab_cer_pvals_float_corrected = holm_bonferroni_correction(vocab_cer_pvals_float)

# Split the corrected results into parallel lists of paradigm pairs and
# p-values (the form expected by the disabled add_stat_annotation calls).
vocab_cer_box_pairs = [
    (first, second)
    for first, second in (
        pair_key.split('&') for pair_key, _ in vocab_cer_pvals_float_corrected.items()
    )
]
vocab_cer_pvals = [
    corrected_p for _, corrected_p in vocab_cer_pvals_float_corrected.items()
]

In [None]:
# Display the rank-sum test statistic of every paradigm pair
# (dict keyed by '<paradigm 1>&<paradigm 2>').
vocab_cer_stats_float

In [None]:
# Print, for each paradigm: its name, the median character error rate, and the
# output of the project's bootstrap confidence-interval helper.
cer_paradigm_col = vocab_cer['Paradigm']
for paradigm in cer_paradigm_col.unique():
    paradigm_cer = vocab_cer.loc[cer_paradigm_col == paradigm, 'Character Error Rate'].values
    print(f'\n{paradigm}')
    print(np.median(paradigm_cer))
    print(bootstrap_confidence_intervals(paradigm_cer))

## Vocab WER

In [None]:
# Convert the word error rate to percent.
# NOTE: the column is overwritten in place, so running this cell a second time
# without reloading the data scales the values by 100 again.
wer_col = 'Word Error Rate'
vocab_wer[wer_col] = 100 * vocab_wer[wer_col].values

# Error-rate samples of every paradigm, in order of first appearance.
wer_paradigms = vocab_wer['Paradigm'].unique()
wer_by_paradigm = {
    paradigm: vocab_wer.loc[vocab_wer['Paradigm'] == paradigm, wer_col].values
    for paradigm in wer_paradigms
}

# Rank-sum test for every unordered pair of paradigms; results are keyed
# by '<paradigm 1>&<paradigm 2>'.
vocab_wer_pvals_float = {}
vocab_wer_stats_float = {}
for first, second in itertools.combinations(wer_paradigms, 2):
    pair_key = f'{first}&{second}'
    statistic, pvalue = stats.ranksums(wer_by_paradigm[first], wer_by_paradigm[second])
    vocab_wer_stats_float[pair_key] = statistic
    vocab_wer_pvals_float[pair_key] = pvalue

# Multiple-comparison correction via the project helper; report how many
# comparisons went into it.
vocab_wer_pvals_float_corrected = holm_bonferroni_correction(vocab_wer_pvals_float)
print(f'{len(vocab_wer_pvals_float)}-way Holm-Bonferroni correction')

# Split the corrected results into parallel lists of paradigm pairs and
# p-values (the form expected by the disabled add_stat_annotation calls).
vocab_wer_box_pairs = [
    (first, second)
    for first, second in (
        pair_key.split('&') for pair_key, _ in vocab_wer_pvals_float_corrected.items()
    )
]
vocab_wer_pvals = [
    corrected_p for _, corrected_p in vocab_wer_pvals_float_corrected.items()
]

In [None]:
# Print, for each paradigm: its name, the median word error rate, and the
# output of the project's bootstrap confidence-interval helper.
wer_paradigm_col = vocab_wer['Paradigm']
for paradigm in wer_paradigm_col.unique():
    paradigm_wer = vocab_wer.loc[wer_paradigm_col == paradigm, 'Word Error Rate'].values
    print(f'\n{paradigm}')
    print(np.median(paradigm_wer))
    print(bootstrap_confidence_intervals(paradigm_wer))

## Conversation CER & WER

In [None]:
# Store the error rate as a percentage in a separate column; the original
# 'Error Rate' column is left untouched, so this cell can be re-run safely.
convo_df['Percent Error Rate'] = 100 * convo_df['Error Rate'].values

# Print, for each metric type: its name, the median percent error rate, and the
# output of the project's bootstrap confidence-interval helper.
metric_col = convo_df['Metric Type']
for metric in metric_col.unique():
    metric_errors = convo_df.loc[metric_col == metric, 'Percent Error Rate'].values
    print(f'\n{metric}')
    print(np.median(metric_errors))
    print(bootstrap_confidence_intervals(metric_errors))

## Overall Figure

In [None]:
##### ----- Styling shared by all panels
linewidth = 1
annot_linewidth = 1
fontsize = 7
panel_label_fontsize = 7
scatter_size = 3
boxplot_kwargs = dict(fliersize=3, width=0.5)
plotting_defaults(font='Arial', fontsize=fontsize, linewidth=linewidth)

# Figure size is specified in millimetres (180 x 50) and converted to inches.
mm = 1 / 25.4
mm_figsize = [180 * mm, 50 * mm]

# One row, four columns: vocab CER | vocab WER | conversation | blank panel.
fig = plt.figure(figsize=mm_figsize)
gs = mpl.gridspec.GridSpec(1, 4, figure=fig, width_ratios=[2, 2, 1, 5])
axs = {}

# Boxes are drawn in the reverse of the order in which the paradigms first
# appear in the CER table; tick labels come from lm_names.
vocab_paradigms = vocab_cer['Paradigm'].unique()
order = np.flip(vocab_paradigms)
xticklabels = [lm_names[paradigm] for paradigm in order]

##### ----- Vocab CER & WER
# (panel key, grid cell, data, y column, y-axis label, upper y limit)
# Significance-star annotation (statannot.add_stat_annotation with the
# precomputed box pairs and corrected p-values) is currently disabled for
# these two panels.
vocab_panels = [
    ('vocab_cer', gs[0, 0], vocab_cer, 'Character Error Rate', 'Character error rate (%)', 50),
    ('vocab_wer', gs[0, 1], vocab_wer, 'Word Error Rate', 'Word error rate (%)', 70),
]
for panel, grid_cell, panel_df, metric_col, metric_label, y_max in vocab_panels:
    panel_ax = fig.add_subplot(grid_cell)
    axs[panel] = sns.boxplot(data=panel_df, x='Paradigm', y=metric_col,
                             order=order, ax=panel_ax, palette='Set2', **boxplot_kwargs)
    axs[panel].axes.set(xlabel='Vocabulary size (words)', ylim=(None, y_max),
                        yticks=range(0, y_max + 1, 10), ylabel=metric_label,
                        xticklabels=xticklabels)
    axs[panel].tick_params(axis='x', labelrotation=45)

##### ----- Conversation CER & WER
# Box plot (outliers hidden) with the individual points overlaid as a strip plot.
colors = sns.color_palette('Set2')[4:6]
convo_plot_kwargs = dict(data=convo_df, x='Metric Type', y='Percent Error Rate',
                         order=['cer', 'wer'], palette=colors)
axs['convo'] = fig.add_subplot(gs[0, 2])
axs['convo'] = sns.boxplot(showfliers=False, ax=axs['convo'],
                           **convo_plot_kwargs, **boxplot_kwargs)
axs['convo'] = sns.stripplot(ax=axs['convo'], edgecolor='black',
                             linewidth=linewidth - 0.5, size=scatter_size,
                             **convo_plot_kwargs)
axs['convo'].axes.set(xlabel='', ylim=(None, 60), xticklabels=['Char.', 'Word'], ylabel='Error rate (%)')

# The last grid column is left as an empty, axis-free panel.
blank_ax = fig.add_subplot(gs[0, -1])
blank_ax.axis('off')

##### ----- Figure panel labels
# (axes key, label, x position in axes-fraction coordinates).
# Label 'd' is anchored to the conversation axes, offset to its right.
panel_labels = [
    ('vocab_cer', 'a', -0.25),
    ('vocab_wer', 'b', -0.25),
    ('convo', 'c', -0.25),
    ('convo', 'd', 1.2),
]
for panel, letter, x_frac in panel_labels:
    axs[panel].annotate(letter, (x_frac, 1.1), xycoords='axes fraction', ha='right',
                        fontsize=panel_label_fontsize, weight='bold')

fig.tight_layout();

### Save figure

In [None]:
# Resolution used for the saved files
figure_dpi = 300

# savefig does not create missing directories, so make sure the output folder
# exists (no-op when it is already there).
os.makedirs(fig_dir, exist_ok=True)

# Write each format twice: once with a transparent background and once with
# an opaque one (the '_white' suffix).
for ext in ['png', 'pdf']:
    fig.savefig(os.path.join(fig_dir, f'figure6_vocab_convo.{ext}'), 
                transparent=True, bbox_inches='tight', dpi=figure_dpi)
    fig.savefig(os.path.join(fig_dir, f'figure6_vocab_convo_white.{ext}'), 
                transparent=False, bbox_inches='tight', dpi=figure_dpi)