In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import (brier_score_loss, precision_recall_curve, roc_curve, roc_auc_score, auc)

#### Part I: Collect Performance Metrics

In [None]:
def get_performance_metrics(pred, real, keep_all_cutoffs=False):
    """Compute AUROC and AUPRC for a set of predictions, optionally with the curves.

    Parameters
    ----------
    pred : array-like
        Predicted scores, passed as the score argument to the scikit-learn
        curve functions.
    real : array-like
        True labels, passed as the label argument.
    keep_all_cutoffs : bool, default False
        When True, the per-threshold ROC and precision-recall tables are
        returned alongside the summary.

    Returns
    -------
    dict
        Always contains 'perf_summ': a two-row DataFrame with columns
        'overall_meas' ('auroc', 'auprc') and 'meas_val'.
        With ``keep_all_cutoffs=True`` it also contains 'perf_roc'
        (columns 'roc_cutoff', 'tpr', 'fpr') and 'perf_prc'
        (columns 'prc_cutoff', 'prec', 'sens').
    """
    # ROC side: curve points first, then the area.
    fpr, tpr, roc_cut = roc_curve(real, pred)
    roc_area = roc_auc_score(real, pred)

    # Precision-recall side: area is integrated over recall (x) vs precision (y).
    precision, recall, prc_cut = precision_recall_curve(real, pred)
    prc_area = auc(recall, precision)

    summary_rows = [('auroc', roc_area), ('auprc', prc_area)]
    result = {
        'perf_summ': pd.DataFrame(summary_rows, columns=['overall_meas', 'meas_val'])
    }
    if not keep_all_cutoffs:
        return result

    result['perf_roc'] = pd.DataFrame({'roc_cutoff': roc_cut, 'tpr': tpr, 'fpr': fpr})
    # The threshold array is padded with 1.0 so it has the same length as the
    # precision/recall arrays it is tabulated with.
    result['perf_prc'] = pd.DataFrame({
        'prc_cutoff': np.append(prc_cut, 1.0),
        'prec': precision,
        'sens': recall
    })
    return result

def get_performance_curve(data, site_label, pred_pt, pred_task, fs_type, grp):
    """Build labelled ROC and precision-recall curve tables for one evaluation set.

    Parameters
    ----------
    data : DataFrame
        Must provide a 'pred' column (predicted scores) and a 'y' column (labels).
    site_label, pred_pt, pred_task, fs_type, grp
        Identifiers copied verbatim into the 'site', 'pred_pt', 'pred_task',
        'fs_type' and 'grp' columns of both returned tables.

    Returns
    -------
    (perf_roc_data, perf_prc_data) : tuple of DataFrame
        The per-threshold ROC table with the AUROC repeated in 'meas_value',
        and the per-threshold precision-recall table with the AUPRC repeated
        in 'auprc'.

    Notes
    -----
    The precision-recall table previously stored the AUPRC under a column
    named 'auroc'; the column is now named 'auprc' to match its content.
    """
    perf_at = get_performance_metrics(data['pred'], data['y'], keep_all_cutoffs=True)

    # Index the two-row summary by metric name so each value is looked up explicitly.
    summ_by_meas = perf_at['perf_summ'].set_index('overall_meas')['meas_val']

    tags = {
        'site': site_label,
        'pred_pt': pred_pt,
        'pred_task': pred_task,
        'fs_type': fs_type,
        'grp': grp,
    }

    perf_roc_data = perf_at['perf_roc'].assign(**tags, meas_value=summ_by_meas['auroc'])
    perf_prc_data = perf_at['perf_prc'].assign(**tags, auprc=summ_by_meas['auprc'])
    return perf_roc_data, perf_prc_data

def get_calibration_curve(pred, real, n_bin=20, cal_set = True, iso_reg_model = None):
    """Bin predictions by quantile and either fit or apply an isotonic recalibration.

    Parameters
    ----------
    pred : array-like
        Predicted probabilities.
    real : array-like
        Observed binary outcomes, same length as ``pred``.
    n_bin : int, default 20
        Number of quantile bins requested; duplicate quantile edges are
        collapsed, so fewer bins may result.
    cal_set : bool, default True
        True: fit an ``IsotonicRegression`` on the per-bin (median prediction,
        observed rate) pairs and return the fitted model.
        False: return the per-bin summary with recalibrated values added.
    iso_reg_model : fitted model or None
        Object with a ``predict`` method; required when ``cal_set`` is False.

    Returns
    -------
    IsotonicRegression or DataFrame
        With ``cal_set=True`` the fitted model. Otherwise one row per bin with
        'pred_bin', 'expos', 'bin_lower', 'bin_upper', 'bin_mid', 'y_agg',
        'pred_p', 'brier_uncal', 'y_p', 'binCI_lower', 'binCI_upper',
        'brier_recal' and 'pred_recal'.

    Raises
    ------
    ValueError
        If ``cal_set`` is False and no ``iso_reg_model`` is supplied.
    """
    # Fail fast with a clear message instead of an AttributeError on None.predict later.
    if not cal_set and iso_reg_model is None:
        raise ValueError("iso_reg_model must be provided when cal_set is False")

    calib = pd.DataFrame({'pred': pred, 'y': real})
    calib = calib.sort_values(by='pred').reset_index(drop=True)

    # Quantile edges; np.unique collapses ties so pd.cut receives strictly increasing bins.
    bin_edges = np.unique(np.quantile(calib['pred'], np.linspace(0, 1, n_bin + 1)))
    calib['pred_bin'] = pd.cut(calib['pred'],
                               bins=bin_edges,
                               include_lowest=True,
                               labels=False) + 1  # Add 1 to labels to start binning from 1

    # Only 'pred' and 'y' are summarised, so select them explicitly; this keeps the
    # grouping column out of the frames handed to the lambda.
    calib_summary = calib.groupby('pred_bin')[['pred', 'y']].apply(lambda df: pd.Series({
        'expos': len(df),
        'bin_lower': df['pred'].min(),
        'bin_upper': df['pred'].max(),
        'bin_mid': df['pred'].median(),
        'y_agg': df['y'].sum(),
        'pred_p': df['pred'].mean()
    })).reset_index()

    calib_summary['brier_uncal'] = brier_score_loss(calib['y'], calib['pred'])
    calib_summary['y_p'] = calib_summary['y_agg'] / calib_summary['expos']

    # Normal-approximation half width from the observed rate, applied around the
    # mean predicted probability of the bin; the lower bound is floored at 0.
    half_width = 1.96 * np.sqrt(calib_summary['y_p'] * (1 - calib_summary['y_p']) / calib_summary['expos'])
    calib_summary['binCI_lower'] = np.maximum(0, calib_summary['pred_p'] - half_width)
    calib_summary['binCI_upper'] = calib_summary['pred_p'] + half_width

    if cal_set:
        # Fit isotonic regression model
        iso_reg_cal = IsotonicRegression(out_of_bounds='clip')
        iso_reg_cal.fit(calib_summary['bin_mid'], calib_summary['y_p'])
        return iso_reg_cal

    # Apply recalibration to original predicted probabilities
    calib['pred_recal'] = iso_reg_model.predict(calib['pred'])
    calib_summary['brier_recal'] = brier_score_loss(calib['y'], calib['pred_recal'])
    calib_summary['pred_recal'] = iso_reg_model.predict(calib_summary['bin_mid'])
    return calib_summary

In [None]:
def _concat_or_empty(frames):
    """Concatenate a list of frames with a fresh index; an empty list gives an empty DataFrame."""
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()


def _summarise_group(frame, grp_label, pred_task, perf_metrics):
    """Return the summary-metric rows for `frame`, tagged with its size, group label and task."""
    summ = get_performance_metrics(frame['pred'], frame['y'], keep_all_cutoffs=False)['perf_summ']
    # .copy() so the tag columns are written to an independent frame rather than
    # to a boolean-filtered slice (which raises SettingWithCopyWarning).
    summ = summ[summ['overall_meas'].isin(perf_metrics)].copy()
    summ['size'] = len(frame)
    summ['grp'] = grp_label
    summ['pred_task'] = pred_task
    return summ


def _subgroup_curves(valid_orig, col, site_label, pred_pt, pred_task, fs_type):
    """Return (roc_frames, prc_frames): one curve table per distinct value of column `col`."""
    roc_frames, prc_frames = [], []
    for grp in valid_orig[col].unique():
        grp_roc, grp_prc = get_performance_curve(
            valid_orig[valid_orig[col] == grp],
            site_label, pred_pt, pred_task, fs_type, f"Subgrp_{col}:{grp}")
        roc_frames.append(grp_roc)
        prc_frames.append(grp_prc)
    return roc_frames, prc_frames


def _bootstrap_subgroup_summaries(valid_orig, col, rng_boot, pred_task, perf_metrics):
    """Resample each level of `col` with replacement (same seed per level) and summarise it."""
    summaries = []
    # Iterating the groups directly always keeps the grouping column in the frames;
    # observed=True skips categorical levels that have no rows.
    for grp, grp_rows in valid_orig.groupby(col, observed=True):
        resampled = grp_rows.sample(frac=1, replace=True, random_state=rng_boot)
        summaries.append(_summarise_group(resampled, f"Subgrp_{col}:{grp}", pred_task, perf_metrics))
    return summaries


def collect_performance_metrics(model_label, pred_pt, aki_subgrp, site_labels, subset = True, n_boots = 199):
    """Collect ROC/PRC curves and bootstrapped AUROC/AUPRC for every site, task and feature set.

    For each site in ``site_labels``, each prediction task ('rvsl', 'stgup') and
    each feature-set option ('no_fs', 'rm_scr_bun'), a result pickle is read from
    ``model_path`` (a module-level variable that must be defined before calling),
    the test-set predictions are attached, and curves plus bootstrap summaries
    are computed for the whole test set and, optionally, for subgroups.

    Parameters
    ----------
    model_label : str
        Model prefix of the pickle file name. For 'rlr' and 'rlr_tmpr' the test
        data is read from the 'data_test_raw' entry, otherwise from 'data_test'.
    pred_pt : int
        Prediction point; ``pred_pt + 1`` appears in the file name as '<n>d'.
    aki_subgrp : str
        Cohort label used in the file name.
    site_labels : list of str
        Keys of the per-site entries inside each pickle.
    subset : bool, default True
        Also evaluate subgroups: curves by age, sex and onset day; bootstrap
        summaries by age, sex, race and onset day.
    n_boots : int, default 199
        Number of bootstrap replicates; replicate ``b`` uses seed ``b + 1234``.

    Returns
    -------
    dict
        'roc_curve' and 'prc_curve': concatenated curve tables.
        'perf_tbl': one row per replicate, group and metric with columns
        'overall_meas', 'meas_val' (rounded to 4 decimals), 'size', 'grp',
        'pred_task', 'pred_pt', 'fs_type', 'site' and 'boots'.
        Each table is an empty DataFrame when nothing was collected.
    """
    outcome_labels = {'rvsl': 'AKI_RVRT', 'stgup': 'AKI_STGUP'}
    fs_type_opt = ['no_fs', 'rm_scr_bun']
    perf_metrics = ['auroc', 'auprc']
    curve_subgrp_cols = ['age', 'sex', 'pod']
    boot_subgrp_cols = ['age', 'sex', 'race', 'pod']
    data_label = 'data_test_raw' if model_label in ['rlr', 'rlr_tmpr'] else 'data_test'

    # Pieces are gathered in lists and concatenated once at the end; growing a
    # DataFrame with pd.concat inside the loops copies all earlier rows each time.
    roc_parts, prc_parts, perf_parts = [], [], []

    for site_label in site_labels:
        for pred_task, outcome_label in outcome_labels.items():
            for fs_type in fs_type_opt:
                # NOTE: unpickling executes code embedded in the file; only load trusted results.
                result_dict = pd.read_pickle(
                    f"{model_path}{model_label}_{aki_subgrp}_{pred_task}_{pred_pt + 1}d_{fs_type}.pkl")
                valid_orig = result_dict[site_label][data_label]
                valid_orig['pred'] = result_dict[site_label]['y_pred']
                valid_orig['y'] = valid_orig[outcome_label]

                # Subgroup labels. 'ckd_stg' and 'hispanic' are derived but not
                # used by any analysis in this function.
                valid_orig['age'] = pd.cut(valid_orig['AGE'], bins=[0, 45, 65, np.inf], right=False)
                valid_orig['sex'] = np.where(valid_orig['MALE'], 'Male', 'Female')
                valid_orig['race'] = np.where(valid_orig['RACE_WHITE'], 'White', 'Nonwhite')
                valid_orig['ckd_stg'] = np.where(valid_orig['PREADM_CKD_FLAG'], 'CKD', 'Non_CKD')
                valid_orig['hispanic'] = np.where(valid_orig['HISPANIC'], 'HISPANIC', 'Non_HISPANIC')
                valid_orig['pod'] = np.where(valid_orig['ONSET_SINCE_ADMIT'] == 0,
                                             'ADM_DAY_ONSET',
                                             'POST_ADM_DAY_ONSET')

                # Curves on the full test set, then per subgroup.
                overall_roc, overall_prc = get_performance_curve(
                    valid_orig, site_label, pred_pt, pred_task, fs_type, 'Overall')
                roc_parts.append(overall_roc)
                prc_parts.append(overall_prc)

                if subset:
                    for col in curve_subgrp_cols:
                        grp_roc, grp_prc = _subgroup_curves(
                            valid_orig, col, site_label, pred_pt, pred_task, fs_type)
                        roc_parts.extend(grp_roc)
                        prc_parts.extend(grp_prc)

                # Bootstrap the performance metrics
                for b in range(n_boots):
                    rng_boot = b + 1234
                    valid = valid_orig.sample(frac=1, replace=True, random_state=rng_boot)
                    boot_parts = [_summarise_group(valid, "Overall", pred_task, perf_metrics)]

                    if subset:
                        # Stratified resampling: each subgroup level is resampled on its own.
                        for col in boot_subgrp_cols:
                            boot_parts.extend(_bootstrap_subgroup_summaries(
                                valid_orig, col, rng_boot, pred_task, perf_metrics))

                    perf_boot = pd.concat(boot_parts, ignore_index=True)
                    perf_boot['pred_pt'] = pred_pt
                    perf_boot['fs_type'] = fs_type
                    perf_boot['site'] = site_label
                    perf_boot['meas_val'] = perf_boot['meas_val'].round(4)
                    perf_parts.append(perf_boot.assign(boots=b))

    return {'roc_curve' : _concat_or_empty(roc_parts),
            'prc_curve' : _concat_or_empty(prc_parts),
            'perf_tbl': _concat_or_empty(perf_parts)}
                

In [None]:
def collect_calibration_results(model_label, pred_pt, aki_subgrp, site_labels, n_boots):
    """Collect bootstrapped calibration summaries for every site, task and feature set.

    For each site in ``site_labels``, each prediction task ('rvsl', 'stgup') and
    each feature-set option ('no_fs', 'rm_scr_bun'), a result pickle is read from
    ``model_path`` (a module-level variable that must be defined before calling).
    An isotonic recalibration model is fitted on the stored validation split
    ('data_val') using the stored 'best_model', then applied to the test split
    and to ``n_boots`` bootstrap resamples of it.

    Parameters
    ----------
    model_label : str
        'cb' or 'cb_tmpr'; determines which columns are removed from the
        validation data before calling ``predict_proba``.
    pred_pt : int
        Prediction point; ``pred_pt + 1`` appears in the file name as '<n>d'.
    aki_subgrp : str
        Cohort label used in the file name.
    site_labels : list of str
        Keys of the per-site entries inside each pickle.
    n_boots : int
        Number of bootstrap replicates; replicate ``b`` uses seed ``b + 1234``.

    Returns
    -------
    DataFrame
        The per-bin calibration summaries of all replicates with 'boots',
        'brier_orig', 'brier_recal_orig' (Brier scores on the un-resampled test
        set), 'site', 'fs_type' and 'pred_task' added. Empty when nothing was
        collected.

    Raises
    ------
    ValueError
        If ``model_label`` is not 'cb' or 'cb_tmpr'.
    """
    outcome_labels = {'rvsl': 'AKI_RVRT', 'stgup': 'AKI_STGUP'}
    fs_type_opt = ['no_fs', 'rm_scr_bun']
    id_cols = ['ID_POD', 'ID_PAT_ENC', 'PATID', 'ENCOUNTERID']
    # Extra non-feature columns to remove, per supported model.
    model_extra_cols = {'cb': [], 'cb_tmpr': ['BCCOVID']}

    # Reject unsupported models before any pickle is loaded.
    if model_label not in model_extra_cols:
        raise ValueError(f"Unknown model: {model_label}")

    # Collected in a list and concatenated once; repeated pd.concat in the loop
    # re-copies every earlier row on each iteration.
    calib_parts = []

    for site_label in site_labels:
        for pred_task, outcome_label in outcome_labels.items():
            for fs_type in fs_type_opt:
                # NOTE: unpickling executes code embedded in the file; only load trusted results.
                result_dict = pd.read_pickle(
                    f"{model_path}{model_label}_{aki_subgrp}_{pred_task}_{pred_pt + 1}d_{fs_type}.pkl")
                site_result = result_dict[site_label]

                valid_orig = site_result['data_test']
                valid_orig['y_test_pred'] = site_result['y_pred']
                valid_orig['y_test'] = valid_orig[outcome_label]

                # Score the validation split with the stored model.
                df_cal = site_result['data_val']
                x_cal = df_cal.drop(id_cols + model_extra_cols[model_label] + [outcome_label], axis = 1)
                y_cal_pred = site_result['best_model'].predict_proba(x_cal)[:, 1]
                y_cal = df_cal[outcome_label]

                # Fit on the validation split, evaluate on the untouched test split.
                recal_model = get_calibration_curve(y_cal_pred, y_cal, cal_set = True)
                calib_orig = get_calibration_curve(valid_orig['y_test_pred'], valid_orig['y_test'],
                                                   cal_set = False,
                                                   iso_reg_model = recal_model)

                # Bootstrap the calibration metrics
                boot_frames = []
                for b in range(n_boots):
                    rng_boot = b + 1234
                    valid = valid_orig.sample(frac=1, replace=True, random_state = rng_boot)

                    calib = get_calibration_curve(valid['y_test_pred'], valid['y_test'],
                                                  cal_set = False, iso_reg_model = recal_model)
                    calib['boots'] = b
                    boot_frames.append(calib)

                calib_tbl = pd.concat(boot_frames, ignore_index=True) if boot_frames else pd.DataFrame()
                calib_tbl['brier_orig'] = calib_orig['brier_uncal'].values[0]
                calib_tbl['brier_recal_orig'] = calib_orig['brier_recal'].values[0]
                calib_tbl['site'] = site_label
                calib_tbl['fs_type'] = fs_type
                calib_tbl['pred_task'] = pred_task
                calib_parts.append(calib_tbl)

    return pd.concat(calib_parts, ignore_index=True) if calib_parts else pd.DataFrame()

In [None]:
# Set paths
base_path = './'
model_path = os.path.join(base_path, 'model') + '/'
result_path = os.path.join(base_path, 'result') + '/'

# Collect performance metrics
nboots = 99
site_labels = ['Site1', 'Site2', 'Site3', 'Site4']

# One entry per model, evaluated at prediction point 0 with subgroup analysis.
perf_aki1_dict = {
    _model: collect_performance_metrics(_model, 0, 'aki1', site_labels, subset=True, n_boots=nboots)
    for _model in ('cb', 'cb_tmpr', 'rlr', 'rlr_tmpr')
}
perf_aki23_dict = {
    _model: collect_performance_metrics(_model, 0, 'aki23', site_labels, subset=True, n_boots=nboots)
    for _model in ('cb', 'cb_tmpr', 'rlr', 'rlr_tmpr')
}

# Collect calibration results
# 'rnd' holds the 'cb' results and 'tmpr' the 'cb_tmpr' results, for each cohort.
calibtbl = {'aki1': {}, 'aki23': {}}

for _val_key, _model in (('rnd', 'cb'), ('tmpr', 'cb_tmpr')):
    for _subgrp in calibtbl:
        calibtbl[_subgrp][_val_key] = collect_calibration_results(_model, 0, _subgrp, site_labels, nboots)

#### Part II: Create Plot of ROC/PRC Curves

In [None]:
def plot_pm_cross_site(perf_dicts, aki_subgrp, meas_type, result_path):
    """Plot per-site ROC or precision-recall curves for every model scenario and save the figure.

    The figure has one column per scenario ('cb', 'cb_tmpr', 'rlr', 'rlr_tmpr')
    and one row per (task, feature-set) combination. Each site's legend entry
    shows the bootstrap median of the area under the curve with its 2.5% and
    97.5% quantiles.

    Parameters
    ----------
    perf_dicts : dict
        Maps each scenario key to a dict with 'perf_tbl' (bootstrap summaries)
        and '<meas_type>_curve' (curve points), as returned by
        ``collect_performance_metrics``.
    aki_subgrp : str
        Cohort label; used only in the output file name.
    meas_type : str
        'roc' or 'prc'.
    result_path : str
        Output root; the PNG is written to ``<result_path>/figure``, which is
        created if it does not exist.
    """
    scenarios = ['cb', 'cb_tmpr', 'rlr', 'rlr_tmpr']
    scenario_labels = {'cb':'Catboost, Int. Val.',
                       'cb_tmpr':'Catboost, Temp. Val.',
                       'rlr': 'Logistic Reg., Int. Val.',
                       'rlr_tmpr': 'Logistic Reg., Temp. Val.'}

    # (task key, task title, feature-set key, feature-set title) for each row
    subplot_config = [
            ('rvsl', 'Reversal',    'no_fs',      ' (all fts.)'),
            ('rvsl', 'Reversal',    'rm_scr_bun', ' (no SCr/BUN)'),
            ('stgup', 'Progression', 'no_fs',      ' (all fts.)'),
            ('stgup', 'Progression', 'rm_scr_bun', ' (no SCr/BUN)')
        ]

    # Curve columns to draw on the x and y axes.
    x_col, y_col = ('sens', 'prec') if meas_type == 'prc' else ('fpr', 'tpr')

    # Row count follows the panel list so the grid cannot drift out of sync with it.
    fig, axes = plt.subplots(nrows=len(subplot_config), ncols=len(scenarios),
                             figsize=(16, 12), sharex=True, sharey=True)

    for col, scenario in enumerate(scenarios):
        perf_dict = perf_dicts[scenario]

        # Collapse the bootstrap replicates to a median and a 95% interval.
        perf_tbl = perf_dict['perf_tbl'].groupby(
            ['site', 'pred_pt', 'pred_task', 'fs_type', 'grp', 'overall_meas']
        ).agg({
            'size': 'mean',
            'meas_val': ['median', lambda x: np.quantile(x, 0.025), lambda x: np.quantile(x, 0.975)]
        }).reset_index()
        perf_tbl.columns = ['site','pred_pt', 'pred_task', 'fs_type', 'grp', 'overall_meas', 'size', 'meas_med', 'meas_lb', 'meas_ub']

        all_curves = perf_dict[meas_type + '_curve']

        for row, (pred_task, task_label, fs, fs_label) in enumerate(subplot_config):
            ax = axes[row, col]
            curve_data = all_curves[
                (all_curves['pred_task'] == pred_task) &
                (all_curves['grp'] == 'Overall') &
                (all_curves['fs_type'] == fs)
            ]
            tbl_data = perf_tbl[
                (perf_tbl['overall_meas'] == ('au' + meas_type)) &
                (perf_tbl['pred_task'] == pred_task) &
                (perf_tbl['grp'] == 'Overall') &
                (perf_tbl['fs_type'] == fs)
            ]

            for site, site_data in curve_data.groupby('site'):
                boot_meas = tbl_data[tbl_data['site'] == site]
                site_label = site + ':' + f"{np.round(boot_meas['meas_med'].values[0],2)} ({np.round(boot_meas['meas_lb'].values[0],2)}, {np.round(boot_meas['meas_ub'].values[0],2)})"
                x_data = site_data[x_col]
                y_data = site_data[y_col]

                line, = ax.plot(x_data, y_data, label=site_label)
                # Sparse markers (every 120th point by position) in the line's colour.
                ax.plot(x_data.iloc[::120], y_data.iloc[::120], '^', color=line.get_color())  # Markers

            if meas_type == 'roc':
                # Chance diagonal
                ax.plot([0, 1], [0, 1], 'k--', lw=2)

            ax.set_title(f' {task_label + fs_label} : {scenario_labels[scenario]}', fontsize=9)
            ax.legend(loc='lower right' if meas_type == 'roc' else 'upper right', fontsize='small')

    x_label = 'False Positive Rate' if meas_type == 'roc' else 'Recall'
    y_label = 'True Positive Rate'  if meas_type == 'roc' else 'Precision'

    # Shared axis labels for the whole grid.
    fig.text(0.5, 0.04, x_label, ha='center', va='center', fontsize=12)
    fig.text(0.04, 0.5, y_label, ha='center', va='center', rotation='vertical', fontsize=12)

    plt.tight_layout(rect=(0.05, 0.05, 1, 1))

    # savefig does not create missing directories.
    figure_dir = os.path.join(result_path, 'figure')
    os.makedirs(figure_dir, exist_ok=True)
    figure_filename = os.path.join(figure_dir, 'plot_eval_'+ aki_subgrp + meas_type + '.png')
    plt.savefig(figure_filename, bbox_inches='tight', dpi=150)
    plt.show()

In [None]:
# Generate ROC and PRC curves
# AKI 1 at onset first, then AKI 2&3 at onset; ROC before PRC for each cohort
for _subgrp, _perf_dicts in (('aki1', perf_aki1_dict), ('aki23', perf_aki23_dict)):
    for _meas in ('roc', 'prc'):
        plot_pm_cross_site(_perf_dicts, _subgrp, _meas, result_path)

#### Part III: Create Plot of Calibration Curves

In [None]:
def plot_calibration_curve(calib_rnd_tbl, calib_tmpr_tbl, aki_subgrp, site_labels, result_path="."):
    """Plot uncalibrated and recalibrated calibration curves per site and save the figure.

    The figure has two rows (top: ``calib_rnd_tbl`` titled "Int. Val.", bottom:
    ``calib_tmpr_tbl`` titled "Temp. Val.") and one column per (task,
    feature-set) combination. For every site, the dashed faint line plots the
    median bin midpoint against the median observed rate, and the solid line
    plots the median recalibrated probability against the same observed rate;
    error bars span the 2.5%-97.5% quantiles of the observed rate across
    bootstrap replicates.

    Parameters
    ----------
    calib_rnd_tbl, calib_tmpr_tbl : DataFrame
        Calibration tables as returned by ``collect_calibration_results``.
    aki_subgrp : str
        Cohort label; used only in the output file name.
    site_labels : list of str
        Sites to draw; the position in the list selects the colour.
    result_path : str, default "."
        Output root; the PNG is written to ``<result_path>/figure``, which is
        created if it does not exist.
    """
    task_labels = {'rvsl': 'Reversal', 'stgup': 'Progression'}
    pred_task_lst = list(task_labels.keys())

    fs_labels = {'no_fs': 'all fts.', 'rm_scr_bun': 'no SCr/BUN'}
    fs_type_opt = list(fs_labels.keys())

    # One panel column per (task, feature-set) pair.
    subplot_titles = [(pred_task, fs_type)
                      for pred_task in pred_task_lst
                      for fs_type in fs_type_opt]

    # One panel row per validation table.
    row_tables = [(calib_rnd_tbl, "Int. Val."), (calib_tmpr_tbl, "Temp. Val.")]

    n_rows, n_cols = len(row_tables), len(subplot_titles)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 8), sharex=True, sharey=True)

    # Figure-level legend explaining the two line styles.
    global_legend = [
        plt.Line2D([0], [0], linestyle='--', color='k', label='Uncalibrated'),
        plt.Line2D([0], [0], linestyle='-', color='k', label='Recalibrated')
    ]

    def plot_one_subplt(ax, calib_tbl, pred_task, fs_type, site_labels,
                        task_labels, fs_labels, row_label):
        """Draw the calibration curves of every site for one (task, feature-set) panel."""
        calib_data = calib_tbl[
            (calib_tbl['pred_task'] == pred_task) &
            (calib_tbl['fs_type'] == fs_type)
        ]

        # Summarise the bootstrap replicates per bin and site.
        agg_data = calib_data.groupby(['pred_bin', 'site']).agg(
            median_pred_p=('pred_recal', 'median'),  # Recalibrated probability
            median_pred_uncal=('bin_mid', 'median'), # Uncalibrated bin midpoint
            median_y_p=('y_p', 'median'),            # Observed probability
            y_p_lower=('y_p', lambda x: np.quantile(x, 0.025)),
            y_p_upper=('y_p', lambda x: np.quantile(x, 0.975))
        ).reset_index()

        for site in site_labels:
            site_data = agg_data[agg_data['site'] == site]
            if site_data.empty:
                continue

            site_rows = calib_data[calib_data['site'] == site]

            # Legend text: Brier score on the original test set with the
            # bootstrap 2.5%-97.5% range of the uncalibrated Brier score.
            if 'brier_orig' in site_rows.columns:
                brier_orig = site_rows['brier_orig'].iloc[0]
            else:
                brier_orig = np.nan

            if 'brier_uncal' in site_rows.columns:
                brier_cil = site_rows['brier_uncal'].quantile(0.025)
                brier_ciu = site_rows['brier_uncal'].quantile(0.975)
            else:
                brier_cil, brier_ciu = np.nan, np.nan

            color_idx = site_labels.index(site)
            site_label = f"{site}:{brier_orig:.2f} ({brier_cil:.2f}, {brier_ciu:.2f})"

            # Asymmetric error bars shared by both curves of this site.
            y_err = [
                site_data['median_y_p'] - site_data['y_p_lower'],
                site_data['y_p_upper'] - site_data['median_y_p']
            ]

            # Uncalibrated curve: faint, dashed, no legend entry.
            ax.errorbar(
                site_data['median_pred_uncal'],
                site_data['median_y_p'],
                yerr=y_err,
                fmt='--o', color=f"C{color_idx}", alpha=0.35,
                capsize=2, label=None
            )

            # Recalibrated curve: solid, carries the site's legend entry.
            ax.errorbar(
                site_data['median_pred_p'],
                site_data['median_y_p'],
                yerr=y_err,
                fmt='-o', color=f"C{color_idx}", capsize=2, label=site_label
            )

        ax.plot([0, 1], [0, 1], 'k--', lw=1, label='Perfect Calibration')
        ax.set_title(f"{task_labels[pred_task]} ({fs_labels[fs_type]}): {row_label}", fontsize=10)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.legend(loc='upper left', fontsize=7.5)

    for row_idx, (calib_tbl, row_label) in enumerate(row_tables):
        for col_idx, (pred_task, fs_type) in enumerate(subplot_titles):
            plot_one_subplt(
                ax=axes[row_idx, col_idx],
                calib_tbl=calib_tbl,
                pred_task=pred_task,
                fs_type=fs_type,
                site_labels=site_labels,
                task_labels=task_labels,
                fs_labels=fs_labels,
                row_label=row_label
            )

    x_label = "Predicted Probability"
    y_label = "Observed Probability"
    fig.text(0.5, 0.06, x_label, ha='center', va='center', fontsize=12)
    fig.text(0.01, 0.5, y_label, ha='center', va='center', rotation='vertical', fontsize=12)

    fig.legend(
        handles=global_legend,
        loc='lower center',
        ncol=2,
        fontsize=10,
        bbox_to_anchor=(0.5, 0)
    )

    plt.tight_layout()
    plt.subplots_adjust(left=0.04, top=0.9, bottom=0.12)

    # savefig does not create missing directories.
    figure_dir = os.path.join(result_path, 'figure')
    os.makedirs(figure_dir, exist_ok=True)
    figure_filename = os.path.join(figure_dir, f'plot_calbr_{aki_subgrp}.png')
    plt.savefig(figure_filename, bbox_inches='tight', dpi=150)
    plt.show()


In [None]:
# Generate plot of calibration curves
# AKI 1 first, then AKI 2&3
for _subgrp in ('aki1', 'aki23'):
    plot_calibration_curve(
        calib_rnd_tbl=calibtbl[_subgrp]['rnd'],
        calib_tmpr_tbl=calibtbl[_subgrp]['tmpr'],
        aki_subgrp=_subgrp,
        site_labels=site_labels,
        result_path=result_path
    )

#### Part IV: Subgroup Analysis

In [None]:
def create_subgrp_table(perf_dict, aki_subgrp, pred_task, meas_type):
    """Build and save a group-by-site table of bootstrapped performance for one task and metric.

    Only rows with ``fs_type == 'no_fs'`` are used. Each cell is formatted as
    "median (2.5% quantile, 97.5% quantile)" with values rounded to 2 decimals.

    Parameters
    ----------
    perf_dict : dict
        Must contain 'perf_tbl', the bootstrap summary table returned by
        ``collect_performance_metrics``.
    aki_subgrp : str
        Cohort label; used only in the output file name.
    pred_task : str
        Prediction task to keep (e.g. 'rvsl' or 'stgup').
    meas_type : str
        Metric to keep (e.g. 'auroc' or 'auprc').

    Returns
    -------
    DataFrame
        One row per group ('grp') and one column per site.

    Notes
    -----
    The table is written to
    ``<result_path>/table/subgroup_<aki_subgrp>_<pred_task>_<meas_type>.csv``,
    where ``result_path`` is a module-level variable that must be defined
    before calling. The task is part of the file name so that tables for
    different tasks do not overwrite each other; the directory is created if
    it does not exist.
    """
    # Collapse the bootstrap replicates to a median and a 95% interval.
    perf_tbl = perf_dict['perf_tbl'].groupby(['site', 'pred_pt', 'pred_task', 'fs_type', 'grp', 'overall_meas']).agg({
        'size': 'mean',
        'meas_val': ['median', lambda x: np.quantile(x, 0.025), lambda x: np.quantile(x, 0.975)]
    }).reset_index()
    perf_tbl.columns = ['site','pred_pt', 'pred_task', 'fs_type', 'grp', 'overall_meas', 'size', 'meas_med', 'meas_lb', 'meas_ub']

    keep = (
        (perf_tbl['fs_type'] == 'no_fs') &
        (perf_tbl['overall_meas'] == meas_type) &
        (perf_tbl['pred_task'] == pred_task)
    )
    # .copy() so the new column is added to an independent frame, not to a filtered slice.
    perf_tbl_filtered = perf_tbl[keep].copy()

    # Built row by row from the three columns; unlike DataFrame.apply(axis=1)
    # this also works when the filter leaves no rows.
    perf_tbl_filtered['meas_formatted'] = [
        f"{np.round(med, 2)} ({np.round(lb, 2)}, {np.round(ub, 2)})"
        for med, lb, ub in zip(perf_tbl_filtered['meas_med'],
                               perf_tbl_filtered['meas_lb'],
                               perf_tbl_filtered['meas_ub'])
    ]

    perf_tbl_pivot = perf_tbl_filtered.pivot(index='grp', columns='site', values='meas_formatted').reset_index()

    # to_csv does not create missing directories.
    table_dir = os.path.join(result_path, 'table')
    os.makedirs(table_dir, exist_ok=True)
    file_name = os.path.join(table_dir, 'subgroup_' + aki_subgrp + '_' + pred_task + '_' + meas_type + '.csv')
    perf_tbl_pivot.to_csv(file_name)
    return perf_tbl_pivot

In [None]:
# Generate performance tables for subgroup analysis
# Only the last table is rendered as the cell output; every call writes a CSV.
# AKI 1 at onset
create_subgrp_table(perf_aki1_dict['cb'], 'aki1', 'rvsl', 'auroc')
create_subgrp_table(perf_aki1_dict['cb'], 'aki1', 'stgup', 'auroc')
create_subgrp_table(perf_aki1_dict['cb'], 'aki1', 'rvsl', 'auprc')
create_subgrp_table(perf_aki1_dict['cb'], 'aki1', 'stgup', 'auprc')
# AKI 2&3 at onset
# Cohort label is 'aki23', matching the label used for the model files and figures.
create_subgrp_table(perf_aki23_dict['cb'], 'aki23', 'rvsl', 'auroc')
create_subgrp_table(perf_aki23_dict['cb'], 'aki23', 'stgup', 'auroc')
create_subgrp_table(perf_aki23_dict['cb'], 'aki23', 'rvsl', 'auprc')
create_subgrp_table(perf_aki23_dict['cb'], 'aki23', 'stgup', 'auprc')