# Figure 3 - Characterization of high-gamma activity (HGA) and low-frequency signals (LFS) during silent-speech attempts

In [None]:
import copy
import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation

from silent_spelling.plotting import plot_images_and_elecs
from silent_spelling.utils import plotting_defaults, holm_bonferroni_correction, correlation_permutation, bootstrap_confidence_intervals

# Apply the project's plotting defaults (helper imported from silent_spelling.utils)
plotting_defaults(font='Arial', fontsize=16)

# IPython magics: enable the autoreload extension in mode 2 so imported modules
# are reloaded before each cell executes (no kernel restart after package edits)
%load_ext autoreload
%autoreload 2

In [None]:
# ---- Analysis settings ----
subject = 'bravo1'
sr = 200  # divisor used later to convert smoothing widths from samples to seconds
sig_thresh = 0.01  # corrected p-values above this are left out of the PC-count annotations
pvalue_thresholds = [[1e-4, "***"], [0.001, "**"], [0.01, "*"], [1, "ns"]]

# ---- Input / output switches ----
fig_dir = 'saved_figures'
load_from_RT = False
save_to_excel = True

# Name of the folder that contains result .pkl's
result_folder_name = 'spelling_paper_signal_analyses'

# Result file numbers, keyed by analysis
result_nums = dict(
    raw_alone_salience=80,
    hga_alone_salience=78,
    rawhga_salience=79,
    hga_raw_acc=77,
    hga_raw_pr=81,
    temporal_smoothing=82,
)

# Result file numbers for cached statistics
stats_nums = dict(salience_corr_permute=86)

# Display names for the utterance sets
stim_set_names = dict(
    alphabet1_1='English alphabet',
    alphabet1_2='NATO codewords',
)

# Display names for every feature-type key that appears in the result tables
all_feature_names = {
    'hga': 'HGA',
    'raw': 'LFS',
    'hga + raw': 'HGA+LFS',
    'hga_and_raw': 'HGA+LFS',
}
# Same display names, restricted to the three keys used for the PC panels
feature_names = {key: all_feature_names[key] for key in ('hga', 'raw', 'hga + raw')}
# Display names keyed by the labels stored in the accuracy table
second_feature_names = {
    'HGA': 'HGA',
    'LFS': 'LFS',
    'HGA + LFS': 'HGA+LFS',
}

# NOTE: despite the name, this is the PARENT of the current working directory
cwd = os.path.dirname(os.getcwd())

# Load brain image and electrode coordinates
brain_img = plt.imread(f'recon/{subject}_brain_2D.png')
elec_coords = np.load(f'recon/{subject}_elecmat_2D.npy')
elec_layout = np.load(f'recon/{subject}_elec_layout.npy')

paradigms = ['overt', 'mimed']

# NOTE: pickle.load executes arbitrary code; only open mapping files you trust
with open(f'{cwd}/letter_codeword_label_mapping.pkl', 'rb') as mapping_file:
    label_mapping = pickle.load(mapping_file)

letter_labels = list(label_mapping['alphabet1_1'])

excel_filepath = os.path.join(cwd, 'source_data', 'source_data.xlsx')

## Load data

In [None]:
# Load the source data for every panel, either from the lab's result store
# (optionally exporting it to the source-data workbook) or from that workbook.
if load_from_RT:

    # Custom software for file handling on Chang Lab systems
    from RT.util import fileHandler, RTConfig

    # Saliences
    saliences = {}
    for sal_key in ['raw_alone_salience', 'hga_alone_salience', 'rawhga_salience']:
        result_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            sub_result_num=result_nums[sal_key]
        )
        temp_sal = np.load(result_path)

        # Find the cutoff for padding: the first step along axis 1 whose value is
        # zero (checked on the first entry of axes 0 and 2). If no zero is found
        # there is no padding, so keep the full length instead of raising IndexError.
        zero_steps = np.flatnonzero(temp_sal[0, :, 0] == 0)
        cutoff = zero_steps[0] if zero_steps.size else temp_sal.shape[1]

        # Take the L2 norm across time, then average by trial
        saliences[sal_key] = np.mean(np.linalg.norm(temp_sal[:, :cutoff, :], 2, axis=1), axis=0)

    ## HGA vs LFS accuracy
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['hga_raw_acc']
    )
    hga_raw_df = pd.read_hdf(result_path)

    ## Number of PCs for spatial and temporal variance
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['hga_raw_pr']
    )
    hga_raw_pr_df = pd.read_hdf(result_path)

    ## Temporal smoothing
    result_path = fileHandler.getSubResultFilePath(
        sub_dir_key='analysis',
        result_label=result_folder_name,
        sub_result_num=result_nums['temporal_smoothing']
    )
    temporal_smoothing = pd.read_hdf(result_path)

    if save_to_excel:

        # Make sure the destination folder exists before opening the workbook
        os.makedirs(os.path.dirname(excel_filepath), exist_ok=True)

        # Append to an existing workbook, otherwise create it
        mode = 'a' if os.path.exists(excel_filepath) else 'w'

        # Append mode is only supported by the openpyxl engine, so request it
        # explicitly (it is also the engine used for reading below).
        # NOTE(review): recent pandas raises in append mode if a sheet of the same
        # name already exists (see `if_sheet_exists`) - confirm behavior on re-runs.
        with pd.ExcelWriter(excel_filepath, mode=mode, engine='openpyxl') as writer:

            hga_raw_df.to_excel(writer, sheet_name='Fig 3A', index=False)

            for key, val in saliences.items():
                pd.DataFrame(val).to_excel(writer, sheet_name=f'Fig 3B-E_{key}', index=False)

            hga_raw_pr_df.to_excel(writer, sheet_name='Fig 3FG', index=False)
            temporal_smoothing.to_excel(writer, sheet_name='Fig 3H', index=False)

else:

    # Read everything back from the source-data workbook, one sheet per panel
    saliences = {}
    for key in ['raw_alone_salience', 'hga_alone_salience', 'rawhga_salience']:
        saliences[key] = pd.read_excel(excel_filepath, sheet_name=f'Fig 3B-E_{key}', engine='openpyxl').values.squeeze()

    hga_raw_df = pd.read_excel(excel_filepath, sheet_name='Fig 3A', engine='openpyxl')
    hga_raw_pr_df = pd.read_excel(excel_filepath, sheet_name='Fig 3FG', engine='openpyxl')
    temporal_smoothing = pd.read_excel(excel_filepath, sheet_name='Fig 3H', engine='openpyxl')

## HGA vs raw vs HGA + raw

### Saliences

In [None]:
# Rescale each salience vector so that its entries sum to one.
raw_only = saliences['raw_alone_salience']
hga_only = saliences['hga_alone_salience']
# The combined-model vector is split at index 128: this code treats the first
# 128 entries as the LFS part and the remainder as the HGA part.
combined_raw = saliences['rawhga_salience'][:128]
combined_hga = saliences['rawhga_salience'][128:]

raw_alone = raw_only / np.sum(raw_only)
hga_alone = hga_only / np.sum(hga_only)
raw_combo = combined_raw / np.sum(combined_raw)
hga_combo = combined_hga / np.sum(combined_hga)

In [None]:
# Spearman correlations (with permutation p-values) between salience maps.
# Number of permutations handed to correlation_permutation
P = 2000

if load_from_RT:

    from RT.util import fileHandler, RTConfig

    if stats_nums['salience_corr_permute'] is not None:
        # A cached result number is configured: load the stored table
        stats_results_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            sub_result_num=stats_nums['salience_corr_permute'],
            extension='.h5'
        )
        sal_corr_permute = pd.read_hdf(stats_results_path, key='df')

    else:

        print('Computing correlations...')

        stats_results_path = fileHandler.getSubResultFilePath(
            sub_dir_key='analysis',
            result_label=result_folder_name,
            next_file_sub_label='salience_corr_permute'
        )

        # (label, first salience vector, second salience vector); evaluated in
        # this order so the sequence of permutation tests is unchanged
        salience_comparisons = [
            ('hga_vs_hgaraw', hga_alone, hga_combo),
            ('raw_vs_hgaraw', raw_alone, raw_combo),
            ('hga_vs_raw', hga_alone, raw_alone),
        ]

        sal_corr_permute = {'feature': [], 'spearman_corr': [], 'pvalue': []}
        for comparison_label, first_sal, second_sal in salience_comparisons:
            r, p = correlation_permutation(first_sal,
                                           second_sal,
                                           corr=stats.spearmanr,
                                           n_permute=P)
            sal_corr_permute['feature'].append(comparison_label)
            sal_corr_permute['spearman_corr'].append(r)
            sal_corr_permute['pvalue'].append(p)

        sal_corr_permute = pd.DataFrame(data=sal_corr_permute)
        sal_corr_permute.to_hdf(stats_results_path + '.h5', key='df', mode='w')

    if save_to_excel:

        # Make sure the destination folder exists before opening the workbook
        os.makedirs(os.path.dirname(excel_filepath), exist_ok=True)

        # Append to an existing workbook, otherwise create it
        mode = 'a' if os.path.exists(excel_filepath) else 'w'

        # Append mode is only supported by the openpyxl engine, so request it
        # explicitly (it is also the engine used for reading below).
        # NOTE(review): recent pandas raises in append mode if the sheet already
        # exists (see `if_sheet_exists`) - confirm behavior on re-runs.
        with pd.ExcelWriter(excel_filepath, mode=mode, engine='openpyxl') as writer:
            sal_corr_permute.to_excel(writer, sheet_name='Fig 3B-E_stats', index=False)

else:

    sal_corr_permute = pd.read_excel(excel_filepath, sheet_name='Fig 3B-E_stats', engine='openpyxl')

# Display the correlation table
sal_corr_permute

### Decoding accuracy

In [None]:
# Pairwise rank-sum tests on decoding accuracy between feature types, with
# Holm-Bonferroni correction, collected both for plotting and as a LaTeX table.

# Accuracy values per feature type (keys are the labels stored in the table)
hgaraw_boxplot_dict = {
    feat: hga_raw_df.loc[hga_raw_df['Features'] == feat]['Accuracy'].values
    for feat in second_feature_names
}

# Uncorrected test statistic and p-value for every pair of feature types
hgaraw_stats_float, hgaraw_pvals_float = {}, {}
for c1, c2 in itertools.combinations(second_feature_names, 2):
    key = f'{c1}&{c2}'
    hgaraw_stats_float[key], hgaraw_pvals_float[key] = stats.ranksums(hgaraw_boxplot_dict[c1], hgaraw_boxplot_dict[c2])

hgaraw_pvals_float_corrected = holm_bonferroni_correction(hgaraw_pvals_float)

# LaTeX column headers. These MUST be raw strings: in a normal literal '\t'
# (in \tnote, \textit) is a tab, '\n' (in \num) is a newline and '\m' is an
# invalid escape, all of which corrupt the LaTeX. As raw text, '\\' is the
# LaTeX row break that pushes "(corrected)" onto a second header line.
comparison_col = r'Statistical comparison \tnote{1}'
z_value_col = r'$\mid$ \textit{z}-value $\mid$'
p_value_col = r'\textit{P}-value \\ & & (corrected\tnote{2}  )'

hgaraw_pvals, hgaraw_box_pairs = [], []
acc_stats_df = {comparison_col: [], z_value_col: [], p_value_col: []}
for key, val in hgaraw_pvals_float_corrected.items():
    c1, c2 = key.split('&')
    # Pairs and corrected p-values in matching order for the plot annotations
    hgaraw_box_pairs.append((c1, c2))
    hgaraw_pvals.append(val)

    acc_stats_df[comparison_col].append(f'{c1} vs. {c2}')
    acc_stats_df[z_value_col].append('{:0.2f}'.format(abs(hgaraw_stats_float[key])))
    acc_stats_df[p_value_col].append(r'\num{{{:0.3g}}}'.format(val))

acc_stats_df = pd.DataFrame(data=acc_stats_df)

In [None]:
# Display the uncorrected rank-sum statistics, keyed by '<feature1>&<feature2>'
hgaraw_stats_float

In [None]:
# Display the p-values returned by holm_bonferroni_correction for the accuracy comparisons
hgaraw_pvals_float_corrected

In [None]:
# Per-feature-type medians of the accuracy table
hga_raw_df.groupby('Features').median()

In [None]:
# Confidence interval helper applied to the accuracies of the combined HGA + LFS features
hgalfs_accuracy = hga_raw_df.loc[hga_raw_df['Features'] == 'HGA + LFS', 'Accuracy']
bootstrap_confidence_intervals(hgalfs_accuracy)

### No. PCs >80% variance explained

In [None]:
# Pairwise rank-sum tests on the number of PCs needed for >80% variance
# (temporal and spatial), Holm-Bonferroni corrected within each type. Only
# comparisons at or below `sig_thresh` are kept for annotation and tables.

# Column of the PC table that holds each type of PC count
pr_columns = {'temporal': '80% var PC, Temporal', 'spatial': '80% var PC, Spatial'}

# PC counts per type and per feature condition
hga_raw_pr_dict = {'temporal': {}, 'spatial': {}}
for feat in feature_names:
    condition_rows = hga_raw_pr_df.loc[hga_raw_pr_df['Condition'] == feat]
    for pr_type, column in pr_columns.items():
        hga_raw_pr_dict[pr_type][feat] = condition_rows[column].values

# Uncorrected test statistic and p-value for every pair of conditions
hga_raw_pr_pvals_float = {'temporal': {}, 'spatial': {}}
hga_raw_pr_stats_float = {'temporal': {}, 'spatial': {}}
for c1, c2 in itertools.combinations(feature_names, 2):
    key = f'{c1}&{c2}'
    for pr_type in ('temporal', 'spatial'):
        hga_raw_pr_stats_float[pr_type][key], hga_raw_pr_pvals_float[pr_type][key] = stats.ranksums(
            hga_raw_pr_dict[pr_type][c1], hga_raw_pr_dict[pr_type][c2]
        )

hgaraw_pr_box_pairs = {'temporal': [], 'spatial': []}
hgaraw_pr_pvals = {'temporal': [], 'spatial': []}

# LaTeX column headers. These MUST be raw strings: in a normal literal '\t'
# (in \tnote, \textit) is a tab, '\n' (in \num) is a newline and '\m' is an
# invalid escape, all of which corrupt the LaTeX. As raw text, '\\' is the
# LaTeX row break that pushes "(corrected)" onto a second header line.
comparison_col = r'Statistical comparison \tnote{1}'
z_value_col = r'$\mid$ \textit{z}-value $\mid$'
p_value_col = r'\textit{P}-value \\ & & (corrected\tnote{2}  )'

pr_stats = {
    pr_type: {comparison_col: [], z_value_col: [], p_value_col: []}
    for pr_type in ('temporal', 'spatial')
}

for pr_type in hga_raw_pr_pvals_float:

    corrected_pvals = holm_bonferroni_correction(hga_raw_pr_pvals_float[pr_type])

    for key, val in corrected_pvals.items():

        # Skip comparisons whose corrected p-value exceeds the threshold
        if val > sig_thresh:
            continue

        c1, c2 = key.split('&')
        hgaraw_pr_box_pairs[pr_type].append((c1, c2))
        hgaraw_pr_pvals[pr_type].append(val)

        pr_stats[pr_type][comparison_col].append(f'{feature_names[c1]} vs. {feature_names[c2]}')
        pr_stats[pr_type][z_value_col].append('{:0.2f}'.format(abs(hga_raw_pr_stats_float[pr_type][key])))
        pr_stats[pr_type][p_value_col].append(r'\num{{{:0.3g}}}'.format(val))

temp_stats_df = pd.DataFrame(data=pr_stats['temporal'])
spatial_stats_df = pd.DataFrame(data=pr_stats['spatial'])

In [None]:
# Display the table of spatial PC-count comparisons that passed the significance threshold
spatial_stats_df

In [None]:
# Display the table of temporal PC-count comparisons that passed the significance threshold
temp_stats_df

In [None]:
# Report the range of PC counts: first line the maxima, second line the minima
# (spatial first, then temporal)
spatial_pc_counts = hga_raw_pr_df['80% var PC, Spatial']
temporal_pc_counts = hga_raw_pr_df['80% var PC, Temporal']
print(spatial_pc_counts.max(), temporal_pc_counts.max())
print(spatial_pc_counts.min(), temporal_pc_counts.min())

In [None]:
# Display the uncorrected rank-sum statistics for the temporal and spatial PC counts
hga_raw_pr_stats_float

### Temporal smoothing

In [None]:
# Drop the rows with a smoothing value of 1, then add derived columns.
# Take an explicit copy of the filtered frame: assigning new columns to a
# `.loc` slice triggers pandas' SettingWithCopyWarning and may not write
# through reliably.
temporal_smoothing = temporal_smoothing.loc[temporal_smoothing['Temporal Smoothing'] != 1].copy()

# Convert to seconds
temporal_smoothing['Temporal Smoothing (s)'] = temporal_smoothing['Temporal Smoothing'].values / sr

# Identifier-friendly aliases of existing columns
temporal_smoothing['temporal_smoothing'] = temporal_smoothing['Temporal Smoothing (s)'].values
temporal_smoothing['accuracy_fraction'] = temporal_smoothing['accuracy (% of original)'].values
temporal_smoothing['feature_type'] = temporal_smoothing['Data Type'].values

In [None]:
# Runs Wilcoxon signed-rank tests between every pair of feature types, pooling
# the accuracy fractions over all smoothing widths, then applies
# Holm-Bonferroni correction.

# Sorted so that the order of the comparisons (and of the result rows) is the
# same on every run; iterating a set of strings is not order-stable.
data_types = sorted(set(temporal_smoothing['Data Type']))
smoothings = sorted(set(temporal_smoothing['Temporal Smoothing (s)']))
acc_fracs = {k: [] for k in data_types}

for cur_smoothing in smoothings:
    cur_df = temporal_smoothing[temporal_smoothing['Temporal Smoothing (s)'] == cur_smoothing]
    cur_data = {k: cur_df[cur_df['Data Type'] == k]['accuracy_fraction'].values for k in data_types}

    # The signed-rank test is paired, so every feature type must contribute the
    # same number of samples at each smoothing width.
    # NOTE(review): this also assumes rows are ordered consistently across
    # feature types within a smoothing width - confirm against the data source.
    sample_counts = {len(v) for v in cur_data.values()}
    if len(sample_counts) != 1:
        raise ValueError(
            f'Unequal sample counts across data types at smoothing {cur_smoothing}: '
            f'{ {k: len(v) for k, v in cur_data.items()} }'
        )

    for k, v in cur_data.items():
        acc_fracs[k].extend(v)

# Run each test once and keep both the statistic and the p-value
stats_results = []
for k1, k2 in itertools.combinations(data_types, 2):
    stat, p = stats.wilcoxon(acc_fracs[k1], acc_fracs[k2])
    stats_results.append({'key1': k1, 'key2': k2, 'p': p, 'stat': stat})

corrected_p_vals = holm_bonferroni_correction(
    {i: v['p'] for i, v in enumerate(stats_results)}
)
for k, cur_corrected_p in corrected_p_vals.items():
    stats_results[k]['p_mc_corrected'] = cur_corrected_p
results_df = pd.DataFrame(stats_results)

results_df

## Overall Figure

In [None]:
# Parameters: these are the same for all plots so we can just reuse them
cbar_params = {
    'plot_colorbar'        : False,
    'colorbar_title'       : 'Color title',
}
elec_size_color_params = {
    'color_params' : {'min': 0,  'max': 1.0,   'relative': True},
    'size_params'  : {'min': 10, 'max': 30, 'relative': True, 'scale': 30},
    'alpha_params' : {'min': 0.2, 'max': 1.0,  'relative': True, 'exponent': 0.5}
}
other_params = {
    'show_fig': False,
    'elec_size_color_params': elec_size_color_params
}
all_plot_params = {'elec_loc_file_path': 'bravo1_elecmat_2D.npy',
 'all_image_params': {'file_name': 'bravo1_brain_2D.png',
  'invert_y': True,
  'alpha': 0.25},
 'y_scale': -1.0,
 'add_height': True,
 'elec_plot_params': {'linewidths': 0.0, 'zorder': 100000.0},
 'elec_size_color_params': {'color_spec': 'black',
  'size_params': {'min': 100.0, 'max': 100.0, 'relative': False},
  'alpha_params': {'min': 0.8, 'max': 0.8, 'relative': False},
  'color_params': {}}}
all_plot_params.update(cbar_params)
all_plot_params.update(other_params)

In [None]:
# Get the seaborn "Set2" palette
set2_colors = sns.color_palette("Set2")

# Palette index assigned to each paradigm, utterance set, and feature type
palette_index = {
    'overt': 0,
    'mimed': 1,
    'alphabet1_1': 4,
    'alphabet1_2': 6,
    'hga': 3,
    'raw': 2,
    'hga + raw': 7,
}
colors = {name: set2_colors[idx] for name, idx in palette_index.items()}

# Axis limits applied to the brain images in the salience panels
brain_closeup = dict(xlim=[200, 500], ylim=[150, 550])

# Display the palette
set2_colors

In [None]:
# ---- Shared styling for the composite figure ----
linewidth = annot_linewidth = 1
fontsize = panel_label_fontsize = 7
plotting_defaults(font='Arial', fontsize=fontsize, linewidth=linewidth)
boxplot_kwargs = dict(fliersize=3)
scatter_size = 3

# Figure size is specified in millimetres and converted to inches
mm = 1 / 25.4
mm_figsize = [mm * extent for extent in (180, 130)]

fig = plt.figure(figsize=mm_figsize)

# ---- Layout: a 2 x 9 grid, with nested grids for the multi-panel regions ----
gs = mpl.gridspec.GridSpec(2, 9, figure=fig, wspace=1.5, hspace=0.6)
# Top row, right side: four salience maps
sal_gs = mpl.gridspec.GridSpecFromSubplotSpec(
    1, 4, subplot_spec=gs[0, 3:], wspace=0.05
)
# Bottom row, left: three spatial PC histograms
spatial_gs = mpl.gridspec.GridSpecFromSubplotSpec(
    1, 3, subplot_spec=gs[1, :3], wspace=0.4, hspace=0.6
)
# Bottom row, middle: three temporal PC histograms
temporal_gs = mpl.gridspec.GridSpecFromSubplotSpec(
    1, 3, subplot_spec=gs[1, 3:6], wspace=0.4, hspace=0.6
)

# Axes are collected here by panel name
axs = {}


##### ----- HGA vs HGA + Raw: decoding accuracies
# Arguments common to the box plot, the strip plot and the annotations
feature_order = second_feature_names.keys()
feature_palette = [colors[key] for key in feature_names.keys()]
accuracy_plot_kwargs = dict(data=hga_raw_df, x='Features', y='Accuracy', order=feature_order)

# Box plot (outliers hidden) with the individual points drawn on top
acc_ax = fig.add_subplot(gs[0, :3])
acc_ax = sns.boxplot(ax=acc_ax, palette=feature_palette, showfliers=False,
                     **boxplot_kwargs, **accuracy_plot_kwargs)
acc_ax = sns.stripplot(ax=acc_ax, palette=feature_palette,
                       color='black', edgecolor='black', linewidth=linewidth-0.5,
                       size=scatter_size, **accuracy_plot_kwargs)
axs['hga_raw_accs_box'] = acc_ax

acc_ax.axes.set(xticklabels=second_feature_names.values(),
                ylabel='NATO code-word accuracy',
                ylim=(0.25, 0.6))
acc_ax.axes.set_xlabel('Feature type', labelpad=0)

# Significance stars use the precomputed, corrected p-values (no test is run here)
add_stat_annotation(acc_ax, box_pairs=hgaraw_box_pairs, perform_stat_test=False,
                    pvalues=hgaraw_pvals, text_format='star', loc='outside',
                    pvalue_thresholds=pvalue_thresholds, linewidth=annot_linewidth,
                    **accuracy_plot_kwargs)


##### ----- Salience maps: HGA alone, HGA from HGA + raw, raw alone, raw from HGA + raw
# One entry per panel, left to right:
# (axes key, electrode weights, key into `colors`, panel title)
salience_panels = [
    ('hga_sal_alone', saliences['hga_alone_salience'], 'hga', 'HGA\nalone'),
    ('hga_sal_from_rawhga', saliences['rawhga_salience'][128:], 'hga', 'HGA\nfrom HGA+LFS'),
    ('raw_sal_alone', saliences['raw_alone_salience'], 'raw', 'LFS\nalone'),
    ('raw_sal_from_rawhga', saliences['rawhga_salience'][:128], 'raw', 'LFS\nfrom HGA+LFS'),
]

for panel_col, (panel_key, panel_weights, color_key, panel_title) in enumerate(salience_panels):
    axs[panel_key] = fig.add_subplot(sal_gs[0, panel_col])

    # The shared parameter dict is updated in place for each panel
    all_plot_params['elec_weights'] = panel_weights
    all_plot_params['fig'] = fig
    all_plot_params['elec_size_color_params']['color_spec'] = colors[color_key]

    plot_images_and_elecs(**all_plot_params, ax=axs[panel_key])
    axs[panel_key].axes.set(xlim=brain_closeup['xlim'], ylim=brain_closeup['ylim'], title=panel_title)

ax = axs['raw_sal_from_rawhga']


##### ----- HGA vs HGA + Raw: Spatial PCs
# Overlay axis spanning the three histograms. It is drawn with white points and
# then switched off, so it only carries the significance annotations.
axs['alphnato_spatial_pr_stat'] = fig.add_subplot(spatial_gs[0, :3])
axs['alphnato_spatial_pr_stat'] = sns.stripplot(data=hga_raw_pr_df, x='Condition', y='80% var PC, Spatial', ax=axs['alphnato_spatial_pr_stat'],
                                 order=feature_names.keys(), palette=3*['white'])
axs['alphnato_spatial_pr_stat'].axes.set(ylim=(2, 14))
add_stat_annotation(axs['alphnato_spatial_pr_stat'], data=hga_raw_pr_df, x='Condition', y='80% var PC, Spatial', order=feature_names.keys(),
                   box_pairs=hgaraw_pr_box_pairs['spatial'], perform_stat_test=False, pvalues=hgaraw_pr_pvals['spatial'],
                   text_format='star', loc='outside', pvalue_thresholds=pvalue_thresholds, line_offset_to_box=0.1, linewidth=annot_linewidth)
axs['alphnato_spatial_pr_stat'].axis('off')

# One horizontal histogram of PC counts per feature type
axs['alphnato_spatial_pr'] = [fig.add_subplot(spatial_gs[0, 0]), fig.add_subplot(spatial_gs[0, 1]), fig.add_subplot(spatial_gs[0, 2])]
for ax, key in zip(axs['alphnato_spatial_pr'], feature_names.keys()):
    pcs, pc_counts = np.unique(hga_raw_pr_df.loc[hga_raw_pr_df['Condition'] == key]['80% var PC, Spatial'].values, return_counts=True)
    # NOTE(review): bars show raw counts; the 'Percent of bootstraps' label is
    # exact only if each condition has 100 rows - confirm.
    ax.barh(pcs, pc_counts, height=0.5, clip_on=False, color=colors[key], edgecolor='grey')
    ax.axes.set(ylim=(4, 18), yticks=range(4, 19, 2), xlim=(0, 100), xticks=[0, 50, 100], xticklabels=[0, '', 100])

# Raw string: '\s' is an invalid escape sequence in a normal string literal
axs['alphnato_spatial_pr'][0].axes.set_ylabel(r'No. feature PCs, $\sigma^2$ > 80%', labelpad=-1)
axs['alphnato_spatial_pr'][1].axes.set(xlabel='Percent of bootstraps')
for ax in axs['alphnato_spatial_pr'][1:]:
    ax.axes.set(yticklabels=[])

# Legend with one colored patch per feature type
handles = []
for key, val in feature_names.items():
    handles.append(mpl.patches.Patch(color=colors[key], label=val))
axs['alphnato_spatial_pr'][0].legend(handles=handles, handlelength=0.5, fontsize=fontsize-1, bbox_to_anchor=(-0.05, 0.98), loc='upper left', frameon=False)


##### ----- HGA vs HGA + Raw: Temporal PCs
# Overlay axis spanning the three histograms. It is drawn with white points and
# then switched off, so it only carries the significance annotations.
axs['alphnato_temporal_pr_stat'] = fig.add_subplot(temporal_gs[0, :3])
axs['alphnato_temporal_pr_stat'] = sns.stripplot(data=hga_raw_pr_df, x='Condition', y='80% var PC, Temporal', ax=axs['alphnato_temporal_pr_stat'],
                                 order=feature_names.keys(), palette=3*['white'])
axs['alphnato_temporal_pr_stat'].axes.set(ylim=(12, 24))
add_stat_annotation(axs['alphnato_temporal_pr_stat'], data=hga_raw_pr_df, x='Condition', y='80% var PC, Temporal', order=feature_names.keys(),
                   box_pairs=hgaraw_pr_box_pairs['temporal'], perform_stat_test=False, pvalues=hgaraw_pr_pvals['temporal'],
                   text_format='star', loc='outside', pvalue_thresholds=pvalue_thresholds, line_offset_to_box=0.1, linewidth=annot_linewidth)
axs['alphnato_temporal_pr_stat'].axis('off')

# One horizontal histogram of PC counts per feature type
axs['alphnato_temporal_pr'] = [fig.add_subplot(temporal_gs[0, 0]), fig.add_subplot(temporal_gs[0, 1]), fig.add_subplot(temporal_gs[0, 2])]
for ax, key in zip(axs['alphnato_temporal_pr'], feature_names.keys()):
    pcs, pc_counts = np.unique(hga_raw_pr_df.loc[hga_raw_pr_df['Condition'] == key]['80% var PC, Temporal'].values, return_counts=True)

    # NOTE(review): bars show raw counts; the 'Percent of bootstraps' label is
    # exact only if each condition has 100 rows - confirm.
    ax.barh(pcs, pc_counts, height=0.5, clip_on=False, color=colors[key], edgecolor='grey')
    ax.axes.set(ylim=(12, 24), xlim=(0, 100), xticks=[0, 50, 100], xticklabels=[0, '', 100])

# Raw string: '\s' is an invalid escape sequence in a normal string literal
axs['alphnato_temporal_pr'][0].axes.set_ylabel(r'No. temporal PCs, $\sigma^2$ > 80%', labelpad=-1)
axs['alphnato_temporal_pr'][1].axes.set(xlabel='Percent of bootstraps')
for ax in axs['alphnato_temporal_pr'][1:]:
    ax.axes.set(yticklabels=[])

# Legend with one colored patch per feature type
handles = []
for key, val in feature_names.items():
    handles.append(mpl.patches.Patch(color=colors[key], label=val))
axs['alphnato_temporal_pr'][0].legend(handles=handles, handlelength=0.5, fontsize=fontsize-1, bbox_to_anchor=(-0.05, 0.98), loc='upper left', frameon=False)


##### ----- Temporal smoothing
smoothing_hue_order = ['hga', 'raw', 'hga_and_raw']
smoothing_palette = [colors['hga'], colors['raw'], colors['hga + raw']]

# Median per (smoothing width, feature type) for the scatter markers.
# numeric_only=True keeps string columns out of the aggregation (pandas >= 2
# raises on them instead of dropping them silently), and reset_index() turns
# the grouping keys back into ordinary columns for seaborn to look up.
smoothing_medians = (
    temporal_smoothing
    .groupby(['Temporal Smoothing (s)', 'Data Type'])
    .median(numeric_only=True)
    .reset_index()
)

axs['temporal_smoothing'] = fig.add_subplot(gs[1, 6:])
# Median line per feature type with 99% confidence-interval error bars
axs['temporal_smoothing'] = sns.lineplot(data=temporal_smoothing, x='Temporal Smoothing (s)', y='accuracy (% of original)', hue='Data Type',
             ax=axs['temporal_smoothing'], hue_order=smoothing_hue_order, ci=99, err_style='bars',
              palette=smoothing_palette, estimator=np.median, clip_on=False)
# Markers at the medians
axs['temporal_smoothing'] = sns.scatterplot(data=smoothing_medians, x='Temporal Smoothing (s)',
                                            y='accuracy (% of original)', hue='Data Type', clip_on=False,
                 ax=axs['temporal_smoothing'], hue_order=smoothing_hue_order, s=25, legend=False,
                  palette=smoothing_palette)

plt.setp(axs['temporal_smoothing'].lines, clip_on=False)
axs['temporal_smoothing'].axes.set(ylim=(0.2, 1.0),
                                   xlabel='Gaussian filter width (s)', xlim=(None, 1.0))
axs['temporal_smoothing'].axes.set_ylabel('Fraction of original accuracy', labelpad=0)

# Relabel the legend entries with the display names of the feature types
handles, labels = axs['temporal_smoothing'].get_legend_handles_labels()
axs['temporal_smoothing'].legend(handles=handles, labels=[all_feature_names[key] for key in labels],
                                 frameon=False, bbox_to_anchor=(1, 0.98), loc='upper right', fontsize=fontsize-1)

# Turn clipping off for every artist found by plt.findobj()
for o in plt.findobj():
    o.set_clip_on(False)

fig.tight_layout()

##### ----- Figure panel labels
# (axes key, label, position in axes-fraction coordinates, horizontal alignment)
panel_label_specs = [
    ('hga_raw_accs_box', 'a', (-0.2, 1.09), 'right'),
    ('hga_sal_alone', 'b', (-0.03, 1.27), 'left'),
    ('hga_sal_from_rawhga', 'c', (-0.03, 1.27), 'left'),
    ('raw_sal_alone', 'd', (-0.03, 1.27), 'left'),
    ('raw_sal_from_rawhga', 'e', (-0.03, 1.27), 'left'),
    ('alphnato_spatial_pr_stat', 'f', (-0.1, 1.1), 'right'),
    ('alphnato_temporal_pr_stat', 'g', (-0.1, 1.1), 'right'),
    ('temporal_smoothing', 'h', (-0.1, 1.1), 'right'),
]
for axes_key, panel_letter, label_xy, label_ha in panel_label_specs:
    axs[axes_key].annotate(panel_letter, label_xy, xycoords='axes fraction', ha=label_ha,
                           fontsize=panel_label_fontsize, weight='bold')

### Save figure

In [None]:
# Save the figure as PNG and PDF, each with a transparent and an opaque background.
figure_dpi = 300

# Create the output folder first; savefig fails if it does not exist
os.makedirs(fig_dir, exist_ok=True)

for ext in ['png', 'pdf']:
    fig.savefig(os.path.join(fig_dir, f'figure3_hga_vs_raw.{ext}'),
                transparent=True, bbox_inches='tight', dpi=figure_dpi)
    fig.savefig(os.path.join(fig_dir, f'figure3_hga_vs_raw_white.{ext}'),
                transparent=False, bbox_inches='tight', dpi=figure_dpi)