# Figures S5, S6, and S7 - Confusion matrices from isolated-target trial classifications

In [None]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix
from silent_spelling.utils import plotting_defaults, bootstrap_confidence_intervals

plotting_defaults(font='Arial')

%load_ext autoreload
%autoreload 2

In [None]:
# ---- Notebook configuration ----

# Directory that the rendered figures are written to.
fig_dir = 'saved_figures'

# True  -> read the result pickle through the RT file handler.
# False -> read the matching sheet from the source-data Excel workbook.
load_from_RT = True

# Only consulted when loading from RT: also write the assembled table to the
# source-data workbook.
save_to_excel = False

# Which analysis variant to plot; must be a key of both dicts below.
plot_key = 'lfs_only'

# Result file number for each analysis variant.
result_nums = dict(hga_only=89, lfs_only=90, hga_and_lfs=91)

# Supplementary-figure id for each variant; used to build the sheet name
# ('Fig S5', 'Fig S6', 'Fig S7') in the source-data workbook.
source_data_keys = dict(hga_only='S5', lfs_only='S6', hga_and_lfs='S7')

# Name of the folder that contains the result .pkl files.
result_folder_name = 'spelling_paper_signal_analyses'

# The workbook is expected at <parent of cwd>/source_data/source_data.xlsx.
excel_filepath = os.path.join(os.path.dirname(os.getcwd()), 'source_data', 'source_data.xlsx')

## Load data

In [None]:
# Build `mimed_df`: one row per trial, with the 27 class scores in columns
# 0..26 followed by the 'label', 'cv' and 'blocks' columns.
if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    # Resolve the result pickle for the selected analysis variant.
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums[plot_key]
    )
    # NOTE: unpickling can execute arbitrary code; only load trusted result files.
    mimed = pd.read_pickle(result_path)

    # Stack the per-trial prediction vectors into a (trials x 27) table.
    mimed_df = pd.DataFrame(np.stack(mimed['pred_vec'].values, axis=0), columns=range(27))
    mimed_df['label'] = mimed['label'].values
    mimed_df['cv'] = mimed['cv'].values
    mimed_df['blocks'] = mimed['blocks'].values

    if save_to_excel:

        # Append to an existing workbook so other figures' sheets are kept;
        # otherwise create a new one.
        mode = 'a' if os.path.exists(excel_filepath) else 'w'

        # Pin the openpyxl engine: pandas otherwise prefers xlsxwriter for
        # .xlsx files when it is installed, and xlsxwriter cannot append.
        # This also matches the engine used by read_excel below.
        # NOTE(review): in append mode recent pandas raises if the sheet
        # already exists (if_sheet_exists='error') - confirm desired behavior.
        with pd.ExcelWriter(excel_filepath, mode=mode, engine='openpyxl') as writer:
            mimed_df.to_excel(writer, sheet_name=f'Fig {source_data_keys[plot_key]}', index=False)

else:

    # Load the previously exported table from the source-data workbook.
    mimed_df = pd.read_excel(excel_filepath, sheet_name=f'Fig {source_data_keys[plot_key]}', engine='openpyxl')

In [None]:
# Make the confusion matrix
nato = ['Alpha', 'Bravo', 'Charlie', 'Delta', 'Echo', 'Foxtrot', 'Golf', 'Hotel', 'India', 'Juliet', 'Kilo', 
        'Lima','Mike', 'November', 'Oscar', 'Papa', 'Quebec','Romeo', 'Sierra', 'Tango', 'Uniform', 'Victor', 
        'Whiskey', 'X-Ray', 'Yankee', 'Zulu', 'Hand-attempt']
n_classes = len(nato)

# Select the class-score columns by position from the front of the frame.
# Slicing with `[:, :-3]` is only correct on the first run: this cell appends
# 'pred_label' and 'correct', so a re-run would silently include 'label' and
# 'cv' in the argmax.
class_scores = mimed_df.iloc[:, :n_classes].to_numpy()
pred = np.argmax(class_scores, axis=-1)
mimed_df['pred_label'] = pred
mimed_df['correct'] = mimed_df.pred_label == mimed_df.label

# sklearn's signature is confusion_matrix(y_true, y_pred): rows are the true
# (target) classes and columns the predicted classes, and normalize='true'
# makes every target row sum to 1. This matches the 'Target' / 'Predicted'
# axis labels and the "normalized by row" colorbar label used when plotting.
# `labels` pins the matrix to n_classes x n_classes even if a class never
# occurs, so it always lines up with the `nato` names.
conf_matrix = confusion_matrix(
    mimed_df['label'],
    pred,
    labels=list(range(n_classes)),
    normalize='true'
)

# Per-fold accuracy: fraction of correctly classified trials in each of the
# ten cross-validation folds (fold ids 0..9 in the 'cv' column).
cv_accs = [mimed_df.loc[mimed_df.cv == cur_cv, 'correct'].mean() for cur_cv in range(10)]

print(f'For {plot_key}, including hand command')
print(np.median(cv_accs))
print(bootstrap_confidence_intervals(cv_accs))

# Labelled matrix used for plotting: index = target class, columns = predicted class.
df = pd.DataFrame(data=conf_matrix, columns=nato, index=nato)

In [None]:
# Draw the confusion matrix as a heatmap with a separate, narrow colorbar axis.
from matplotlib.ticker import FuncFormatter

fig = plt.figure(figsize=(10, 9), constrained_layout=True)
gs = fig.add_gridspec(100, 100)
ax = fig.add_subplot(gs[:, :80])    # heatmap: left 80% of the grid
cax = fig.add_subplot(gs[:, 97:])   # colorbar: rightmost 3 grid columns

sns.heatmap(df, cbar_kws={"shrink": 0.3}, cmap='bone_r', ax=ax, cbar_ax=cax, vmin=0, vmax=1, xticklabels=True)

# Draw a full frame around the heatmap.
for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_visible(True)
ax.set_ylabel('Target', labelpad=10.0, fontsize='large')
ax.set_xlabel('Predicted', labelpad=10.0, fontsize='large')

# Show colorbar ticks as percentages. A formatter is used instead of
# set_yticklabels() so the labels always follow the actual tick positions;
# fixed labels without fixed tick locations trigger a matplotlib warning and
# can become mismatched if the ticks are recomputed.
cax.yaxis.set_major_formatter(
    FuncFormatter(lambda value, _pos: '{:<2g}'.format(float(value) * 100.))
)
cax.set_ylabel('Confusion value (%; normalized by row)', fontsize='large', labelpad=20.0, rotation=270);

### Save figure

In [None]:
# Save the figure as PNG and PDF, each in a transparent and a white-background
# ('_white') version.
figure_dpi = 300

# savefig does not create missing directories; make sure the target exists.
os.makedirs(fig_dir, exist_ok=True)

for ext in ['png', 'pdf']:
    for suffix, transparent in (('', True), ('_white', False)):
        fig.savefig(os.path.join(fig_dir, f'suppfig_confusion_{plot_key}{suffix}.{ext}'),
                    transparent=transparent, bbox_inches='tight', dpi=figure_dpi)