In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.optimize import minimize
from scipy.special import logit, expit
from sklearn.calibration import calibration_curve
from sklearn.model_selection import train_test_split

from prediction.outcome_prediction.Transformer.calibration.calibration_measures import evaluate_calibration

# Recalibration by Platt scaling

Ref for implementation: https://github.com/nplan-io/kdd2020-calibration/blob/master/tutorial/KDD%202020%20-%20nPlan%20calibration%20session%20(completed).ipynb

In [None]:
# Directory that receives the exported calibration tables (and the figures, if
# the commented-out savefig cells are re-enabled). It is resolved from the
# current user's home directory instead of a hard-coded '/Users/jk1' prefix so
# the notebook is not tied to one machine; for a user whose home is
# '/Users/jk1' this is still '/Users/jk1/Downloads'.
output_dir = os.path.join(os.path.expanduser('~'), 'Downloads')

## MRS02

In [None]:
# Pickle files with the stored predictions for the MRS02 outcome, one per split
# (validation, test, train).
val_predictions_path, test_predictions_path, train_predictions_path = (
    '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/predictions_for_all_sets/val_predictions_and_gt.pkl',
    '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/predictions_for_all_sets/test_predictions_and_gt.pkl',
    '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/predictions_for_all_sets/train_predictions_and_gt.pkl',
)

In [None]:
# Load the stored predictions of the validation, test and training splits.
# Each pickle is unpacked into three values (raw predictions, sigmoid
# predictions, ground truth); the two prediction sequences are turned into
# NumPy arrays once the file has been closed.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open(val_predictions_path, 'rb') as f:
    raw_predictions_validation, sigm_predictions_validation, gt_validation = pickle.load(f)
raw_predictions_validation, sigm_predictions_validation = (
    np.array(raw_predictions_validation),
    np.array(sigm_predictions_validation),
)

with open(test_predictions_path, 'rb') as f:
    raw_predictions_test, sigm_predictions_test, gt_test = pickle.load(f)
raw_predictions_test, sigm_predictions_test = (
    np.array(raw_predictions_test),
    np.array(sigm_predictions_test),
)

with open(train_predictions_path, 'rb') as f:
    raw_predictions_train, sigm_predictions_train, gt_train = pickle.load(f)
raw_predictions_train, sigm_predictions_train = (
    np.array(raw_predictions_train),
    np.array(sigm_predictions_train),
)

In [None]:
# Four-colour seaborn palette; the bare name on the last line makes the
# notebook display it.
all_colors_palette = sns.color_palette(
    ['#f61067', '#049b9a', '#012D98', '#a76dfe'],
    n_colors=4,
)
all_colors_palette

In [None]:
# Two-colour palette read by plot_reliability_diagram: index 0 for the curve on
# a newly created figure, index 1 for a curve overlaid on existing axes.
base_colors = sns.color_palette(
    ['#012D98', '#049b9a'],
    n_colors=2,
)
base_colors

In [None]:
def plot_reliability_diagram(prob_true, prob_pred, model_name, y_prob=None, predefined_ax=None, n_bins=10, hist=False):
    """Draw a reliability (calibration) curve on a new figure or on existing axes.

    Without ``predefined_ax`` a new 10x10 inch figure is created, the dotted
    diagonal of a perfectly calibrated model is drawn, optionally followed by a
    histogram of ``y_prob``, and the curve is plotted in ``base_colors[0]``.
    With ``predefined_ax`` only the curve is added to those axes, in
    ``base_colors[1]`` (no diagonal, no histogram).

    Parameters
    ----------
    prob_true : array-like
        Values plotted on the y-axis ("Fraction of positives").
    prob_pred : array-like
        Values plotted on the x-axis ("Mean predicted value").
    model_name : str
        Legend label of the curve.
    y_prob : array-like, optional
        Values for the histogram; only read when ``hist`` is true and a new
        figure is created.
    predefined_ax : matplotlib axes, optional
        Axes to draw on. ``None`` (default) creates a new figure.
    n_bins : int, default 10
        Requested number of histogram bins; at least 10 bins are used.
    hist : bool, default False
        Whether to draw the histogram of ``y_prob``.

    Returns
    -------
    tuple
        ``(fig, ax)``: the figure and the axes that were drawn on.
    """
    tick_label_size = 13
    label_font_size = 15

    # Identity test instead of `== None`: it is the idiomatic check and does not
    # depend on how the passed object implements equality.
    creates_new_figure = predefined_ax is None

    if creates_new_figure:
        fig = plt.figure(figsize=(10, 10))
        ax = plt.gca()
    else:
        ax = predefined_ax
        # Keep pyplot's "current axes" pointing at the axes being drawn on.
        plt.sca(ax)
        fig = plt.gcf()

    if creates_new_figure:
        ax.plot([0, 1], [0, 1], color="#FE4A49", linestyle=":", label="Perfectly calibrated model")
        if hist and y_prob is not None:
            # The weights turn bar heights into fractions of all samples.
            ax.hist(y_prob, weights=np.ones_like(y_prob) / len(y_prob), alpha=.2,
                    bins=np.maximum(10, n_bins))
        color = base_colors[0]
    else:
        color = base_colors[1]
    ax.plot(prob_pred, prob_true, "s-", label=model_name, color=color)

    ax.set_ylabel("Fraction of positives", fontsize=label_font_size)
    ax.set_xlabel("Mean predicted value", fontsize=label_font_size)

    # Re-created on every call so a curve added later also appears in the legend.
    ax.legend(fontsize=label_font_size)
    ax.tick_params(axis="both", labelsize=tick_label_size)

    ax.grid(True, color="#B2C7D9")

    return fig, ax


In [None]:
# Inputs for fitting the Platt scaling on the validation split: the labels as a
# NumPy array and the sigmoid predictions mapped back to logits. scale_fun_bce
# reads both names as globals.
y_val_binary = np.array(gt_validation)
y_logits = logit(sigm_predictions_validation)

In [None]:
def scale_fun_bce(x, *args):
    """Summed binary cross-entropy of Platt-scaled predictions (objective for ``minimize``).

    The logits are transformed to ``a * logit + b`` and passed through the
    sigmoid; the result is the sum over the samples of
    ``-(y * log(p) + (1 - y) * log(1 - p))``.

    * Every sample contributes. There is no cap on the number of samples, so a
      calibration set with more than 1000 rows is used in full.
    * Probabilities are clipped away from exactly 0 and 1, so saturated
      predictions add a large finite penalty instead of being skipped or
      turning the loss into inf/nan.

    Parameters
    ----------
    x : sequence of two floats
        ``(a, b)``: slope and intercept applied to the logits.
    *args
        Optional ``(logits, labels)``. With fewer than two extra arguments
        (as in ``minimize(scale_fun_bce, [1, 0], ...)``) the notebook globals
        ``y_logits`` and ``y_val_binary`` are used.

    Returns
    -------
    float
        The summed binary cross-entropy.
    """
    a, b = x
    if len(args) >= 2:
        logits, labels = args[0], args[1]
    else:
        logits, labels = y_logits, y_val_binary

    # Flatten so that (N,) and (N, 1) inputs pair up element-wise instead of broadcasting.
    logits = np.ravel(np.asarray(logits, dtype=float))
    labels = np.ravel(np.asarray(labels, dtype=float))
    # Pair up to the shorter of the two sequences, as zip() would.
    n_samples = min(logits.size, labels.size)
    logits, labels = logits[:n_samples], labels[:n_samples]

    eps = np.finfo(float).eps
    probs = np.clip(expit(a * logits + b), eps, 1 - eps)
    return float(-np.sum(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))


In [None]:
# Fit the two Platt-scaling parameters by minimising scale_fun_bce with
# Nelder-Mead, starting from a=1, b=0 (the identity transform of the logits).
min_obj = minimize(
    scale_fun_bce,
    [1, 0],
    method='Nelder-Mead',
    options={'xatol': 1e-8, 'disp': True},
)
min_obj

In [None]:
# Apply the fitted Platt scaling (slope min_obj.x[0], intercept min_obj.x[1]) to
# the test-set logits. The test logits get their own name: `y_logits` is the
# calibration input that scale_fun_bce reads as a global, and overwriting it
# here would make a re-run of the fitting cell optimise on test logits.
y_logits_test = logit(sigm_predictions_test)
y_test_pred_corr = expit(min_obj.x[0] * y_logits_test + min_obj.x[1])

# Reliability curves (10 bins) before and after recalibration on shared axes.
prob_true_binary, prob_pred_binary = calibration_curve(gt_test, sigm_predictions_test, n_bins=10)
prob_true_binary_corr, prob_pred_binary_corr = calibration_curve(gt_test, y_test_pred_corr, n_bins=10)
fig, ax = plot_reliability_diagram(prob_true_binary, prob_pred_binary, "Transformer")
plot_reliability_diagram(prob_true_binary_corr, prob_pred_binary_corr, "Transformer (calibrated)", predefined_ax=ax)
plt.show()

In [None]:
# fig.savefig(f'{output_dir}/mrs02_reliability_diagram.svg', bbox_inches="tight", format='svg', dpi=1200)

In [None]:
# Calibration measures of the uncorrected test predictions. The 'state' entry
# labels the result so it can be told apart after concatenation.
# NOTE(review): evaluate_calibration is a project helper; its exact output is not visible here.
initial_cal_df = evaluate_calibration(gt_test, sigm_predictions_test)
initial_cal_df['state'] = 'initial'

In [None]:
# Same calibration measures for the Platt-scaled test predictions.
recal_df = evaluate_calibration(gt_test, y_test_pred_corr)
recal_df['state'] = 'Platt scaled'

In [None]:
# Stack the initial and Platt-scaled results and write them to output_dir as CSV.
mrs02_cal_df = pd.concat([initial_cal_df, recal_df])
mrs02_cal_df.to_csv(f'{output_dir}/mrs02_calibration.csv')

## Death

In [None]:
# Pickle files with the stored predictions for the Death outcome, one per split
# (validation, test, train). These rebind the path names used in the MRS02 section.
val_predictions_path, test_predictions_path, train_predictions_path = (
    '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/all_sets_predictions/val_predictions_and_gt.pkl',
    '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/all_sets_predictions/test_predictions_and_gt.pkl',
    '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/all_sets_predictions/train_predictions_and_gt.pkl',
)

In [None]:
# Load the stored Death-outcome predictions of the validation, test and training
# splits. Each pickle is unpacked into three values (raw predictions, sigmoid
# predictions, ground truth); the two prediction sequences are turned into
# NumPy arrays once the file has been closed.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open(val_predictions_path, 'rb') as f:
    raw_predictions_validation, sigm_predictions_validation, gt_validation = pickle.load(f)
raw_predictions_validation, sigm_predictions_validation = (
    np.array(raw_predictions_validation),
    np.array(sigm_predictions_validation),
)

with open(test_predictions_path, 'rb') as f:
    raw_predictions_test, sigm_predictions_test, gt_test = pickle.load(f)
raw_predictions_test, sigm_predictions_test = (
    np.array(raw_predictions_test),
    np.array(sigm_predictions_test),
)

with open(train_predictions_path, 'rb') as f:
    raw_predictions_train, sigm_predictions_train, gt_train = pickle.load(f)
raw_predictions_train, sigm_predictions_train = (
    np.array(raw_predictions_train),
    np.array(sigm_predictions_train),
)

In [None]:
# Inputs for fitting the Platt scaling on the Death validation split: labels as
# a NumPy array and the sigmoid predictions mapped back to logits. scale_fun_bce
# reads both names as globals.
y_val_binary = np.array(gt_validation)
y_logits = logit(sigm_predictions_validation)

In [None]:
def scale_fun_bce(x, *args):
    """Summed binary cross-entropy of Platt-scaled predictions (objective for ``minimize``).

    The logits are transformed to ``a * logit + b`` and passed through the
    sigmoid; the result is the sum over the samples of
    ``-(y * log(p) + (1 - y) * log(1 - p))``.

    * Every sample contributes. There is no cap on the number of samples, so a
      calibration set with more than 1000 rows is used in full.
    * Probabilities are clipped away from exactly 0 and 1, so saturated
      predictions add a large finite penalty instead of being skipped or
      turning the loss into inf/nan.

    Parameters
    ----------
    x : sequence of two floats
        ``(a, b)``: slope and intercept applied to the logits.
    *args
        Optional ``(logits, labels)``. With fewer than two extra arguments
        (as in ``minimize(scale_fun_bce, [1, 0], ...)``) the notebook globals
        ``y_logits`` and ``y_val_binary`` are used.

    Returns
    -------
    float
        The summed binary cross-entropy.
    """
    a, b = x
    if len(args) >= 2:
        logits, labels = args[0], args[1]
    else:
        logits, labels = y_logits, y_val_binary

    # Flatten so that (N,) and (N, 1) inputs pair up element-wise instead of broadcasting.
    logits = np.ravel(np.asarray(logits, dtype=float))
    labels = np.ravel(np.asarray(labels, dtype=float))
    # Pair up to the shorter of the two sequences, as zip() would.
    n_samples = min(logits.size, labels.size)
    logits, labels = logits[:n_samples], labels[:n_samples]

    eps = np.finfo(float).eps
    probs = np.clip(expit(a * logits + b), eps, 1 - eps)
    return float(-np.sum(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))

In [None]:
# Fit the two Platt-scaling parameters for the Death model by minimising
# scale_fun_bce with Nelder-Mead, starting from a=1, b=0.
min_obj = minimize(
    scale_fun_bce,
    [1, 0],
    method='Nelder-Mead',
    options={'xatol': 1e-8, 'disp': True},
)
min_obj

In [None]:
# Apply the fitted Platt scaling (slope min_obj.x[0], intercept min_obj.x[1]) to
# the Death test-set logits. The test logits get their own name: `y_logits` is
# the calibration input that scale_fun_bce reads as a global, and overwriting it
# here would make a re-run of the fitting cell optimise on test logits.
y_logits_test = logit(sigm_predictions_test)
y_test_pred_corr = expit(min_obj.x[0] * y_logits_test + min_obj.x[1])

# Reliability curves (10 bins) before and after recalibration on shared axes.
prob_true_binary, prob_pred_binary = calibration_curve(gt_test, sigm_predictions_test, n_bins=10)
prob_true_binary_corr, prob_pred_binary_corr = calibration_curve(gt_test, y_test_pred_corr, n_bins=10)
fig2, ax = plot_reliability_diagram(prob_true_binary, prob_pred_binary, "Transformer")
plot_reliability_diagram(prob_true_binary_corr, prob_pred_binary_corr, "Transformer (calibrated)", predefined_ax=ax)
plt.show()

In [None]:
# fig2.savefig(f'{output_dir}/death_reliability_diagram.svg', bbox_inches="tight", format='svg', dpi=1200)

In [None]:
# Calibration measures of the uncorrected Death test predictions; 'state' labels
# the result so it can be told apart after concatenation.
# NOTE(review): evaluate_calibration is a project helper; its exact output is not visible here.
initial_cal_df = evaluate_calibration(gt_test, sigm_predictions_test)
initial_cal_df['state'] = 'initial'

In [None]:
# Same calibration measures for the Platt-scaled Death test predictions.
recal_df = evaluate_calibration(gt_test, y_test_pred_corr)
recal_df['state'] = 'Platt scaled'

In [None]:
# Stack the initial and Platt-scaled results and write them to output_dir as CSV.
death_cal_df = pd.concat([initial_cal_df, recal_df])
death_cal_df.to_csv(f'{output_dir}/death_calibration.csv')

## MIMIC

In [None]:
# Pickle with the external-validation ground truth and predictions (fold 1).
# Note: this rebinds test_predictions_path, which pointed at the Death test split before.
test_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/external_validation/fold_1_test_gt_and_pred.pkl'

In [None]:
# Load the external-validation pickle, which unpacks into two values (ground
# truth, predictions); the predictions are additionally kept as a NumPy array.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open(test_predictions_path, 'rb') as f:
    y_test, y_pred_test = pickle.load(f)
sigm_predictions_test = np.array(y_pred_test)

In [None]:
# Apply the Platt scaling currently stored in min_obj (no refit in this section)
# to the external test logits. The logits get their own name: `y_logits` is the
# calibration input that scale_fun_bce reads as a global, and overwriting it
# here would make a re-run of a fitting cell optimise on these test logits.
y_logits_test = logit(sigm_predictions_test)
y_test_pred_corr = expit(min_obj.x[0] * y_logits_test + min_obj.x[1])

In [None]:
# Reliability curves (10 bins) of the external test predictions before and
# after applying the Platt scaling, drawn on shared axes.
prob_true_binary, prob_pred_binary = calibration_curve(y_test, sigm_predictions_test, n_bins=10)
prob_true_binary_corr, prob_pred_binary_corr = calibration_curve(y_test, y_test_pred_corr, n_bins=10)

fig3, ax = plot_reliability_diagram(
    prob_true_binary,
    prob_pred_binary,
    "Transformer",
    y_prob=sigm_predictions_test,
    hist=False,
)
plot_reliability_diagram(
    prob_true_binary_corr,
    prob_pred_binary_corr,
    "Transformer (calibrated)",
    predefined_ax=ax,
)
plt.show()

In [None]:
# fig3.savefig(f'{output_dir}/mimic_reliability_diagram.svg', bbox_inches="tight", format='svg', dpi=1200)

In [None]:
# Calibration measures of the uncorrected external test predictions; 'state'
# labels the result so it can be told apart after concatenation.
# NOTE(review): evaluate_calibration is a project helper; its exact output is not visible here.
initial_cal_df = evaluate_calibration(y_test, sigm_predictions_test)
initial_cal_df['state'] = 'initial'

In [None]:
# Same calibration measures for the Platt-scaled external test predictions.
recal_df = evaluate_calibration(y_test, y_test_pred_corr)
recal_df['state'] = 'Platt scaled'

In [None]:
# Stack the initial and Platt-scaled results and write them to output_dir as CSV.
mimic_cal_df = pd.concat([initial_cal_df, recal_df])
mimic_cal_df.to_csv(f'{output_dir}/mimic_calibration.csv')

## Recalibrate with a fraction of MIMIC

Gist: a small fraction (20%) of the external validation data is used to refit the Platt scaling. The recalibrated model is then evaluated on the remaining 80% of the external validation data.

In [None]:
# Split the external predictions and labels into two parts: train_test_split
# returns the 80% part first (kept for evaluation, y_test_*) and the 20% part
# second (used to refit the Platt scaling, y_recal_*). random_state fixes the
# split so it is reproducible. train_test_split is imported in the import cell
# at the top of the notebook.
y_test_pred, y_recal_pred, y_test_gt, y_recal_gt = train_test_split(y_pred_test, y_test, test_size=0.2, random_state=42)
y_test_gt = np.array(y_test_gt)
y_test_pred = np.array(y_test_pred)

In [None]:
# Inputs for refitting the Platt scaling on the recalibration part of the
# external data: labels as a NumPy array and predictions mapped back to logits.
# scale_fun_bce reads both names as globals.
y_val_binary = np.array(y_recal_gt)
y_logits = logit(y_recal_pred)

In [None]:
def scale_fun_bce(x, *args):
    """Summed binary cross-entropy of Platt-scaled predictions (objective for ``minimize``).

    The logits are transformed to ``a * logit + b`` and passed through the
    sigmoid; the result is the sum over the samples of
    ``-(y * log(p) + (1 - y) * log(1 - p))``.

    * Every sample contributes. There is no cap on the number of samples, so a
      calibration set with more than 1000 rows is used in full.
    * Probabilities are clipped away from exactly 0 and 1, so saturated
      predictions add a large finite penalty instead of being skipped or
      turning the loss into inf/nan.

    Parameters
    ----------
    x : sequence of two floats
        ``(a, b)``: slope and intercept applied to the logits.
    *args
        Optional ``(logits, labels)``. With fewer than two extra arguments
        (as in ``minimize(scale_fun_bce, [1, 0], ...)``) the notebook globals
        ``y_logits`` and ``y_val_binary`` are used.

    Returns
    -------
    float
        The summed binary cross-entropy.
    """
    a, b = x
    if len(args) >= 2:
        logits, labels = args[0], args[1]
    else:
        logits, labels = y_logits, y_val_binary

    # Flatten so that (N,) and (N, 1) inputs pair up element-wise instead of broadcasting.
    logits = np.ravel(np.asarray(logits, dtype=float))
    labels = np.ravel(np.asarray(labels, dtype=float))
    # Pair up to the shorter of the two sequences, as zip() would.
    n_samples = min(logits.size, labels.size)
    logits, labels = logits[:n_samples], labels[:n_samples]

    eps = np.finfo(float).eps
    probs = np.clip(expit(a * logits + b), eps, 1 - eps)
    return float(-np.sum(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))

In [None]:
# Refit the two Platt-scaling parameters on the recalibration part by minimising
# scale_fun_bce with Nelder-Mead, starting from a=1, b=0.
min_obj = minimize(
    scale_fun_bce,
    [1, 0],
    method='Nelder-Mead',
    options={'xatol': 1e-8, 'disp': True},
)
min_obj

In [None]:
# Apply the refitted Platt scaling (slope min_obj.x[0], intercept min_obj.x[1])
# to the held-out part of the external data. The held-out logits get their own
# name: `y_logits` is the calibration input that scale_fun_bce reads as a
# global, and overwriting it here would make a re-run of the fitting cell
# optimise on the evaluation logits.
y_logits_test = logit(y_test_pred)
y_test_pred_corr = expit(min_obj.x[0] * y_logits_test + min_obj.x[1])

# Reliability curves (10 bins) before and after recalibration on shared axes.
prob_true_binary, prob_pred_binary = calibration_curve(y_test_gt, y_test_pred, n_bins=10)
prob_true_binary_corr, prob_pred_binary_corr = calibration_curve(y_test_gt, y_test_pred_corr, n_bins=10)
fig4, ax = plot_reliability_diagram(prob_true_binary, prob_pred_binary, "Transformer")
plot_reliability_diagram(prob_true_binary_corr, prob_pred_binary_corr, "Transformer (calibrated)", predefined_ax=ax)
plt.show()

In [None]:
# Calibration measures on the held-out part of the external data, before and
# after recalibration; 'state' labels each result.
# NOTE(review): evaluate_calibration is a project helper; its exact output is not visible here.
initial_cal_df = evaluate_calibration(y_test_gt, y_test_pred)
initial_cal_df['state'] = 'initial'
recal_df = evaluate_calibration(y_test_gt, y_test_pred_corr)
recal_df['state'] = 'Platt scaled'
# Note: rebinds mimic_cal_df (earlier: results on the full external set). The
# table is only displayed here, not written to CSV.
mimic_cal_df = pd.concat([initial_cal_df, recal_df])
mimic_cal_df