In [None]:
import pandas as pd
import pickle
import numpy as np
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

In [None]:
# Pickled (ground truth, predictions) of the 3-month death transformer on the external
# validation set (fold 1 model).
predictions_path = ('/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/'
                    'external_validation/fold_1_test_gt_and_pred.pkl')

In [None]:
# Trained model: checkpoint and the hyper-parameter configuration selected for it.
death_model_dir = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death'
model_weights_path = (f'{death_model_dir}/testing/checkpoints_opsum_transformer_20230409_060354_cv_1/'
                      'opsum_transformer_epoch=14_val_auroc=0.9105.ckpt')
model_config_path = f'{death_model_dir}/hyperopt_selected_transformer_death_20230409_060354.json'

# Development cohort (GSU): preprocessed features and outcomes.
gsu_prepro_dir = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050'
features_path = f'{gsu_prepro_dir}/preprocessed_features_01012023_233050.csv'
labels_path = f'{gsu_prepro_dir}/preprocessed_outcomes_01012023_233050.csv'

# External validation cohort (MIMIC): preprocessed features and outcomes.
mimic_prepro_dir = '/Users/jk1/temp/opsum_mimic/preprocessing/mimic_prepro_25012023_232713'
ext_features_path = f'{mimic_prepro_dir}/preprocessed_features_25012023_232713.csv'
ext_labels_path = f'{mimic_prepro_dir}/preprocessed_outcomes_25012023_232713.csv'

## Initial calibration curve

Load the ground-truth labels and the model's predictions on the external validation set (see `predictions_path`).

In [None]:
# Ground-truth labels and model predictions saved by the external-validation run.
# death_predictions is presumably a torch tensor of probabilities (it is converted with
# .numpy() and passed to brier_score_loss below) — TODO confirm.
# NOTE: pickle.load can execute arbitrary code; only load files produced by our own pipeline.
# The `with` block makes sure the file handle is closed after loading.
with open(predictions_path, 'rb') as predictions_file:
    death_gt, death_predictions = pickle.load(predictions_file)

In [None]:
from sklearn.metrics import brier_score_loss

# Brier score: mean squared difference between predicted probabilities and 0/1 outcomes (lower is better).
death_brier = brier_score_loss(death_gt, death_predictions.numpy())
print(f"Brier score for death at 3 months: {death_brier}")

In [None]:
from prediction.utils.visualisation_helper_functions import plot_calibration_curve

# Calibration of the original (not recalibrated) model on the external validation set.
# Uses the same number of bins as the recalibrated curves below so the plots are directly comparable;
# the title refers to the outcome this model actually predicts (3-month death).
plot_calibration_curve(death_gt, death_predictions.numpy(), n_bins=10,
                       title="Calibration curve for uncalibrated prediction of death at 3 months")
plt.show()

## Recalibration

Load the development (GSU) data and the external (MIMIC) data.

In [None]:
# Experiment configuration. Presumably these must match the values used at training time,
# so that load_data reproduces the same split and CV folds — TODO confirm.
outcome = '3M Death'   # outcome label passed to both data loaders
test_size = 0.2        # passed to load_data as the test-set size
n_splits = 5           # passed to load_data as the number of CV splits
seed = 42              # random seed for data splitting (load_data, train_test_split)
use_gpu = False        # selects the Lightning accelerator ('gpu' vs 'cpu')

In [None]:
from prediction.outcome_prediction.data_loading.data_loader import load_data

# Development (GSU) cohort: train/test data plus the CV splits indexed below by best_cv_fold.
gsu_data = load_data(features_path, labels_path, outcome, test_size, n_splits, seed)
pids, train_data, test_data, train_splits, test_features_lookup_table = gsu_data


In [None]:
from prediction.outcome_prediction.data_loading.data_loader import load_external_data

# External (MIMIC) cohort, used as the test set for the rest of the notebook.
# NOTE: this rebinds test_features_lookup_table, replacing the GSU lookup table returned by load_data.
ext_data = load_external_data(ext_features_path, ext_labels_path, outcome)
test_X_np, test_y_np, test_features_lookup_table = ext_data


Load the trained transformer from its checkpoint

In [None]:
import json

# Hyper-parameters of the selected model (presumably chosen by the hyper-parameter search).
# The `with` block closes the file handle instead of leaking it.
with open(model_config_path, 'r') as config_file:
    model_config = json.load(config_file)

In [None]:
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitModel
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger
import pytorch_lightning as pl

accelerator = 'gpu' if use_gpu else 'cpu'
logger = DictLogger(0)
# Not used for fitting in this notebook; it is handed to the calibration wrapper further below.
# max_epochs / gradient clipping presumably mirror the training script.
trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=1000,
                     gradient_clip_val=model_config['grad_clip_value'], logger=logger)

# Rebuild the architecture with the selected hyper-parameters so the checkpoint weights fit.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

model_architecture = OPSUMTransformer(
    input_dim=84,  # number of input features — TODO confirm it matches the preprocessed feature set
    num_layers=int(model_config['num_layers']),
    model_dim=int(model_config['model_dim']),
    # dropout is a probability: int() would truncate e.g. 0.2 to 0.
    dropout=float(model_config['dropout']),
    ff_dim=int(ff_dim),
    num_heads=int(model_config['num_head']),
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor
)

trained_model = LitModel.load_from_checkpoint(checkpoint_path=model_weights_path, model=model_architecture,
                                              lr=model_config['lr'],
                                              wd=model_config['weight_decay'],
                                              train_noise=model_config['train_noise'])
# Inference/calibration only: eval mode disables dropout. The trailing `;` suppresses the module repr.
trained_model.eval();

In [None]:
# CV fold recorded for the selected model (presumably the fold the checkpoint was trained on);
# its training part drives rescaling, its validation part serves as the GSU calibration set.
best_cv_fold = model_config['best_cv_fold']
fold_X_train, fold_X_val, fold_y_train, fold_y_val = train_splits[best_cv_fold]

In [None]:
from prediction.outcome_prediction.Transformer.utils.utils import prepare_dataset

# External (MIMIC) test set. The GSU training fold is passed alongside, presumably so the
# rescaling applied to the test data is fitted on the model's training data; only the
# second (test-side) dataset is kept. Honour the notebook-level use_gpu flag.
_, test_dataset = prepare_dataset((fold_X_train, test_X_np, fold_y_train, test_y_np),
                                  balanced=model_config['balanced'],
                                  rescale=True,
                                  use_gpu=use_gpu)

In [None]:
from torch.utils.data import DataLoader

# Training and validation part of the selected GSU CV fold. The validation part is the
# calibration set for the GSU-based recalibration below. Honour the use_gpu flag.
train_dataset, val_dataset = prepare_dataset((fold_X_train, fold_X_val, fold_y_train, fold_y_val),
                                             balanced=model_config['balanced'],
                                             rescale=True,
                                             use_gpu=use_gpu)

val_loader = DataLoader(val_dataset, batch_size=1024)
# Batch size equals the number of rows in fold_X_train.
train_loader = DataLoader(train_dataset, batch_size=fold_X_train.shape[0])

In [None]:
# Materialise the (input, label) pairs of the GSU validation fold so they can be
# re-packed into a TensorDataset with integer labels in the next cell.
val_pairs = [val_dataset[idx] for idx in range(len(val_dataset))]
samples = [x for x, _ in val_pairs]
samples_y = [y for _, y in val_pairs]

In [None]:
from torch.utils.data import TensorDataset
import torch as ch
import numpy as np

# Same validation samples with labels cast to int64 ("long"); presumably the loss used by
# ModelWithTemperature.set_temperature expects integer targets — TODO confirm.
val_inputs = ch.stack(samples)
val_labels_long = ch.from_numpy(np.array(samples_y)).long()
long_val_dataset = TensorDataset(val_inputs, val_labels_long)

## Re-calibration using GSU dataset


Fit temperature scaling on the GSU validation fold

In [None]:
from prediction.utils.calibration_tools import ModelWithTemperature

# Temperature scaling (presumably a single learned temperature dividing the logits), fitted on
# the GSU validation fold. Device follows the notebook-level use_gpu flag.
temp_scale_model = ModelWithTemperature(trained_model.model, use_gpu=use_gpu)
temp_scale_model.set_temperature(valid_loader=DataLoader(long_val_dataset, batch_size=1024))

Predict on the external (MIMIC) test set with the temperature-scaled model

In [None]:
from tqdm import tqdm

# Temperature-scaled outputs on the external (MIMIC) test set, one sample at a time.
# no_grad: pure inference, so don't build (and keep alive) an autograd graph per sample.
temp_calibrated_preds = []
with ch.no_grad():
    for i in tqdm(range(len(test_dataset))):
        sample, sample_y = test_dataset[i]
        temp_calibrated_preds.append(temp_scale_model.forward(sample.unsqueeze(0)))

In [None]:
# One prediction per test sample is expected; fail loudly instead of on an IndexError.
assert len(temp_calibrated_preds) == len(test_dataset)
# Sigmoid of output entry [0, -1, -1] -> probability. Presumably the output is shaped
# (batch, time, classes) and this picks the last time step — TODO confirm.
# .cpu() lets the NumPy conversion work when the model runs on GPU (no-op on CPU).
temp_calibrated_preds_arr = np.array([ch.sigmoid(pred[0, -1, -1]).detach().cpu().numpy()
                                      for pred in temp_calibrated_preds])

In [None]:
# Use the labels test_dataset was built from (test_y_np) so they are aligned with the predictions
# by construction — as done for the MIMIC split below — rather than the separately pickled death_gt.
assert len(test_y_np) == len(temp_calibrated_preds_arr)
plot_calibration_curve(test_y_np, temp_calibrated_preds_arr, n_bins=10,
                       title="Calibration curve for prediction of death at 3 months "
                             "(temperature scaling, calibrated on GSU)")
plt.show()

Recalibration using isotonic regression (fitted on the GSU validation fold)

In [None]:
from prediction.utils.calibration_tools import CalibratableModelFactory

# Factory kept as a notebook-level name: it is reused further below to create a fresh wrapper.
factory = CalibratableModelFactory()

# Wrap the Lightning model (and trainer) so post-hoc calibrators can be fitted and applied to it.
wrapped_model = factory.get_model(trained_model, trainer)

In [None]:
# Fit the wrapper's calibrators (presumably including the 'isotonic' one used below) on the GSU validation fold.
wrapped_model.calibrate(val_loader, fold_y_val)

In [None]:
# Isotonic-calibrated predictions on the external (MIMIC) test set.
test_loader = DataLoader(test_dataset, batch_size=1024)
iso_cal_pred = wrapped_model.predict_calibrated(test_loader, 'isotonic')

In [None]:
# Labels test_dataset was built from (aligned with the predictions by construction), and a title
# that distinguishes this curve from the other recalibration plots.
assert len(test_y_np) == len(iso_cal_pred)
plot_calibration_curve(test_y_np, iso_cal_pred, n_bins=10,
                       title="Calibration curve for prediction of death at 3 months "
                             "(isotonic regression, calibrated on GSU)")
plt.show()

## Re-calibration using a fraction of MIMIC dataset

Idea: hold out a stratified 10% (`calib_size`) of the MIMIC cohort to recalibrate the model, and evaluate calibration on the remaining 90%.

In [None]:
from sklearn.model_selection import train_test_split

# Fraction of the MIMIC cohort held out to recalibrate the model; the rest is used for evaluation.
calib_size = 0.1

# Stratified on the outcome so both parts keep the same event rate.
(ext_test_X, ext_calib_X,
 ext_test_y, ext_calib_y) = train_test_split(test_X_np, test_y_np,
                                             test_size=calib_size,
                                             stratify=test_y_np,
                                             random_state=seed)

In [None]:
def _prepare_external_split(X, y):
    """Build the rescaled dataset for a subset of the MIMIC cohort.

    The GSU training fold is passed alongside, presumably so the rescaling is fitted on the
    model's training data; only the second (evaluation-side) dataset is returned.
    Device follows the notebook-level use_gpu flag.
    """
    _, dataset = prepare_dataset((fold_X_train, X, fold_y_train, y),
                                 balanced=model_config['balanced'],
                                 rescale=True,
                                 use_gpu=use_gpu)
    return dataset


ext_calib_dataset = _prepare_external_split(ext_calib_X, ext_calib_y)
ext_test_dataset = _prepare_external_split(ext_test_X, ext_test_y)

Fit temperature scaling on the MIMIC calibration subset

In [None]:
# Re-pack the MIMIC calibration subset with int64 labels, mirroring long_val_dataset.
calib_pairs = [ext_calib_dataset[idx] for idx in range(len(ext_calib_dataset))]
samples = [x for x, _ in calib_pairs]
samples_y = [y for _, y in calib_pairs]

long_ext_calib_dataset = TensorDataset(ch.stack(samples),
                                       ch.from_numpy(np.array(samples_y)).long())

In [None]:
# Fresh temperature-scaling wrapper, fitted on the MIMIC calibration subset
# (rebinds temp_scale_model). Device follows the notebook-level use_gpu flag.
temp_scale_model = ModelWithTemperature(trained_model.model, use_gpu=use_gpu)
temp_scale_model.set_temperature(valid_loader=DataLoader(long_ext_calib_dataset, batch_size=1024))

In [None]:
# Temperature-scaled outputs on the held-out MIMIC test part, one sample at a time.
# no_grad: pure inference, so don't build (and keep alive) an autograd graph per sample.
ext_temp_calibrated_preds = []
with ch.no_grad():
    for i in tqdm(range(len(ext_test_dataset))):
        sample, sample_y = ext_test_dataset[i]
        ext_temp_calibrated_preds.append(temp_scale_model.forward(sample.unsqueeze(0)))

In [None]:
# One prediction per test sample is expected; fail loudly instead of on an IndexError.
assert len(ext_temp_calibrated_preds) == len(ext_test_dataset)
# Sigmoid of output entry [0, -1, -1] -> probability (presumably last time step of a
# (batch, time, classes) output — TODO confirm). .cpu() makes the NumPy conversion GPU-safe.
ext_temp_calibrated_preds_arr = np.array([ch.sigmoid(pred[0, -1, -1]).detach().cpu().numpy()
                                          for pred in ext_temp_calibrated_preds])

In [None]:
# Title identifies the method and calibration set, so this figure can be told apart from the others.
plot_calibration_curve(ext_test_y, ext_temp_calibrated_preds_arr, n_bins=10,
                       title="Calibration curve for prediction of death at 3 months "
                             "(temperature scaling, calibrated on MIMIC subset)")
plt.show()

Recalibration using isotonic regression (fitted on the MIMIC calibration subset)

In [None]:
# Re-create the wrapper so calibration state fitted on GSU above is not reused
# (presumably calibrate() stores the fitted calibrators on the wrapper instance).
wrapped_model = factory.get_model(trained_model, trainer)

In [None]:
# Fit the wrapper's calibrators on the MIMIC calibration subset.
ext_calib_loader = DataLoader(ext_calib_dataset, batch_size=1024)
wrapped_model.calibrate(ext_calib_loader, ext_calib_y)

In [None]:
# Isotonic-calibrated predictions on the held-out MIMIC test part.
ext_test_loader = DataLoader(ext_test_dataset, batch_size=1024)
ext_iso_cal_pred = wrapped_model.predict_calibrated(ext_test_loader, 'isotonic')

In [None]:
# Title identifies the method and calibration set, so this figure can be told apart from the others.
plot_calibration_curve(ext_test_y, ext_iso_cal_pred, n_bins=10,
                       title="Calibration curve for prediction of death at 3 months "
                             "(isotonic regression, calibrated on MIMIC subset)")
plt.show()