In [None]:
# Reload edited project modules (e.g. distractors_figures) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import numpy as np
from distractors_figures.plotting_utils import plot_fncs, color_pals

import pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager
import matplotlib as mpl

In [None]:
# Directory holding every input file read below (relative to this notebook).
data_path = '/'.join(('..', 'data'))

In [None]:
# Electrode table with per-task z-stats and p-values (speech/read/listen _z/_p).
# NOTE: pickle.load executes code from the file -- only load trusted data.
with open(f'{data_path}/function_resp_df.pkl', 'rb') as fh:
    df = pickle.load(fh)
# Smallest read-task z-stat among electrodes with p < 0.05; drawn as the dashed
# threshold line on the scatter panels.
stat_thresh = df.loc[df.read_p < 0.05, 'read_z'].min()

In [None]:
# Gate-model accuracy tables, one per participant (keys 'b3' and 'b1').
with open(f'{data_path}/gate_dfs_ensem.pkl', 'rb') as fh:
    gate_dfs = pickle.load(fh)
df_gate, df_gate_b1 = gate_dfs['b3'], gate_dfs['b1']

In [None]:
# Band-power summary (columns used below: patient, band, comparison, power).
with open(f'{data_path}/lf_results.pkl', 'rb') as fh:
    summary_psd = pickle.load(fh)

In [None]:
# Example Bravo-3 responses for the ERP panel, one array per task.
with open(f'{data_path}/sample_fnc_erps.pkl', 'rb') as fh:
    erp_examples = pickle.load(fh)

speech, read, listen = (erp_examples[key] for key in ('speech', 'read', 'listen'))

In [None]:
# Per-electrode gate-model contributions, keyed by participant ('b3', 'b1').
with open(f'{data_path}/gate_sals.pkl', 'rb') as fh:
    d_sals = pickle.load(fh)

In [None]:
# Gate-model confusion matrices, keyed by participant ('b3', 'b1').
with open(f'{data_path}/gate_cms.pkl', 'rb') as fh:
    d_cms = pickle.load(fh)

In [None]:
# Global plotting style for the figure.
import warnings
# NOTE(review): this silences *all* warnings, including pandas chained-assignment
# and deprecation warnings that would flag real problems below.
warnings.filterwarnings('ignore')
mpl.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    # 'font.sans-serif': ['Arial'],
    'font.size': 6,
    'svg.fonttype': 'none',  # keep text as text in exported SVGs
})

In [None]:

# --- Figure layout -----------------------------------------------------------
# Panels live on an abstract grid cols['total'] (1000) wide by rows['total']
# (1350) tall. Each panel in all_panel_params names a row span and a column
# span (<name>_start / <name>_stop); plot_fncs.setup_figure builds the axes.

# Three equal columns for the scatter row, separated by `pad`.
pad = 150
width = int((1000 - pad * 2) / 3 - 35)
pad_brains = 25
width_brains = int((1000 - pad_brains * 3) / 4)

# Palettes and category orders shared by the panels below.
region_pal = color_pals.regions[:4]
region_order = ['temporal', 'precentral', 'postcentral', 'frontal']
distractor_pal = color_pals.distractors
distractor_order = ['Speech', 'Read', 'Listen']
region_pal_dict = color_pals.region_dict
lw = 0.5

# Six brain maps across the top: brains 1-3 and 4-6, with a wider gap between.
brain_width = 125
brain_pad = 5
funciton_pad = 15  # (sic) gap between the two brain triplets

rows = {
    'b3brain_start': 0,             'b3brain_stop': 350,
    'top_start': 0,                 'top_stop': 300,
    'topartic_start': 50,           'topartic_stop': 300,  # alt: 70 / 320
    'b1brain_start': 200,           'b1brain_stop': 325,
    'middle_start': 350,            'middle_stop': 350 + width,
    'bottom_start': 700,            'bottom_stop': 900,
    'last_start': 975,              'last_stop': 1125,
    'lastb3regstart_start': 1000,   'lastb3regstart_stop': 1125,
    'lastlast_start': 1200,         'lastlast_stop': 1350,
    'cbar_start': 1160,             'cbar_stop': 1175,
    'total': 1350,
}

cols = {}
for n_brain in range(1, 7):
    # Brains are separated by brain_pad, except that the gap before brain 4
    # is funciton_pad instead.
    if n_brain <= 3:
        gap = (n_brain - 1) * brain_pad
    else:
        gap = (n_brain - 2) * brain_pad + funciton_pad
    cols[f'brain{n_brain}_start'] = (n_brain - 1) * brain_width + gap
    cols[f'brain{n_brain}_stop'] = n_brain * brain_width + gap
cols.update({
    'artic_start': 6 * brain_width + 4 * brain_pad + 3 * funciton_pad,
    'artic_stop': 1000 - 70,
    'left_start': 0,                   'left_stop': width,
    'middle_start': width + pad,       'middle_stop': 2 * width + pad,
    'right_start': 2 * width + 2 * pad, 'right_stop': 3 * width + 2 * pad,
    'erp_start': 0,                    'erp_stop': 200,
    'b1lfs_start': 300,                'b1lfs_stop': 640,
    'b3lfs_start': 660,                'b3lfs_stop': 1000,
    'left_bottom_start': 0,            'left_bottom_stop': 600 - 30,
    'right_bottom1_start': 625 - 30,   'right_bottom1_stop': 775 - 30,
    'right_bottom2_start': 850,        'right_bottom2_stop': 1000,
    'total': 1000,
})

# Shifted copies of brain columns 1-3.
start_offset = 120
stop_offset = 40
for n_brain in (1, 2, 3):
    cols[f'b1brain{n_brain}_start'] = cols[f'brain{n_brain}_start'] + start_offset
    cols[f'b1brain{n_brain}_stop'] = cols[f'brain{n_brain}_stop'] + stop_offset

# panel name -> (row span, column span)
panel_grid = {
    'speech': ('b3brain', 'brain4'),
    'read': ('b3brain', 'brain5'),
    'listen': ('b3brain', 'brain6'),
    'speechb1': ('b3brain', 'brain1'),
    'readb1': ('b3brain', 'brain2'),
    'listenb1': ('b3brain', 'brain3'),
    'artic': ('topartic', 'artic'),
    'scatter1': ('middle', 'left'),
    'scatter2': ('middle', 'middle'),
    'scatter3': ('middle', 'right'),
    'erp_ex': ('bottom', 'erp'),
    'b1lfs': ('bottom', 'b1lfs'),
    'b3lfs': ('bottom', 'b3lfs'),
    'regaccs': ('lastb3regstart', 'left_bottom'),
    'hgasal': ('last', 'right_bottom1'),
    'lfssal': ('last', 'right_bottom2'),
    'regaccs_b1': ('lastlast', 'left_bottom'),
    'b1sal': ('lastlast', 'right_bottom1'),
    'b1cm': ('lastlast', 'right_bottom2'),
    'cbar': ('cbar', 'right_bottom2'),
}
all_panel_params = {name: {'row_and_col_spec': spec} for name, spec in panel_grid.items()}

# Aspect ratio of the grid (currently unused; figsize is set explicitly).
dims = np.array([cols['total'], rows['total']]) / np.max([rows['total'], cols['total']])
scale = 17

fig, axs = plot_fncs.setup_figure(
    all_panel_params=all_panel_params,
    row_specs=rows,
    col_specs=cols,
    figsize=(7.20+1.1, 7.6+2),  # alt: scale*dims, or (15, 12)
)

### PANELS: task-activation brain maps (Bravo-1: brains 1-3, Bravo-3: brains 4-6)
# Crop windows applied to each participant's brain image after plotting.
b3_xlims = [205,515] #310
b3_ylims = [100,525] #425

b1_xlims = [195,505] # 310
b1_ylims = [125,550] # 425
size_max = 20


def _plot_task_activation(ax, patient, task, title, xlims, ylims, subject=None):
    """Draw one participant's per-electrode z-stats for one task on the brain.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    patient : value of ``df.patient`` to select (e.g. 'B1', 'B3').
    task : column prefix; ``{task}_z`` and ``{task}_p`` are read from ``df``.
    title : axes title.
    xlims, ylims : crop window applied after plotting.
    subject : forwarded to ``plot_vals_on_brain`` only when given, so calls
        without it keep that function's own default.
    """
    pt_rows = df[df.patient == patient]
    vals = pt_rows[f'{task}_z'].values.copy()
    # Electrodes that fail p < 0.05 (the significance test used by the other
    # panels) are pinned to the participant's minimum z-stat. `~(p < 0.05)`
    # rather than `p > 0.05` also masks p == 0.05 and NaN p-values.
    not_sig = ~(pt_rows[f'{task}_p'].values < 0.05)
    vals[not_sig] = np.min(vals)
    extra = {} if subject is None else {'subject': subject}
    plot_fncs.plot_vals_on_brain(vals, '#4F4E4C', fig, ax, data_path,
                                 size_min=0, size_max=size_max,
                                 region_color=region_pal_dict['precentral'], **extra)
    ax.set(xlim=xlims, ylim=ylims)
    ax.set_title(title)


brain_panels = [
    # (axes key, patient, task, title, xlims, ylims, subject)
    ('speech',   'B3', 'speech', 'Attempted speech', b3_xlims, b3_ylims, None),
    ('read',     'B3', 'read',   'Read',             b3_xlims, b3_ylims, None),
    ('listen',   'B3', 'listen', 'Listen',           b3_xlims, b3_ylims, None),
    ('speechb1', 'B1', 'speech', 'Attempted speech', b1_xlims, b1_ylims, 'bravo1'),
    ('readb1',   'B1', 'read',   'Read',             b1_xlims, b1_ylims, 'bravo1'),
    ('listenb1', 'B1', 'listen', 'Listen',           b1_xlims, b1_ylims, 'bravo1'),
]
for panel, patient, task, title, xlims, ylims, subject in brain_panels:
    _plot_task_activation(axs[panel], patient, task, title, xlims, ylims, subject=subject)

# Participant label under the middle ("Read") brain of each triplet.
for panel, label in [('read', 'Bravo-3'), ('readb1', 'Bravo-1')]:
    axs[panel].annotate(label,
                        xy=(0.5, -0.1),
                        xycoords='axes fraction',
                        ha='center',
                        va='center', fontsize=6)

# Row label to the left of the first (leftmost) brain.
axs['speechb1'].annotate('Task activation \n(Statistical Z-stat)', xy=(-0.2, 0.5), rotation=90,
                         va='center', ha='center', xycoords='axes fraction', fontsize=6)


### PANEL: B3 grid map for artic vs tri-function
ax = axs['artic']

# load the grid layout (2-D array; entries presumably electrode indices -- TODO confirm)
with open(f'{data_path}/b3_layout.pkl','rb') as f:
    elec_layout = pickle.load(f)
n_grid_rows, n_grid_cols = elec_layout.shape
# One (x=column, y=-row, layout value) triple per grid position, enumerated
# column by column. The row count now comes from the layout itself instead of
# a hard-coded 253. Negating the row puts grid row 0 at the top.
two_d_coords = np.array(
    [[c, -r, elec_layout[r, c]] for c in range(n_grid_cols) for r in range(n_grid_rows)],
    dtype=float,
).reshape(-1, 3)
# Sort by layout value so row k holds the k-th smallest entry; the cluster
# index arrays below are used as row indices into this array.
two_d_coords = two_d_coords[two_d_coords[:, 2].argsort(),:]

panel_f = pd.read_excel(f'{data_path}/fig5_sourcedata.xlsx',engine='openpyxl',sheet_name='Panel_f').iloc[:,1:]

# define the electrodes corresponding to each cluster
hand_clust = plot_fncs.adjust_elec_inds(panel_f['top_hand'])
lips_clust = plot_fncs.adjust_elec_inds(panel_f['top_lips'])
coronal_clust = plot_fncs.adjust_elec_inds(panel_f['top_coronal'])
# Bravo-3 electrodes significant (p < 0.05) in all three tasks.
b3_rows = df[df.patient == 'B3']
trifnc = (b3_rows.speech_p.values < 0.05) & (b3_rows.read_p.values < 0.05) & (b3_rows.listen_p.values < 0.05)
trifnc = np.where(trifnc)[0]
vowel_clust = trifnc


# core scatter plot of the electrodes with color by whether they are in top for an artic cluster
scat_size = 10
ax.scatter(two_d_coords[:,0],two_d_coords[:,1],color='k',s=2,alpha=0.2)
ax.scatter(two_d_coords[hand_clust,0],two_d_coords[hand_clust,1],color='k',s=scat_size,marker='x',alpha=0.4)
ax.scatter(two_d_coords[lips_clust,0],two_d_coords[lips_clust,1],color='k',s=scat_size,marker='*',alpha=0.5)
ax.scatter(two_d_coords[coronal_clust,0],two_d_coords[coronal_clust,1],color='k',s=scat_size,marker='s',alpha=0.6)
ax.scatter(two_d_coords[vowel_clust,0],two_d_coords[vowel_clust,1],color='k',s=scat_size)
ax.set_ylim([-22.5,0.5])
ax.set_xlim([-0.5,10.5])
ax.set(yticks=[],xticks=[],title='Bravo-3')
ax.spines[['top','right']].set_visible(True)

# Text legend to the right of the grid; alphas match the marker alphas above.
x_pad = 1.05
y_start = 0.95
y_width = 0.1
legend_rows = [('Shared (●)', None), ('Hand (x)', 0.4), ('Lips (*)', 0.5), ('Tongue (■)', 0.6)]
for n_row, (text, alpha) in enumerate(legend_rows):
    ax.annotate(text, (x_pad, y_start - y_width * n_row), xycoords='axes fraction', color='k', alpha=alpha)


# delete the dummy (Unconnected) channels before scatterplot
df_scat = df[df.anatomy != 'dummy'].copy()

ax_lims = [0,25]
ax_lims_rl = [0,25]
ax_ticks = [0,5,10,15,20,25]
ax_ticks_rl = [0,10,20,25]
scat_size = 10


def _plot_task_pair(ax, x_task, y_task, x_label, y_label, ylim, faint_alpha, legend=False):
    """Scatter two tasks' z-stats for electrodes significant (p < 0.05) in either.

    Electrodes significant in both tasks are drawn at alpha 0.8; those
    significant in only one task are drawn at ``faint_alpha``. Dashed lines
    mark ``stat_thresh`` on both axes and the identity line. ``legend`` is
    forwarded to the first (both-significant) seaborn call only.

    Returns the plotted subset of ``df_scat`` with an ``active_both`` column.
    """
    pts = df_scat[(df_scat[f'{x_task}_p'] < 0.05) | (df_scat[f'{y_task}_p'] < 0.05)].copy()
    # Whole-column assignment: the former chained `pts['active_both'][mask] = 1`
    # never reaches `pts` under pandas copy-on-write, leaving every point faint.
    pts['active_both'] = (pts[f'{x_task}_p'] < 0.05) & (pts[f'{y_task}_p'] < 0.05)
    common = dict(x=f'{x_task}_z', y=f'{y_task}_z', hue='anatomy', style='patient',
                  hue_order=region_order, clip_on=False, palette=region_pal, s=scat_size)
    sns.scatterplot(ax=ax, data=pts[pts['active_both']], legend=legend, alpha=0.8, **common)
    sns.scatterplot(ax=ax, data=pts[~pts['active_both']], legend=False, alpha=faint_alpha, **common)
    ax.set(xlabel=x_label, ylabel=y_label, xlim=ax_lims, ylim=ylim, xticks=ax_ticks, yticks=ax_ticks)
    ax.axhline(stat_thresh, color='k', linestyle='--', zorder=0)
    ax.axvline(stat_thresh, color='k', linestyle='--', zorder=0)
    ax.plot([0, ax_lims[1]], [0, ax_lims[1]], color='k', linestyle='--', zorder=0)
    sns.despine(ax=ax, offset=5)
    return pts


### PANEL: Scatter speech vs read
tmp = _plot_task_pair(axs['scatter1'], 'speech', 'read', 'Speech (Z-stat)', 'Read (Z-stat)',
                      ax_lims_rl, faint_alpha=0.2)

### PANEL: Scatter speech vs listen
tmp = _plot_task_pair(axs['scatter3'], 'speech', 'listen', 'Speech (Z-stat)', 'Listen (Z-stat)',
                      ax_lims_rl, faint_alpha=0.4)

### PANEL: Scatter read vs listen (its seaborn legend supplies the handles below)
ax = axs['scatter2']
tmp = _plot_task_pair(ax, 'read', 'listen', 'Read (Z-stat)', 'Listen (Z-stat)',
                      ax_lims, faint_alpha=0.4, legend=True)

handles, labels = ax.get_legend_handles_labels()
axs['scatter2'].legend(handles[1:5], labels[1:5], frameon=False, ncol=1, title='Region',
                      loc='lower left', bbox_to_anchor=(2.4, 0.5), columnspacing=0.2, handletextpad=0.1)
axs['scatter3'].legend(handles[-2:], labels[-2:], frameon=False, ncol=1, title='Participant',
                      loc='lower left', bbox_to_anchor=(1, 0.2), columnspacing=0.2, handletextpad=0.1)

### PANEL: example erp
ax = axs['erp_ex']
from scipy.stats import sem

# Columns 200:800 of each response array are drawn against t_ar (-1 to 2 s).
# Rows are averaged (mean +/- SEM over axis 0) -- presumably trials; TODO confirm.
erp_win = slice(200, 800)
t_ar = np.linspace(-1,2,600)
assert len(t_ar) == erp_win.stop - erp_win.start, 'time axis must match the sample window'

for trials, color, label in [(speech, distractor_pal[0], 'Speech'),
                             (read, distractor_pal[1], 'Read'),
                             (listen, distractor_pal[2], 'Listen')]:
    mean_sig = np.mean(trials, axis=0)[erp_win]
    err = sem(trials[:, erp_win], axis=0)
    ax.plot(t_ar, mean_sig, color=color, label=label)
    ax.fill_between(t_ar, mean_sig + err, mean_sig - err, color=color, alpha=0.2)

ax.axvline(0,linestyle='--',color='k')
ax.set(ylabel='HGA (Z)',xlabel='Time (s)',ylim=[-0.5,2],yticks=[-0.5,0,1,2],xlim=[-1,2])
ax.legend(loc=(0.70,0.70),frameon=False)
sns.despine(ax=ax,offset=5)

### PANEL: LFS/HGA power for B1
ax = axs['b1lfs']
# Normalise the lowercase 'speech' label to match distractor_order. This updates
# summary_psd itself, so the B3 panel sees it too. Whole-column assignment
# replaces the chained `summary_psd['comparison'][mask] = ...`, which is a
# silent no-op under pandas copy-on-write and would drop the Speech violins.
summary_psd['comparison'] = summary_psd['comparison'].replace({'speech': 'Speech'})
g = sns.violinplot(data=summary_psd[summary_psd.patient == 'B1']
            ,x='band',hue='comparison',y='power',ax=ax,palette=distractor_pal,hue_order=distractor_order,zorder=100)
ax.set(ylim=[-5,5],yticks=[-5,0,5],ylabel='Mean Power (0-2s, Z)',xlabel='Bravo-1')
g.legend(ncol=1,title='',frameon=False,bbox_to_anchor=(0.25,0.3))#(0.77,1.05))
plt.setp(g, clip_on=False)

# Stats imports live here; later panels in this cell reuse them.
from itertools import combinations
from statannot import add_stat_annotation
from scipy.stats import mannwhitneyu, ranksums
from statsmodels.stats.multitest import multipletests

# Rank-sum test for each pair of conditions within each band
# (3 bands x 3 pairs), Holm-corrected across all of them.
x_order = ['theta','beta','HGA']
box_pairs, p_vals = [], []
data_og = summary_psd[summary_psd.patient == 'B1'].copy()
for x in x_order:
    data = data_og[data_og.band == x]
    # combinations yields (Speech, Read), (Speech, Listen), (Read, Listen)
    for cond_a, cond_b in combinations(distractor_order, 2):
        p_vals.append(ranksums(data[data.comparison == cond_a].power.values,
                               data[data.comparison == cond_b].power.values)[1])
        box_pairs.append(((x, cond_a), (x, cond_b)))
p_vals = multipletests(p_vals,method='holm')[1]
add_stat_annotation(ax, data=summary_psd[summary_psd.patient == 'B1']
            ,x='band',hue='comparison',y='power', box_pairs=box_pairs,
                    test=None, loc='outside', verbose=2,pvalues=p_vals,perform_stat_test=False,
                   line_offset=0.01,line_offset_to_box=0)

sns.despine(ax=ax,offset={'left': 5, 'bottom': 5})
ax.spines['bottom'].set_visible(False)
ax.xaxis.set_tick_params(length=0)

### PANEL: LFS/HGA power for B3 (same y-range as the Bravo-1 panel)
ax = axs['b3lfs']
b3_psd = summary_psd[summary_psd.patient == 'B3']
g = sns.violinplot(data=b3_psd, x='band', hue='comparison', y='power', ax=ax,
                   palette=distractor_pal, hue_order=distractor_order, zorder=100)
# No y-axis of its own: label, ticks and legend are shown on the B1 panel.
ax.set(ylabel='', yticks=[], xlabel='Bravo-3', ylim=[-5, 5])
ax.spines['left'].set_visible(False)
ax.get_legend().remove()
sns.despine(ax=ax, offset={'bottom': 5}, left=True)
plt.setp(g, clip_on=False)

# Rank-sum test per condition pair within each band, Holm-corrected.
condition_pairs = ((0, 1), (0, 2), (1, 2))
pairs_b3, pvals_b3 = [], []
for band in x_order:
    in_band = b3_psd[b3_psd.band == band]
    for i, j in condition_pairs:
        cond_a, cond_b = distractor_order[i], distractor_order[j]
        pvals_b3.append(ranksums(in_band[in_band.comparison == cond_a].power.values,
                                 in_band[in_band.comparison == cond_b].power.values)[1])
        pairs_b3.append(((band, cond_a), (band, cond_b)))
pvals_b3 = multipletests(pvals_b3, method='holm')[1]

add_stat_annotation(ax, data=b3_psd, x='band', hue='comparison', y='power', box_pairs=pairs_b3,
                    test=None, loc='outside', verbose=2,
                    pvalues=pvals_b3, perform_stat_test=False, line_offset=0.01, line_offset_to_box=0)

ax.spines['bottom'].set_visible(False)
ax.xaxis.set_tick_params(length=0)


### PANEL: Salience for B3 gate
ax = axs['hgasal']
b3_sals_adj = d_sals['b3']
plot_fncs.plot_vals_on_brain(b3_sals_adj, '#4F4E4C', fig, ax, data_path,
                             size_min=1., size_max=30.,
                             region_color=region_pal_dict['precentral'])
ax.set(xlim=b3_xlims, ylim=b3_ylims)
ax.annotate('Electrode contributions\nBravo-3', xy=(0.5, -0.1), xycoords='axes fraction',
            ha='center', va='center', fontsize=6)

# Class labels shared by both confusion-matrix panels.
ticks = list(distractor_order)

### PANEL: CM for B3 gate (no x-label; the Bravo-1 matrix below carries "Predicted")
ax = axs['lfssal']
C = d_cms['b3']
ax.imshow(C, cmap='binary', vmin=0, vmax=1)
ax.spines[['top', 'right']].set_visible(True)
class_pos = [0, 1, 2]
ax.set(xticks=class_pos, yticks=class_pos, xticklabels=ticks, yticklabels=ticks)
ax.set_ylabel('Ground truth', labelpad=0)

### PANEL: Region accs for B3 gate
ax = axs['regaccs']
# Rename 'temporal_lobe' to match region_pal_dict. Whole-column assignment
# replaces a chained `df_gate['region'][mask] = ...`, which is a silent no-op
# under pandas copy-on-write.
df_gate['region'] = df_gate['region'].replace({'temporal_lobe': 'temporal'})
# Regions ordered by mean accuracy (ascending). Only `accuracy` is averaged:
# groupby(...).mean() over the whole frame raises TypeError in pandas >= 2.0
# because feat_stream holds strings.
reg_order = list(df_gate.groupby('region')['accuracy'].mean().sort_values(ascending=True).index.values)
b3_gate_pal = [region_pal_dict[anat] for anat in reg_order]

g = sns.boxplot(data=df_gate,x='feat_stream',hue='region',y='accuracy', order = ['hga','hga_lfs'],
             hue_order=reg_order,palette=b3_gate_pal,ax=ax,linewidth=lw)
ax.set(ylim=[80,100],ylabel='Bravo-3\n accuracy (%)',xticks=range(2),
       xticklabels=['HGA', 'HGA + LFS'],yticks=[80,90,100],xlim=[-0.5,1.5],xlabel='')
g.legend(loc='lower right',ncol=2,title='',frameon=False)
sns.despine(ax=ax,offset=5)

# Rank-sum tests between accuracy-adjacent regions within each feature stream,
# Holm-corrected. Results are collected in overall_stat_df; the Bravo-1 panel
# appends its rows to the same dict.
overall_stat_df = {'P-value':[],'Box_pair':[],'Participant': []}

x_order = ['hga', 'hga_lfs']
box_pairs, p_vals = [], []
for x in x_order:
    data = df_gate[df_gate.feat_stream == x]
    for reg_lo, reg_hi in zip(reg_order, reg_order[1:]):
        p_vals.append(ranksums(data[data.region == reg_lo].accuracy.values,
                               data[data.region == reg_hi].accuracy.values)[1])
        box_pairs.append(((x, reg_lo), (x, reg_hi)))

p_vals = multipletests(p_vals,method='holm')[1]
overall_stat_df['P-value'].extend(p_vals)
overall_stat_df['Box_pair'].extend(box_pairs)
overall_stat_df['Participant'].extend(np.repeat('Bravo3',len(p_vals)))

# Annotate only comparisons that survive correction.
final_pvals = [p for p in p_vals if p < 0.05]
final_bps = [bp for p, bp in zip(p_vals, box_pairs) if p < 0.05]

add_stat_annotation(ax, data=df_gate,x='feat_stream',hue='region',y='accuracy'
                    , box_pairs=final_bps,order = ['hga','hga_lfs'], hue_order=reg_order,
                    test=None, loc='outside', verbose=2,pvalues=final_pvals,perform_stat_test=False)

ax.set(ylim=[80,100],yticks=[80,85,90,95,100],clip_on=False)
ax.spines['bottom'].set_visible(False)
ax.xaxis.set_tick_params(length=0)




### PANEL: Region accs for B1 gate
ax = axs['regaccs_b1']
# Regions ordered by mean accuracy (ascending). Only `accuracy` is averaged:
# groupby(...).mean() over the whole frame raises TypeError in pandas >= 2.0
# because feat_stream holds strings. (The palette previously built from
# region_pal here was immediately overwritten and has been dropped.)
reg_order = list(df_gate_b1.groupby('region')['accuracy'].mean().sort_values(ascending=True).index.values)
b1_gate_pal = [region_pal_dict[anat] for anat in reg_order]


g = sns.boxplot(data=df_gate_b1,x='feat_stream',hue='region',y='accuracy', order = ['hga','hga_lfs'],
             hue_order=reg_order,ax=ax,palette=b1_gate_pal,zorder=500,linewidth=lw)
ax.set(ylim=[50,100],ylabel='Bravo-1\n accuracy (%)',xticks=range(2),
       xticklabels=['HGA', 'HGA + LFS'],xlim=[-0.5,1.5],xlabel='')
g.legend(loc='lower right',ncol=2,title='',frameon=False)
sns.despine(ax=ax,offset=5)

# Rank-sum tests between accuracy-adjacent regions within each feature stream,
# Holm-corrected.
x_order = ['hga', 'hga_lfs']
box_pairs, p_vals = [], []
for x in x_order:
    data = df_gate_b1[df_gate_b1.feat_stream == x]
    for reg_lo, reg_hi in zip(reg_order, reg_order[1:]):
        p_vals.append(ranksums(data[data.region == reg_lo].accuracy.values,
                               data[data.region == reg_hi].accuracy.values)[1])
        box_pairs.append(((x, reg_lo), (x, reg_hi)))

p_vals = multipletests(p_vals,method='holm')[1]
# Append Bravo-1 rows to the stats dict started by the Bravo-3 panel, then
# freeze it as a DataFrame.
overall_stat_df['P-value'].extend(p_vals)
overall_stat_df['Box_pair'].extend(box_pairs)
overall_stat_df['Participant'].extend(np.repeat('Bravo1',len(p_vals)))
overall_stat_df = pd.DataFrame(overall_stat_df)

# Annotate only comparisons that survive correction.
final_pvals = [p for p in p_vals if p < 0.05]
final_bps = [bp for p, bp in zip(p_vals, box_pairs) if p < 0.05]

add_stat_annotation(ax, data=df_gate_b1,x='feat_stream',hue='region',y='accuracy'
                    , box_pairs=final_bps,order = ['hga','hga_lfs'], hue_order=reg_order,
                    test=None, loc='inside', verbose=2,pvalues=final_pvals,
                    perform_stat_test=False)
ax.set(ylim=[50,100],clip_on=False)
ax.spines['bottom'].set_visible(False)
ax.xaxis.set_tick_params(length=0)


### PANEL: Salience for B1 gate
ax = axs['b1sal']
b1_sals = d_sals['b1']
plot_fncs.plot_vals_on_brain(b1_sals, '#4F4E4C', fig, ax, data_path,
                             size_min=1., size_max=30., subject='bravo1',
                             region_color=region_pal_dict['precentral'])
ax.set(xlim=b1_xlims, ylim=b1_ylims)
ax.annotate('Electrode contributions\nBravo-1', xy=(0.5, -0.1), xycoords='axes fraction',
            ha='center', va='center', fontsize=6)

### PANEL: CM for B1 gate (bottom row, so it carries the "Predicted" x-label)
ax = axs['b1cm']
C = d_cms['b1']
im = ax.imshow(C, cmap='binary', vmin=0, vmax=1)
ax.spines[['top', 'right']].set_visible(True)
class_pos = [0, 1, 2]
ax.set(xticks=class_pos, yticks=class_pos, xticklabels=ticks, yticklabels=ticks,
       xlabel='Predicted')
ax.set_ylabel('Ground truth', labelpad=0)

# Horizontal colour bar in its own axes; both matrices use the same cmap and 0-1 range.
cb = plt.colorbar(mappable=im, cax=axs['cbar'], orientation='horizontal',
                  anchor=(1, 1), ticks=[0, 1])
cb.set_label(label='Confusion', labelpad=-8)

####### Mark the panels
# Bold panel letters, positioned in axes points relative to each panel.
left_pad = -10
t_pad_bot = 60
panel_letters = [
    ('speechb1', 'A', (left_pad, 100)),
    ('artic',    'B', (left_pad/2, 100)),
    ('scatter1', 'C', (left_pad-20, 85)),
    ('scatter2', 'D', (left_pad-20, 85)),
    ('scatter3', 'E', (left_pad-20, 85)),
    ('erp_ex',   'F', (left_pad-20, 85)),
    ('b1lfs',    'G', (left_pad-20, 85)),
    ('regaccs',  'H', (left_pad-20, t_pad_bot)),
    ('hgasal',   'I', (left_pad/2, t_pad_bot)),
    ('lfssal',   'J', (left_pad-20, t_pad_bot)),
    # Bravo-1 gate row currently unlabeled:
    # ('regaccs_b1', 'K', (-40, t_pad_bot)),
    # ('b1sal',      'L', (left_pad/2, t_pad_bot)),
    # ('b1cm',       'M', (left_pad-20, t_pad_bot)),
]
for panel, letter, xy in panel_letters:
    axs[panel].annotate(letter, xy, xycoords='axes points', weight='bold', ha='right', fontsize=9)

plt.show();