In [None]:
import matplotlib.pyplot as plt
from prediction.utils.visualisation_helper_functions import plot_calibration_curve
import pickle

In [None]:
# Pickle files that each hold a (ground truth, predictions) pair.
# NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.
death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/fold_1_test_gt_and_pred.pkl'
mrs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation/fold_2_test_gt_and_pred.pkl'

In [None]:
# NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.

# Preprocessed input data (features and outcome labels).
features_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv'
labels_path = '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv'

# Model artefacts: JSON configuration and the checkpoint holding the weights.
model_config_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/hyperopt_selected_transformer_death_20230409_060354.json'
model_weights_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/checkpoints_opsum_transformer_20230409_060354_cv_1/opsum_transformer_epoch=14_val_auroc=0.9105.ckpt'

## Initial calibration curve


Load predictions

In [None]:
# Each pickle holds a (ground truth, predictions) pair.
# Context managers close the file handles as soon as loading is done
# (a bare `pickle.load(open(...))` leaves them open until garbage collection).
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
with open(mrs_predictions_path, 'rb') as mrs_file:
    mrs_gt, mrs_predictions = pickle.load(mrs_file)
with open(death_predictions_path, 'rb') as death_file:
    death_gt, death_predictions = pickle.load(death_file)

In [None]:
# Calibration plot (5 bins) for the mRS (0-2) predictions loaded from the pickle.
mrs_predictions_arr = mrs_predictions.numpy()
plot_calibration_curve(mrs_gt, mrs_predictions_arr, n_bins=5)
plt.title("Calibration curve for prediction of mRS (0-2) at 3 months")
plt.show()

In [None]:
# Calibration plot (10 bins) for the death predictions loaded from the pickle.
death_predictions_arr = death_predictions.numpy()
plot_calibration_curve(death_gt, death_predictions_arr, n_bins=10)
plt.title("Calibration curve for prediction of death at 3 months")
plt.show()

## Recalibration

In [None]:
# Hardware: set to True to request the 'gpu' accelerator for the trainer.
use_gpu = False

# Outcome label and data-splitting parameters passed on to load_data.
seed = 42
n_splits = 5
test_size = 0.2
outcome = '3M Death'

In [None]:
from prediction.outcome_prediction.data_loading.data_loader import load_data

# Load and split the data for the configured outcome; the five results are
# unpacked in the order load_data returns them.
(pids, train_data, test_data,
 train_splits, test_features_lookup_table) = load_data(
    features_path, labels_path, outcome, test_size, n_splits, seed
)


In [None]:
import json

# Read the model configuration; the context manager closes the file handle
# (a bare `json.load(open(...))` leaves it open until garbage collection).
with open(model_config_path, 'r') as config_file:
    model_config = json.load(config_file)

In [None]:
from prediction.outcome_prediction.Transformer.lightning_wrapper import LitModel
from prediction.outcome_prediction.Transformer.architecture import OPSUMTransformer
from prediction.outcome_prediction.Transformer.utils.utils import DictLogger
import pytorch_lightning as pl

# Pick the accelerator from the notebook-level `use_gpu` switch.
accelerator = 'gpu' if use_gpu else 'cpu'
logger = DictLogger(0)
trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=1000,
                     gradient_clip_val=model_config['grad_clip_value'], logger=logger)

# define model
# The feed-forward width is a fixed multiple (ff_factor) of the model dimension.
ff_factor = 2
ff_dim = ff_factor * model_config['model_dim']
pos_encode_factor = 1

model_architecture = OPSUMTransformer(
    input_dim=84,  # NOTE(review): hard-coded input dimension - confirm it matches the feature file
    num_layers=int(model_config['num_layers']),
    model_dim=int(model_config['model_dim']),
    # Dropout is a probability: int() would truncate any fractional value to 0.
    dropout=float(model_config['dropout']),
    ff_dim=int(ff_dim),
    num_heads=int(model_config['num_head']),
    num_classes=1,
    max_dim=500,
    pos_encode_factor=pos_encode_factor
)

# Restore the trained weights into the freshly built architecture.
trained_model = LitModel.load_from_checkpoint(checkpoint_path=model_weights_path, model=model_architecture,
                                              lr=model_config['lr'],
                                              wd=model_config['weight_decay'],
                                              train_noise=model_config['train_noise'])

# The model is only used for inference below: evaluation mode keeps dropout inactive
# so predictions stay deterministic. The trailing `;` suppresses the module repr.
trained_model.eval();

In [None]:
from prediction.utils.calibration_tools import CalibratableModelFactory

# Wrap the trained model together with its trainer; the wrapper is what later
# exposes `calibrate` and `predict_calibrated`.
factory = CalibratableModelFactory()

wrapped_model = factory.get_model(trained_model, trainer)

In [None]:
# Take the train/validation split of the fold that the model config marks as best.
(fold_X_train, fold_X_val,
 fold_y_train, fold_y_val) = train_splits[model_config['best_cv_fold']]

In [None]:
from prediction.outcome_prediction.Transformer.utils.utils import prepare_dataset
from torch.utils.data import DataLoader

# Turn the selected fold into train / validation datasets.
train_dataset, val_dataset = prepare_dataset(
    (fold_X_train, fold_X_val, fold_y_train, fold_y_val),
    balanced=model_config['balanced'],
    rescale=True,
    use_gpu=False,
)

# Training loader: batch size equals the number of rows in fold_X_train.
train_loader = DataLoader(train_dataset, batch_size=fold_X_train.shape[0])
# Validation loader: batches of 1024 samples.
val_loader = DataLoader(val_dataset, batch_size=1024)

In [None]:
# Fetch every validation item exactly once, then split the (input, target) pairs
# into two parallel lists.
val_items = [val_dataset[idx] for idx in range(len(val_dataset))]
samples = [inputs for inputs, _target in val_items]
samples_y = [target for _inputs, target in val_items]

In [None]:
from torch.utils.data import TensorDataset
import torch as ch
import numpy as np

# Rebuild the validation set as a TensorDataset whose targets are cast to long.
stacked_val_inputs = ch.stack(samples)
long_val_targets = ch.from_numpy(np.array(samples_y)).long()
long_val_dataset = TensorDataset(stacked_val_inputs, long_val_targets)

Calibrate with validation data

In [None]:
# Calibrate the wrapped model using the validation loader and the validation labels.
wrapped_model.calibrate(val_loader, fold_y_val)

In [None]:
# Build the test dataset by passing the held-out test data in the validation slot
# of prepare_dataset, alongside the training fold.
# NOTE(review): this rebinds `train_dataset` from the earlier cell.
X_test, y_test = test_data
train_dataset, test_dataset = prepare_dataset(
    (fold_X_train, X_test, fold_y_train, y_test),
    balanced=model_config['balanced'],
    rescale=True,
    use_gpu=False,
)

test_loader = DataLoader(test_dataset, batch_size=1024)

In [None]:
# Predict on the test loader with the 'isotonic' calibration method of the wrapper.
cali_pred = wrapped_model.predict_calibrated(test_loader, 'isotonic')

In [None]:
# NOTE(review): the ground truth is the pickled `death_gt`, while the predictions come
# from `test_loader` - confirm both refer to the same samples in the same order.
plot_calibration_curve(
    death_gt,
    cali_pred,
    n_bins=10,
    title="Calibration curve for calibrated prediction of death at 3 months",
)
plt.show()

Calibrate with temperature scaling

In [None]:
from prediction.utils.calibration_tools import ModelWithTemperature

# Wrap the underlying network (the `.model` attribute of the Lightning module) for temperature scaling, on CPU.
temp_scale_model = ModelWithTemperature(trained_model.model, use_gpu=False)

In [None]:
# Set the temperature from the long-target validation dataset, fed in batches of 1024.
temp_scale_model.set_temperature(valid_loader=DataLoader(long_val_dataset, batch_size=1024))

In [None]:
from tqdm import tqdm
# Run the temperature-scaled model on each test sample individually;
# unsqueeze(0) adds a leading batch axis of size 1.
# Inference only: no_grad stops autograd from building and retaining a graph for
# every stored output, which would otherwise grow memory with the test-set size.
temp_calibrated_preds = []
with ch.no_grad():
    for i in tqdm(range(len(test_dataset))):
        sample, sample_y = test_dataset[i]
        temp_calibrated_preds.append(temp_scale_model.forward(sample.unsqueeze(0)))

In [None]:
# For every test sample, take element [0, -1, -1] of the model output and squash it with a
# sigmoid. NOTE(review): presumably the last position of the single-sample batch - confirm.
temp_calibrated_preds_arr = np.array([
    ch.sigmoid(temp_calibrated_preds[idx][0, -1, -1]).detach().numpy()
    for idx in range(len(test_dataset))
])

In [None]:
# Calibration plot (10 bins) for the temperature-scaled predictions.
# NOTE(review): the ground truth is the pickled `death_gt` - confirm it lines up with `test_dataset`.
plot_calibration_curve(
    death_gt,
    temp_calibrated_preds_arr,
    n_bins=10,
    title="Calibration curve for calibrated prediction of death at 3 months",
)
plt.show()