# Figure 5 - Differences in neural signals and classification performance between overt- and silent-speech attempts

In [None]:
import decimal
import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation

from silent_spelling.utils import plotting_defaults, holm_bonferroni_correction, bootstrap_confidence_intervals

%load_ext autoreload
%autoreload 2

In [None]:
# ----- Analysis parameters
subject = 'bravo1'
sr = 200  # presumably a sampling rate in Hz; not referenced by the cells shown here -- TODO confirm
sig_thresh = 0.01  # corrected p-values above this are reported as not significant

# [threshold, annotation text] pairs forwarded to the stat-annotation call
pvalue_thresholds = [
    [1e-4, "***"],
    [0.001, "**"],
    [0.01, "*"],
    [1, "ns"],
]

# ----- Output location and data-source switches
fig_dir = 'saved_figures'
load_from_RT = False   # True: read the original result files; False: read the Excel source data
save_to_excel = True   # only consulted when load_from_RT is True

# Folder holding the result .pkl's, and the result number of each analysis inside it
result_folder_name = 'spelling_paper_signal_analyses'
result_nums = dict(alphabet1_2_ecog_trials=51, overt_mimed_acc=84)

# Training-scheme key (as stored in the accuracy table) -> x tick label
training_schemes = dict([
    ('mimed', 'Silent only'),
    ('overt', 'Overt only'),
    ('mimed-> ft overt', 'Silent pre-train,\novert fine-tune'),
    ('overt-> ft mimed', 'Overt pre-train,\nsilent fine-tune'),
])

# ----- Brain reconstruction: 2D image, electrode coordinates and electrode layout
brain_img = plt.imread(f'recon/{subject}_brain_2D.png')
elec_coords = np.load(f'recon/{subject}_elecmat_2D.npy')
elec_layout = np.load(f'recon/{subject}_elec_layout.npy')

# ----- Paradigms: internal key -> name used in the figure
paradigm_names = {'overt': 'overt', 'mimed': 'silent'}
paradigms = list(paradigm_names)

# ----- ERP panels: electrodes and words to show, and the time window (s) spanned by each trial
erp_elecs = [0, 101]
erp_words = ['kilo', 'tango']
alphabet1_2_ecog_window = np.array([-1, 3])

# Source-data workbook lives in <parent of the working directory>/source_data/
excel_filepath = os.path.join(os.path.dirname(os.getcwd()), 'source_data', 'source_data.xlsx')

## Load data

In [None]:
# Build `erp_dict[paradigm][word][electrode]` (one 2D array of trials per entry) and
# `overt_mimed_acc_df` (accuracy table), either from the original result files or from
# the Excel source-data workbook.
if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    ## Accuracy
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['overt_mimed_acc']
    )
    overt_mimed_acc_df = pd.read_hdf(result_path)

    ## ERP
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['alphabet1_2_ecog_trials']
    )
    # NOTE: unpickling can execute arbitrary code; only load result files from a trusted source.
    with open(result_path, 'rb') as f:
        alphabet1_2_ecog = pickle.load(f)

    # Keep, for every paradigm / word / electrode, the trials whose label matches the word
    erp_dict = {}
    for para in paradigms:

        erp_dict[para] = {}
        for word in erp_words:

            # The matching trial indices do not depend on the electrode: look them up once
            idx = np.where(alphabet1_2_ecog['labels'][para] == word)[0]

            erp_dict[para][word] = {}
            for elec in erp_elecs:
                erp_dict[para][word][elec] = alphabet1_2_ecog['ecog'][para][idx, :, elec]

    if save_to_excel:

        # The workbook folder may not exist yet on a fresh checkout
        os.makedirs(os.path.dirname(excel_filepath), exist_ok=True)

        # Append to an existing workbook, otherwise create it
        mode = 'a' if os.path.exists(excel_filepath) else 'w'

        # Append mode is only supported by the openpyxl engine, so request it explicitly
        # (this is also the engine used to read the workbook back below).
        with pd.ExcelWriter(excel_filepath, mode=mode, engine='openpyxl') as writer:

            for para, word, elec in itertools.product(paradigms, erp_words, erp_elecs):
                pd.DataFrame(erp_dict[para][word][elec]).to_excel(
                    writer, sheet_name=f'Fig 5BC_{para}_{word}_{elec}', index=False
                )

            overt_mimed_acc_df.to_excel(writer, sheet_name='Fig 5D', index=False)

else:

    # Rebuild the same nested dictionary from the per-panel sheets of the workbook
    erp_dict = {}
    for para in paradigms:

        erp_dict[para] = {}
        for word in erp_words:

            erp_dict[para][word] = {}
            for elec in erp_elecs:

                data = pd.read_excel(excel_filepath, sheet_name=f'Fig 5BC_{para}_{word}_{elec}', engine='openpyxl')
                erp_dict[para][word][elec] = data.values

    overt_mimed_acc_df = pd.read_excel(excel_filepath, sheet_name='Fig 5D', engine='openpyxl')

## Overt vs silent

### `overt` vs `silent`, decoding accuracy

In [None]:
# Label every row with its condition, encoded as "<training scheme>$<test data>"
overt_mimed_acc_df['unique scheme'] = [
    f'{scheme}${test_set}'
    for scheme, test_set in zip(
        overt_mimed_acc_df['training scheme'].values,
        overt_mimed_acc_df['test data'].values,
    )
]

# Rank-sum test between the accuracies of every pair of conditions;
# results are keyed by "<condition 1>&<condition 2>"
pvals_float = {}
stats_float = {}
scheme_column = overt_mimed_acc_df['unique scheme']
for c1, c2 in itertools.combinations(scheme_column.unique(), 2):
    key = f'{c1}&{c2}'
    group1 = overt_mimed_acc_df.loc[scheme_column == c1, 'accuracy'].values
    group2 = overt_mimed_acc_df.loc[scheme_column == c2, 'accuracy'].values
    statistic, pvalue = stats.ranksums(group1, group2)
    stats_float[key] = statistic
    pvals_float[key] = pvalue

# Holm-Bonferroni correction
hbc_pval_str = {}
hbc_pvals = holm_bonferroni_correction(pvals_float)
print(f'{len(pvals_float.keys())}-way Holm-Bonferroni correction')

# Format for stat annot: only the NON-significant comparisons are collected, since the
# figure marks those as "ns" and states that every unmarked comparison is significant
acc_box_pairs, acc_pvals = [], []
for key, val in hbc_pvals.items():

    condition_pair = key.split('&')

    # Significant comparison: report it and move on without annotating
    if not val > sig_thresh:
        print(condition_pair, val)
        continue

    print(condition_pair, val, 'not significant')

    # Turn each "<training scheme>$<test data>" back into an (x, hue) tuple
    c1, c2 = condition_pair
    acc_box_pairs.append((tuple(c1.split('$')), tuple(c2.split('$'))))
    acc_pvals.append(val)

In [None]:
## Print output for Latex table (Supplementary Table S4)

def _expand_fine_tune(scheme):
    """Spell out an "a-> ft b" training scheme as "a pre-train, b fine-tune".

    Schemes without the '-> ft' marker are returned unchanged.
    """
    if '-> ft' in scheme:
        scheme = scheme.replace('-> ft', ' pre-train,')
        scheme += ' fine-tune'
    return scheme


def _thead(text):
    """Capitalize `text` and wrap it in a LaTeX \\thead{...} cell."""
    return "\\thead{" + text.capitalize() + "}"


# One table row per pairwise comparison:
# train 1 & test 1 & train 2 & test 2 & |statistic| & corrected p-value
for key, hbpval in hbc_pvals.items():
    group1, group2 = key.split('&')
    train1, test1 = group1.split('$')
    train2, test2 = group2.split('$')

    train1 = _thead(_expand_fine_tune(train1))
    train2 = _thead(_expand_fine_tune(train2))
    test1 = _thead(test1)
    test2 = _thead(test2)

    # np.round replaces the np.round_ alias, which no longer exists in NumPy 2.0
    statistic_str = str(abs(np.round(stats_float[key], 2)))
    pvalue_str = '\\num{' + str(np.format_float_scientific(hbpval, precision=2)) + '}'

    string = ' & '.join([train1, test1, train2, test2, statistic_str, pvalue_str])
    string += "  \\\\"

    # The data use the key "mimed"; the table uses "silent"
    string = string.replace('mimed', 'silent').replace('Mimed', 'Silent')
    print(string)

In [None]:
# For every condition, print its [training scheme, test data] pair, the median accuracy
# and the output of bootstrap_confidence_intervals, both scaled by 100
condition_labels = overt_mimed_acc_df['unique scheme']
for unique_scheme in condition_labels.unique():
    accuracies = overt_mimed_acc_df.loc[condition_labels == unique_scheme, 'accuracy'].values
    median_scaled = 100 * np.median(accuracies)
    ci_scaled = 100 * bootstrap_confidence_intervals(accuracies)
    print(unique_scheme.split('$'), median_scaled, ci_scaled)

## Overall Figure

In [None]:
# Get Set2 colors
set2_colors = sns.color_palette("Set2")

# Assign each plotted condition / feature type a fixed entry of the Set2 palette
colors = {
    name: set2_colors[palette_idx]
    for name, palette_idx in [
        ('overt', 0),
        ('mimed', 1),
        ('alphabet1_1', 4),
        ('alphabet1_2', 6),
        ('hga', 3),
        ('raw', 2),
        ('hga + raw', 7),
    ]
}

# Axis limits for a close-up of the brain image
brain_closeup = dict(xlim=[200, 500], ylim=[150, 550])

# Show the palette as the cell output
set2_colors

In [None]:
# Assemble the full figure on a 4 x 6 grid: brain image with the ERP electrodes (panel a),
# per-electrode / per-word ERP traces for both paradigms (panels b and c), and the
# accuracy box plots with their statistical annotations (panel d).

# ----- Global styling
linewidth = 1
annot_linewidth = 1
fontsize = 7
plotting_defaults(font='Arial', fontsize=fontsize, linewidth=linewidth)
panel_label_fontsize = 7
boxplot_kwargs = {'fliersize': 3}
scatter_size = 3
# Figure size is given in inches; `mm` converts millimetres to inches (180 mm x 120 mm)
mm = 1 / 25.4
mm_figsize = [mm*180, mm*120]

fig = plt.figure(figsize=mm_figsize, constrained_layout=True)

# fig = plt.figure(figsize=(15, 10), constrained_layout=True)
gs = mpl.gridspec.GridSpec(4, 6, figure=fig)
# All axes are stored in this dict, keyed by panel name
axs = {}


##### ----- Brain plot
# Top-left 2 x 2 grid cells: faded brain image with the axes frame hidden
axs['brain'] = fig.add_subplot(gs[:2, :2])
axs['brain'].imshow(brain_img, alpha=0.2)
axs['brain'].axis('off')

# Electrodes shown in the ERP panels get a larger, labelled marker; all others a small faint dot
for i in range(elec_coords.shape[0]):
    if i in erp_elecs:
        axs['brain'].scatter(elec_coords[i, 0], elec_coords[i, 1], color='k', s=scatter_size+5)
        # Label is shifted 15 units in x so it does not sit on top of the marker
        axs['brain'].annotate(f'e{i}', (elec_coords[i, 0] + 15, elec_coords[i, 1]))
    else:
        axs['brain'].scatter(elec_coords[i, 0], elec_coords[i, 1], color='k', s=0.5, alpha=0.5)
        
##### ----- Make ERP panel
erp_ylim = [-0.5, 1.5]
erp_yticks = np.arange(-0.5, 1.51, 0.5)

# Four ERP axes in the order: top row (left, right), then second row (left, right)
axs['erps'] = [fig.add_subplot(gs[0, 2:4]), fig.add_subplot(gs[0, 4:6]), fig.add_subplot(gs[1, 2:4]), fig.add_subplot(gs[1, 4:6])]

# Dotted reference lines at time 0 and at amplitude 0, plus shared y-axis settings
for ax in axs['erps']:
    ax.axvline(x=0.0, linestyle=':', color='k')
    ax.axhline(y=0.0, linestyle=':', color='k')
    ax.axes.set(xlabel='Time (s)', ylim=erp_ylim, yticks=erp_yticks)
    
# Time axis: evenly spaced over the trial window, one point per column of a trial array
x = np.linspace(alphabet1_2_ecog_window[0], alphabet1_2_ecog_window[1], num=erp_dict['overt'][erp_words[0]][erp_elecs[0]].shape[1])
# `counter` indexes axs['erps']; it advances once per (electrode, word) pair, so each
# electrode fills one row and both paradigms are drawn on the same axes
counter = 0
for cur_elec, elec in enumerate(erp_elecs):
    # y label on the left-most axes of this electrode's row
    axs['erps'][cur_elec*2].axes.set(ylabel=f'e{elec} HGA\n(z-score)')
    for cur_word, word in enumerate(erp_words):
        # Column titles go on the top-row axes only
        axs['erps'][cur_word].axes.set(title=f'"{word}"')
        
        for para in paradigms:
            ax = axs['erps'][counter]
            sns.despine(ax=ax, offset=dict(bottom=5, left=5))
            ax.axes.set(xlim=alphabet1_2_ecog_window)
            hga_trials = erp_dict[para][word][elec]
            # Mean over the first axis, with a shaded band of +/- one standard error
            y = hga_trials.mean(0)
            err = stats.sem(hga_trials, axis=0)
            ax.plot(x, y, label=paradigm_names[para].capitalize(), color=colors[para])
            ax.fill_between(x, y - err, y + err, alpha=0.5, color=colors[para])
            
        counter += 1

# One legend for all ERP panels, attached to the top-right axes
axs['erps'][1].legend(bbox_to_anchor=(1.05, 1.2), loc='upper right', frameon=False)


##### ----- Overt vs mimed accuracies
# Bottom two grid rows, middle four columns; box plot (outliers hidden) with the
# individual data points overlaid as a dodged strip plot
axs['overt_mimed_acc'] = fig.add_subplot(gs[2:, 1:5])
axs['overt_mimed_acc'] = sns.boxplot(data=overt_mimed_acc_df, y='accuracy', x='training scheme', hue='test data',
                                     palette=[colors['overt'], colors['mimed']], ax=axs['overt_mimed_acc'],
                                     hue_order=['overt', 'mimed'], order=training_schemes.keys(),showfliers=False)
axs['overt_mimed_acc'] = sns.stripplot(data=overt_mimed_acc_df, y='accuracy', x='training scheme', hue='test data',
                                     palette=[colors['overt'], colors['mimed']], ax=axs['overt_mimed_acc'],
                                     hue_order=['overt', 'mimed'], order=training_schemes.keys(), 
                                       edgecolor='black', linewidth=linewidth - 0.5, size=scatter_size, dodge=True)
# Rebuild the legend with the display names ("Overt" / "Silent") instead of the data keys
current_handles, current_labels = axs['overt_mimed_acc'].get_legend_handles_labels()
axs['overt_mimed_acc'].legend(frameon=True, title=r'$\bf{Test\,data}$', loc='upper left',
                              handles=current_handles, labels=[l.capitalize() for l in paradigm_names.values()],
                             bbox_to_anchor=(1, 1.05))
axs['overt_mimed_acc'].axes.set(xlabel='Training scheme', ylabel='Accuracy', ylim=(0.2, 0.7),
                                xticklabels=list(training_schemes.values()))
# NOTE(review): y=1/26 (~0.038) is below the lower y limit of 0.2 set just above, so this
# line falls outside the visible range unless the limits change afterwards -- confirm intent
axs['overt_mimed_acc'].axhline(y=1/26, linestyle=':', color='k')
# Draw the precomputed (corrected) p-values for the pairs collected in acc_box_pairs;
# no test is run here (perform_stat_test=False)
add_stat_annotation(axs['overt_mimed_acc'], data=overt_mimed_acc_df, y='accuracy', x='training scheme', hue='test data',
                    hue_order=['overt', 'mimed'], order=training_schemes.keys(),
                    box_pairs=acc_box_pairs, perform_stat_test=False, pvalues=acc_pvals,
                    text_format='star', loc='outside', pvalue_thresholds=pvalue_thresholds, 
                    fontsize='medium', line_offset=0.01, linewidth=annot_linewidth)
# Explanatory note positioned in the accuracy panel's axes coordinates (via transAxes)
fig.text(1.02, 1.05, 'P < 0.01 for all\ncomparisons\nnot marked ns', ha='left', fontsize=fontsize, transform=axs['overt_mimed_acc'].transAxes)


##### ----- Figure panel labels
# Bold panel letters placed in figure coordinates; the trailing ';' suppresses the cell output
fig.text(0.00, 0.98, r'$\bf{a}$', ha='left', fontsize=panel_label_fontsize, weight='bold')
fig.text(0.29, 0.98, r'$\bf{b}$', ha='left', fontsize=panel_label_fontsize, weight='bold')
fig.text(0.65, 0.98, r'$\bf{c}$', ha='left', fontsize=panel_label_fontsize, weight='bold')
fig.text(0.1, 0.48, r'$\bf{d}$', ha='left', fontsize=panel_label_fontsize, weight='bold');

### Save figure

In [None]:
# Save the figure as PNG and PDF, each in a transparent and a non-transparent version
figure_dpi = 300

# savefig does not create missing folders, so make sure the output directory exists first
if fig_dir:
    os.makedirs(fig_dir, exist_ok=True)

for ext in ['png', 'pdf']:
    # Transparent background
    fig.savefig(os.path.join(fig_dir, f'figure5_overt_vs_mimed.{ext}'),
                transparent=True, bbox_inches='tight', dpi=figure_dpi)
    # Non-transparent background ("_white" suffix)
    fig.savefig(os.path.join(fig_dir, f'figure5_overt_vs_mimed_white.{ext}'),
                transparent=False, bbox_inches='tight', dpi=figure_dpi)