In [1]:
import os
import re
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
from catboost import Pool
from pandas import CategoricalDtype
from pandas.core.dtypes.common import is_bool_dtype, is_integer_dtype
warnings.filterwarnings("ignore")

#### Part I: Create the Cross-Site Feature Importance Plot

In [None]:
def collect_shap_cross_site(result_path, aki_subgrp, pred_pt, fs_type, site_labels, top_n=100):
    """Compute per-site SHAP feature rankings for the reversal and progression models.

    For each task ('rvsl' -> outcome column AKI_RVRT, 'stgup' -> AKI_STGUP) the pickled
    results file ``cb_<aki_subgrp>_<task>_<pred_pt + 1>d_<fs_type>.pkl`` is loaded from
    ``result_path``. For every site, SHAP values of that site's ``best_model`` are
    computed on its ``data_test`` (ID columns and the outcome dropped). Features are
    ranked by mean absolute SHAP value (MAS), only the ``top_n`` best are kept, and
    their rank is rescaled so the most important feature gets 1.0 and the
    ``top_n``-th gets ``1 / top_n``.

    Args:
        result_path: Directory prefix (ending in '/') holding the result pickles.
        aki_subgrp: AKI subgroup key used in the pickle file name.
        pred_pt: Prediction point; the file name uses ``pred_pt + 1``.
        fs_type: Feature-selection tag used in the file name.
        site_labels: Keys of the per-site entries inside each results dict.
        top_n: Number of best-ranked features kept per site and task (default 100).

    Returns:
        DataFrame with columns feature, MAS, rank, task, site
        (an empty DataFrame when ``site_labels`` is empty).
    """
    # Task -> outcome column; dict order fixes the task processing order.
    outcome_labels = {'rvsl': 'AKI_RVRT', 'stgup': 'AKI_STGUP'}
    id_cols = ['ID_POD', 'ID_PAT_ENC', 'PATID', 'ENCOUNTERID']

    site_frames = []
    for pred_task, outcome_label in outcome_labels.items():
        # NOTE: pickle files can execute code on load; only read trusted result files.
        result_dict = pd.read_pickle(
            result_path + 'cb_' + aki_subgrp + '_' + pred_task + '_' + str(pred_pt + 1) + 'd_' + fs_type + '.pkl')

        for site in site_labels:
            site_result = result_dict[site]
            model = site_result['best_model']
            data_test = site_result['data_test']
            shapX = data_test.drop(id_cols + [outcome_label], axis=1)
            shapy = data_test[outcome_label]
            cat_features = list(shapX.select_dtypes(include=['bool']).columns)

            # Create Pool object for SHAP value calculation
            pshap = Pool(data=shapX, label=shapy, cat_features=cat_features)

            # SHAP values; the last column is the bias (expected value), which is dropped
            shap_df = model.get_feature_importance(data=pshap, type='ShapValues', prettified=True)
            shap_no_bias = shap_df.iloc[:, :-1]
            shap_no_bias.columns = shapX.columns

            # Mean absolute SHAP value per feature, most important first
            shap_importance = shap_no_bias.abs().mean().sort_values(ascending=False).reset_index()
            shap_importance.columns = ['feature', 'MAS']
            # 0-based rank; ties share the best (lowest) rank
            shap_importance['rank'] = shap_importance['MAS'].rank(method='min', ascending=False) - 1
            # .copy() so the following assignments act on an independent frame
            shap_importance = shap_importance[shap_importance['rank'] < top_n].copy()
            # Rescale to a relative ranking in (0, 1]: best feature -> 1.0
            shap_importance['rank'] = (top_n - shap_importance['rank']) / top_n
            shap_importance['task'] = pred_task
            shap_importance['site'] = site
            site_frames.append(shap_importance)

    if not site_frames:
        return pd.DataFrame()
    # Single concat instead of growing the frame inside the loop
    return pd.concat(site_frames, axis=0)

In [12]:
def get_shap_plot_data(shap_importances, pred_task, concept, n_sites=None):
    """Summarise per-site SHAP rankings of one task into plot-ready feature statistics.

    Args:
        shap_importances: Output of ``collect_shap_cross_site`` (columns feature, rank,
            task, site; ``rank`` is the rescaled relative ranking, 1.0 = most important).
        pred_task: Task to summarise ('rvsl' or 'stgup').
        concept: Concept table (concept_code, concept_name, domain_id) used to build
            readable feature names.
        n_sites: Denominator of ``Count``. Defaults to the number of distinct sites
            present for ``pred_task`` (previously a hard-coded 4).

    Returns:
        One row per feature with feature_short, Median and IQR of the relative ranking
        across sites, Count (fraction of sites in which the feature is among the
        ranked features), Label_rank (0, -1, ..., -9 for the 10 best features within
        each Count level, -100 otherwise), feature_name, cusheight (initialised to 1,
        the default label offset) and the helper columns added by
        ``replace_feature_name``.
    """
    task_importances = shap_importances[shap_importances['task'] == pred_task]
    if n_sites is None:
        n_sites = task_importances['site'].nunique()

    # One relative ranking per (feature, site)
    df_importances = (task_importances[['feature', 'rank', 'site']]
                      .groupby(['feature', 'site']).median()
                      .reset_index())
    df_importances['feature_short'] = df_importances['feature']

    # Median and interquartile range of the ranking across sites
    quartiles = df_importances.groupby('feature_short')['rank'].quantile([0.25, 0.5, 0.75]).unstack()
    df_importances_stat = pd.DataFrame({
        'Median': quartiles[0.5],
        'IQR': quartiles[0.75] - quartiles[0.25],
    })

    # Commonality: share of sites in which the feature appears
    site_counts = df_importances.groupby('feature_short')['site'].count()
    df_importances_stat['Count'] = site_counts / n_sites

    # Rank features by median ranking within each commonality level; keep the top 10
    label_rank = (df_importances_stat[['Median', 'Count']]
                  .sort_values('Median', ascending=False)
                  .groupby('Count')
                  .rank(method='first', ascending=False))
    label_rank.columns = ['Label_rank']
    top10 = label_rank[label_rank['Label_rank'] <= 10].copy()
    # Map ranks 1..10 to 0..-9 (used by the plots for label placement)
    top10['Label_rank'] = 1 - top10['Label_rank']
    df_importances_stat = df_importances_stat.merge(top10[['Label_rank']], left_index=True,
                                                    right_index=True, how='left')
    # Features outside the top 10 get a sentinel well below any label threshold
    df_importances_stat['Label_rank'] = df_importances_stat['Label_rank'].fillna(-100)

    # get readable feature names
    df_importances_stat = df_importances_stat.rename_axis('feature_short').reset_index()
    df_importances_plot = replace_feature_name(df_importances_stat, 'feature_short', concept, var_nm_dict)
    df_importances_plot['cusheight'] = 1

    return df_importances_plot

def reshape_concept_name(row, max_words=4):
    """Build a short display label "<first words of concept name> (<concept code>)".

    Bracketed/parenthesised fragments are removed from the concept name, which is then
    split on spaces, commas, pipes and semicolons. Empty tokens (from consecutive
    separators such as ", " or the gap left by a removed bracket) are discarded so
    they neither count toward ``max_words`` nor leave double spaces.

    Args:
        row: Mapping/Series with 'concept_name' (may be missing/NaN) and 'concept_code'.
        max_words: Maximum number of words kept from the concept name (default 4).

    Returns:
        The label string; "Unmatched code (<code>)" when the concept name is missing.
    """
    if pd.isna(row['concept_name']):
        return f"Unmatched code ({row['concept_code']})"
    concept_name_clean = re.sub(r'[\(\[].*?[\)\]]', '', str(row['concept_name']))
    words = [word for word in re.split(r' |,|\||;', concept_name_clean) if word]
    return f"{' '.join(words[:max_words])} ({row['concept_code']})"

def filter_by_domain(row, mapping=None):
    """Return True when a concept match's domain is allowed for the feature's family.

    Args:
        row: Mapping/Series with 'domain_code' (feature-family prefix, e.g. 'RX') and
            'domain_id' (domain of the matched concept; may be NaN).
        mapping: Dict of domain_code -> list of allowed domain_id values. Defaults to
            the module-level ``domain_mapping``.

    Returns:
        True if ``domain_code`` is known and ``domain_id`` is one of its allowed
        domains, otherwise False.
    """
    if mapping is None:
        mapping = domain_mapping
    allowed_domains = mapping.get(row['domain_code'])
    if allowed_domains is None:
        return False
    return row['domain_id'] in allowed_domains


# Readable plot labels for engineered (non-concept) features, keyed by feature code.
# NOTE(review): the abbreviation 'mrv' is kept verbatim; its expansion is not defined
# in this notebook — confirm.
var_nm_dict = {'SCR_MEAN': 'SCr (mrv)',
               'SCR_FD': 'SCr (slope)',
               'SCR_RANGE': 'SCr (intra-day range)',
               'SYSTOLIC_MEAN': 'Systolic pressure (mrv)',
               'SYSTOLIC_FD': 'Systolic pressure (slope)',
               'SYSTOLIC_RANGE': 'Systolic pressure (intra-day range)',
               'DIASTOLIC_MEAN': 'Diastolic pressure (mrv)',
               'DIASTOLIC_FD': 'Diastolic pressure (slope)',
               'DIASTOLIC_RANGE': 'Diastolic pressure (intra-day range)',
               'AKI_INIT_STG': 'AKI stage at onset',
               'SCR_BASELINE': 'Baseline SCr',
               'SCR_ONSET': 'Onset SCr',
               'SCR_REFERENCE': 'SCr (2 day pre-onset min)',
               'HT': 'Height',
               'WT': 'Weight',
               'BMI': 'BMI',
               'AGE': 'Age',
               'MALE': 'Male',
               'RACE_BLACK': 'Race: black',
               'RACE_WHITE': 'Race: white',
               'PREADM_CKD_FLAG': 'Pre-admission CKD',
               'PREADM_CKD_STAGE': 'Pre-admission CKD stage',
               'ONSET_SINCE_ADMIT': 'Days since admission',
               'POD_1': 'Post-onset day 1',
               'POD_2': 'Post-onset day 2',
               'POD_3': 'Post-onset day 3',
               'POD_4': 'Post-onset day 4',
               'POD_5': 'Post-onset day 5',
               'POD_6': 'Post-onset day 6',
               'POD_7': 'Post-onset day 7'
               }

# Concept-table domains (domain_id) accepted for each feature-family prefix, i.e. the
# text before the first '_' of a feature code. filter_by_domain uses this to discard
# concept matches whose domain does not fit the feature family.
domain_mapping = {'DX': ['Condition Status', 'Condition', 'Condition/Meas', 'Condition/Device'],
                  'LAB': ['Condition/Meas', 'Measurement', 'Observation'],
                  'LABCAT': ['Condition/Meas', 'Measurement', 'Observation'],
                  'RX': ['Drug'],
                  'PX': ['Procedure', 'Visit', 'Device']}


def replace_feature_name(df, var_col, concept, var_nm_dict):
    """Attach a readable ``feature_name`` to every feature code in ``df``.

    Codes listed in ``var_nm_dict`` get the dictionary label. All other codes are
    split into a family prefix (text before the first '_', -> ``domain_code``) and a
    concept code (text after the last '_', -> ``concept_code``). They are matched
    against ``concept``, keeping only matches whose domain passes
    ``filter_by_domain``. Codes with no domain-consistent match fall back to their
    first match in ``concept``, or to an "Unmatched code" label; see
    ``reshape_concept_name``.

    The input frame is not modified.

    Args:
        df: Frame with a feature-code column ``var_col``.
        var_col: Name of the feature-code column.
        concept: Concept table with concept_code, concept_name and domain_id columns.
        var_nm_dict: Mapping of engineered feature code -> readable label.

    Returns:
        Rows of ``df`` (one per feature code) plus concept_code, domain_code and
        feature_name columns.
    """
    df = df.copy()  # do not add helper columns to the caller's frame
    df['concept_code'] = df[var_col].apply(lambda x: x.split('_')[-1]).astype('str')
    df['domain_code'] = df[var_col].apply(lambda x: x.split('_')[0]).astype('str')

    is_named = df[var_col].isin(list(var_nm_dict.keys()))
    df1 = df[is_named].copy()
    df1['feature_name'] = df1[var_col].map(var_nm_dict)

    df2 = df[~is_named]
    concept_cols = concept[['concept_code', 'concept_name', 'domain_id']]
    df_merged = df2.merge(concept_cols, on='concept_code', how='left')

    # List-based evaluation also works on an empty frame, where DataFrame.apply(axis=1)
    # would return a DataFrame instead of a boolean Series.
    keep = np.array([filter_by_domain(row) for _, row in df_merged.iterrows()], dtype=bool)
    df_filtered = df_merged.loc[keep]

    # Codes without a domain-consistent match fall back to their first concept match
    missing_concepts = df2[~df2['concept_code'].isin(df_filtered['concept_code'])]
    fallback_matches = missing_concepts.merge(concept_cols,
                                              on='concept_code',
                                              how='left').drop_duplicates(subset=['concept_code'])

    df2_final = pd.concat([df_filtered, fallback_matches]).drop_duplicates(subset=[var_col])
    df2_final['feature_name'] = [reshape_concept_name(row) for _, row in df2_final.iterrows()]
    df2_final = df2_final.drop(['concept_name', 'domain_id'], axis=1)
    df_replaced = pd.concat([df1, df2_final], axis=0)
    return df_replaced

In [None]:
# Function to plot the SHAP rankings for AKI 1 patients
def plot_shap_aki1(shap_importances, task_name, task_label, subplot_position):
    """Draw one cross-site SHAP ranking panel for the AKI stage 1 cohort.

    Each marker is a feature placed at x = median relative ranking across sites and
    y = share of sites in which it appears among that site's ranked features
    (``Count``), coloured by the IQR of the ranking across sites. The best-ranked
    features are annotated; vertical label offsets (``cusheight``) are hand-tuned per
    feature for this cohort's figure.

    Args:
        shap_importances: Output of ``collect_shap_cross_site``.
        task_name: 'rvsl' or 'stgup'.
        task_label: Text used in the panel title.
        subplot_position: Row (1 or 2) in the 2x1 subplot grid; also selects the
            "(A)"/"(B)" title prefix.
    """
    plot_df = get_shap_plot_data(shap_importances, task_name, concept)

    # Hand-written display names that override the automatically built concept labels
    display_names = {
        'RX_ATC_N02BF': 'Gabapentinoids (N02BF)',
        'RX_RXN_1545706': 'Bupivacaine hydrochloride injectable (1545706)',
        'RX_RXN_543249': 'Furosemide injectable (543249)',
        'RX_ND_781318895': 'Vancomycin injectable (781318895)',
        'RX_ND_338004904': 'Sodium chloride injectable (338004904)',
        'PX_CH_97116': 'Therapeutic procedure (97116)',
        'LABCAT_10381-2(LOW)': 'Target cells present in blood: Low (10381-2)',
        'LABCAT_802-9(LOW)': 'Spherocytes presence in blood: Low (802-9)',
        'LABCAT_779-9(LOW)': 'Poikilocytosis presence in blood: Low (779-9)',
        'LABCAT_LG40867-0(NEGATIVE)': 'Leukocytes presence in urine: Negative (LG40867-0)',
        'LABCAT_LG40953-8(MODERATE)': 'Leukocyte esterase presence in urine: Moderate (LG40953-8)',
    }
    present_codes = plot_df['feature_short'].values
    for code, label in display_names.items():
        if code in present_codes:
            plot_df.loc[plot_df['feature_short'] == code, 'feature_name'] = label

    # Hand-tuned label offsets (multiplied by 0.04 on the y-axis) that keep the
    # annotations of this panel from overlapping; unlisted features keep the default.
    label_heights = {
        'LAB_736-9': 1.7,
        'LAB_LG13614-9': 2.5,
        'LAB_LG12083-8': 2.5,
        'LAB_48642-3': 1.4,
        'LAB_LG5465-2': 1.9,
        'RX_RXN_543249': 2.8,
        'RX_ND_338004904': 1.9,
        'LAB_LG12080-4': 1.9,
        'LAB_LG1777-4': 1.4,
        'LAB_48643-1': 2.1,
        'LAB_61151-7': 1,
        'LAB_3173-2': 1.6,
        'LAB_LG49829-1': 1.4,
        'RX_ATC_B02BA': 1.6,
        'RX_ATC_A04AA': 1.6,
        'SYSTOLIC_FD': 1.7,
        'DIASTOLIC_FD': 1.4,
        'SCR_FD': 1.2,
        'SYSTOLIC_MEAN': 1.9,
        'LAB_LG5665-7': 1.9,
        'PX_CH_93306': 1.6,
        'PX_CH_96361': 1.4,
        'BMI': 2.1,
        'MALE': 1.6,
        'DX_401': 1.9,
        'PX_CH_97116': 2.1,
        'LAB_LG7247-2': 1.5,
        'LAB_LG51070-7': 1.6,
        'PX_CH_Q9950': 1.5,
        'LAB_LG6426-3': 1.5,
        'RX_ND_781318895': 2.3,
        'PX_CH_99152': 1.8,
        'LABCAT_779-9(LOW)': 1.8,
        'LAB_LG32857-1': 1.8,
        'LAB_LG2807-8': 1.8,
        'RX_ATC_J01XA': 1.6,
        'RX_ATC_J01DH': 1.9,
        'RX_ATC_J01CG': 1.2,
        'RX_ATC_C03CA': 2.2,
        'RX_ATC_B02AA': 1.6,
        'LABCAT_LG40867-0(NEGATIVE)': 2.2,
        'LAB_LG6039-4': 1.7,
    }
    for code, height in label_heights.items():
        plot_df.loc[plot_df['feature_short'] == code, 'cusheight'] = height

    shown = plot_df[plot_df['Count'] <= 1]

    plt.subplot(2, 1, subplot_position)
    plt.scatter(shown['Median'], shown['Count'], c=shown['IQR'], cmap='coolwarm', s=200)

    colorbar = plt.colorbar(label='IQR Value')
    colorbar.ax.tick_params(labelsize=12)
    colorbar.set_label('IQR Value', fontsize=20)

    for _, feat in shown.iterrows():
        # Label the 7 best-ranked features of each commonality level (the 10 best for
        # features found at every site), skipping features without a concept match.
        label_floor = -6 - 4 * (feat['Count'] == 1)
        ranked_high = feat['Label_rank'] >= label_floor
        has_concept = 'Unmatched' not in feat['feature_name']
        if ranked_high & has_concept:
            # Odd label ranks are placed below the marker, even ranks above it
            height = feat['cusheight']
            y_offset = -height if feat['Label_rank'] % 2 == 1 else height
            plt.annotate(
                feat['feature_name'],
                xy=(feat['Median'], feat['Count']),
                xytext=(feat['Median'] + 0.05 * feat['Label_rank'] - 0.1,
                        feat['Count'] + 0.04 * y_offset),
                arrowprops=dict(arrowstyle='-', lw=1),
                fontsize=12
            )
        # Extra hand-placed label for this lab in the reversal panel
        if feat['feature_short'] == 'LAB_LG6199-6' and task_name == 'rvsl':
            plt.annotate(
                feat['feature_name'],
                xy=(feat['Median'], feat['Count']),
                xytext=(feat['Median'] + 0.05 * (-5) - 0.1,
                        feat['Count'] + 0.04 * -2.1),
                arrowprops=dict(arrowstyle='-', lw=1),
                fontsize=12
            )

    # Dashed, semi-transparent red box over the high-ranking / high-commonality corner
    highlight = plt.Rectangle((0.7, 0.73), 0.310, 0.295, linewidth=2, edgecolor='red', facecolor='none',
                              linestyle='--', alpha=0.45)
    plt.gca().add_patch(highlight)

    # y-ticks every 0.25 (0.25, 0.5, 0.75, 1.0)
    plt.yticks(np.arange(0.25, 1.01, 0.25))
    plt.ylim(0.15, 1.1)
    plt.xlim(0, 1.1)
    plt.grid(True)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Importance Ranking (Median of Relative Ranking)', fontsize=20)
    plt.ylabel('Commonality Across Sites', fontsize=20)
    plt.title(f'({chr(65 + subplot_position - 1)}) {task_label}', fontsize=20)


In [None]:
# Function to plot the SHAP rankings for AKI 2&3 patients
def plot_shap_aki23(shap_importances, task_name, task_label, subplot_position):
    """Draw one cross-site SHAP ranking panel for the AKI stage 2&3 cohort.

    Each marker is a feature placed at x = median relative ranking across sites and
    y = share of sites in which it appears among that site's ranked features
    (``Count``), coloured by the IQR of the ranking across sites. The best-ranked
    features are annotated; vertical label offsets (``cusheight``) are hand-tuned per
    feature for this cohort's figure.

    Args:
        shap_importances: Output of ``collect_shap_cross_site``.
        task_name: 'rvsl' or 'stgup'.
        task_label: Text used in the panel title.
        subplot_position: Row (1 or 2) in the 2x1 subplot grid; also selects the
            "(A)"/"(B)" title prefix.
    """
    df_importances_plot = get_shap_plot_data(shap_importances, task_name, concept)

    # Hand-written display names that override the automatically built concept labels
    ft_code_name_mapping = {
        'RX_ATC_N02BF': 'Gabapentinoids (N02BF)',
        'RX_RXN_1545706': 'Bupivacaine hydrochloride injectable (1545706)',
        'RX_RXN_543249': 'Furosemide injectable (543249)',
        'RX_ND_781318895': 'Vancomycin injectable (781318895)',
        'RX_ND_338004904': 'Sodium chloride injectable (338004904)',
        'PX_CH_97116': 'Therapeutic procedure (97116)',
        'LABCAT_10381-2(LOW)': 'Target cells present in blood: Low (10381-2)',
        'LABCAT_802-9(LOW)': 'Spherocytes presence in blood: Low (802-9)',
        'LABCAT_779-9(LOW)': 'Poikilocytosis presence in blood: Low (779-9)',
        'LABCAT_LG40867-0(NEGATIVE)': 'Leukocytes presence in urine: Negative (LG40867-0)',
        'LABCAT_LG40953-8(MODERATE)': 'Leukocyte esterase presence in urine: Moderate (LG40953-8)',
    }
    for ft_code, feature_name in ft_code_name_mapping.items():
        if ft_code in df_importances_plot['feature_short'].values:
            df_importances_plot.loc[df_importances_plot['feature_short'] == ft_code, 'feature_name'] = feature_name

    # Hand-tuned label offsets (multiplied by 0.04 on the y-axis) that keep the
    # annotations of this panel from overlapping; unlisted features keep the default.
    # Keys must be full feature codes: 'RX_ATC_A12CA' was previously written as
    # 'A12CA', which never matched any feature.
    cusheights = {
        'LAB_LG10990-6': 1.9,
        'LAB_4544-3': 1.5,
        'POD_1': 1.9,
        'LAB_LG1314-6': 1.9,
        'RX_RXN_1545706': 1.9,
        'LAB_LG13614-9': 1.9,
        'RX_ATC_A12CA': 1.9,
        'LAB_LG11363-5': 1.6,
        'LAB_LG12083-8': 2.1,
        'LAB_48642-3': 1.2,
        'LAB_LG33051-0': 1.9,
        'LAB_LG5465-2': 1.9,
        'RX_RXN_543249': 1.6,
        'RX_ND_338004904': 1.9,
        'RX_ND_781318895': 1.9,
        'LAB_LG12080-4': 1.6,
        'LAB_LG1777-4': 1.6,
        'LAB_48643-1': 2.1,
        'LAB_61151-7': 1,
        'LAB_3173-2': 1.6,
        'LAB_LG49829-1': 1.4,
        'RX_ATC_B02BA': 1.6,
        'RX_ATC_A04AA': 1.6,
        'SYSTOLIC_FD': 1.2,
        'DIASTOLIC_FD': 1.9,
        'SCR_FD': 1.6,
        'SYSTOLIC_MEAN': 1.9,
        'DIASTOLIC_MEAN': 1.9,
        'LAB_LG5665-7': 1.4,
        'PX_CH_93306': 1.6,
        'PX_CH_96361': 1.4,
        'BMI': 1.6,
        'RX_ATC_J01EE': 1.9,
        'LAB_LG32850-6': 1.9,
        'LAB_LG32886-0': 1.9,
        'DX_401': 1.9,
        'PX_CH_97116': 1.9,
        'PX_CH_70498': 2.6,
        'RX_ATC_N02AA': 2.5,
        'LAB_LG35626-7': 1.9,
        'LAB_LG51070-7': 1.9,
        'LABCAT_LG40867-0(NEGATIVE)': 1.9,
        'RX_ATC_J01XA': 1.6,
        'RX_ATC_J01CG': 1.6,
        'RX_ATC_C03CA': 1.6,
    }
    for ft_code, height in cusheights.items():
        df_importances_plot.loc[df_importances_plot['feature_short'] == ft_code, 'cusheight'] = height

    df = df_importances_plot[df_importances_plot['Count'] <= 1]

    plt.subplot(2, 1, subplot_position)
    plt.scatter(df['Median'], df['Count'], c=df['IQR'], cmap='coolwarm', s=200)
    cbar = plt.colorbar(label='IQR Value')
    cbar.ax.tick_params(labelsize=12)
    cbar.set_label('IQR Value', fontsize=20)

    for _, row in df.iterrows():
        # Label the 7 best-ranked features of each commonality level (the 10 best for
        # features found at every site), skipping features without a concept match.
        ranked_high = row['Label_rank'] >= (-6 - 4 * (row['Count'] == 1))
        has_concept = 'Unmatched' not in row['feature_name']
        if ranked_high & has_concept:
            # Odd label ranks are placed below the marker, even ranks above it
            y_offset = -row['cusheight'] if row['Label_rank'] % 2 == 1 else row['cusheight']
            plt.annotate(
                row['feature_name'],
                xy=(row['Median'], row['Count']),
                xytext=(row['Median'] + 0.05 * row['Label_rank'] - 0.1,
                        row['Count'] + 0.04 * y_offset),
                arrowprops=dict(arrowstyle='-', lw=1),
                fontsize=12
            )

    # Dashed, semi-transparent red box over the high-ranking / high-commonality corner
    rect = plt.Rectangle((0.7, 0.73), 0.310, 0.295, linewidth=2, edgecolor='red', facecolor='none', linestyle='--',
                         alpha=0.45)
    plt.gca().add_patch(rect)

    # y-ticks every 0.25 (0.25, 0.5, 0.75, 1.0)
    plt.yticks(np.arange(0.25, 1.01, 0.25))
    plt.ylim(0.15, 1.1)
    plt.xlim(0, 1.1)
    plt.grid(True)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.xlabel('Importance Ranking (Median of Relative Ranking)', fontsize=20)
    plt.ylabel('Commonality Across Sites', fontsize=20)
    plt.title(f'({chr(65 + subplot_position - 1)}) {task_label}', fontsize=20)

In [None]:
def plot_shap_all(aki_subgrp, site_labels, pred_pt=0, fs_type='no_fs'):
    """Build, save and show the two-panel cross-site SHAP ranking figure for one subgroup.

    Panel (A) shows AKI reversal ('rvsl') and panel (B) AKI progression ('stgup').
    'aki1' uses the stage-1 panel layout; any other subgroup uses the stage 2&3 layout.
    The figure is written to ``<result_path>/figure/plot_csshap_<aki_subgrp>.png``;
    the ``figure`` directory is created if needed.

    Args:
        aki_subgrp: AKI subgroup key (also part of the result pickle names).
        site_labels: Sites to include.
        pred_pt: Prediction point passed to ``collect_shap_cross_site`` (default 0).
        fs_type: Feature-selection tag passed to ``collect_shap_cross_site``
            (default 'no_fs').
    """
    # Collect SHAP values and rankings
    shap_importances = collect_shap_cross_site(result_path, aki_subgrp, pred_pt, fs_type, site_labels)
    plot_panel = plot_shap_aki1 if aki_subgrp == 'aki1' else plot_shap_aki23

    plt.figure(figsize=(23, 18))
    plot_panel(shap_importances, 'rvsl', 'AKI Reversal', 1)
    plot_panel(shap_importances, 'stgup', 'AKI Progression', 2)
    plt.tight_layout()

    figure_dir = os.path.join(result_path, 'figure')
    os.makedirs(figure_dir, exist_ok=True)  # savefig does not create missing directories
    figure_filename = os.path.join(figure_dir, 'plot_csshap_' + aki_subgrp + '.png')
    plt.savefig(figure_filename, bbox_inches='tight', dpi=150)
    plt.show()

In [None]:
# Plot SHAP rankings across sites for both AKI subgroups
base_path = './'
result_path = os.path.join(base_path, 'result') + '/'
aux_path = os.path.join(base_path, 'aux_files') + '/'
site_labels = ['Site1', 'Site2', 'Site3', 'Site4']

# Read concept codes as strings: replace_feature_name matches them against string
# codes, and mixed numeric/alphanumeric codes can otherwise be parsed as mixed types.
concept = pd.read_csv(aux_path + 'CONCEPT.csv', sep='\t', dtype={'concept_code': str})

plot_shap_all('aki1', site_labels)
# NOTE(review): subgroup key assumed to be 'aki23' (matching plot_shap_aki23); it must
# match the cb_<subgroup>_... result pickle names — confirm.
plot_shap_all('aki23', site_labels)

#### Part III: Bootstrap SHAP values and Marginal Effect Plots

In [25]:
def get_boot_shap_data(aki_subgrp, pred_task, fts_shap_boot_lst, fs_type, site_labels, pred_pt, n_boot, seed=42):
    """Bootstrap SHAP marginal effects of selected features at every site.

    For each site, the test set is resampled ``n_boot`` times with replacement,
    separately within the positive and negative outcome classes (class sizes are
    preserved). SHAP values of the site's ``best_model`` are computed on each
    resample. For every requested feature, the SHAP values are then averaged per
    (site, bootstrap replicate, feature value). Feature values are:

    - cast to int for bool, categorical and integer features;
    - rounded to whole numbers for features whose name contains 'SYSTOLIC_' or
      'DIASTOLIC_';
    - rounded to 2 decimals otherwise.

    Features absent from a site's data are carried as NaN and therefore produce no
    summary rows for that site.

    Args:
        aki_subgrp, pred_task, fs_type, pred_pt: Select the result pickle
            ``cb_<aki_subgrp>_<pred_task>_<pred_pt + 1>d_<fs_type>.pkl`` in ``result_path``.
        fts_shap_boot_lst: Feature columns to summarise.
        site_labels: Sites to process.
        n_boot: Number of bootstrap replicates per site.
        seed: Value the global NumPy RNG is reset to before each site (default 42).

    Returns:
        DataFrame with columns site, boot, val, effect, var.

    Raises:
        ValueError: If ``pred_task`` is not 'rvsl' or 'stgup'.
    """
    # NOTE: pickle files can execute code on load; only read trusted result files.
    result_dict = pd.read_pickle(
        result_path + 'cb_' + aki_subgrp + '_' + pred_task + '_' + str(pred_pt + 1) + 'd_' + fs_type + '.pkl')

    if pred_task == 'rvsl':
        outcome_label = 'AKI_RVRT'
    elif pred_task == 'stgup':
        outcome_label = 'AKI_STGUP'
    else:
        raise ValueError(f"Unknown pred_task: {pred_task}")

    # Order-preserving, de-duplicated list of requested feature columns
    requested = list(dict.fromkeys(fts_shap_boot_lst))
    all_sites_shap_data = {}

    for site in site_labels:
        # The global NumPy seed is reset to the same value for every site, so each
        # site's bootstrap draws are reproducible independently of the site order.
        np.random.seed(seed)
        model = result_dict[site]['best_model']
        shapX = result_dict[site]['data_test'].drop(['ID_POD', 'ID_PAT_ENC', 'PATID', 'ENCOUNTERID', outcome_label],
                                                    axis=1)
        shapy = result_dict[site]['data_test'][outcome_label]
        positive_class_idx = np.where(shapy == True)[0]
        negative_class_idx = np.where(shapy == False)[0]
        n_positive = len(positive_class_idx)
        n_negative = len(negative_class_idx)

        pred_brkdn_b = []
        x_val_b = []

        ## Bootstrap SHAP calculations (stratified by outcome class)
        for b in range(n_boot):
            sampled_pos_idx = np.random.choice(positive_class_idx, n_positive, replace=True)
            sampled_neg_idx = np.random.choice(negative_class_idx, n_negative, replace=True)
            idxset = np.concatenate([sampled_pos_idx, sampled_neg_idx])

            shapX_boot = shapX.iloc[idxset, :]
            shapy_boot = shapy.iloc[idxset]
            cat_ft_boot = list(shapX_boot.select_dtypes(include=['bool']).columns)

            pshap_boot = Pool(data=shapX_boot, label=shapy_boot, cat_features=cat_ft_boot)
            shap_boot = model.get_feature_importance(data=pshap_boot, type='ShapValues', prettified=True)
            # Drop the bias (last) column and restore feature names
            shap_boot_nb = shap_boot.iloc[:, :-1]
            shap_boot_nb.columns = shapX_boot.columns

            # Select requested features in one step; absent features become NaN columns
            shap_sel_df = shap_boot_nb.reindex(columns=requested)
            shap_sel_df['boot'] = b
            shap_sel_df['idx'] = idxset
            pred_brkdn_b.append(shap_sel_df)

            x_key_b = shapX_boot.reindex(columns=requested)
            x_key_b['boot'] = b
            x_key_b['idx'] = idxset
            x_val_b.append(x_key_b)

        # reset_index aligns SHAP values and feature values row-by-row (positionally)
        pred_brkdn_b_df = pd.concat(pred_brkdn_b).reset_index(drop=True)
        x_val_df = pd.concat(x_val_b).reset_index(drop=True)

        all_sites_shap_data[site] = {
            'pred_brkdn_b_df': pred_brkdn_b_df,
            'x_val_df': x_val_df,
            'site': site,
            'n_boot': n_boot,
            'pred_task': pred_task,
            'pred_pt': pred_pt,
            'fs_type': fs_type
        }

    pred_brkdn = []
    for site, data in all_sites_shap_data.items():
        pred_brkdn_b_df = data['pred_brkdn_b_df']
        x_val_df = data['x_val_df']
        for var in fts_shap_boot_lst:
            x_col = x_val_df[var]
            col_dtype = x_col.dtype

            if (
                    is_bool_dtype(col_dtype)
                    or isinstance(col_dtype, CategoricalDtype)
                    or is_integer_dtype(col_dtype)
            ):
                # Discrete features: group by exact level
                val = x_col.astype('int')
            elif ('SYSTOLIC_' in var) or ('DIASTOLIC_' in var):
                val = x_col.round(0)
            else:
                val = x_col.round(2)

            df_shap_grp = pd.DataFrame({
                'val': val,
                'effect': pred_brkdn_b_df[var],
                'boot': pred_brkdn_b_df['boot'],
                'site': site
            })
            summary_df = df_shap_grp.groupby(['site', 'boot', 'val']).agg({'effect': 'mean'}).reset_index()
            summary_df['var'] = var
            pred_brkdn.append(summary_df)

    if not pred_brkdn:
        # No sites or no requested features: return an empty frame with the usual columns
        return pd.DataFrame(columns=['site', 'boot', 'val', 'effect', 'var'])
    pred_brkdn_df = pd.concat(pred_brkdn, ignore_index=True)
    return pred_brkdn_df

In [5]:
def plot_shap_with_ci_by_site(data, feature, feature_label, colors_dict, title_weight, ax, sites=None):
    """Plot the bootstrapped SHAP marginal effect of one feature, one series per site.

    For each site, bootstrap effects are summarised per feature value as the median
    with 2.5th-97.5th percentile error bars. At most 200 randomly chosen feature
    values are kept per site. A LOWESS smoother (frac=0.2) is drawn through the
    medians with a +/-1.96 * std(residuals) band. Sites without any rows are skipped.

    Args:
        data: Rows of ``get_boot_shap_data`` output for this feature (columns site,
            boot, val, effect, var).
        feature: Feature code; selects binary-axis handling and fixed y-limits.
        feature_label: x-axis label.
        colors_dict: Mapping site -> matplotlib colour.
        title_weight: Font weight of the x-axis label (e.g. 'bold').
        ax: matplotlib Axes to draw on.
        sites: Sites to plot; defaults to the module-level ``site_labels``.
    """
    if sites is None:
        sites = site_labels
    binary_features = ['MALE', 'RACE_BLACK', 'PREADM_CKD_FLAG', 'DX_09_401']
    sample_size = 200  # cap on plotted feature values per site
    lowess = sm.nonparametric.lowess

    x_mins = []
    x_maxs = []
    for site in sites:
        site_data = data[data['site'] == site]
        # Days since admission: show only the first 14 days
        site_data = site_data[~((site_data['var'] == 'ONSET_SINCE_ADMIT') & (site_data['val'] > 14))]

        grouped = site_data.groupby('val').agg(
            median_effect=('effect', 'median'),
            lower=('effect', lambda x: np.quantile(x, 0.025)),
            upper=('effect', lambda x: np.quantile(x, 0.975)),
            count=('effect', 'count')
        )
        if grouped.empty:
            # Nothing to draw (e.g. the feature is missing at this site). Skipping also
            # avoids a LOWESS failure and NaN entries in the shared x-range below.
            continue

        if len(grouped) > sample_size:
            grouped = grouped.sample(n=sample_size, random_state=42).sort_index()

        # Scatter plot with error bars for 2.5th and 97.5th percentiles
        ax.errorbar(grouped.index, grouped['median_effect'],
                    yerr=[grouped['median_effect'] - grouped['lower'], grouped['upper'] - grouped['median_effect']],
                    fmt='o', color=colors_dict[site], alpha=0.25,
                    label=f'{site} Log Odds')

        lowess_fit = lowess(grouped['median_effect'], grouped.index, frac=0.2)
        lowess_x = lowess_fit[:, 0]
        lowess_y = lowess_fit[:, 1]

        # Constant-width band from the spread of residuals around the smoother
        residuals = grouped['median_effect'] - np.interp(grouped.index, lowess_x, lowess_y)
        sigma = np.std(residuals)
        ci_upper = lowess_y + 1.96 * sigma
        ci_lower = lowess_y - 1.96 * sigma

        # Plot LOWESS curve and confidence band
        ax.plot(lowess_x, lowess_y, color=colors_dict[site], alpha=0.8, label=f'{site} LOWESS Smoother')
        ax.fill_between(lowess_x, ci_lower, ci_upper, color=colors_dict[site], alpha=0.15,
                        label=f'{site} 95% LOWESS C.I.')

        feature_vals = site_data[site_data['var'] == feature]['val']
        x_mins.append(feature_vals.quantile(0.05))
        x_maxs.append(feature_vals.quantile(0.95))

    if feature in binary_features:
        ax.set_xticks([0, 1])
    elif x_mins:
        # Range covered by every plotted site: largest 5th to smallest 95th percentile
        ax.set_xlim(np.max(x_mins), np.min(x_maxs))

    if feature in ('LAB_LG7967-5', 'DIASTOLIC_MEAN'):
        ax.set_ylim(-0.3, 0.2)

    # Plot settings
    ax.axhline(0, color='black', linewidth=0.8, linestyle='--')
    ax.set_xlabel(feature_label, weight=title_weight)
    ax.set_ylabel('')
    ax.grid(True, linestyle='--', alpha=0.5)

In [6]:
# Function for plotting marginal effects of AKI 1 patients
def plot_marginal_effect_aki1(pred_brkdn_df, fts_all, fts_dict, fts_common, site_labels, result_path):
    """Plot per-site SHAP marginal-effect panels for AKI 1 patients.

    The figure is a 9x4 grid:
      * rows 0-3: (A) AKI Reversal ('rvsl') panels,
      * row 4: a thin blank spacer row,
      * rows 5-8: (B) AKI Progression ('stgup') panels.
    Each feature panel is drawn by ``plot_shap_with_ci_by_site``. Features in
    ``fts_common`` get a bold x-axis label. Unused subplots are removed. A
    single de-duplicated legend is built from the progression panels. The
    figure is saved to ``<result_path>/figure/plot_me_aki1.png`` and then shown.

    Parameters
    ----------
    pred_brkdn_df : dict
        Maps 'rvsl' and 'stgup' to long-format DataFrames that contain at
        least a 'var' (feature name) and a 'val' (feature value) column.
    fts_all : dict
        Maps 'rvsl' and 'stgup' to the ordered features to plot
        (at most 16 per task, i.e. four rows of four panels).
    fts_dict : dict
        Feature name -> human-readable axis label.
    fts_common : collection
        Features selected for both tasks (rendered with a bold label).
    site_labels : sequence
        Site names; each is assigned its own color.
    result_path : str
        Output root; the PNG is written into its 'figure' subdirectory,
        which is created if missing.

    Raises
    ------
    ValueError
        If a task has more features than its section has subplots.
    """
    # One color per site. Sizing the palette by the number of sites (instead
    # of a fixed 4) keeps every site in ``color_dict`` when there are more sites.
    colors = sns.color_palette("tab10", len(site_labels))
    color_dict = dict(zip(site_labels, colors))

    def _plot_task_panels(pred_task, task_axes):
        # Draw one panel per feature of ``pred_task``, then remove the unused axes.
        task_df = pred_brkdn_df[pred_task]
        features = list(fts_all[pred_task])
        if len(features) > len(task_axes):
            raise ValueError(
                f"{len(features)} '{pred_task}' features exceed the {len(task_axes)} available subplots")
        for ax, feature in zip(task_axes, features):
            # .copy() makes the filtered frame independent, so the rescale below is
            # an explicit write instead of chained assignment on a slice.
            feature_data = task_df[task_df['var'] == feature].copy()
            if feature == 'LAB_LG6199-6':
                # Bilirubin values are divided by 10 before plotting.
                # NOTE(review): presumably a unit/scale correction -- confirm with the data source.
                feature_data['val'] = feature_data['val'] / 10
            title_weight = 'bold' if feature in fts_common else 'normal'
            plot_shap_with_ci_by_site(feature_data, feature, fts_dict[feature], color_dict, title_weight, ax)
        for ax in task_axes[len(features):]:
            ax.remove()

    fig, axes = plt.subplots(9, 4, figsize=(20, 30),
                             gridspec_kw={'height_ratios': [1, 1, 1, 1, 0.15, 1, 1, 1, 1]})

    fig.text(0.5, 0.96, '(A) AKI Reversal', ha='center', fontsize=18)
    fig.text(0.5, 0.53, '(B) AKI Progression', ha='center', fontsize=18)

    _plot_task_panels('rvsl', axes[:4, :].flatten())

    # Blank spacer row separating the two sections.
    for ax in axes[4, :]:
        ax.axis('off')

    axes_stgup = axes[5:, :].flatten()
    _plot_task_panels('stgup', axes_stgup)

    fig.text(0.04, 0.5, 'Median Log Odds', ha='center', va='center', rotation='vertical', fontsize=18)

    # Shared legend built from the progression panels, de-duplicated by label.
    handles, labels = [], []
    for ax in axes_stgup:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in labels:
                handles.append(handle)
                labels.append(label)
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.05), bbox_transform=fig.transFigure, ncol=4,
               fontsize=12)

    fig.tight_layout(rect=(0.05, 0.1, 1, 0.95))
    figure_filename = os.path.join(result_path, 'figure', 'plot_me_aki1.png')
    # Create the output directory so savefig does not fail on a fresh result_path.
    os.makedirs(os.path.dirname(figure_filename), exist_ok=True)
    fig.savefig(figure_filename, bbox_inches='tight', dpi=150)
    plt.show()

In [None]:
# Selected features for AKI 1 patinets
stgup_fts_dict = {
    'SCR_BASELINE': 'Baseline SCr (mg/dL)',
    'SCR_FD': 'Most recent SCr (slope, mg/(dL*day))',
    'SCR_MEAN': 'Most recent SCr (level, mg/dL)',
    'SYSTOLIC_FD': 'Systolic pressure (slope, mmHg/day)',
    'SYSTOLIC_MEAN': 'Systolic pressure (level, mmHg)',
    'LAB_LG5465-2': 'Albumin (g/dL)',
    'LAB_LG6199-6': 'Bilirubin (mg/dL)',
    'BMI': 'BMI',
    'DIASTOLIC_FD': 'Diastolic pressure (slope, mmHg/day)',
    'LAB_LG13614-9': 'Anion Gap (mEq/L)',
    'LAB_LG6033-7': 'AST (IU/L)',
    'LAB_LG4454-7': 'CO2 (mEq/L)',

    'LAB_LG7967-5': 'Glucose (mg/dL)',
    'LAB_736-9': 'Lymphocytes/(100*Leukocytes)',
    'LAB_5902-2': 'PT (s)'
}  # stgup
stgup_fts = stgup_fts_dict.keys()

rvsl_fts_dict = {
    'SCR_BASELINE': 'Baseline SCr (mg/dL)',
    'SCR_FD': 'Most recent SCr (slope, mg/(dL*day))',
    'SCR_MEAN': 'Most recent SCr (level, mg/dL)',
    'SYSTOLIC_FD': 'Systolic pressure (slope, mmHg/day)',
    'SYSTOLIC_MEAN': 'Systolic pressure (level, mmHg)',
    'LAB_LG5465-2': 'Albumin (g/dL)',
    'LAB_LG6199-6': 'Bilirubin (mg/dL)',
    'DIASTOLIC_MEAN': 'Diastolic pressure (level, mmHg)',
    'ONSET_SINCE_ADMIT': 'Days since admission',
    'DX_09_401': 'Essential Hypertension',
    'LAB_LG5665-7': 'ALP (IU/L)',
    'LAB_LG12080-4': 'BNP (pg/mL)',
    'LAB_LG7247-2': 'Calcium (mg/dL)',
    'LAB_LG6039-4': 'Lactate (mmol/L)',
    'LAB_LG10990-6': 'Potassium (mEq/L)',
}  # rvsl
rvsl_fts = rvsl_fts_dict.keys()

fts_all = {'rvsl': rvsl_fts,
           'stgup': stgup_fts} # combined

fts_common = ['SCR_BASELINE', 'SCR_FD', 'SCR_MEAN', 'SYSTOLIC_FD', 
              'SYSTOLIC_MEAN', 'LAB_LG5465-2', 'LAB_LG6199-6']  # common across tasks

fts_dict = {
    'SCR_BASELINE': 'Baseline SCr (mg/dL)',
    'SCR_FD': 'Most recent SCr (slope, mg/(dL*day))',
    'SCR_MEAN': 'Most recent SCr (level, mg/dL)',
    'SYSTOLIC_FD': 'Systolic pressure (slope, mmHg/day)',
    'SYSTOLIC_MEAN': 'Systolic pressure (level, mmHg)',
    'DIASTOLIC_FD': 'Diastolic pressure (slope, mmHg/day)',
    'DIASTOLIC_MEAN': 'Diastolic pressure (level, mmHg)',
    'BMI': 'BMI',
    'ONSET_SINCE_ADMIT': 'Days since admission',
    'DX_09_401': 'Essential Hypertension',
    'LAB_LG13614-9': 'Anion Gap (mEq/L)',
    'LAB_LG5465-2': 'Albumin (g/dL)',
    'LAB_LG5665-7': 'ALP (IU/L)',
    'LAB_LG6033-7': 'AST (IU/L)',
    'LAB_LG6199-6': 'Bilirubin (mg/dL)',
    'LAB_LG12080-4': 'BNP (pg/mL)',
    'LAB_LG7247-2': 'Calcium (mg/dL)',
    'LAB_LG4454-7': 'CO2 (mEq/L)',
    'LAB_LG7967-5': 'Glucose (mg/dL)',
    'LAB_LG6039-4': 'Lactate (mmol/L)',
    'LAB_736-9': 'Lymphocytes/(100*Leukocytes)',
    'LAB_LG10990-6': 'Potassium (mEq/L)',
    'LAB_5902-2': 'PT (s)'
}  # all
fts_shap_boot_lst = fts_dict.keys()

In [None]:
# Collect bootstrapped SHAP values for AKI 1 patients -- one result per prediction
# task, keyed by task name -- then draw the combined marginal-effect figure.
# NOTE(review): the positional 0 presumably selects the prediction point; confirm
# against get_boot_shap_data's signature.
n_boot = 99
pred_tasks = ['rvsl', 'stgup']
marginal_effect_aki1 = {
    task: get_boot_shap_data('aki1', task, fts_shap_boot_lst, 'no_fs', site_labels, 0, n_boot)
    for task in pred_tasks
}

plot_marginal_effect_aki1(marginal_effect_aki1, fts_all, fts_dict, fts_common, site_labels, result_path)

In [None]:
# Function for plotting marginal effects of AKI 2&3 patients
def plot_marginal_effect_aki23(pred_brkdn_df, fts_all, fts_dict, fts_common, site_labels, result_path):
    """Plot per-site SHAP marginal-effect panels for AKI 2&3 patients.

    The figure is an 8x4 grid:
      * rows 0-3: (A) AKI Reversal ('rvsl') panels (up to 16),
      * row 4: a thin blank spacer row,
      * rows 5-7: (B) AKI Progression ('stgup') panels (up to 12).
    Each feature panel is drawn by ``plot_shap_with_ci_by_site``. Features in
    ``fts_common`` get a bold x-axis label. Unused subplots are removed. A
    single de-duplicated legend is built from the progression panels. The
    figure is saved to ``<result_path>/figure/plot_me_aki23.png`` and then shown.

    Parameters
    ----------
    pred_brkdn_df : dict
        Maps 'rvsl' and 'stgup' to long-format DataFrames that contain at
        least a 'var' (feature name) and a 'val' (feature value) column.
    fts_all : dict
        Maps 'rvsl' and 'stgup' to the ordered features to plot.
    fts_dict : dict
        Feature name -> human-readable axis label.
    fts_common : collection
        Features selected for both tasks (rendered with a bold label).
    site_labels : sequence
        Site names; each is assigned its own color.
    result_path : str
        Output root; the PNG is written into its 'figure' subdirectory,
        which is created if missing.

    Raises
    ------
    ValueError
        If a task has more features than its section has subplots.
    """
    # One color per site. Sizing the palette by the number of sites (instead
    # of a fixed 4) keeps every site in ``color_dict`` when there are more sites.
    colors = sns.color_palette("tab10", len(site_labels))
    color_dict = dict(zip(site_labels, colors))

    def _plot_task_panels(pred_task, task_axes):
        # Draw one panel per feature of ``pred_task``, then remove the unused axes.
        task_df = pred_brkdn_df[pred_task]
        features = list(fts_all[pred_task])
        if len(features) > len(task_axes):
            raise ValueError(
                f"{len(features)} '{pred_task}' features exceed the {len(task_axes)} available subplots")
        for ax, feature in zip(task_axes, features):
            # .copy() makes the filtered frame independent, so the rescale below is
            # an explicit write instead of chained assignment on a slice.
            feature_data = task_df[task_df['var'] == feature].copy()
            if feature == 'LAB_LG6199-6':
                # Bilirubin values are divided by 10 before plotting.
                # NOTE(review): presumably a unit/scale correction -- confirm with the data source.
                feature_data['val'] = feature_data['val'] / 10
            title_weight = 'bold' if feature in fts_common else 'normal'
            plot_shap_with_ci_by_site(feature_data, feature, fts_dict[feature], color_dict, title_weight, ax)
        for ax in task_axes[len(features):]:
            ax.remove()

    fig, axes = plt.subplots(8, 4, figsize=(20, 30), gridspec_kw={'height_ratios': [1, 1, 1, 1, 0.15, 1, 1, 1]})

    fig.text(0.5, 0.96, '(A) AKI Reversal', ha='center', fontsize=18)
    fig.text(0.5, 0.46, '(B) AKI Progression', ha='center', fontsize=18)

    _plot_task_panels('rvsl', axes[:4, :].flatten())

    # Blank spacer row separating the two sections.
    for ax in axes[4, :]:
        ax.axis('off')

    axes_stgup = axes[5:, :].flatten()
    _plot_task_panels('stgup', axes_stgup)

    # Shared y-axis label
    fig.text(0.04, 0.5, 'Median Log Odds', ha='center', va='center', rotation='vertical', fontsize=18)

    # Combined legend outside the main plot, built from the progression panels.
    handles, labels = [], []
    for ax in axes_stgup:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in labels:
                handles.append(handle)
                labels.append(label)
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.05), bbox_transform=fig.transFigure, ncol=4, fontsize=12)

    fig.tight_layout(rect=(0.05, 0.1, 1, 0.95))
    figure_filename = os.path.join(result_path, 'figure', 'plot_me_aki23.png')
    # Create the output directory so savefig does not fail on a fresh result_path.
    os.makedirs(os.path.dirname(figure_filename), exist_ok=True)
    fig.savefig(figure_filename, bbox_inches='tight', dpi=150)
    plt.show()

In [None]:
# Selected features for AKI 2&3 patinets
stgup_fts_dict = {
    'SCR_BASELINE' : 'Baseline SCr (mg/dL)',
    'SCR_FD' : 'Most recent SCr (slope, mg/(dL*day))',
    'SCR_MEAN' : 'Most recent SCr (level, mg/dL)',
    'SYSTOLIC_FD': 'Systolic pressure (slope, mmHg/day)',
    'SYSTOLIC_MEAN': 'Systolic pressure (level)',
    'LAB_LG5465-2': 'Albumin (g/dL)',
    'LAB_LG6033-7': 'AST (IU/L)',
    'LAB_LG7247-2': 'Calcium (mg/dL)',
    'LAB_LG4454-7':  'CO2 (mEq/L)',
    'LAB_LG10990-6': 'Potassium (mEq/L)'
}  # stgup
stgup_fts = stgup_fts_dict.keys()

rvsl_fts_dict = {
    'SCR_BASELINE' : 'Baseline SCr (mg/dL)',
    'SCR_FD' : 'Most recent SCr (slope, mg/(dL*day))',
    'SCR_MEAN' : 'Most recent SCr (level, mg/dL)',
    'SYSTOLIC_FD': 'Systolic pressure (slope, mmHg/day)',
    'SYSTOLIC_MEAN': 'Systolic pressure (level, mmHg)',
    'LAB_LG5465-2': 'Albumin (g/dL)',
    'DIASTOLIC_FD': 'Diastolic pressure (slope, mmHg/day)',
    'DIASTOLIC_MEAN': 'Diastolic pressure (level, mmHg)',
    'AGE': 'Age',
    'ONSET_SINCE_ADMIT': 'Days since admission',
    'LAB_LG5665-7': 'ALP (IU/L)',
    'LAB_LG6199-6': 'Bilirubin (mg/dL)',
    'LAB_LG12080-4': 'BNP (pg/mL)',
    'LAB_4544-3': 'Hematocrit (%)',
    'LAB_LG32850-6': 'RBC (10^6 cells/µL)',
}  # rvsl
rvsl_fts = rvsl_fts_dict.keys()

fts_all = {'rvsl': rvsl_fts,
           'stgup': stgup_fts} # combined 

fts_common = ['SCR_BASELINE', 'SCR_FD', 'SCR_MEAN', 'SYSTOLIC_FD', 'SYSTOLIC_MEAN', 'LAB_LG5465-2'] # common across tasks

fts_dict = {
    'SCR_BASELINE' : 'Baseline SCr (mg/dL)',
    'SCR_FD' : 'Most recent SCr (slope, mg/(dL*day))',
    'SCR_MEAN' : 'Most recent SCr (level, mg/dL)',
    'SYSTOLIC_FD': 'Systolic pressure (slope, mmHg/day)',
    'SYSTOLIC_MEAN': 'Systolic pressure (level, mmHg)',
    'DIASTOLIC_FD': 'Diastolic pressure (slope, mmHg/day)',
    'DIASTOLIC_MEAN': 'Diastolic pressure (level, mmHg)',
    'AGE': 'Age',
    'ONSET_SINCE_ADMIT': 'Days since admission',
    'LAB_LG5465-2': 'Albumin (g/dL)',
    'LAB_LG5665-7': 'ALP (IU/L)',
    'LAB_LG6033-7': 'AST (IU/L)',
    'LAB_LG6199-6': 'Bilirubin (mg/dL)',
    'LAB_LG12080-4': 'BNP (pg/mL)',
    'LAB_LG7247-2':  'Calcium (mg/dL)',
    'LAB_LG4454-7':  'CO2 (mEq/L)',
    'LAB_4544-3': 'Hematocrit (%)',
    'LAB_LG10990-6': 'Potassium (mEq/L)',
    'LAB_LG32850-6': 'RBC (10^6 cells/µL)'
}  # all
fts_shap_boot_lst = fts_dict.keys()

In [None]:
# Collect bootstrapped SHAP values for AKI 2&3 patients -- one result per prediction
# task, keyed by task name -- then draw the combined marginal-effect figure.
# NOTE(review): the positional 0 presumably selects the prediction point; confirm
# against get_boot_shap_data's signature.
n_boot = 99
pred_tasks = ['rvsl', 'stgup']
marginal_effect_aki23 = {
    task: get_boot_shap_data('aki23', task, fts_shap_boot_lst, 'no_fs', site_labels, 0, n_boot)
    for task in pred_tasks
}

plot_marginal_effect_aki23(marginal_effect_aki23, fts_all, fts_dict, fts_common, site_labels, result_path)