In [None]:
import numpy as np
import pandas as pd
from scipy import stats
import pingouin as pg
import re

import pylab as plt
import seaborn as sns

from statsmodels.stats.multitest import multipletests

In [None]:
# One colour per feature domain (clinical domains followed by digital domains),
# drawn in order from seaborn's 'deep' palette.
domain_names = ['motor', 'cognition', 'psychiatric', 'autonomic', 'daily', 'medication',
                'physical activity', 'sleep', 'vital']
color_map = {name: colour for name, colour in zip(domain_names, sns.color_palette('deep'))}

# Greyscale colours for the coarse clinical-vs-digital colour strip.
color_map2 = {name: colour for name, colour in zip(['clinical', 'digital'], sns.color_palette('gray'))}

In [None]:
# Display names of the clinical scores. Their order matches the domain lists below
# when those are stacked cognition → psychiatric → autonomic → daily → motor → medication.
cl_names = ['Semantic Fluency', 'MOCA', 'Benton',
            'Letter Number Sequencing', 'HVLT Recall', 'HVLT Recognition', 'HVLT Retention',
            'Symbol Digit', 'STAI trait', 'STAI state', 'GDS', 'QUIP',
            'ESS', 'RBDSQ', 'Systolic BP Drop', 'SCOPA autonome',
            'Schwab England ADL', 'UPDRS I', 'UPDRS II', 'UPDRS III OFF', 'UPDRS III ON', 'UPDRS IV', 'LEDD']

# Clinical variable codes, grouped by domain.
neuropsychiatric = ['stai_trait', 'stai_state', 'gds', 'quip']
cognition = ['semantic_fluency', 'moca', 'benton', 'lns', 'hvlt_recall', 'hvlt_recognition',
             'hvlt_retention', 'symbol_digit']
autonome = ['epworth', 'rbd', 'systolic_bp_drop', 'scopa_aut']
daily = ['se_adl', 'updrs_i']
motor = ['updrs_ii', 'updrs_iii_OFF']
medication = ['updrs_iii_ON', 'updrs_iv', 'LEDD']
features = np.hstack([cognition, neuropsychiatric, autonome, daily, motor, medication])
covs = ['visit_age', 'date_y']

# Weekly means of the digital features.
grouped_mean = pd.read_csv('/scratch/c.c21013066/data/ppmi/accelerometer/weekly_mean.csv', parse_dates=['date_y'])


def _drop_ms(columns):
    """Return, as a list, the column names that do not contain the substring '_ms'."""
    return [col for col in columns if re.search('_ms', col) is None]


# All digital predictors (pandas Index), plus per-domain column lists with '_ms' columns removed.
predictors = grouped_mean.filter(regex='(walking|step|efficiency|total_sleep|pulse|deep|light|rem|nrem|rmssd|wake)').columns
predictors_filt = _drop_ms(predictors)
sleep_col = _drop_ms(grouped_mean.filter(regex='(efficiency|total_sleep|deep|light|rem|nrem|wake)').columns)
phys = _drop_ms(grouped_mean.filter(regex='(walking|step)').columns)
vital = _drop_ms(grouped_mean.filter(regex='(pulse|rmssd)').columns)

In [None]:
# Temporal granularity of the rate-of-change correlation table to load;
# the same value is reused in the output figure file names.
kind = 'week'
corr_path = f'/scratch/c.c21013066/data/ppmi/analyses/studywatch/ratechange_{kind}_mean_corr.csv'
# Two-level row index: one level per variable of each correlated pair.
corr = pd.read_csv(corr_path, index_col=[0, 1], header=[0])

In [None]:
# Multiple-comparison correction applied to the pairwise correlation p-values:
# 'FDR' (Benjamini-Hochberg), 'HOLM', or 'Bonferroni'.
method = 'FDR'
# NOTE(review): `plot_context` is neither imported nor defined in this notebook's visible cells;
# presumably a project plotting helper — confirm it is in scope on a fresh kernel.
plot_context()

# Heatmap row/column order: clinical scores first, then the digital features.
order = np.hstack([cl_names, phys, sleep_col, vital])
n_clinical = len(cl_names)

# Domain label per clinical score; positions line up with `cl_names`.
labels = pd.Series(np.hstack([np.repeat('cognition', len(cognition)),
                              np.repeat('psychiatric', len(neuropsychiatric)),
                              np.repeat('autonomic', len(autonome)),
                              np.repeat('daily', len(daily)),
                              np.repeat('motor', len(motor)),
                              np.repeat('medication', len(medication))]),
                   index=cl_names)
# Domain label per digital feature.
labels_dig = pd.Series(np.hstack([np.repeat('physical activity', len(phys)),
                                  np.repeat('sleep', len(sleep_col)),
                                  np.repeat('vital', len(vital))]),
                       index=np.hstack([phys, sleep_col, vital]))
labels_all = pd.concat([labels, labels_dig])
# Coarse clinical-vs-digital label used for the second colour strip.
labels2 = pd.Series(np.hstack([np.repeat('clinical', n_clinical),
                               np.repeat('digital', labels_dig.shape[0])]),
                    index=order)

# Only pairs with a p-value take part in the correction.
corr = corr.dropna(subset=['p-value'])
pvals = corr['p-value'].to_numpy()
if method == 'HOLM':
    _, p_corrected, _, _ = multipletests(pvals, alpha=0.05, method='holm')
    corr['p-correctedHOLM'] = p_corrected
    is_significant = p_corrected < 0.05
elif method == 'FDR':
    _, p_corrected, _, _ = multipletests(pvals, alpha=0.05, method='fdr_bh')
    corr['p-correctedFDR'] = p_corrected
    is_significant = p_corrected < 0.05
elif method == 'Bonferroni':
    # NOTE(review): the test count here is n*(n-1)/2 over *all* plotted variables, whereas
    # Holm/FDR above correct over len(corr) tests — confirm this difference is intended.
    n_tests = len(order) * (len(order) - 1) / 2
    is_significant = corr['p-value'] < 0.05 / n_tests
else:
    # Without this, `corr['sign']` would never be created and the next step fails obscurely.
    raise ValueError(f"Unknown method {method!r}; expected 'HOLM', 'FDR' or 'Bonferroni'")
# '*' marks a significant pair in the heatmap annotation, '' otherwise. (No NaN p-values remain
# after the dropna above, so no NaN marks are needed.)
corr['sign'] = np.where(is_significant, '*', '')

# Square (variable x variable) matrices; pairs absent from `corr` stay NaN.
whole_corr = pd.DataFrame(index=pd.MultiIndex.from_product([order, order], names=['p1', 'p2']),
                          columns=['pearson r', 'p-value', 'N', 'sign'])
whole_corr.loc[corr.index, :] = corr[['pearson r', 'p-value', 'N', 'sign']].values
rs = (whole_corr[['pearson r']].unstack().droplevel(level=0, axis=1)
      .loc[order, order].astype(float))
signs = whole_corr[['sign']].unstack().droplevel(level=0, axis=1).loc[order, order]

# clustermap creates its own figure; no separate plt.figure() is needed.
g = sns.clustermap(rs,
                   cmap='coolwarm', center=0, xticklabels=True, yticklabels=True,
                   annot=signs, fmt='',
                   row_cluster=False, col_cluster=False,
                   row_colors=[labels_all.map(color_map), labels2.map(color_map2)],
                   col_colors=[labels_all.map(color_map), labels2.map(color_map2)],
                   figsize=(15, 15),
                   cbar_kws={'label': 'pearson r'})
g.ax_heatmap.set_xlabel('')
g.ax_heatmap.set_ylabel('')

# Zero-size bars exist only to produce legend handles for the colour strips.
for label in np.unique(labels_all.to_numpy()):
    g.ax_col_dendrogram.bar(0, 0, color=color_map[label], label=label, linewidth=0)
for label in np.unique(labels2.to_numpy()):
    g.ax_col_dendrogram.bar(0, 0, color=color_map2[label], label=label, linewidth=0)
g.ax_col_dendrogram.legend(bbox_to_anchor=(0.4, 0.59), ncol=4);

# Black lines separating the clinical block from the digital block.
ax = g.ax_heatmap
n_total = len(order)
ax.plot([0, n_total], [n_clinical, n_clinical], 'k-', lw=2)
ax.plot([n_clinical, n_clinical], [0, n_total], 'k-', lw=2)

# Save the clustermap's own figure explicitly rather than the implicit "current" pyplot figure.
out_stem = f'/scratch/c.c21013066/images/ppmi/studywatch/digital_clinical_corr_{kind}_pd_change_LEDD_removevisitday_mean'
for ext in ('png', 'pdf'):
    g.savefig(f'{out_stem}.{ext}', bbox_inches='tight', dpi=300)