# Figure S4 - Effects of feature selection on code-word classification accuracy

In [None]:
import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

from silent_spelling.utils import plotting_defaults, holm_bonferroni_correction, correlation_permutation

# Apply the project's shared plotting style (helper from silent_spelling.utils) with Arial at base size 16.
plotting_defaults(font='Arial', fontsize=16)

# IPython magics: automatically reload edited modules (e.g. silent_spelling.utils) before executing code.
%load_ext autoreload
%autoreload 2

In [None]:
# ----- Analysis parameters
subject = 'bravo1'  # participant identifier
sr = 200  # presumably a sampling rate in Hz — not used by the visible cells; confirm
sig_thresh = 0.01
# (threshold, label) pairs for significance annotation, strictest threshold first
pvalue_thresholds = [[1e-4, "***"], [0.001, "**"], [0.01, "*"], [1, "ns"]]

# ----- I/O switches
fig_dir = 'saved_figures'  # where figures are written (relative to the working directory)
load_from_RT = False  # True: fetch results via the lab's RT file handler; False: read the source-data workbook
save_to_excel = True  # only consulted when load_from_RT is True: also export the loaded tables to the workbook

# Result label (folder) that holds the analysis result files
result_folder_name = 'spelling_paper_signal_analyses'

# Result-file number of the per-letter accuracy table
result_nums = {
    'hga_raw_acc_single_letter': 83,
}
# Result-file number of the stored permutation-test result to load.
# Set the value to None to recompute the test and save it as a new result file.
stats_nums = {
    'hgaraw_single_codeword_corr_permute': 88,
}

# Display names for the utterance sets
stim_set_names = {
    'alphabet1_1': 'English alphabet',
    'alphabet1_2': 'NATO codewords',
}

# Display names for the feature sets (every spelling variant of the keys)
all_feature_names = {
    'hga': 'HGA',
    'raw': 'LFS',
    'hga + raw': 'HGA+LFS',
    'hga_and_raw': 'HGA+LFS',
}
# Feature-set keys as they appear in the accuracy table's 'Paradigm' column
feature_names = {
    'hga': 'HGA',
    'raw': 'LFS',
    'hga + raw': 'HGA+LFS',
}
# Maps already-formatted labels to their compact display form
second_feature_names = {
    'HGA': 'HGA',
    'LFS': 'LFS',
    'HGA + LFS': 'HGA+LFS',
}

# Shared data files live in the parent of the notebook's working directory
_parent_dir = os.path.dirname(os.getcwd())

with open(f'{_parent_dir}/letter_codeword_label_mapping.pkl', 'rb') as mapping_file:
    label_mapping = pickle.load(mapping_file)
# Letters in the mapping's key order; this order drives every per-letter list below
letter_labels = list(label_mapping['alphabet1_1'].keys())

excel_filepath = os.path.join(_parent_dir, 'source_data', 'source_data.xlsx')

## Load data

In [None]:
if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    ## Single letter accuracy
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['hga_raw_acc_single_letter'],
    )
    # Presumably one row per (Paradigm, Letter) with an 'Accuracy' column — the columns read below
    hgaraw_single_letter_acc_df = pd.read_hdf(result_path)

    if save_to_excel:

        # Append to an existing workbook, replacing this sheet if it is already present
        # (the default would raise on re-runs); otherwise create a new workbook.
        # pandas only accepts `if_sheet_exists` in append mode.
        if os.path.exists(excel_filepath):
            writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
        else:
            writer_kwargs = {'mode': 'w'}

        with pd.ExcelWriter(excel_filepath, **writer_kwargs) as writer:
            hgaraw_single_letter_acc_df.to_excel(writer, sheet_name='Fig S4', index=False)

else:

    # Read the same table from the published source-data workbook
    hgaraw_single_letter_acc_df = pd.read_excel(excel_filepath, sheet_name='Fig S4', engine='openpyxl')

## Single-letter decoding accuracy

In [None]:
# Per-feature-set lists of accuracies, one entry per letter in `letter_labels` order,
# so the plotting cell can index letters and accuracies in parallel.
hgaraw_letter_accs = {}
for feature_key in feature_names:

    hgaraw_letter_accs[feature_key] = accs_for_feature = []
    in_paradigm = hgaraw_single_letter_acc_df['Paradigm'] == feature_key

    for letter in letter_labels:
        matching_rows = hgaraw_single_letter_acc_df.loc[
            in_paradigm & (hgaraw_single_letter_acc_df['Letter'] == letter)
        ]
        # First matching row's accuracy (IndexError if the letter is missing)
        accs_for_feature.append(matching_rows['Accuracy'].values[0])

In [None]:
# Number of permutations for the correlation permutation test
P = 2000

if load_from_RT:

    from RT.util import fileHandler, RTConfig

    if stats_nums['hgaraw_single_codeword_corr_permute'] is not None:

        # Load a previously computed permutation-test result
        stats_results_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            sub_result_num=stats_nums['hgaraw_single_codeword_corr_permute'],
        )
        with open(stats_results_path, 'rb') as f:
            stats_dict = pickle.load(f)
        hga_vs_hgaraw_corr = stats_dict['hga_vs_hgaraw_corr']
        raw_vs_hgaraw_corr = stats_dict['raw_vs_hgaraw_corr']

    else:

        print('Computing correlation...')

        stats_results_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            next_file_sub_label='hgaraw_single_codeword_corr_permute',
        )

        # Spearman correlation of per-letter accuracy between each single feature set and HGA+LFS.
        # correlation_permutation's result is indexed below as (rho, p) — assumed contract, confirm in utils.
        hga_vs_hgaraw_corr = correlation_permutation(hgaraw_letter_accs['hga'],
                                                     hgaraw_letter_accs['hga + raw'],
                                                     corr=stats.spearmanr,
                                                     n_permute=P)
        raw_vs_hgaraw_corr = correlation_permutation(hgaraw_letter_accs['raw'],
                                                     hgaraw_letter_accs['hga + raw'],
                                                     corr=stats.spearmanr,
                                                     n_permute=P)

        with open(stats_results_path + '.pkl', 'wb') as f:
            pickle.dump({'raw_vs_hgaraw_corr': raw_vs_hgaraw_corr,
                         'hga_vs_hgaraw_corr': hga_vs_hgaraw_corr}, f)

    # One row per comparison: name, Spearman rho, permutation p-value
    stats_df = pd.DataFrame({
        'feature_name': ['raw_vs_hgaraw_corr', 'hga_vs_hgaraw_corr'],
        'spearman_corr': [raw_vs_hgaraw_corr[0], hga_vs_hgaraw_corr[0]],
        'pvalue': [raw_vs_hgaraw_corr[1], hga_vs_hgaraw_corr[1]],
    })

    if save_to_excel:

        # Replace the sheet on re-runs; `if_sheet_exists` is only valid in append mode
        if os.path.exists(excel_filepath):
            writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
        else:
            writer_kwargs = {'mode': 'w'}

        with pd.ExcelWriter(excel_filepath, **writer_kwargs) as writer:
            stats_df.to_excel(writer, sheet_name='Fig S4_stats', index=False)

else:

    stats_df = pd.read_excel(excel_filepath, sheet_name='Fig S4_stats', engine='openpyxl')
    # Select rows by name, not position, so a reordered sheet cannot swap the two comparisons
    stats_by_name = stats_df.set_index('feature_name')[['spearman_corr', 'pvalue']]
    raw_vs_hgaraw_corr = stats_by_name.loc['raw_vs_hgaraw_corr'].to_numpy()
    hga_vs_hgaraw_corr = stats_by_name.loc['hga_vs_hgaraw_corr'].to_numpy()

In [None]:
def _pvalue_bound(pvalue, thresholds):
    """Return the first threshold in *thresholds* that *pvalue* is strictly below.

    *thresholds* is a sequence of ``(threshold, label)`` pairs ordered strictest
    first (as in ``pvalue_thresholds``), so the result is the tightest bound that
    can be reported as ``P < bound``. Returns ``None`` if no threshold is exceeded.
    """
    for threshold, _label in thresholds:
        if pvalue < float(threshold):
            return float(threshold)
    return None


# Tightest reportable bound for each comparison's permutation p-value (index 1 of the result)
hga_vs_hgaraw_corr_p = _pvalue_bound(hga_vs_hgaraw_corr[1], pvalue_thresholds)
raw_vs_hgaraw_corr_p = _pvalue_bound(raw_vs_hgaraw_corr[1], pvalue_thresholds)

print(hga_vs_hgaraw_corr, hga_vs_hgaraw_corr_p)
print(raw_vs_hgaraw_corr, raw_vs_hgaraw_corr_p)

## Overall Figure

In [None]:
# Two scatter panels of per-letter NATO code-word accuracy:
# a) HGA alone vs. HGA+LFS, b) LFS alone vs. HGA+LFS.
axs = {}
fig, (axs['hga_hgaraw_single_letter'], axs['raw_hgaraw_single_letter']) = plt.subplots(1, 2, figsize=(12, 5))

# (axes key, x-axis feature key, correlation result, reportable P bound)
panels = [
    ('hga_hgaraw_single_letter', 'hga', hga_vs_hgaraw_corr, hga_vs_hgaraw_corr_p),
    ('raw_hgaraw_single_letter', 'raw', raw_vs_hgaraw_corr, raw_vs_hgaraw_corr_p),
]
combined_key = 'hga + raw'

for ax_key, x_key, corr, corr_p in panels:
    ax = axs[ax_key]

    # Draw each letter at (single-feature accuracy, combined-feature accuracy)
    for letter, x_acc, y_acc in zip(letter_labels, hgaraw_letter_accs[x_key], hgaraw_letter_accs[combined_key]):
        ax.annotate(letter, (x_acc, y_acc), fontsize=16)

    # Identity line: letters above it are decoded more accurately with the combined features
    ax.plot(range(2), range(2), linestyle=':', color='k')
    ax.set(xlabel=f'NATO code-word accuracy - {feature_names[x_key]}',
           ylabel=f'NATO code-word accuracy - {feature_names[combined_key]}',
           ylim=(0, 1), xlim=(0, 1))
    ax.annotate(r'$P$ < {}'.format(corr_p), (0.1, 0.95))
    ax.annotate(r'$\rho$ = {:0.3f}'.format(corr[0]), (0.1, 0.9))

# Panel letters
fig.text(0.00, 0.99, 'a', ha='left', fontsize=22, weight='bold')
fig.text(0.5, 0.99, 'b', ha='left', fontsize=22, weight='bold')

fig.tight_layout()

### Save figure

In [None]:
figure_dpi = 300

# savefig does not create missing directories
os.makedirs(fig_dir, exist_ok=True)

# Save every format twice: with a transparent background, and with the figure's own
# (opaque) background under the '_white' suffix
for ext in ['png', 'pdf']:
    fig.savefig(os.path.join(fig_dir, f'suppfig_hga_vs_raw_corrs.{ext}'),
                transparent=True, bbox_inches='tight', dpi=figure_dpi)
    fig.savefig(os.path.join(fig_dir, f'suppfig_hga_vs_raw_corrs_white.{ext}'),
                transparent=False, bbox_inches='tight', dpi=figure_dpi)