# Figure 4 - Comparison of neural signals during attempts to silently say English letters and NATO code words

In [None]:
import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation

from silent_spelling.utils import plotting_defaults, correlation_permutation, bootstrap_confidence_intervals

# Apply the project-wide plotting style (helper imported from silent_spelling.utils).
plotting_defaults(font='Arial', fontsize=16)

# IPython magics: automatically reload imported modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# ---- Notebook configuration ----

# Subject
subject = 'bravo1'

# Pairwise accuracy comparisons with p > sig_thresh are not annotated on the figure
sig_thresh = 0.01
# [threshold, label] pairs passed to statannot for the significance stars
pvalue_thresholds = [[1e-4, "***"], [0.001, "**"], [0.01, "*"], [1, "ns"]]

fig_dir = 'saved_figures'
# True: fetch results through the lab's RT file handler.
# False: read everything from the exported Excel workbook.
load_from_RT = False
# Only consulted when load_from_RT is True: export the loaded results to the workbook
save_to_excel = True

# Name of the folder that contains result .pkl's
result_folder_name = 'spelling_paper_signal_analyses'

# Define the result file nums
result_nums = {
    'nearest_class_distance': 19,
    'char_vs_nato_acc': 61
}

# Parent of the current working directory, computed once and reused for every
# path below so that all of them are built the same way.
parent_dir = os.path.split(os.getcwd())[0]

excel_filepath = os.path.join(parent_dir, 'source_data', 'source_data.xlsx')

# Result number of the stored correlation permutation test
# (None triggers a fresh computation when load_from_RT is True)
stats_nums = {
    'char_vs_nato_corr_permute': 71
}

# Proper utterance set names
stim_set_names = {
    'alphabet1_1': 'English\nletters',
    'alphabet1_2': 'NATO\ncode words'
}
# Same display names, keyed by the 'training scheme' values of the accuracy table
df_set_names = {
    'Characters': 'English\nletters',
    'Codewords': 'NATO\ncode words'
}

paradigms = ['overt', 'mimed']

# NOTE: unpickling can execute arbitrary code; only load a mapping file you trust.
with open(os.path.join(parent_dir, 'letter_codeword_label_mapping.pkl'), 'rb') as f:
    label_mapping = pickle.load(f)

## Load data

In [None]:
# Load the accuracy table (Fig 4A) and the per-set distance arrays (Fig 4B/C),
# either through the lab's RT file handler or from the exported Excel workbook.
if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    # Accuracy table, stored as HDF
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['char_vs_nato_acc']
    )
    char_nato_df = pd.read_hdf(result_path)

    # Distance arrays, stored as a pickled dict keyed by utterance set
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['nearest_class_distance']
    )
    with open(result_path, 'rb') as f:
        nearest_class = pickle.load(f)

    if save_to_excel:

        # NOTE(review): this writes into the current working directory, whereas the
        # mapping is read from the parent directory when the notebook is configured;
        # confirm which location is intended.
        with open('letter_codeword_label_mapping.pkl', 'wb') as f:
            pickle.dump(label_mapping, f)

        # Append when the workbook already exists so its other sheets are kept
        mode = 'a' if os.path.exists(excel_filepath) else 'w'

        # Append mode needs the openpyxl engine (xlsxwriter cannot append), so
        # request it explicitly instead of relying on pandas' default writer.
        # NOTE(review): re-running against a workbook that already holds these sheets
        # may fail or create renamed copies depending on the pandas version.
        with pd.ExcelWriter(excel_filepath, engine='openpyxl', mode=mode) as writer:
            char_nato_df.to_excel(writer, sheet_name='Fig 4A', index=False)

            for key, distances in nearest_class.items():
                # One row per entry along the leading axis, remaining axes flattened.
                # The row count comes from the array itself rather than a fixed 1000,
                # so a different number of samples cannot be silently scrambled.
                flat = distances.reshape((distances.shape[0], -1))
                pd.DataFrame(flat).to_excel(writer, sheet_name=f'Fig 4BC_{key}', index=False)

else:

    char_nato_df = pd.read_excel(excel_filepath, sheet_name='Fig 4A', engine='openpyxl')

    nearest_class = {}
    for key in stim_set_names:

        flat_array = pd.read_excel(excel_filepath, sheet_name=f'Fig 4BC_{key}', engine='openpyxl').values

        # Undo the flattening done at export time: every row holds one square
        # class-by-class matrix, so the side length is the square root of the
        # column count. reshape raises if the columns are not a perfect square.
        n_rows, n_flat = flat_array.shape
        n_classes = int(round(np.sqrt(n_flat)))
        nearest_class[key] = flat_array.reshape((n_rows, n_classes, n_classes))

## alphabet1_1 vs alphabet1_2

### Nearest class distance

In [None]:
# Letter labels; later cells pair letter_labels[i] with the i-th distance, so this
# order is assumed to match the row order of the distance matrices.
letter_labels = list(label_mapping['alphabet1_1'].keys())

# Average each set's matrices over the leading axis and symmetrize the result
confusion_df_avg = {}
for stim_set, distances in nearest_class.items():
    A = np.mean(distances, axis=0)
    # Mirror the upper triangle into the lower one. k=-1 restricts the mirrored
    # copy to entries strictly below the diagonal, so neither the diagonal nor
    # the first super-diagonal is added twice.
    W = np.triu(A) + np.tril(A.T, -1)
    confusion_df_avg[stim_set] = W

# Per class, the smallest distance to any *other* class (nearest-neighbor distance)
nearest_class_distance = {}
for stim_set, W in confusion_df_avg.items():
    n_classes = W.shape[0]
    # Exclude the self-distance by position (the diagonal) rather than by searching
    # for the value 0.0, which fails when the diagonal is not exactly zero and
    # removes the wrong entry when another zero precedes it.
    off_diagonal = W[~np.eye(n_classes, dtype=bool)].reshape(n_classes, n_classes - 1)
    nearest_class_distance[stim_set] = list(off_diagonal.min(axis=1))

# Long-format table (one row per class and utterance set) for the seaborn plots
nearest_class_dist_dict = {
    'utterance_set': [key for key, vals in nearest_class_distance.items() for _ in vals],
    'distance': [v for vals in nearest_class_distance.values() for v in vals],
}
nearest_class_dist_df = pd.DataFrame(data=nearest_class_dist_dict)

# Wilcoxon rank-sum test between the two utterance sets
dist_box_pairs = [('alphabet1_1', 'alphabet1_2')]
stat, pval = stats.ranksums(nearest_class_distance['alphabet1_1'], nearest_class_distance['alphabet1_2'])
nearest_dist_pvals = [pval]
print('Letters vs codeword p-value:', pval)
print('Letters vs codeword z-value:', stat)
       
# Number of permutations for the correlation permutation test
P = 2000

# Obtain the letters-vs-code-words correlation statistics, either from the lab's
# result store (computing them if no stored result is configured) or from the
# exported Excel workbook.
if load_from_RT:

    from RT.util import fileHandler, RTConfig

    if stats_nums['char_vs_nato_corr_permute'] is not None:

        # A stored result is configured: load it
        stats_results_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            sub_result_num=stats_nums['char_vs_nato_corr_permute']
        )
        with open(stats_results_path, 'rb') as f:
            alphnato_corr = pickle.load(f)['stats']

    else:

        # No stored result: run the Spearman permutation test and pickle the outcome.
        # NOTE(review): next_file_sub_label presumably reserves a new result path;
        # confirm against RT.util.
        stats_results_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            next_file_sub_label='char_vs_nato_corr_permute'
        )
        alphnato_corr = correlation_permutation(nearest_class_distance['alphabet1_1'], nearest_class_distance['alphabet1_2'], corr=stats.spearmanr, n_permute=P)
        with open(stats_results_path + '.pkl', 'wb') as f:
            pickle.dump({'stats': alphnato_corr}, f)

    if save_to_excel:

        # Append when the workbook already exists so its other sheets are kept
        mode = 'a' if os.path.exists(excel_filepath) else 'w'

        # Append mode needs the openpyxl engine (xlsxwriter cannot append), so
        # request it explicitly instead of relying on pandas' default writer.
        with pd.ExcelWriter(excel_filepath, engine='openpyxl', mode=mode) as writer:
            pd.DataFrame(alphnato_corr).to_excel(writer, sheet_name='Fig 4C_stats', index=False)

else:

    # Flatten back to 1-D: the figure cell reads index 0 as the correlation and
    # index 1 as the p-value.
    alphnato_corr = pd.read_excel(excel_filepath, sheet_name='Fig 4C_stats', engine='openpyxl').values.reshape((-1))

print('Correlation and p-value:', alphnato_corr)

### Decoding accuracy

In [None]:
# Rank-sum test for every pair of training schemes; only pairs that are not
# above sig_thresh are kept for annotation on the accuracy panel.
alph_nato_box_pairs, alph_nato_acc_pvals, alph_nato_acc_stats = [], [], []

scheme_column = char_nato_df['training scheme']
for scheme_a, scheme_b in itertools.combinations(scheme_column.unique(), 2):
    acc_a = char_nato_df.loc[scheme_column == scheme_a, 'accuracy'].values
    acc_b = char_nato_df.loc[scheme_column == scheme_b, 'accuracy'].values
    stat, pval = stats.ranksums(acc_a, acc_b)

    # Skip comparisons whose p-value exceeds the threshold
    if pval > sig_thresh:
        continue

    alph_nato_box_pairs.append((scheme_a, scheme_b))
    alph_nato_acc_pvals.append(pval)
    alph_nato_acc_stats.append(stat)
    print(scheme_a, scheme_b, pval, stat)

## Overall Figure

In [None]:
# Get Dark2 colors
# Get the seaborn "Set2" palette
set2_colors = sns.color_palette("Set2")

# One colour per utterance set, reused by every panel of the figure
colors = {
    'alphabet1_1': set2_colors[4],
    'alphabet1_2': set2_colors[6]
}

# Bare expression: shows the palette as the cell output
set2_colors

In [None]:
# ---- Styling constants shared by all panels ----
fontsize = 7
panel_label_fontsize = 7
linewidth = 1
annot_linewidth = 1
scatter_size = 3
boxplot_kwargs = dict(fliersize=3)
plotting_defaults(font='Arial', fontsize=fontsize, linewidth=linewidth)

# Figure size is specified in millimetres and converted to inches for matplotlib
mm = 1 / 25.4
mm_figsize = [180 * mm, 60 * mm]

# One row of three axes: accuracy (left), distance boxplot (middle),
# distance scatter (right)
fig, axs_tuple = plt.subplots(nrows=1, ncols=3, figsize=mm_figsize)
acc_axis, dist_box_axis, dist_scat_axis = axs_tuple
axs = {
    'alphnato_dist_box': dist_box_axis,
    'alphnato_dist_scat': dist_scat_axis,
    'alphnato_acc': acc_axis,
}


##### ----- 1_1 vs 1_2: ACCURACIES
acc_ax = axs['alphnato_acc']
acc_palette = [colors['alphabet1_1'], colors['alphabet1_2']]
# Arguments shared by both seaborn calls and the statannot annotation
acc_data_kwargs = dict(data=char_nato_df, x='training scheme', y='accuracy',
                       order=df_set_names.keys())

# Boxes are drawn without fliers; the strip plot overlays the individual points
acc_ax = sns.boxplot(ax=acc_ax, palette=acc_palette, showfliers=False,
                     **acc_data_kwargs, **boxplot_kwargs)
acc_ax = sns.stripplot(ax=acc_ax, palette=acc_palette, edgecolor='black',
                       linewidth=linewidth - 0.5, size=scatter_size,
                       **acc_data_kwargs)
axs['alphnato_acc'] = acc_ax

# Disable clipping on the axes' line artists
plt.setp(acc_ax.lines, clip_on=False)
acc_ax.axes.set(xticklabels=df_set_names.values(),
                xlabel='',
                ylabel='Classification accuracy',
                ylim=(0.0, 0.6))

# Dotted horizontal reference line at 1 / (number of letters)
acc_ax.axhline(y=1 / len(letter_labels), linestyle=':', color='k', linewidth=linewidth)

# Significance stars from the precomputed p-values (no test is re-run here)
add_stat_annotation(acc_ax, box_pairs=alph_nato_box_pairs, perform_stat_test=False,
                    pvalues=alph_nato_acc_pvals, text_format='star', loc='outside',
                    pvalue_thresholds=pvalue_thresholds, linewidth=annot_linewidth,
                    **acc_data_kwargs)


##### ----- 1_1 vs 1_2: Nearest class distance, boxplot
dist_ax = axs['alphnato_dist_box']
dist_palette = [colors['alphabet1_1'], colors['alphabet1_2']]
# Arguments shared by both seaborn calls and the statannot annotation
dist_data_kwargs = dict(data=nearest_class_dist_df, x='utterance_set', y='distance',
                        order=stim_set_names.keys())

# Boxes are drawn without fliers; the strip plot overlays the individual points
dist_ax = sns.boxplot(ax=dist_ax, palette=dist_palette, showfliers=False,
                      **dist_data_kwargs, **boxplot_kwargs)
dist_ax = sns.stripplot(ax=dist_ax, palette=dist_palette, edgecolor='black',
                        linewidth=linewidth - 0.5, size=scatter_size,
                        **dist_data_kwargs)
axs['alphnato_dist_box'] = dist_ax

dist_ax.axes.set(xticklabels=stim_set_names.values(),
                 xlabel='',
                 ylim=(3.66, 3.78),
                 ylabel='Nearest-class distance')

# Significance stars from the precomputed rank-sum p-value (no test is re-run here)
add_stat_annotation(dist_ax, box_pairs=dist_box_pairs, perform_stat_test=False,
                    pvalues=nearest_dist_pvals, text_format='star', loc='outside',
                    pvalue_thresholds=pvalue_thresholds, linewidth=annot_linewidth,
                    **dist_data_kwargs)


##### ----- 1_1 vs 1_2: Nearest class distance, scatter plot
scat_ax = axs['alphnato_dist_scat']
axes_limit = [3.65, 3.85]

# Dotted identity line, extended well beyond the visible axis range
lower_bound, upper_bound = (int(bound) for bound in axes_limit)
diag_line = np.arange(lower_bound - 5, upper_bound + 5)
scat_ax.plot(diag_line, diag_line, color='k', linestyle=':', linewidth=linewidth)

# Each letter is written as text at
# (distance for English letters, distance for NATO code words)
letters_dist = nearest_class_distance['alphabet1_1']
codewords_dist = nearest_class_distance['alphabet1_2']
for idx, letter in enumerate(letter_labels):
    scat_ax.annotate(letter, (letters_dist[idx], codewords_dist[idx]), fontsize=fontsize)

# Axis labels reuse the display names with their line break replaced by a space
letters_name = stim_set_names['alphabet1_1'].replace('\n', ' ')
codewords_name = stim_set_names['alphabet1_2'].replace('\n', ' ')
scat_ax.axes.set(ylim=axes_limit, xlim=axes_limit,
                 xlabel='Nearest-class distance -\n' + letters_name,
                 ylabel='Nearest-class distance -\n' + codewords_name)

# alphnato_corr[0] is shown as the correlation, alphnato_corr[1] as the p-value
rho_value, p_value = alphnato_corr[0], alphnato_corr[1]
scat_ax.annotate(r'$P$ = {:0.3f}'.format(p_value), (3.75, 3.835), fontsize=fontsize)
scat_ax.annotate(r'$\rho$ = {:0.3f}'.format(rho_value), (3.75, 3.82), fontsize=fontsize)

# Lay out the panels first; the panel letters are added afterwards
fig.tight_layout()

##### ----- Figure panel labels
# Bold lowercase letter just outside the top-left corner of each axes
panel_letters = (
    ('alphnato_acc', 'a'),
    ('alphnato_dist_box', 'b'),
    ('alphnato_dist_scat', 'c'),
)
for panel_key, panel_letter in panel_letters:
    axs[panel_key].annotate(panel_letter, (-0.15, 1.1), xycoords='axes fraction',
                            ha='right', fontsize=panel_label_fontsize, weight='bold')

### Save figure

In [None]:
figure_dpi = 300

# savefig does not create missing directories, so make sure the target exists
os.makedirs(fig_dir, exist_ok=True)

# Each format is written twice: with a transparent background and, with the
# '_white' suffix, with the figure's own opaque background.
for ext in ['png', 'pdf']:
    for suffix, transparent in (('', True), ('_white', False)):
        fig.savefig(os.path.join(fig_dir, f'figure4_codeword_vs_alphabet{suffix}.{ext}'),
                    transparent=transparent, bbox_inches='tight', dpi=figure_dpi)