In [None]:
# Re-import edited modules automatically before each cell executes
# (autoreload mode 2 reloads all modules, so edits to project helpers are picked up live).
%load_ext autoreload
%autoreload 2

In [None]:
# speech content decode
import pickle
import os
import numpy as np
import pandas as pd 
from distractors_figures.plotting_utils import plot_fncs, color_pals

data_path = '../data'  # relative to the notebook's working directory

# Region-wise decoding accuracies (columns used below: Patient, Paradigm, Region, Accuracy).
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open(f'{data_path}/region_acc_dfs.pkl', mode='rb') as pkl_file:
    df = pickle.load(pkl_file)

In [None]:
# Salience values per condition; downstream code reads the 'speech' and 'listening' keys
# (presumably one value per electrode — TODO confirm).
with open(f'{data_path}/speech_listen_sals.pkl', mode='rb') as sals_file:
    d_sals = pickle.load(sals_file)

In [None]:
# Cross-condition decoding results keyed 'train_speech_test_listen' / 'train_listen_test_speech';
# values are multiplied by 100 downstream (presumably fractions in [0, 1] — confirm).
with open(f'{data_path}/cross_train_test.pkl', mode='rb') as xval_file:
    cross_train_test = pickle.load(xval_file)

In [None]:
# Functional-response table; not referenced by the figure cell in this notebook.
with open(f'{data_path}/function_resp_df.pkl', mode='rb') as resp_file:
    df_activations = pickle.load(resp_file)

In [None]:
# set plot params 
import warnings
warnings.filterwarnings('ignore')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager
import matplotlib as mpl
from statannot import add_stat_annotation
# Shared line width for box outlines.
lw = 0.5

# Region colours from the project palette module. The per-patient box palettes
# (b1_gate_pal / b3_gate_pal) are built in the panel code from `region_pal_dict`
# once each panel's data-driven region order is known, so they are not pre-built here.
region_pal = color_pals.regions
region_pal_dict = color_pals.region_dict

# Canonical region lists per patient; the panels in this notebook order regions
# by mean accuracy instead.
b3_gate_order = ['temporal', 'precentral', 'postcentral', 'all']
b1_gate_order = ['precentral', 'postcentral', 'frontal', 'all']

# First and last distractor colours for the attempted-speech vs. listen comparison.
distractor_pal = list(np.array(color_pals.distractors)[[0, -1]])
distractor_order = ['Attempted speech', 'Listen']

# Global matplotlib style: no top/right spines, Helvetica, 8 pt text.
mpl.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    'font.sans-serif': ['Helvetica'],
    'font.size': 8,
})
# Figure layout on an abstract grid: rows run 0-550 and columns 0-800 units.
# Keys follow a '<prefix>_start' / '<prefix>_stop' convention (presumably what
# plot_fncs.setup_figure looks up for each panel's row/col prefix — confirm).
rows = dict(
    top_start=0,
    top_stop=200,
    second_start=350,
    second_stop=550,
    secondscat_start=370,  # scatter panel starts 20 units after the other 'second' panels
    secondscat_stop=550,
    total=550,
)

cols = dict(
    paradigmb1_start=0,
    paradigmb1_stop=390,
    paradigmb3_start=410,
    paradigmb3_stop=800,
    traintest_start=0,
    traintest_stop=180,
    salsspeech_start=200,
    salsspeech_stop=370,
    salslisten_start=380,
    salslisten_stop=550,
    salsscat_start=630,
    salsscat_stop=800,
    total=800,
)

# panel name -> (row prefix, column prefix) into the grids above
all_panel_params = {
    panel: {'row_and_col_spec': (row_key, col_key)}
    for panel, row_key, col_key in (
        ('paradigmb1', 'top', 'paradigmb1'),
        ('paradigmb3', 'top', 'paradigmb3'),
        ('traintest', 'second', 'traintest'),
        ('salslisten', 'second', 'salslisten'),
        ('salsspeech', 'second', 'salsspeech'),
        ('salsscat', 'secondscat', 'salsscat'),
    )
}

# (width, height) aspect normalised so the longer side is 1; multiplied by `scale` for inches.
dims = np.array([cols['total'], rows['total']]) / max(rows['total'], cols['total'])
scale = 8.17

# Create the figure and a dict of named axes (indexed by the panel names in
# all_panel_params) laid out on the row/column grids; figsize is in inches.
fig, axs = plot_fncs.setup_figure(
    all_panel_params=all_panel_params,
    row_specs=rows,
    col_specs=cols,
    figsize=dims * scale,
)

### PANEL: Content acc B1
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests

CHANCE_ACC = 10  # chance accuracy (%): dashed reference line and Wilcoxon null value
PARADIGM_ORDER = ['Attempted speech', 'Listen', 'Read']


def _region_order(patient_df):
    """Return the regions of ``patient_df`` sorted by ascending mean accuracy.

    Only ``Accuracy`` is averaged: calling ``.mean()`` on the whole grouped
    frame raises ``TypeError`` on pandas >= 2.0 because of the string columns
    (e.g. ``Patient``, ``Paradigm``).
    """
    mean_acc = patient_df.groupby('Region')['Accuracy'].mean()
    return list(mean_acc.sort_values(ascending=True).index.values)


def _annotate_above_chance(ax, data, x_col, hue_col, x_order, hue_order,
                           chance=CHANCE_ACC, alpha=0.05, buffer=2, box_width=0.8):
    """Draw a '*' above every box whose accuracy differs from chance.

    For each (x, hue) box this runs a two-sided one-sample Wilcoxon signed-rank
    test on ``Accuracy - chance``. It then Holm-corrects across all boxes of the
    panel and stars those with adjusted p <= ``alpha``.

    Parameters
    ----------
    ax : Axes holding the seaborn boxplot.
    data : DataFrame with an ``Accuracy`` column plus ``x_col`` and ``hue_col``.
    x_order, hue_order : the same orders passed to ``sns.boxplot``, so the star
        x-positions line up with the dodged boxes.
    chance : null value subtracted before the test.
    alpha : threshold on the Holm-adjusted p-values.
    buffer : vertical gap (data units) between a box's maximum and its star.
    box_width : total dodge width per x category (seaborn's default is 0.8).
    """
    n_hue = len(hue_order)
    positions, tops, p_raw = [], [], []
    for xi, x in enumerate(x_order):
        for hj, hue in enumerate(hue_order):
            acc = data.loc[(data[x_col] == x) & (data[hue_col] == hue), 'Accuracy'].to_numpy()
            if acc.size == 0:  # no box drawn for this combination: nothing to test
                continue
            p_raw.append(wilcoxon(acc - chance)[1])
            tops.append(np.max(acc))
            # centre of the hj-th dodged box around integer tick position xi
            positions.append(xi - box_width / 2 + box_width * (hj + 0.5) / n_hue)
    if not p_raw:
        return
    p_adj = multipletests(p_raw, method='holm')[1]
    for x_pos, top, p in zip(positions, tops, p_adj):
        if p <= alpha:  # NaN p-values compare False and are never starred
            ax.text(x_pos, top + buffer, '*', ha='center', va='bottom', color='k', fontsize=12)


ax = axs['paradigmb1']
b1_df = df[df.Patient == 'B1']
reg_order = _region_order(b1_df)
b1_gate_pal = [region_pal_dict[anat] for anat in reg_order]

# Grey + first three seaborn 'deep' colours; palette[0] is also used by the saliency scatter.
base_pal = sns.color_palette('deep').as_hex()
palette = ['#4F4E4C', base_pal[0], base_pal[1], base_pal[2]]

# Passing `order` pins the x layout that _annotate_above_chance assumes.
g = sns.boxplot(data=b1_df, x='Paradigm', y='Accuracy', hue='Region', ax=ax,
                order=PARADIGM_ORDER, hue_order=reg_order, palette=b1_gate_pal, linewidth=lw)
g.legend(frameon=False)
ax.axhline(CHANCE_ACC, linestyle='--', color='k', zorder=-1)
ax.set(ylim=[0, 100], ylabel='Classification Accuracy (%)', xlabel='Bravo-1')
ax.spines[['top', 'right', 'bottom']].set_visible(False)
ax.xaxis.set_tick_params(length=0)
_annotate_above_chance(ax, b1_df, 'Paradigm', 'Region', PARADIGM_ORDER, reg_order)


### PANEL: Content acc B3
ax = axs['paradigmb3']
# Harmonise the temporal label in place (the train/test panel filters on 'temporal').
# .loc is required: the chained form df['Region'][mask] = ... writes to a temporary
# copy under pandas Copy-on-Write and leaves df unchanged.
df.loc[df.Region == 'temporal_lobe', 'Region'] = 'temporal'
b3_df = df[df.Patient == 'B3']
reg_order = _region_order(b3_df)
b3_gate_pal = [region_pal_dict[anat] for anat in reg_order]

g = sns.boxplot(data=b3_df, x='Paradigm', y='Accuracy', hue='Region', ax=ax,
                order=PARADIGM_ORDER, hue_order=reg_order, palette=b3_gate_pal, linewidth=lw)
g.legend(frameon=False)
ax.axhline(CHANCE_ACC, linestyle='--', color='k')
ax.set(ylim=[0, 100], yticks=[], ylabel='', xlabel='Bravo-3')
ax.spines[['top', 'right', 'left', 'bottom']].set_visible(False)
ax.xaxis.set_tick_params(length=0)
_annotate_above_chance(ax, b3_df, 'Paradigm', 'Region', PARADIGM_ORDER, reg_order)


### PANEL: Train/test on speech/listening
ax = axs['traintest']
# NOTE(review): no Patient filter. This assumes only one patient (presumably B3)
# has 'temporal' rows; confirm.
temporal_df = df[df.Region == 'temporal']
# (accuracies in %, train condition, test condition). Cross-condition values are
# scaled by 100 (presumably fractions -> %).
# TODO: cross-condition values are placeholders exported from W&B; replace with exact results.
train_test_parts = [
    (temporal_df.loc[temporal_df.Paradigm == 'Attempted speech', 'Accuracy'], 'Attempted speech', 'Attempted speech'),
    (temporal_df.loc[temporal_df.Paradigm == 'Listen', 'Accuracy'], 'Listen', 'Listen'),
    (np.array(cross_train_test['train_speech_test_listen']) * 100, 'Attempted speech', 'Listen'),
    (np.array(cross_train_test['train_listen_test_speech']) * 100, 'Listen', 'Attempted speech'),
]
all_accs, all_train, all_test = [], [], []
for accs, train_cond, test_cond in train_test_parts:
    accs = list(accs)
    all_accs.extend(accs)
    # label count follows the data rather than assuming 10 folds per condition
    all_train.extend([train_cond] * len(accs))
    all_test.extend([test_cond] * len(accs))
train_test_df = pd.DataFrame({'Accuracy': all_accs, 'Train': all_train, 'Test': all_test})

g = sns.boxplot(data=train_test_df, x='Train', y='Accuracy', hue='Test', ax=ax,
                order=distractor_order, hue_order=distractor_order,
                palette=distractor_pal, linewidth=lw)
g.legend(frameon=False, title='Test')
ax.axhline(CHANCE_ACC, linestyle='--', color='k')
ax.set(ylim=[0, 100], ylabel='Temporal-lobe accuracy (%)')
ax.xaxis.set_tick_params(length=0)
sns.despine(ax=ax, offset=5)
ax.spines['bottom'].set_visible(False)
_annotate_above_chance(ax, train_test_df, 'Train', 'Test', distractor_order, distractor_order)




# The two saliency-map panels are left blank here (NOTE(review): presumably filled in
# outside matplotlib; confirm). axis('off') hides the frame, ticks and background.
# Unlike set_visible(False), it still draws the axes' children, so a panel label
# annotated on axs['salsspeech'] is rendered.
axs['salslisten'].axis('off')
axs['salsspeech'].axis('off')

### PANEL: Scatter saliences
ax = axs['salsscat']
from sklearn.preprocessing import MinMaxScaler
# NOTE(review): the other plotting helpers are imported from
# distractors_figures.plotting_utils; confirm this top-level path resolves.
from plotting_utils.utils import correlation_permutation
from scipy.stats import spearmanr

sals_listen_raw = np.asarray(d_sals['listening'])
sals_speech_raw = np.asarray(d_sals['speech'])
# Keep only electrodes with positive salience in BOTH conditions. Masking each
# array with its own `> 0` test breaks the electrode-wise pairing whenever the
# masks differ: lengths mismatch, or points are silently misaligned.
# Assumes both arrays index the same electrodes in the same order.
shared_pos = (sals_listen_raw > 0) & (sals_speech_raw > 0)

# Rescale each condition independently to [0, 1] (column vectors for sklearn).
scaler = MinMaxScaler()
sals_final_listen = scaler.fit_transform(sals_listen_raw[shared_pos].reshape(-1, 1))
sals_final_speech = scaler.fit_transform(sals_speech_raw[shared_pos].reshape(-1, 1))

r, p = correlation_permutation(sals_final_listen, sals_final_speech, corr=spearmanr)

ax.scatter(sals_final_listen, sals_final_speech, color=palette[0], alpha=0.4, clip_on=False, s=10)
ax.set(xlabel='Listen', ylabel='Attempted speech', xlim=[0, 1], ylim=[0, 1],
       xticks=[0, 0.5, 1], yticks=[0, 0.5, 1])
ax.plot([0, 1], [0, 1], color='k', linestyle='--', zorder=-5)  # identity line
sns.despine(ax=ax, offset=5)

# Rounding alone would render a small p-value as "p = 0.0".
p_text = 'p < 0.01' if p < 0.01 else f"p = {np.round(p, 2)}"
ax.text(0.4, 0.9, f"r = {np.round(r, 2)}, {p_text}", ha='center', va='bottom', color='k')



# Bold panel letters, offset in axes points from each panel's lower-left corner.
# A label is drawn only if its axes is visible: an Axes hidden via
# set_visible(False) skips drawing its children, including annotations.
for panel, letter, x_off, y_off in (
    ('paradigmb1', 'A', -40, 120),
    ('traintest', 'B', -40, 130),
    ('salsspeech', 'C', -20, 130),
    ('salsscat', 'D', -40, 130),
):
    axs[panel].annotate(letter, (x_off, y_off), xycoords='axes points',
                        weight='bold', ha='right', fontsize=9)

# Re-assert the intended size in inches before rendering.
fig.set_size_inches(scale * dims)

plt.show();