In [None]:
from sklearn.model_selection import KFold

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap
import os
import pickle
import scipy
import sklearn
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

from sklearn.manifold import TSNE
from sklearn.metrics import RocCurveDisplay, auc, roc_curve

from custom_utils import validation_setup

---
### read data

In [None]:
# Root folders: cached MIMIC inputs and the model outputs analysed in this notebook.
data_path, output_path = '../data/mimic', '../output/mimic'

In [None]:
# Load the cached feature and label tables; missing entries are replaced with 0.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open(data_path+'/df_feature.pickle', 'rb') as feature_file:
    df_feature = pickle.load(feature_file)
df_feature = df_feature.fillna(0)

with open(data_path+'/df_label.pickle', 'rb') as label_file:
    df_label = pickle.load(label_file)
df_label = df_label.fillna(0)

In [None]:
# Lab-item dictionary, read from the working directory (not from data_path) and indexed by itemid.
# NOTE(review): `names=` is given without `header=0`, so a header row in the CSV (if there is one)
# is read as a data row - confirm against the actual file.
concept_info = pd.read_csv('./d_labitems.csv', names=['itemid','label','fluid','category'], index_col='itemid')
# Display name used for feature columns: "<label> / <fluid>".
concept_info['used_name'] = concept_info['label'] + ' / ' + concept_info['fluid']

In [None]:
# Rename raw lab columns ('M' + itemid) to their "<label> / <fluid>" display names, then give the
# physical / demographic columns names in the same "<name> / <group>" style.
# The itemid -> name dict is built once here instead of once per key inside the comprehension.
# NOTE(review): 'M' + key requires the itemid index to hold strings - confirm.
lab_item_names = concept_info['used_name'].to_dict()
df_feature = df_feature.rename(
    columns={'M'+key: used_name for key, used_name in lab_item_names.items()}
).rename(
    columns={'BMI (kg/m2)':'BMI / Physical', 'Height (Inches)':'Height / Physical', 'Weight (Lbs)':'Weight / Physical', 'age':'age / Demography', 'Blood_Pressure_High': 'SBP / Physical', 'Blood_Pressure_Low': 'DBP / Physical'}
)
# Strip the first two characters from every label column name.
df_label = df_label.rename(columns={key: key[2:] for key in df_label.columns})

---
### Table 1. feature characteristic

In [None]:
# (n_rows, n_columns) of the feature table and of the label table.
df_feature.shape, df_label.shape

In [None]:
# Column sums of the label table, rounded (positive counts if labels are 0/1 - TODO confirm).
np.round(df_label.sum(axis=0))

In [None]:
# Column sums as a percentage of the number of rows, rounded to two decimals.
np.round(df_label.sum(axis=0) / df_label.shape[0] * 100, 2)

#### remove HUA label

In [None]:
# HUA is not predicted: most providers have no corresponding diagnosis records.
# errors='ignore' keeps this cell idempotent, so re-running it does not raise KeyError.
df_label = df_label.drop(columns='HUA', errors='ignore')

---
### setups

In [None]:
dataset_name = 'MIMIC-IV'

# Providers are the outermost level of the feature table's index.
provider_ids = df_feature.index.get_level_values(0).unique()

# Cross-validation layout: number of folds and providers held out per fold.
n_folds = 11
n_providers_per_fold = 5
n_providers = n_folds * n_providers_per_fold
n_inputs, n_labels = df_feature.shape[1], df_label.shape[1]

col_labels = df_label.columns

# NOTE(review): these three values are not referenced in the visible part of the notebook.
batch_size, n_iters, epsilon = 100, 40, 1e-7

In [None]:
# validation_setup comes from the project-local custom_utils module. In the cells below,
# valid_list[j_fold] is used to index provider_ids; train_list is not used in the visible cells.
train_list, valid_list = validation_setup(n_folds, n_providers_per_fold)

---
### figure setups

In [None]:
# Font sizes shared by all figures.
LABEL_SIZE, TICK_SIZE, TICK_SIZE_SMALL = 30, 25, 20
# Resolution passed as `dpi` when figures are created.
united_dpi = 400

In [None]:
# Display labels (mathtext) for the agreement metrics; the evaluation list adds the ROC AUC.
correlation_metrics = [r'Pearson $r$', r'Spearman $\rho$', r'Kendall $\tau$']
evaluation_metrics = [*correlation_metrics, r'$AUC_{ROC}$']
# Keys used in the 'metric' level of the score tables, in the same order as evaluation_metrics.
metrics_names = ['pearsonr', 'spearmanr', 'kendalltau',  'roc_auc_score']

In [None]:
markers = 'ov^s'  # indexed by metric position when drawing the trend curves
layer_shape = [200, 50, 30]
# Regularisation strengths and w_PA weights that were swept (0.0, 0.5, ..., 5.0).
common_reg_list = [2e-5, 3e-5, 4e-5, 5e-5, 7e-5, 8e-5, 1e-4, 2e-4, 3e-4, 4e-4, 5e-4, 7e-4, 1e-3]
common_wpa_list = np.arange(11) / 2
# Values highlighted in the figures.
used_reg, used_wpa = 8e-5, 2

---
### reading model predictions

In [None]:
# Cross-provider predictions: for every training provider, stack the pickled predictions
# made on each provider's data.
ind_pred_total, df_ind_pred_total = [], []
pid_range = range(df_feature.shape[0])

for i_provider in provider_ids:

    ind_pred_total_i = []
    for j_valid in provider_ids:
        pred_file = output_path+'/ind/pred_train_on_%s_valid_on_%s.pickle' % (i_provider, j_valid)
        with open(pred_file, 'rb') as f:
            ind_pred_total_i.append(pickle.load(f))
    ind_pred_total_i = np.vstack(ind_pred_total_i)

    # Same predictions as a frame keyed by (provider, pid); pid is a running row number.
    row_index = pd.MultiIndex.from_product([[i_provider], pid_range], names=['provider', 'pid'])
    df_ind_pred_total_i = pd.DataFrame(data=ind_pred_total_i, index=row_index, columns=col_labels)

    ind_pred_total.append(ind_pred_total_i)
    df_ind_pred_total.append(df_ind_pred_total_i)

# Array with the training provider on axis 1, and a flat frame with provider / pid as columns.
ind_pred_total = np.stack(ind_pred_total, axis=1)
df_ind_pred_total = pd.concat(df_ind_pred_total).reset_index()

In [None]:
# `g` outputs saved for the phyC model at w_PA = used_wpa.
# NOTE(review): indexed as [sample, provider, label] in later cells - confirm.
phyC_g_path = output_path+'/phyC-all-w=%.1f/train_pred_g.pickle' % used_wpa
with open(phyC_g_path, 'rb') as f:
    phyC_g_total = pickle.load(f)

In [None]:
# Predictions saved for the phyC model at w_PA = used_wpa.
# NOTE(review): indexed as [sample, label] in later cells - confirm.
phyC_pred_path = output_path+'/phyC-all-w=%.1f/train_pred.pickle' % used_wpa
with open(phyC_pred_path, 'rb') as f:
    phyC_pred_total = pickle.load(f)

In [None]:
# Kernel saved for the phyC model at w_PA = used_wpa.
# NOTE(review): the middle axis is indexed by provider and the last by label in later cells - confirm.
phyC_kernel_path = output_path+'/phyC-all-w=%.1f/train_kernel.pickle' % used_wpa
with open(phyC_kernel_path, 'rb') as f:
    phyC_kernel_total = pickle.load(f)

In [None]:
# Bias saved for the phyC model at w_PA = used_wpa (not used in the visible cells).
phyC_bias_path = output_path+'/phyC-all-w=%.1f/train_bias.pickle' % used_wpa
with open(phyC_bias_path, 'rb') as f:
    phyC_bias_total = pickle.load(f)

---
### Figure 1b - physician-paired inconsistency measures

In [None]:
# consistency_between_physicians = np.zeros([n_providers, n_providers, n_labels, 3])

# for i_provider in range(n_providers):
#     for j_provider in range(n_providers):
#         for k_label in range(n_labels):
#             consistency_between_physicians[i_provider, j_provider, k_label, 0] = scipy.stats.pearsonr(ind_pred_total[:, i_provider,k_label], ind_pred_total[:, j_provider,k_label]).statistic
#             consistency_between_physicians[i_provider, j_provider, k_label, 1] = scipy.stats.spearmanr(ind_pred_total[:, i_provider,k_label], ind_pred_total[:, j_provider,k_label]).statistic
#             consistency_between_physicians[i_provider, j_provider, k_label, 2] = scipy.stats.kendalltau(ind_pred_total[:, i_provider,k_label], ind_pred_total[:, j_provider,k_label]).statistic
            
#         print(i_provider, j_provider, end='\r')
            
# with open(output_path+'/MLP/pair consistency.pickle', 'wb') as f:
    
#     pickle.dump(consistency_between_physicians, f)

In [None]:
# Cached pairwise agreement, indexed [provider, provider, label, metric].
with open(output_path+'/MLP/pair consistency.pickle', 'rb') as f:
    consistency_between_physicians = pickle.load(f)

fig, axes = plt.subplots(nrows=3, ncols=9, figsize=(48, 12), dpi=united_dpi)
plt.suptitle(dataset_name, x=0.45, fontsize=LABEL_SIZE+5)

# Only the first and last provider are labelled on each heat map.
first_last_ticks = [0, n_providers-1]
first_last_names = [1, n_providers]

for i_metric in range(3):
    metric_maps = consistency_between_physicians[..., i_metric]
    for k_label in range(n_labels):
        ax = axes.flat[i_metric*n_labels+k_label]
        if i_metric == 0:
            # label names head the columns
            ax.set_title(col_labels[k_label], fontsize=LABEL_SIZE)
        if k_label == 0:
            # metric names head the rows
            ax.set_ylabel(correlation_metrics[i_metric], fontsize=LABEL_SIZE)
        im = ax.imshow(metric_maps[:, :, k_label], vmin=-1, vmax=1, cmap='RdBu')
        ax.set_xticks(first_last_ticks, first_last_names, fontsize=TICK_SIZE_SMALL)
        ax.set_yticks(first_last_ticks, first_last_names, fontsize=TICK_SIZE_SMALL)

# One colour bar shared by all panels (every panel uses the same vmin / vmax).
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), pad=0.01, ticks=np.arange(-1+1e-5,1.1,0.2-1e-6), format='%4.1f')
cbar.ax.tick_params(labelsize=TICK_SIZE)

plt.show()

In [None]:
# Mean agreement of each provider with the others, then averaged over providers and labels,
# leaving one value per metric. The "-1" removes the self-comparison (assumed to equal 1).
n_compared = consistency_between_physicians.shape[0]
agreement_with_others = (consistency_between_physicians.sum(axis=0) - 1) / (n_compared - 1)
agreement_with_others.mean(axis=0).mean(axis=0)

### Figure 1d - group agreement (ICC) scores

In [None]:
# import pandas as pd
# import pingouin as pg

# # Compute ICC
# icc_df_list = {}
# for col_name in col_labels:
#     print(col_name, end='\r')
#     icc_df_list[col_name] = pg.intraclass_corr(data=df_ind_pred_total, targets='pid', raters='provider', ratings=col_name)
    
# with open(output_path+'/MLP/group consistency.pickle', 'wb') as f:
    
#     pickle.dump(icc_df_list, f)

In [None]:
with open(output_path+'/MLP/group consistency.pickle', 'rb') as f:
    icc_df_list = pickle.load(f)

plt.figure(figsize=(13,12), dpi=united_dpi)

# Shaded agreement bands: (upper bound, lower bound, legend label, colour).
agreement_bands = [
    (1.0, 0.75, 'excellent', 'tab:cyan'),
    (0.75, 0.6, 'good', 'tab:green'),
    (0.6, 0.4, 'fair', 'tab:orange'),
    (0.4, 0.0, 'poor', 'tab:red'),
]
for band_top, band_bottom, band_name, band_color in agreement_bands:
    plt.fill_between([-1,n_labels], [band_top, band_top], [band_bottom, band_bottom],
                     alpha=0.2, label=band_name, color=band_color, lw=0)

# Third row of each per-label ICC table (the y-axis labels it ICC(3,1)).
icc_values = [icc_df_list[label_name].iloc[2]['ICC'] for label_name in col_labels]

markerline, stemlines, baseline = plt.stem(range(n_labels), icc_values, bottom=-1)

# Print each value just above its stem.
for k_label, icc in enumerate(icc_values):
    plt.text(k_label, icc+0.02, '%.2f' % (icc), fontsize=TICK_SIZE, horizontalalignment='center')

plt.setp(stemlines, 'linewidth', 5)
plt.setp(markerline, markersize = 10)

plt.xlim([-1,n_labels])
plt.ylim([0,1])
plt.title(dataset_name, fontsize=LABEL_SIZE)
plt.xticks(range(n_labels), list(col_labels), fontsize=TICK_SIZE)
plt.yticks(fontsize=TICK_SIZE)
plt.xlabel('Disease Name', fontsize=LABEL_SIZE)
plt.ylabel('ICC(3,1)', fontsize=LABEL_SIZE)
plt.legend(fontsize=TICK_SIZE)

---
### Figure 3
#### evaluation_scores: change on $w_{PA}$, reg_scores: change on regularization

In [None]:
def metrics_analysis(x):
    """Summarise per-fold scores of several algorithms against the first column.

    Parameters
    ----------
    x : pandas.DataFrame
        One column per algorithm (the first column is the baseline). The row
        index must contain the levels 'valid', 'label' and 'metric'. Rows with
        any NaN are dropped and the rest are averaged per (valid, label, metric).

    Returns
    -------
    pandas.DataFrame
        Transposed summary: rows 'mean', 'rank', 'std', 'statistic', 'pvalue',
        'significance' and 'display', one column per algorithm. 'statistic' and
        'pvalue' come from a Wilcoxon signed-rank test against the baseline
        column (the baseline itself keeps the placeholders 0 and 0.5).
        'significance' holds one to three '*' when the algorithm ranks better
        than the baseline, or '•' when it ranks worse, for p below 0.05, 0.01
        and 0.001 respectively.
    """
    # The score tables are created empty (object dtype); make the values numeric explicitly.
    x = x.dropna().astype(float)
    x = x.groupby(['valid', 'label', 'metric']).mean()

    df = pd.DataFrame()
    df['mean'] = x.mean(axis=0)
    df['rank'] = df['mean'].rank(ascending=False)
    df['std'] = x.std(axis=0)

    # Collect the test results in float arrays and assign whole columns at once.
    # The former `df['col'].iloc[i] = value` was chained assignment, which does not
    # write back to `df` under pandas Copy-on-Write.
    n_algs = x.shape[1]
    statistics = np.zeros(n_algs)
    pvalues = np.full(n_algs, 0.5)
    for i in range(1, n_algs):
        statistics[i], pvalues[i] = scipy.stats.wilcoxon(x.iloc[:, 0], x.iloc[:, i])
    df['statistic'] = statistics
    df['pvalue'] = pvalues

    # Number of thresholds passed: 0 (not significant) up to 3 (p < 0.001).
    n_marks = (df['pvalue'] < 0.05).astype(int) + (df['pvalue'] < 0.01).astype(int) + (df['pvalue'] < 0.001).astype(int)
    baseline_rank = df['rank'].iloc[0]
    df['significance'] = ['*' * int(k1) + '•' * int(k2)
                          for k1, k2
                          in zip((df['rank'] < baseline_rank) * n_marks,
                                 (df['rank'] > baseline_rank) * n_marks)]

    df['display'] = df.apply(lambda y: ('%.3f'%y['mean'])+'±'+('%.3f'%y['std']), axis=1)

    return df.T

In [None]:
alg_name_list = ['phyC-w=%.1f' % i for i in common_wpa_list]

n_algs = len(alg_name_list)

# One row per (fold, label, metric, split) and one column per w_PA value. The agreement
# metrics fill only submetric 0; the ROC AUC is stored for both prediction splits.
evaluation_scores = pd.DataFrame(columns=common_wpa_list, index=pd.MultiIndex.from_product([range(n_folds), col_labels, metrics_names, [0, 1]], names=['valid', 'label', 'metric', 'submetric']))

for j_fold in range(n_folds):

    valid_labels = df_label.loc[provider_ids[valid_list[j_fold]]].values

    for i_alg, alg_name in enumerate(alg_name_list):

        with open(output_path + '/%s/pred_valid_%s_split_0.pickle' % (alg_name, j_fold), 'rb') as f:
            pred_0 = pickle.load(f)
        with open(output_path + '/%s/pred_valid_%s_split_1.pickle' % (alg_name, j_fold), 'rb') as f:
            pred_1 = pickle.load(f)

        for i_label in range(n_labels):

            # Agreement between the two prediction splits.
            _cor = scipy.stats.pearsonr(pred_0[:, i_label], pred_1[:, i_label]).statistic
            _rho = scipy.stats.spearmanr(pred_0[:, i_label], pred_1[:, i_label]).statistic
            _tau = scipy.stats.kendalltau(pred_0[:, i_label], pred_1[:, i_label]).statistic

            # ROC AUC is undefined unless both classes occur in the fold. Checking for two
            # distinct values also covers an all-positive fold, which would otherwise make
            # roc_auc_score raise.
            if np.unique(valid_labels[:, i_label]).size > 1:
                _roc_auc_0 = sklearn.metrics.roc_auc_score(valid_labels[:, i_label], pred_0[:, i_label])
                _roc_auc_1 = sklearn.metrics.roc_auc_score(valid_labels[:, i_label], pred_1[:, i_label])
            else:
                _roc_auc_0 = np.nan
                _roc_auc_1 = np.nan

            evaluation_scores.loc[(j_fold, col_labels[i_label], 'pearsonr', 0), common_wpa_list[i_alg]] = _cor
            evaluation_scores.loc[(j_fold, col_labels[i_label], 'spearmanr', 0), common_wpa_list[i_alg]] = _rho
            evaluation_scores.loc[(j_fold, col_labels[i_label], 'kendalltau', 0), common_wpa_list[i_alg]] = _tau
            evaluation_scores.loc[(j_fold, col_labels[i_label], 'roc_auc_score', 0), common_wpa_list[i_alg]] = _roc_auc_0
            evaluation_scores.loc[(j_fold, col_labels[i_label], 'roc_auc_score', 1), common_wpa_list[i_alg]] = _roc_auc_1

            print(j_fold, i_label, end='\r')

In [None]:
# Show the filled score table (rows: fold / label / metric / split, columns: w_PA values).
evaluation_scores

In [None]:
# Summarise each (label, metric) group across folds with metrics_analysis.
scores_statistic = evaluation_scores.groupby(['label', 'metric']).apply(metrics_analysis)
# Average the per-label 'mean' and 'std' rows over labels: one row per metric, one column per w_PA value.
scores_mean_list = scores_statistic.loc[:,:,'mean'].groupby('metric').mean()
scores_std_list = scores_statistic.loc[:,:,'std'].groupby('metric').mean()

In [None]:
# Label-averaged mean score per metric (rows) and w_PA value (columns).
scores_mean_list

In [None]:
alg_name = 'MLP_grid_search'

n_algs = len(common_reg_list)

# One row per (fold, label, metric, split) and one column per regularisation strength. The
# agreement metrics fill only submetric 0; the ROC AUC is stored for both prediction splits.
reg_scores = pd.DataFrame(columns=common_reg_list, index=pd.MultiIndex.from_product([range(n_folds), col_labels, metrics_names, [0, 1]], names=['valid', 'label', 'metric', 'submetric']))

for j_fold in range(n_folds):

    valid_labels = df_label.loc[provider_ids[valid_list[j_fold]]].values

    for common_reg in common_reg_list:

        # Each file holds the pair of split predictions; the name embeds the
        # (regularisation, layer shape, fold) tuple.
        with open(output_path + '/%s/preds/pred_%s.pickle' % (alg_name, (common_reg, str(layer_shape), j_fold)), 'rb') as f:

            pred_0, pred_1 = pickle.load(f)
            pred_0, pred_1 = pred_0.values, pred_1.values

        for i_label in range(n_labels):

            # Agreement between the two prediction splits.
            _cor = scipy.stats.pearsonr(pred_0[:, i_label], pred_1[:, i_label]).statistic
            _rho = scipy.stats.spearmanr(pred_0[:, i_label], pred_1[:, i_label]).statistic
            _tau = scipy.stats.kendalltau(pred_0[:, i_label], pred_1[:, i_label]).statistic

            # ROC AUC is undefined unless both classes occur in the fold. Checking for two
            # distinct values also covers an all-positive fold, which would otherwise make
            # roc_auc_score raise.
            if np.unique(valid_labels[:, i_label]).size > 1:
                _roc_auc_0 = sklearn.metrics.roc_auc_score(valid_labels[:, i_label], pred_0[:, i_label])
                _roc_auc_1 = sklearn.metrics.roc_auc_score(valid_labels[:, i_label], pred_1[:, i_label])
            else:
                _roc_auc_0 = np.nan
                _roc_auc_1 = np.nan

            reg_scores.loc[(j_fold, col_labels[i_label], 'pearsonr', 0), common_reg] = _cor
            reg_scores.loc[(j_fold, col_labels[i_label], 'spearmanr', 0), common_reg] = _rho
            reg_scores.loc[(j_fold, col_labels[i_label], 'kendalltau', 0), common_reg] = _tau
            reg_scores.loc[(j_fold, col_labels[i_label], 'roc_auc_score', 0), common_reg] = _roc_auc_0
            reg_scores.loc[(j_fold, col_labels[i_label], 'roc_auc_score', 1), common_reg] = _roc_auc_1

            print(j_fold, i_label, end='\r')

In [None]:
# Summarise each (label, metric) group across folds with metrics_analysis.
reg_scores_statistic = reg_scores.groupby(['label', 'metric']).apply(metrics_analysis)
# Average the per-label 'mean' and 'std' rows over labels: one row per metric, one column per regularisation value.
reg_scores_mean_list = reg_scores_statistic.loc[:,:,'mean'].groupby('metric').mean()
reg_scores_std_list = reg_scores_statistic.loc[:,:,'std'].groupby('metric').mean()

In [None]:
# Label-averaged mean score per metric (rows) and regularisation value (columns).
reg_scores_mean_list

### Figure 3a Comparison outputs

In [None]:
# Compare the baseline (w_PA = 0) with the selected model (w_PA = used_wpa): one panel per
# metric, one bar + box + per-fold scatter per label and algorithm.
algs = [0, used_wpa]
# Raw string: '\m' is not a valid escape sequence in a normal string literal.
legends = ['Baseline', r'Ours($w_\mathrm{PA}=%.1f$)' % used_wpa]

bar_width, box_width, scatter_width = 0.2, 0.1, 0.2
ind_width = bar_width + scatter_width
group_width = (len(algs) + 0.5) * ind_width

x_bar_anchor = np.arange(n_labels) * group_width
x_scatter_anchor, _ = np.meshgrid(x_bar_anchor, np.arange(n_folds))

y_sig_height = 0.035
y_sig_shift = -  y_sig_height / 3

plt.figure(figsize=(11,20), dpi=united_dpi)

for i_metric, metric in enumerate(metrics_names):

    plt.subplot(4,1, i_metric+1)

    high_caps_y = []

    for i_alg, alg_value in enumerate(algs):

        # Build the (n_folds, n_labels) matrix by label instead of reshaping a flat selection,
        # so the result does not depend on the row order returned by a MultiIndex `.loc`.
        this_metric = (
            evaluation_scores[alg_value]
            .groupby(['valid','label','metric']).mean()
            .xs(metric, level='metric')
            .unstack('label')
            .reindex(index=range(n_folds), columns=col_labels)
            .to_numpy(dtype=float)
        )
        this_color = 'tab:orange' if i_alg == 0 else 'tab:blue'
        this_bar_position = x_bar_anchor + i_alg * ind_width
        # Baseline points go left of the baseline bar, the other algorithm's points right of its bar.
        this_scatter_position = x_scatter_anchor - scatter_width if i_alg == 0 else x_scatter_anchor + ind_width * i_alg + scatter_width

        plt.bar(this_bar_position,
                scores_statistic.loc[col_labels, metric, 'mean'][alg_value],
                width=bar_width, label=legends[i_alg],alpha=0.5,
                color=this_color)
        bplt = plt.boxplot(this_metric, positions=this_bar_position, widths=box_width, showfliers=True,medianprops={'color':'black'}, sym='+')
        plt.scatter(this_scatter_position, this_metric, s=10, alpha=0.3, color=this_color)

        # Upper whisker cap of every box (two caps per box); used to place the significance brackets.
        high_caps_y.append(np.array([cap.get_ydata()[0] for cap in bplt['caps']]).reshape([n_labels, 2]).max(axis=1))

    plt.plot([-group_width/2,group_width*n_labels],[1,1], '--', c='black', alpha=0.4)
    plt.ylim([0.4,1.07])
    plt.xlim([-group_width/2,group_width*n_labels])
    plt.title(evaluation_metrics[i_metric], fontsize=LABEL_SIZE)
    plt.xticks(x_bar_anchor + 0.5 * group_width - 0.3, [_name for _name in col_labels], fontsize=TICK_SIZE)
    plt.yticks([0.4,0.6,0.8,1.0], [0.4,0.6,0.8,1.0], fontsize=TICK_SIZE)
    plt.legend(fontsize=TICK_SIZE_SMALL-3, loc='lower center', ncol=2)

    # Highest cap over both algorithms, per label.
    high_caps_y = np.vstack(high_caps_y).max(axis=0)

    for i_sig in np.arange(n_labels):

        n_sig = 0

        for i_alg, alg_value in enumerate(algs):

            # The baseline is the reference of the test and gets no bracket.
            if i_alg == 0:
                continue

            _text = scores_statistic.loc[col_labels[i_sig], metric, 'significance'][alg_value]

            if len(_text) != 0:

                # Bracket spanning the baseline bar and this algorithm's bar, stacked above the caps.
                _rect_left, _rect_right = i_sig * group_width, i_sig * group_width + i_alg * ind_width
                _rect_low, _rect_high = high_caps_y[i_sig]+y_sig_height*(n_sig+1.1) + y_sig_height/4, high_caps_y[i_sig]+y_sig_height*(n_sig+1.1) + 2*y_sig_height/4

                plt.text(_rect_left/2 + _rect_right/2,
                         _rect_high,
                         _text,
                         horizontalalignment='center', fontsize=TICK_SIZE_SMALL-5)

                plt.plot([_rect_left, _rect_left, _rect_right, _rect_right],
                         [_rect_low, _rect_high, _rect_high, _rect_low],
                         color='black', lw=1.1)

                n_sig += 1

plt.suptitle(dataset_name, x=0.2, fontsize=LABEL_SIZE+2)
plt.tight_layout()
plt.show()

### Figure3c trend of metrics with reg

In [None]:
plt.figure(figsize=(11,11), dpi=united_dpi)

# Reference level of the selected phyC model (w_PA = used_wpa), one value per metric.
phyc_mean = scores_mean_list[used_wpa]
phyc_std = scores_std_list[used_wpa]

# Solid horizontal lines: phyC reference for each metric.
for i_metric, metric in enumerate(metrics_names):

    plt.plot([common_reg_list[0], common_reg_list[-1]], [phyc_mean.loc[metric], phyc_mean.loc[metric]], '-', color=cm.Set1(i_metric), lw=3)

# Dashed curves with a +/- std band: grid-searched MLP as a function of the regularisation strength.
for i_metric, metric in enumerate(metrics_names):

    plt.plot(common_reg_list, reg_scores_mean_list.loc[metric], markers[i_metric]+'--', markersize=10, label=evaluation_metrics[i_metric],  color=cm.Set1(i_metric), lw=3)
    plt.fill_between(common_reg_list,
                     reg_scores_mean_list.loc[metric] + reg_scores_std_list.loc[metric],
                     reg_scores_mean_list.loc[metric] - reg_scores_std_list.loc[metric],
                     color=cm.Set1(i_metric),
                     alpha=0.1, lw=0)

# Grey triangle marking the regularisation value that was used.
plt.plot(used_reg, phyc_mean.max(axis=0) + 0.02, 'v', markersize=12, color='grey')

plt.ylim([0.5,0.95])
plt.xlim([common_reg_list[0], common_reg_list[-1]])
plt.xscale('log')
# rotation must be a number (or 'vertical' / 'horizontal'); the string '90' is rejected by
# recent matplotlib versions.
plt.xticks(common_reg_list,['  %.0e' % x for x in common_reg_list],fontsize=TICK_SIZE, rotation=90)
plt.yticks(np.arange(0.5,0.91, 0.1), fontsize=TICK_SIZE)
plt.xlabel('L1L2', fontsize=LABEL_SIZE)
plt.ylabel('score', fontsize=LABEL_SIZE)
plt.title(dataset_name, fontsize=LABEL_SIZE)
plt.legend(fontsize=TICK_SIZE, loc='lower right')
plt.show()

### Figure3e trend of metrics with $w_{PA}$

In [None]:
plt.figure(figsize=(11,11), dpi=united_dpi)

# Dashed horizontal lines: baseline level (w_PA = 0) of each metric.
for i_metric, metric in enumerate(metrics_names):

    plt.plot([0, common_wpa_list[-1]], [scores_mean_list.loc[metric][0], scores_mean_list.loc[metric][0]], '--', color=cm.Set1(i_metric), lw=3)

# Solid curves with a +/- std band: score as a function of w_PA.
for i_metric, metric in enumerate(metrics_names):

    plt.plot(common_wpa_list, scores_mean_list.loc[metric], markers[i_metric]+'-', markersize=10, label=evaluation_metrics[i_metric],  color=cm.Set1(i_metric), lw=3)
    # The band uses the same x values as the curve. It was previously drawn against
    # range(n_columns), i.e. 0..10, which did not line up with the 0..5 axis of the curves.
    plt.fill_between(common_wpa_list,
                     scores_mean_list.loc[metric] + scores_std_list.loc[metric],
                     scores_mean_list.loc[metric] - scores_std_list.loc[metric],
                     color=cm.Set1(i_metric),
                     alpha=0.1, lw=0)

# Grey triangle marking the w_PA value that was used.
plt.plot(used_wpa, scores_mean_list.max(axis=0)[used_wpa]+ 0.02, 'v', markersize=12, color='grey')

plt.ylim([0.5,0.95])
plt.xlim([0,common_wpa_list[-1]])
plt.xticks(common_wpa_list,common_wpa_list,fontsize=TICK_SIZE)
plt.yticks(np.arange(0.5,0.91, 0.1), fontsize=TICK_SIZE)
# Raw string: '\m' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'$w_\mathrm{PA}$', fontsize=LABEL_SIZE)
plt.ylabel('score', fontsize=LABEL_SIZE)
plt.title(dataset_name, fontsize=LABEL_SIZE)
plt.legend(fontsize=TICK_SIZE, loc='lower right')
plt.show()

In [None]:
# Export the 'display', 'pvalue' and 'significance' rows of the baseline (0) and selected (used_wpa)
# columns, ordered by metric and then label, to eval.csv in the output folder.
scores_statistic[[0, used_wpa]].swaplevel(0,1).loc[metrics_names, col_labels, ['display', 'pvalue', 'significance']].to_csv(output_path+'/eval.csv', sep=',', index=True, encoding='utf-8')

---
### Figure 4a 4b scatter map / Bland–Altman plots

In [None]:
# Colour map interpolating between the first two tab10 colours; the scatter plots below colour points by label value with it.
_cmap = LinearSegmentedColormap.from_list('1', cm.tab10.colors[:2])

def cor_scatter_plt(alg_name, j_fold, k_label):
    """Scatter the split-0 predictions against the split-1 predictions of one label.

    Draws on the current pyplot axes. `alg_name` selects the output sub-folder,
    `j_fold` the validation fold and `k_label` the label column. Points are
    coloured by the label values of the fold's validation providers.
    """
    fold_providers = provider_ids[valid_list[j_fold]]
    valid_labels = df_label.loc[fold_providers].values[:, k_label]

    split_0_path = output_path+'/%s/pred_valid_%d_split_0.pickle'% (alg_name, j_fold)
    split_1_path = output_path+'/%s/pred_valid_%d_split_1.pickle'% (alg_name, j_fold)

    with open(split_0_path, 'rb') as f:
        _f0 = pickle.load(f)[:, k_label]
    with open(split_1_path, 'rb') as f:
        _f1 = pickle.load(f)[:, k_label]

    plt.scatter(_f0, _f1, c=valid_labels, alpha=0.3, s=10, cmap=_cmap)
    plt.axis('square')
    
def bland_altman_plot(alg_name, j_fold, k_label):
    """Draw a Bland-Altman plot of the two prediction splits for one label.

    Draws on the current pyplot axes: the x-axis is the mean of the two splits,
    the y-axis their difference (split 0 minus split 1). Dashed lines mark the
    mean difference and the mean +/- 1.96 standard deviations. `alg_name`
    selects the output sub-folder, `j_fold` the validation fold and `k_label`
    the label column; points are coloured by the fold's label values.
    """
    fold_providers = provider_ids[valid_list[j_fold]]
    valid_labels = df_label.loc[fold_providers].values[:, k_label]

    split_0_path = output_path+'/%s/pred_valid_%d_split_0.pickle'% (alg_name, j_fold)
    split_1_path = output_path+'/%s/pred_valid_%d_split_1.pickle'% (alg_name, j_fold)

    with open(split_0_path, 'rb') as f:
        data1 = pickle.load(f)[:,k_label]
    with open(split_1_path, 'rb') as f:
        data2 = pickle.load(f)[:,k_label]

    pair_mean = np.mean([data1, data2], axis=0)
    diff = data1 - data2
    md = np.mean(diff)            # mean difference
    sd = np.std(diff, axis=0)     # spread of the differences
    upper_limit = md + 1.96*sd
    lower_limit = md - 1.96*sd

    plt.scatter(pair_mean, diff, c=valid_labels, alpha=0.3, s=10, cmap=_cmap)

    # Annotations sit just outside the limits of agreement and in the lower-left corner.
    plt.text(0.67,upper_limit + 0.02,'$μ+1.96σ$', fontsize=TICK_SIZE_SMALL-2)
    plt.text(0.67,lower_limit - 0.08,'$μ-1.96σ$', fontsize=TICK_SIZE_SMALL-2)
    plt.text(0.02,-0.58,'$μ=%.3f$\n$σ=%.3f$'%(md,sd), fontsize=TICK_SIZE_SMALL)

    for level in (md, upper_limit, lower_limit):
        plt.axhline(level, color='gray', linestyle='--')
    
plt.figure(figsize=(20,15))

j_fold = 1

# (output sub-folder, title template) of the two compared models: baseline first, ours second.
compared_models = [
    ('phyC-w=%.1f' % 0, 'Baseline: %s'),
    ('phyC-w=%.1f' % used_wpa, 'Ours: %s'),
]
unit_ticks = np.arange(0,11,2)/10
signed_ticks = np.arange(-10,11,2)/10

# One row per label: two agreement scatters followed by two Bland-Altman plots.
for i_place, label_name in enumerate(['HTN', 'HLP', 'CHD']):

    k_label = col_labels.tolist().index(label_name)
    row_start = 1 + i_place*4

    # Columns 1-2: split-0 vs split-1 scatter.
    for i_model, (alg, title_fmt) in enumerate(compared_models):
        ax = plt.subplot(3,4,row_start+i_model)
        cor_scatter_plt(alg, j_fold, k_label)
        plt.xlim([0,1])
        plt.ylim([0,1])
        plt.xticks(unit_ticks[1:], fontsize=TICK_SIZE_SMALL)
        plt.yticks(unit_ticks, fontsize=TICK_SIZE_SMALL)
        plt.title(title_fmt%label_name, fontsize=TICK_SIZE)
        x_left, x_right, y_low, y_high = 0,1, 0,1
        ax.set_aspect(abs((x_right-x_left)/(y_low-y_high)))

    # Columns 3-4: Bland-Altman plots (ticks are set before the limits, as in the scatter-free axes).
    for i_model, (alg, title_fmt) in enumerate(compared_models):
        ax = plt.subplot(3,4,row_start+2+i_model)
        bland_altman_plot(alg, j_fold, k_label)
        plt.xticks(signed_ticks[1:], fontsize=TICK_SIZE_SMALL)
        plt.yticks(signed_ticks, fontsize=TICK_SIZE_SMALL)
        plt.title(title_fmt%label_name, fontsize=TICK_SIZE)
        plt.xlim([0,1])
        plt.ylim([-0.6,0.6])
        x_left, x_right, y_low, y_high = 0,1, -0.6, 0.6
        ax.set_aspect(abs((x_right-x_left)/(y_low-y_high)))

plt.suptitle(dataset_name,fontsize=LABEL_SIZE)
plt.tight_layout()
plt.show()

---
### Figure 5 Distribution-Based Interpretation

In [None]:
# N=10000
# rndperm = np.random.permutation(phyC_pred_total.shape[0])
# phyC_pred_total_subset = phyC_pred_total[rndperm[:N]]
# phyC_g_total_subset = phyC_g_total[rndperm[:N]]
# tsne = TSNE(n_components=2, verbose=1, perplexity=400, n_iter=1000, init='pca',learning_rate='auto')

# tsne_results = tsne.fit_transform(phyC_pred_total_subset)

# with open(output_path+'/MLP/tsne_result.pickle', 'wb') as f:
#     pickle.dump((N, rndperm, phyC_pred_total_subset, phyC_g_total_subset, tsne_results), f)

#### Figure 5a Propensity distribution maps

In [None]:
# Load the cached t-SNE embedding together with the random subset it was computed on.
with open(output_path+'/MLP/tsne_result.pickle', 'rb') as f:
    
    N, rndperm, phyC_pred_total_subset, phyC_g_total_subset, tsne_results = pickle.load(f)

# Grid of small maps: providers run down the rows, labels across the columns. Providers with
# index >= nrows wrap into a second block of n_labels columns on the right.
nrows, ncols = 28, 18
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(36, 55), layout='constrained', dpi=united_dpi)

for i_provider in range(n_providers):
    for k_label in range(n_labels):
        ax = axes[i_provider % nrows, k_label + (i_provider >= nrows)*n_labels]
        # Label names head the top row of each block; provider numbers (1-based) label the first column of each block.
        if i_provider % nrows == 0: 
            ax.set_title(col_labels[k_label], fontsize=TICK_SIZE*2)
        if k_label == 0:
            ax.set_ylabel(i_provider+1, fontsize=TICK_SIZE*2)
        # Same embedding in every panel; colour is this provider's g value for this label, clipped to [0.8, 1].
        im = ax.scatter(tsne_results[:,0], tsne_results[:,1], alpha=1, s=0.2, c=phyC_g_total_subset[:,i_provider,k_label],cmap='viridis', vmin=0.8, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        
# Hide the last row of the right-hand block (columns n_labels onwards).
for ax in axes[-1,n_labels:]:
    ax.set_visible(False)
    
# Horizontal colour bar attached to the last nine axes of the grid.
cbar = fig.colorbar(im, ax=axes.ravel().tolist()[-9:], pad=-0.8, ticks=np.arange(0.8,1.05,0.1), format='%4.1f', fraction=0.6, aspect=30, shrink=0.9, alpha=1, orientation='horizontal')
cbar.ax.tick_params(labelsize=TICK_SIZE*2)

fig.suptitle('t-SNE map for physician\'s propensities', fontsize=LABEL_SIZE*3)
plt.show()

#### Figure 5b Disease predicted probability distribution maps

In [None]:
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(9, 11), layout='constrained', dpi=united_dpi)

embedding_x, embedding_y = tsne_results[:,0], tsne_results[:,1]

# One panel per label: the shared t-SNE embedding coloured by that label's predicted value.
for i_label, label_name in enumerate(col_labels):
    ax = axes.flat[i_label]
    im = ax.scatter(embedding_x, embedding_y, c=phyC_pred_total_subset[:,i_label], cmap='viridis_r', vmin=0, vmax=1, s=1)
    ax.set_title(label_name, fontsize=LABEL_SIZE)
    ax.set_xticks([])
    ax.set_yticks([])

# Single horizontal colour bar shared by all panels.
cbar = fig.colorbar(im, ax=axes, pad=0.02, format='%4.1f', orientation='horizontal')
cbar.ax.tick_params(labelsize=LABEL_SIZE)
fig.suptitle('t-SNE map for disease predicted probability', fontsize=LABEL_SIZE+2)

plt.show()

---
### Figure 6 Correlation-Based Interpretation

In [None]:
# Feature rows matching the random subset used for the t-SNE cells.
phyC_feature_total_subset = df_feature.iloc[rndperm[:N]].values

fig, axes = plt.subplots(nrows=9, ncols=1, figsize=(12, 11), layout='constrained', dpi=united_dpi)
_provider = 0
input_positions = np.arange(n_inputs)
coef_funcs = [scipy.stats.pearsonr, scipy.stats.spearmanr, scipy.stats.kendalltau]

# One strip per label: correlation between provider 1's g values and every input feature.
for _label in range(n_labels):
    ax = axes.flat[_label]
    g_values = phyC_g_total_subset[:, _provider, _label]

    # zero reference line
    ax.plot([-1, n_inputs], [0, 0], '--', c='grey', linewidth=1)

    for i_func, coef_func in enumerate(coef_funcs):
        # Undefined coefficients (NaN) are shown as 0.
        coef_label = np.nan_to_num([
            coef_func(g_values, phyC_feature_total_subset[:, k_input]).statistic
            for k_input in range(n_inputs)
        ])
        ax.plot(input_positions, coef_label, '-%s' % markers[i_func], color=cm.Set1(i_func), label=correlation_metrics[i_func], linewidth=0, alpha=0.5, ms=4)

    ax.set_ylim([-0.6,0.6])
    ax.set_xlim([-15, n_inputs])
    ax.text(-13,0.3, '$p_{%d}$-%s'%(_provider+1, col_labels[_label]), fontsize=TICK_SIZE_SMALL/2)
    ax.tick_params(labelsize=TICK_SIZE_SMALL/3)

    # Feature names are written only under the last strip.
    is_last_strip = (_label + 1 == n_labels)
    if is_last_strip:
        ax.set_xticks(range(n_inputs), df_feature.columns, fontsize=TICK_SIZE_SMALL/3, rotation=90)
    else:
        ax.set_xticks([],[], fontsize=TICK_SIZE_SMALL/3, rotation=90)

    ax.legend(loc='lower left', fontsize=TICK_SIZE/3)

In [None]:
# Feature rows matching the random subset used for the t-SNE cells.
phyC_feature_total_subset = df_feature.iloc[rndperm[:N]].values

_provider = 0

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 9), layout='tight', dpi=united_dpi)

top_k = 10

# One panel per label: the top_k most negative and top_k most positive features, ranked by the
# average of the Pearson, Spearman and Kendall coefficients with provider 1's g values.
for _label in range(n_labels):

    ax = axes.flat[_label]

    # Zero line, and a dashed separator between the negative and the positive group.
    ax.plot([0, 0], [-1, n_inputs], color='grey')
    ax.plot([-0.3, 0.3], [top_k-0.5, top_k-0.5], '--', color='grey')

    # Undefined coefficients (NaN) count as 0. `nan=` is passed by keyword: the second
    # positional argument of np.nan_to_num is `copy`, not the replacement value.
    coef_label = np.nan_to_num([sum(
        [coef_func(phyC_g_total_subset[:, _provider, _label], phyC_feature_total_subset[:, k_input]).statistic for coef_func in [scipy.stats.pearsonr, scipy.stats.spearmanr, scipy.stats.kendalltau]]
    ) for k_input in range(n_inputs)], nan=0.0) / 3

    # Sort once; the first top_k entries are the most negative, the last top_k the most positive.
    coef_order = coef_label.argsort()
    top_pos_idx = np.hstack([coef_order[:top_k], coef_order[-top_k:]])

    ax.barh(range(top_k*2), coef_label[top_pos_idx], height=0.7)

    ax.set_xlim([-0.6,0.6])
    ax.set_ylim([-1, top_k*2])
    ax.set_title('$p_{%d}$-%s'%(_provider+1, col_labels[_label]), fontsize=TICK_SIZE_SMALL/2)
    ax.tick_params(labelsize=TICK_SIZE_SMALL/2)
    ax.set_yticks(range(top_k*2), df_feature.columns[top_pos_idx], fontsize=TICK_SIZE_SMALL/3)

plt.show()

---
### Figure 7 Cluster-Based Interpretation

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram

from sklearn.cluster import AgglomerativeClustering


def plot_dendrogram(model, ax, **kwargs):
    """Draw the dendrogram of a fitted hierarchical clustering model.

    Builds a SciPy linkage matrix from the model's `children_`, `distances_`
    and `labels_` attributes and passes it to `scipy.cluster.hierarchy.dendrogram`.

    Parameters
    ----------
    model : fitted clustering model exposing `children_`, `distances_`, `labels_`.
    ax : matplotlib axes to draw on (forwarded to `dendrogram`).
    **kwargs : forwarded unchanged to `dendrogram`.

    Returns
    -------
    dict
        The dictionary returned by `dendrogram`.
    """
    n_leaves = len(model.labels_)
    n_merges = model.children_.shape[0]

    # Number of original samples below each merge node. Children with an index below
    # n_leaves are single samples; larger indices refer to earlier merges.
    subtree_sizes = np.zeros(n_merges)
    for node, children in enumerate(model.children_):
        subtree_sizes[node] = sum(
            1 if child < n_leaves else subtree_sizes[child - n_leaves]
            for child in children
        )

    # Linkage rows: [child a, child b, merge distance, subtree size].
    linkage_rows = np.column_stack([model.children_, model.distances_, subtree_sizes])
    linkage_matrix = linkage_rows.astype(float)

    return dendrogram(linkage_matrix, ax=ax, **kwargs)

#### Figure 7 Hierarchical clustering dendrogram / transfer matrix / transfer vectors

In [None]:
range_v = 0.12  # symmetric colour range of the heat maps
ddgs = []

fig, axes = plt.subplots(2,2,figsize=(30, 12), gridspec_kw={'width_ratios': [6.5,20], 'height_ratios': [3,8]}, layout='constrained', dpi=united_dpi)

# Label whose transfer vectors are clustered, and provider whose transfer matrix is shown.
k_label = 6
col_name = col_labels[k_label]
i_provider = 0

# One row per provider: that provider's kernel column for the selected label.
X = phyC_kernel_total[:,:,k_label].T
# distance_threshold=0 with n_clusters=None makes the model compute the full tree.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
model = model.fit(X)

# Top-left cell of the grid is left empty.
ax = axes[0, 0]
ax.set_visible(False)

# dendrogram
ax = axes[0, 1]
ax.set_title("Hierarchical clustering dendrogram on %s" % col_name, fontsize=LABEL_SIZE)
# Show up to 10 levels of the tree.
# NOTE(review): if the tree is deeper than that, truncated leaves get labels such as "(3)" and
# the int(...) conversions below fail - confirm this cannot happen for the data used.
ddg = plot_dendrogram(model, ax=ax, truncate_mode="level", p=10)
ddgs.append(ddg)
ax.set_xticks([],[])
ax.yaxis.set_tick_params(labelsize=TICK_SIZE)
ax.xaxis.set_tick_params(labelsize=TICK_SIZE)

# transfer vectors, with providers ordered as the dendrogram leaves
ax = axes[1, 1]

im = ax.imshow(phyC_kernel_total[:,[int(s) for s in ddg['ivl']], k_label], vmin=-range_v, vmax=range_v, aspect=2, cmap='RdBu')
ax.set_title("Transfer vectors", fontsize=LABEL_SIZE)
ax.set_xticks(np.arange(n_providers), [int(i)+1 for i in ddg['ivl']], fontsize=TICK_SIZE, rotation=90)
ax.set_yticks([],[])
ax.set_xlabel('Physician id ', fontsize=LABEL_SIZE, loc='right')
cbar = fig.colorbar(im, ax=ax, format='%4.2f', pad=0.03, shrink=0.7)
cbar.ax.tick_params(labelsize=TICK_SIZE)

# transfer matrix of the selected provider; indexed with i_provider so it always matches the title
ax = axes[1, 0]
im = ax.imshow(phyC_kernel_total[:,i_provider,:], vmin=-range_v, vmax=range_v, cmap='RdBu')
ax.set_title("Transfer matrix of $p_%d$" % (i_provider+1), fontsize=LABEL_SIZE)
ax.set_xticks(range(n_labels),col_labels, fontsize=TICK_SIZE, rotation=90)
ax.set_yticks(range(n_labels),col_labels, fontsize=TICK_SIZE)

plt.show()