# Cross-validation sets: ROC & PR Curve comparison with confidence intervals

Confidence intervals are obtained from the variance of performance across folds.

In [None]:
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from sklearn.metrics import roc_curve, auc, precision_recall_curve
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

In [None]:
# Plot spacing has only been seen to render correctly in the opsum_shap
# environment, so refuse to run on anything other than Python 3.7.
import sys
assert sys.version_info[:2] == (3, 7)

In [None]:
# Locations of the stored predictions, grouped by outcome.
# The THRIVE-C entries point to a single pickle file; the transformer, LSTM and
# XGBoost entries point to directories that hold one pickle per fold.

# --- Outcome: 3-month mRS 0-2 ---
thrive_c_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_3m_mrs02_predictions/test_gt_and_pred.pkl'
transformer_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/predictions_for_all_sets/all_folds'
lstm_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h/testing/3M_mRS02/2023_01_02_1057/prediction_all_sets'
xgb_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/3M_mrs02/with_feature_aggregration/testing/predictions_all_sets'

# --- Outcome: 3-month death ---
thrive_c_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_3m_death_predictions/3m_death_test_gt_and_pred.pkl'
transformer_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/all_sets_predictions/all_folds'
lstm_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h/testing/3M_Death/2023_01_04_2020/predictions_all_sets'
xgb_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/3M_Death/inference/all_sets'

# --- Outcome: death in hospital ---
thrive_c_death_in_hosp_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_death_in_hospital_predictions/death_in_hospital_test_gt_and_pred.pkl'
transformer_death_in_hosp_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/Death_in_hospital/inference/training_sets'
lstm_death_in_hosp_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h/testing/Death_in_hospital/2024_02_05_1346/inference/training_sets'
xgb_death_in_hosp_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/Death_in_hospital/inference/training_sets'

# CSV with the reported model comparison; its 'ROC AUC', 'Model', 'Outcome' and
# 'Dataset' columns are read further down to obtain the THRIVE-C AUC values.
roc_auc_statistics_path = '/Users/jk1/Library/CloudStorage/OneDrive-unige.ch/stroke_research/geneva_stroke_unit_dataset/opsum_paper/supp_figures_and_tables/model_comparison/table/revision/overall_model_comparison.csv'

In [None]:
# Output directory (presumably where figures get saved; not used in the cells visible here - TODO confirm)
output_dir = '/Users/jk1/Downloads'

In [None]:
# Number of cross-validation folds; also used to split the THRIVE-C data into pseudo-folds
n_folds = 5
# NOTE(review): seed is not referenced in the cells visible here - confirm it is still needed
seed = 42

# Index (0-based position in the per-model fold list) of the fold whose curve is
# drawn as the main line for each model / outcome.

# for transformer, indices already start at 0
mrs02_transformer_selected_fold = 2
death_transformer_selected_fold = 1
death_in_hosp_transformer_selected_fold = 2

# subtract one to go from fold number to index (as folds will be in a list)
mrs02_xgb_selected_fold = 3 - 1
death_xgb_selected_fold = 3 - 1
death_in_hosp_xgb_selected_fold = 3 - 1

# subtract one to go from fold number to index (as folds will be in a list)
mrs02_lstm_selected_fold = 3 - 1 
death_lstm_selected_fold = 2 - 1
death_in_hosp_lstm_selected_fold = 3 - 1

Load data

In [None]:
def _leading_number(cell):
    """Return the text before the first space of *cell* converted to float."""
    return float(cell.split(' ')[0])


# Reported model comparison table; 'clean_roc_auc' keeps only the numeric value
# that precedes the first space in each 'ROC AUC' entry.
roc_auc_statistics = pd.read_csv(roc_auc_statistics_path)
roc_auc_statistics['clean_roc_auc'] = roc_auc_statistics['ROC AUC'].apply(_leading_number)

In [None]:
# THRIVE-C predictions: every pickle unpacks into (ground truth, predictions).
# The files are opened through a context manager so the handles are closed
# again (a bare pickle.load(open(...)) leaves them open).
with open(thrive_c_mrs02_predictions_path, 'rb') as pickle_file:
    thrivec_mrs02_gt, thrivec_mrs02_predictions = pickle.load(pickle_file)
with open(thrive_c_death_predictions_path, 'rb') as pickle_file:
    thrivec_death_gt, thrivec_death_predictions = pickle.load(pickle_file)
with open(thrive_c_death_in_hosp_predictions_path, 'rb') as pickle_file:
    thrivec_death_in_hosp_gt, thrivec_death_in_hosp_predictions = pickle.load(pickle_file)

# Reported THRIVE-C ROC AUC per outcome, looked up in the statistics table
# (first matching row is used).
# NOTE(review): unlike the two lookups below, the mrs02 lookup does not filter on
# Dataset == 'GSU' - confirm the table has a single THRIVE-C row for '3M mrs02'.
thrivec_mrs02_roc_auc = roc_auc_statistics[(roc_auc_statistics.Model == 'THRIVE-C') & (roc_auc_statistics.Outcome == '3M mrs02')]['clean_roc_auc'].values[0]
thrivec_death_roc_auc = roc_auc_statistics[(roc_auc_statistics.Model == 'THRIVE-C') & (roc_auc_statistics.Outcome == '3M Death')
                                               & (roc_auc_statistics.Dataset == 'GSU')]['clean_roc_auc'].values[0]
thrivec_death_in_hosp_roc_auc = roc_auc_statistics[(roc_auc_statistics.Model == 'THRIVE-C') & (roc_auc_statistics.Outcome == 'Death in hospital')
                                               & (roc_auc_statistics.Dataset == 'GSU')]['clean_roc_auc'].values[0]

In [None]:
def _load_transformer_folds(predictions_dir, fold_count):
    """Load the transformer validation predictions of folds 0..fold_count-1.

    Each ``val_predictions_fold{i}.pkl`` is stored as ``(y_true, y_pred)`` where
    ``y_pred`` contains raw predictions, probabilities (after sigmoid) and
    ground truth. Only ``y_true`` and the probabilities are kept.

    Returns:
        List of ``(y_true, probabilities)`` tuples, one per fold, in fold order.
    """
    folds = []
    for fidx in range(fold_count):
        fold_path = os.path.join(predictions_dir, f'val_predictions_fold{fidx}.pkl')
        # context manager so the file handle is closed after reading
        with open(fold_path, 'rb') as pickle_file:
            stored = pickle.load(pickle_file)
        folds.append((stored[0], stored[1][1]))
    return folds


transformer_mrs02_folds = _load_transformer_folds(transformer_mrs02_predictions_path, n_folds)
transformer_death_folds = _load_transformer_folds(transformer_death_predictions_path, n_folds)
transformer_death_in_hosp_folds = _load_transformer_folds(transformer_death_in_hosp_predictions_path, n_folds)

In [None]:
def _load_xgb_folds(predictions_dir, fold_count):
    """Load ``val_gt_and_pred_cv_{i}.pkl`` for i in 0..fold_count-1.

    Returns:
        List with the unpickled content of each file, in fold order.
    """
    folds = []
    for fidx in range(fold_count):
        fold_path = os.path.join(predictions_dir, f'val_gt_and_pred_cv_{fidx}.pkl')
        # context manager so the file handle is closed after reading
        with open(fold_path, 'rb') as pickle_file:
            folds.append(pickle.load(pickle_file))
    return folds


xgb_mrs02_folds = _load_xgb_folds(xgb_mrs02_predictions_path, n_folds)
xgb_death_folds = _load_xgb_folds(xgb_death_predictions_path, n_folds)
xgb_death_in_hosp_folds = _load_xgb_folds(xgb_death_in_hosp_predictions_path, n_folds)


In [None]:
def _load_lstm_folds(predictions_dir, fold_count):
    """Load the LSTM validation predictions of every fold.

    The fold files are named with 1-based fold numbers
    (``val_gt_and_pred_fold_{i + 1}.pkl``) and are searched for in every
    immediate subdirectory of ``predictions_dir``. Every match is appended, so
    the result has one entry per fold only if each file name occurs in exactly
    one subdirectory.

    Returns:
        List with the unpickled content of each matching file, ordered by fold.
    """
    folds = []
    for fidx in range(fold_count):
        wanted_name = f'val_gt_and_pred_fold_{fidx + 1}.pkl'
        for subdir in os.listdir(predictions_dir):
            subdir_path = os.path.join(predictions_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            if wanted_name in os.listdir(subdir_path):
                # context manager so the file handle is closed after reading
                with open(os.path.join(subdir_path, wanted_name), 'rb') as pickle_file:
                    folds.append(pickle.load(pickle_file))
    return folds


lstm_mrs02_folds = _load_lstm_folds(lstm_mrs02_predictions_path, n_folds)
lstm_death_folds = _load_lstm_folds(lstm_death_predictions_path, n_folds)
lstm_death_in_hosp_folds = _load_lstm_folds(lstm_death_in_hosp_predictions_path, n_folds)



In [None]:
# Four-colour palette; plot_roc_auc_curve builds the same palette and uses
# index 0 for the Transformer, 1 for THRIVE-C, 2 for XGBoost and 3 for the LSTM.
all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
# bare expression: lets the notebook display the palette as the cell output
all_colors_palette

# Helper functions

In [None]:
def compute_roc_and_pr_curves(folds, n_interpolated_points=200):
    """Compute ROC and precision-recall curves for every fold.

    Each curve is returned twice: at its native operating points and linearly
    interpolated onto ``n_interpolated_points`` evenly spaced positions in
    [0, 1], so that folds can be compared point by point.

    All entries of ``folds`` are processed; the function no longer depends on
    the notebook-level ``n_folds`` variable.

    Args:
        folds: Sequence with one entry per fold. ``fold[0]`` is passed to
            scikit-learn as the ground truth and ``fold[1]`` as the score.
        n_interpolated_points: Size of the common interpolation grid.

    Returns:
        Tuple ``(roc_df, resampled_roc_df, roc_aucs, pr_df, resampled_pr_df,
        pr_aucs)``. The ROC frames have the columns ``fpr``, ``tpr``, ``fold``;
        the PR frames ``recall``, ``precision``, ``fold`` (``fold`` is the
        position of the fold in ``folds``). The AUC lists hold one value per fold.
    """
    def _stack(frames):
        # pd.concat refuses an empty list; an empty frame matches the old result
        return pd.concat(frames) if frames else pd.DataFrame()

    # common grid, computed once instead of several times per fold
    grid = np.linspace(0, 1, n_interpolated_points)

    roc_frames, resampled_roc_frames, roc_aucs = [], [], []
    pr_frames, resampled_pr_frames, pr_aucs = [], [], []

    for fidx, fold in enumerate(tqdm(folds)):
        y_true, y_score = fold[0], fold[1]

        fpr, tpr, _ = roc_curve(y_true, y_score, drop_intermediate=False)
        roc_aucs.append(auc(fpr, tpr))
        roc_frames.append(pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'fold': fidx}))
        resampled_roc_frames.append(
            pd.DataFrame({'fpr': grid, 'tpr': np.interp(grid, fpr, tpr), 'fold': fidx}))

        precision, recall, _ = precision_recall_curve(y_true, y_score)
        # auc() and np.interp() need increasing x values, so order the points by recall
        recall, precision = zip(*sorted(zip(recall, precision)))
        pr_aucs.append(auc(recall, precision))
        pr_frames.append(pd.DataFrame({'recall': recall, 'precision': precision, 'fold': fidx}))
        resampled_pr_frames.append(
            pd.DataFrame({'recall': grid, 'precision': np.interp(grid, recall, precision), 'fold': fidx}))

    # Frames are collected in lists and concatenated once: DataFrame.append is
    # deprecated since pandas 1.4 / removed in 2.0 and copies all rows on every call.
    return (_stack(roc_frames), _stack(resampled_roc_frames), roc_aucs,
            _stack(pr_frames), _stack(resampled_pr_frames), pr_aucs)


## Prepare data for mrs02 outcome

### Compute ROC and PR curve standard deviation for THRIVE-C

In [None]:
# Split the THRIVE-C mrs02 data into n_folds interleaved subsets:
# subset k holds every n_folds-th sample starting at position k.
mrs02_thrivec_folds = [
    (thrivec_mrs02_gt[offset::n_folds], thrivec_mrs02_predictions[offset::n_folds])
    for offset in range(n_folds)
]

In [None]:
# THRIVE-C curves on the complete mrs02 set, resampled onto 200 evenly spaced points in [0, 1]
_mrs02_grid = np.linspace(0, 1, 200)

# ROC
mrs02_thrivec_fpr, mrs02_thrivec_tpr, _ = roc_curve(thrivec_mrs02_gt, thrivec_mrs02_predictions)
mrs02_thrivec_roc_auc = auc(mrs02_thrivec_fpr, mrs02_thrivec_tpr)
mrs02_thrivec_resampled_tpr = np.interp(_mrs02_grid, mrs02_thrivec_fpr, mrs02_thrivec_tpr)

# Precision-recall (points ordered by increasing recall before integrating / interpolating)
mrs02_thrivec_precision, mrs02_thrivec_recall, _ = precision_recall_curve(thrivec_mrs02_gt, thrivec_mrs02_predictions)
_mrs02_pr_points = sorted(zip(mrs02_thrivec_recall, mrs02_thrivec_precision))
mrs02_thrivec_recall, mrs02_thrivec_precision = zip(*_mrs02_pr_points)
mrs02_thrivec_pr_auc = auc(mrs02_thrivec_recall, mrs02_thrivec_precision)
mrs02_thrivec_resampled_precision = np.interp(_mrs02_grid, mrs02_thrivec_recall, mrs02_thrivec_precision)

In [None]:
# Per-fold ROC / PR curves of THRIVE-C (mrs02)
(
    mrs02_thrivec_roc_df,
    mrs02_thrivec_resampled_roc_df,
    mrs02_thrivec_roc_aucs,
    mrs02_thrivec_pr_df,
    mrs02_thrivec_resampled_pr_df,
    mrs02_thrivec_pr_aucs,
) = compute_roc_and_pr_curves(mrs02_thrivec_folds)

In [None]:
# Standard deviation across folds at every point of the resampling grid
mrs02_thrivec_resampled_roc_std = (
    mrs02_thrivec_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
mrs02_thrivec_resampled_pr_std = (
    mrs02_thrivec_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

### Transformer curves with fold variation

In [None]:
# Per-fold ROC / PR curves of the transformer (mrs02)
(
    mrs02_transformer_roc_df,
    mrs02_transformer_resampled_roc_df,
    mrs02_transformer_roc_aucs,
    mrs02_transformer_pr_df,
    mrs02_transformer_resampled_pr_df,
    mrs02_transformer_pr_aucs,
) = compute_roc_and_pr_curves(transformer_mrs02_folds)

In [None]:
# Standard deviation across folds at every point of the resampling grid
mrs02_transformer_resampled_roc_std = (
    mrs02_transformer_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
mrs02_transformer_resampled_pr_std = (
    mrs02_transformer_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

### XGB curves with fold variation

In [None]:
# Per-fold ROC / PR curves of XGBoost (mrs02)
(
    mrs02_xgb_roc_df,
    mrs02_xgb_resampled_roc_df,
    mrs02_xgb_roc_aucs,
    mrs02_xgb_pr_df,
    mrs02_xgb_resampled_pr_df,
    mrs02_xgb_pr_aucs,
) = compute_roc_and_pr_curves(xgb_mrs02_folds)

In [None]:
# Standard deviation across folds at every point of the resampling grid
mrs02_xgb_resampled_roc_std = (
    mrs02_xgb_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
mrs02_xgb_resampled_pr_std = (
    mrs02_xgb_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

### LSTM curves with fold variation

In [None]:
# Per-fold ROC / PR curves of the LSTM (mrs02)
(
    mrs02_lstm_roc_df,
    mrs02_lstm_resampled_roc_df,
    mrs02_lstm_roc_aucs,
    mrs02_lstm_pr_df,
    mrs02_lstm_resampled_pr_df,
    mrs02_lstm_pr_aucs,
) = compute_roc_and_pr_curves(lstm_mrs02_folds)

In [None]:
# Standard deviation across folds at every point of the resampling grid
mrs02_lstm_resampled_roc_std = (
    mrs02_lstm_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
mrs02_lstm_resampled_pr_std = (
    mrs02_lstm_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

## Prepare data for mortality prediction

In [None]:
# Split the THRIVE-C 3-month death data into n_folds interleaved subsets:
# subset k holds every n_folds-th sample starting at position k.
death_thrivec_folds = [
    (thrivec_death_gt[offset::n_folds], thrivec_death_predictions[offset::n_folds])
    for offset in range(n_folds)
]

# Curves on the complete set, resampled onto 200 evenly spaced points in [0, 1]
_death_grid = np.linspace(0, 1, 200)

# ROC
death_thrivec_fpr, death_thrivec_tpr, _ = roc_curve(thrivec_death_gt, thrivec_death_predictions)
death_thrivec_roc_auc = auc(death_thrivec_fpr, death_thrivec_tpr)
death_thrivec_resampled_tpr = np.interp(_death_grid, death_thrivec_fpr, death_thrivec_tpr)

# Precision-recall (points ordered by increasing recall before integrating / interpolating)
death_thrivec_precision, death_thrivec_recall, _ = precision_recall_curve(thrivec_death_gt, thrivec_death_predictions)
_death_pr_points = sorted(zip(death_thrivec_recall, death_thrivec_precision))
death_thrivec_recall, death_thrivec_precision = zip(*_death_pr_points)
death_thrivec_pr_auc = auc(death_thrivec_recall, death_thrivec_precision)
death_thrivec_resampled_precision = np.interp(_death_grid, death_thrivec_recall, death_thrivec_precision)

# Per-fold curves on the interleaved subsets
(
    death_thrivec_roc_df,
    death_thrivec_resampled_roc_df,
    death_thrivec_roc_aucs,
    death_thrivec_pr_df,
    death_thrivec_resampled_pr_df,
    death_thrivec_pr_aucs,
) = compute_roc_and_pr_curves(death_thrivec_folds)

# Standard deviation across folds at every point of the resampling grid
death_thrivec_resampled_roc_std = (
    death_thrivec_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
death_thrivec_resampled_pr_std = (
    death_thrivec_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

In [None]:
# Per-fold ROC / PR curves of the transformer (3-month death)
(
    transformer_death_roc_df,
    transformer_death_resampled_roc_df,
    transformer_death_roc_aucs,
    transformer_death_pr_df,
    transformer_death_resampled_pr_df,
    transformer_death_pr_aucs,
) = compute_roc_and_pr_curves(transformer_death_folds)

# Standard deviation across folds at every point of the resampling grid
transformer_death_resampled_roc_std = (
    transformer_death_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
transformer_death_resampled_pr_std = (
    transformer_death_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)


In [None]:
# Per-fold ROC / PR curves of XGBoost (3-month death)
(
    death_xgb_roc_df,
    death_xgb_resampled_roc_df,
    death_xgb_roc_aucs,
    death_xgb_pr_df,
    death_xgb_resampled_pr_df,
    death_xgb_pr_aucs,
) = compute_roc_and_pr_curves(xgb_death_folds)

# Standard deviation across folds at every point of the resampling grid
death_xgb_resampled_roc_std = (
    death_xgb_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
death_xgb_resampled_pr_std = (
    death_xgb_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

In [None]:
# Per-fold ROC / PR curves of the LSTM (3-month death)
(
    death_lstm_roc_df,
    death_lstm_resampled_roc_df,
    death_lstm_roc_aucs,
    death_lstm_pr_df,
    death_lstm_resampled_pr_df,
    death_lstm_pr_aucs,
) = compute_roc_and_pr_curves(lstm_death_folds)

# Standard deviation across folds at every point of the resampling grid
death_lstm_resampled_roc_std = (
    death_lstm_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
death_lstm_resampled_pr_std = (
    death_lstm_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

## Prepare data for death in hospital prediction

In [None]:
# Split the THRIVE-C death-in-hospital data into n_folds interleaved subsets:
# subset k holds every n_folds-th sample starting at position k.
death_in_hosp_thrivec_folds = [
    (thrivec_death_in_hosp_gt[offset::n_folds], thrivec_death_in_hosp_predictions[offset::n_folds])
    for offset in range(n_folds)
]

# Curves on the complete set, resampled onto 200 evenly spaced points in [0, 1]
_death_in_hosp_grid = np.linspace(0, 1, 200)

# ROC
death_in_hosp_thrivec_fpr, death_in_hosp_thrivec_tpr, _ = roc_curve(thrivec_death_in_hosp_gt, thrivec_death_in_hosp_predictions)
death_in_hosp_thrivec_roc_auc = auc(death_in_hosp_thrivec_fpr, death_in_hosp_thrivec_tpr)
death_in_hosp_thrivec_resampled_tpr = np.interp(_death_in_hosp_grid, death_in_hosp_thrivec_fpr, death_in_hosp_thrivec_tpr)

# Precision-recall (points ordered by increasing recall before integrating / interpolating)
death_in_hosp_thrivec_precision, death_in_hosp_thrivec_recall, _ = precision_recall_curve(thrivec_death_in_hosp_gt, thrivec_death_in_hosp_predictions)
_death_in_hosp_pr_points = sorted(zip(death_in_hosp_thrivec_recall, death_in_hosp_thrivec_precision))
death_in_hosp_thrivec_recall, death_in_hosp_thrivec_precision = zip(*_death_in_hosp_pr_points)
death_in_hosp_thrivec_pr_auc = auc(death_in_hosp_thrivec_recall, death_in_hosp_thrivec_precision)
death_in_hosp_thrivec_resampled_precision = np.interp(_death_in_hosp_grid, death_in_hosp_thrivec_recall, death_in_hosp_thrivec_precision)

# Per-fold curves on the interleaved subsets
(
    death_in_hosp_thrivec_roc_df,
    death_in_hosp_thrivec_resampled_roc_df,
    death_in_hosp_thrivec_roc_aucs,
    death_in_hosp_thrivec_pr_df,
    death_in_hosp_thrivec_resampled_pr_df,
    death_in_hosp_thrivec_pr_aucs,
) = compute_roc_and_pr_curves(death_in_hosp_thrivec_folds)

# Standard deviation across folds at every point of the resampling grid
death_in_hosp_thrivec_resampled_roc_std = (
    death_in_hosp_thrivec_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
death_in_hosp_thrivec_resampled_pr_std = (
    death_in_hosp_thrivec_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

In [None]:
# Per-fold ROC / PR curves of the transformer (death in hospital)
(
    transformer_death_in_hosp_roc_df,
    transformer_death_in_hosp_resampled_roc_df,
    transformer_death_in_hosp_roc_aucs,
    transformer_death_in_hosp_pr_df,
    transformer_death_in_hosp_resampled_pr_df,
    transformer_death_in_hosp_pr_aucs,
) = compute_roc_and_pr_curves(transformer_death_in_hosp_folds)

# Standard deviation across folds at every point of the resampling grid
transformer_death_in_hosp_resampled_roc_std = (
    transformer_death_in_hosp_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
transformer_death_in_hosp_resampled_pr_std = (
    transformer_death_in_hosp_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

In [None]:
# Per-fold ROC / PR curves of XGBoost (death in hospital)
(
    death_in_hosp_xgb_roc_df,
    death_in_hosp_xgb_resampled_roc_df,
    death_in_hosp_xgb_roc_aucs,
    death_in_hosp_xgb_pr_df,
    death_in_hosp_xgb_resampled_pr_df,
    death_in_hosp_xgb_pr_aucs,
) = compute_roc_and_pr_curves(xgb_death_in_hosp_folds)

# Standard deviation across folds at every point of the resampling grid
death_in_hosp_xgb_resampled_roc_std = (
    death_in_hosp_xgb_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
death_in_hosp_xgb_resampled_pr_std = (
    death_in_hosp_xgb_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

In [None]:
# Per-fold ROC / PR curves of the LSTM (death in hospital)
(
    death_in_hosp_lstm_roc_df,
    death_in_hosp_lstm_resampled_roc_df,
    death_in_hosp_lstm_roc_aucs,
    death_in_hosp_lstm_pr_df,
    death_in_hosp_lstm_resampled_pr_df,
    death_in_hosp_lstm_pr_aucs,
) = compute_roc_and_pr_curves(lstm_death_in_hosp_folds)

# Standard deviation across folds at every point of the resampling grid
death_in_hosp_lstm_resampled_roc_std = (
    death_in_hosp_lstm_resampled_roc_df.groupby('fpr')['tpr'].agg(['mean', 'std'])['std']
)
death_in_hosp_lstm_resampled_pr_std = (
    death_in_hosp_lstm_resampled_pr_df.groupby('recall')['precision'].agg(['mean', 'std'])['std']
)

### Resampled ROC curve with fold variation


In [None]:
def _roc_area_label(model_name, area, decimals):
    """Build the legend entry '<model> (area = x.xxx)' with *decimals* decimals."""
    return f'{model_name} (area = {area:.{decimals}f})'


def _draw_fold_roc(ax, roc_df, resampled_roc_df, resampled_roc_std, selected_fold,
                   color, label, zorder=None):
    """Draw the ROC curve of one fold with a +/- 1 s.d. band around its resampled curve."""
    # zorder is only forwarded when requested so matplotlib's defaults stay untouched
    extra = {} if zorder is None else {'zorder': zorder}

    spread = resampled_roc_std.values
    centre = resampled_roc_df[resampled_roc_df.fold == selected_fold].tpr
    ax.fill_between(resampled_roc_std.index, centre - spread, centre + spread,
                    alpha=0.2, color=color, **extra)

    return sns.lineplot(data=roc_df[roc_df.fold == selected_fold], x='fpr', y='tpr',
                        color=color, label=label, ax=ax, errorbar=None, **extra)


def plot_roc_auc_curve(transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std, transformer_roc_auc, transformer_selected_fold,
                       xgb_roc_df, xgb_resampled_roc_df, xgb_resampled_roc_std, xgb_roc_auc, xgb_selected_fold,
                       lstm_roc_df, lstm_resampled_roc_df, lstm_resampled_roc_std, lstm_roc_auc, lstm_selected_fold,
                       thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc,
                       ax, plot_legend=True, tick_label_size=11, label_font_size=13):
    """Plot the ROC curves of the Transformer, THRIVE-C, XGBoost and LSTM on one axis.

    For the Transformer, XGBoost and LSTM the curve of the selected fold is
    drawn, with a band of +/- 1 standard deviation (across folds) around the
    resampled curve of that fold. THRIVE-C is drawn from its resampled curve
    with the same kind of band.

    Args:
        transformer_roc_df, xgb_roc_df, lstm_roc_df: ROC points with the
            columns ``fpr``, ``tpr`` and ``fold``.
        transformer_resampled_roc_df, xgb_resampled_roc_df,
        lstm_resampled_roc_df: resampled ROC points, same columns.
        transformer_resampled_roc_std, xgb_resampled_roc_std,
        lstm_resampled_roc_std: Series with the TPR standard deviation; its
            index supplies the x positions of the band.
        transformer_roc_auc, xgb_roc_auc, lstm_roc_auc, thrivec_roc_auc: area
            values shown in the legend.
        transformer_selected_fold, xgb_selected_fold, lstm_selected_fold:
            value of the ``fold`` column to draw.
        thrivec_resampled_tpr: THRIVE-C TPR values, taken to lie on an evenly
            spaced FPR grid over [0, 1] (the grid size follows its length).
        thrivec_resampled_roc_std: Series with the THRIVE-C TPR standard
            deviation on the same grid.
        ax: Matplotlib axis to draw on.
        plot_legend: Draw the legend (including a '± s.d.' entry) when True,
            otherwise remove the legend.
        tick_label_size: Font size of the tick labels.
        label_font_size: Font size of the axis labels and legend.

    Returns:
        The figure that owns ``ax``.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    n_float_numbers = 3

    ## Main model: Transformer (drawn above the comparators)
    ax = _draw_fold_roc(ax, transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std,
                        transformer_selected_fold, all_colors_palette[0],
                        _roc_area_label('Transformer', transformer_roc_auc, n_float_numbers), zorder=100)

    ## Comparators: THRIVE-C
    # x grid derived from the input length instead of a hard-coded 200 points
    thrivec_grid = np.linspace(0, 1, len(thrivec_resampled_tpr))
    thrivec_spread = thrivec_resampled_roc_std.values
    ax.fill_between(thrivec_grid, thrivec_resampled_tpr - thrivec_spread, thrivec_resampled_tpr + thrivec_spread,
                    alpha=0.1, color=all_colors_palette[1])
    sns.lineplot(x=thrivec_grid, y=thrivec_resampled_tpr, color=all_colors_palette[1],
                 label=_roc_area_label('THRIVE-C', thrivec_roc_auc, n_float_numbers),
                 ax=ax, linewidth=2)

    ## Comparators: XGBoost
    ax = _draw_fold_roc(ax, xgb_roc_df, xgb_resampled_roc_df, xgb_resampled_roc_std,
                        xgb_selected_fold, all_colors_palette[2],
                        _roc_area_label('XGBoost', xgb_roc_auc, n_float_numbers))

    ## Comparators: LSTM
    ax = _draw_fold_roc(ax, lstm_roc_df, lstm_resampled_roc_df, lstm_resampled_roc_std,
                        lstm_selected_fold, all_colors_palette[3],
                        _roc_area_label('LSTM', lstm_roc_auc, n_float_numbers))

    # chance-level diagonal
    ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

    ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
    ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
    ax.tick_params('x', labelsize=tick_label_size)
    ax.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax.get_legend_handles_labels()
        # one combined legend entry showing the band colour of all four models
        sd_marker = tuple(mpatches.Patch(color=color, alpha=0.3) for color in all_colors_palette)
        legend_markers.append(sd_marker)
        legend_labels.append('± s.d.')
        ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})
    else:
        # remove the legend; guarded because get_legend() returns None when there is none
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    return ax.get_figure()

#### Plot mrs02 ROC curve

In [None]:
# Figure style: no top / right spines, square 10 x 10 figure
custom_params = {
    "axes.spines.right": False,
    "axes.spines.top": False,
    'figure.figsize': (10, 10),
}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale=1)

ax = plt.subplot(111)

# Legend areas: median AUC across folds for the trained models, reported AUC for THRIVE-C
plot_roc_auc_curve(
    transformer_roc_df=mrs02_transformer_roc_df,
    transformer_resampled_roc_df=mrs02_transformer_resampled_roc_df,
    transformer_resampled_roc_std=mrs02_transformer_resampled_roc_std,
    transformer_roc_auc=np.median(mrs02_transformer_roc_aucs),
    transformer_selected_fold=mrs02_transformer_selected_fold,
    xgb_roc_df=mrs02_xgb_roc_df,
    xgb_resampled_roc_df=mrs02_xgb_resampled_roc_df,
    xgb_resampled_roc_std=mrs02_xgb_resampled_roc_std,
    xgb_roc_auc=np.median(mrs02_xgb_roc_aucs),
    xgb_selected_fold=mrs02_xgb_selected_fold,
    lstm_roc_df=mrs02_lstm_roc_df,
    lstm_resampled_roc_df=mrs02_lstm_resampled_roc_df,
    lstm_resampled_roc_std=mrs02_lstm_resampled_roc_std,
    lstm_roc_auc=np.median(mrs02_lstm_roc_aucs),
    lstm_selected_fold=mrs02_lstm_selected_fold,
    thrivec_resampled_tpr=mrs02_thrivec_resampled_tpr,
    thrivec_resampled_roc_std=mrs02_thrivec_resampled_roc_std,
    thrivec_roc_auc=thrivec_mrs02_roc_auc,
    ax=ax,
    plot_legend=True,
    tick_label_size=11,
    label_font_size=13,
)

plt.show()

#### Plot death ROC curve

In [None]:
# Figure style: no top / right spines, square 10 x 10 figure
custom_params = {
    "axes.spines.right": False,
    "axes.spines.top": False,
    'figure.figsize': (10, 10),
}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale=1)

ax = plt.subplot(111)

# Legend areas: median AUC across folds for the trained models, reported AUC for THRIVE-C
plot_roc_auc_curve(
    transformer_roc_df=transformer_death_roc_df,
    transformer_resampled_roc_df=transformer_death_resampled_roc_df,
    transformer_resampled_roc_std=transformer_death_resampled_roc_std,
    transformer_roc_auc=np.median(transformer_death_roc_aucs),
    transformer_selected_fold=death_transformer_selected_fold,
    xgb_roc_df=death_xgb_roc_df,
    xgb_resampled_roc_df=death_xgb_resampled_roc_df,
    xgb_resampled_roc_std=death_xgb_resampled_roc_std,
    xgb_roc_auc=np.median(death_xgb_roc_aucs),
    xgb_selected_fold=death_xgb_selected_fold,
    lstm_roc_df=death_lstm_roc_df,
    lstm_resampled_roc_df=death_lstm_resampled_roc_df,
    lstm_resampled_roc_std=death_lstm_resampled_roc_std,
    lstm_roc_auc=np.median(death_lstm_roc_aucs),
    lstm_selected_fold=death_lstm_selected_fold,
    thrivec_resampled_tpr=death_thrivec_resampled_tpr,
    thrivec_resampled_roc_std=death_thrivec_resampled_roc_std,
    thrivec_roc_auc=thrivec_death_roc_auc,
    ax=ax,
    plot_legend=True,
    tick_label_size=11,
    label_font_size=13,
)

plt.show()

#### Plot death in hospital ROC curve

In [None]:
# Figure style: no top / right spines, square 10 x 10 figure
custom_params = {
    "axes.spines.right": False,
    "axes.spines.top": False,
    'figure.figsize': (10, 10),
}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale=1)

ax = plt.subplot(111)

# Legend areas: median AUC across folds for the trained models, reported AUC for THRIVE-C
plot_roc_auc_curve(
    transformer_roc_df=transformer_death_in_hosp_roc_df,
    transformer_resampled_roc_df=transformer_death_in_hosp_resampled_roc_df,
    transformer_resampled_roc_std=transformer_death_in_hosp_resampled_roc_std,
    transformer_roc_auc=np.median(transformer_death_in_hosp_roc_aucs),
    transformer_selected_fold=death_in_hosp_transformer_selected_fold,
    xgb_roc_df=death_in_hosp_xgb_roc_df,
    xgb_resampled_roc_df=death_in_hosp_xgb_resampled_roc_df,
    xgb_resampled_roc_std=death_in_hosp_xgb_resampled_roc_std,
    xgb_roc_auc=np.median(death_in_hosp_xgb_roc_aucs),
    xgb_selected_fold=death_in_hosp_xgb_selected_fold,
    lstm_roc_df=death_in_hosp_lstm_roc_df,
    lstm_resampled_roc_df=death_in_hosp_lstm_resampled_roc_df,
    lstm_resampled_roc_std=death_in_hosp_lstm_resampled_roc_std,
    lstm_roc_auc=np.median(death_in_hosp_lstm_roc_aucs),
    lstm_selected_fold=death_in_hosp_lstm_selected_fold,
    thrivec_resampled_tpr=death_in_hosp_thrivec_resampled_tpr,
    thrivec_resampled_roc_std=death_in_hosp_thrivec_resampled_roc_std,
    thrivec_roc_auc=thrivec_death_in_hosp_roc_auc,
    ax=ax,
    plot_legend=True,
    tick_label_size=11,
    label_font_size=13,
)

plt.show()

### Overall Precision-Recall curve

In [None]:
def _plot_fold_pr_curve(ax, pr_df, resampled_pr_df, resampled_pr_std, pr_aucs, selected_fold,
                        name, color, n_float_numbers):
    """Draw one fold-based model on `ax`: a +/- 1 s.d. precision band plus the selected fold's PR curve.

    The band is centred on the `precision` column of `resampled_pr_df` restricted to
    `selected_fold` and is drawn against the index of `resampled_pr_std`. The curve
    itself comes from the rows of `pr_df` belonging to `selected_fold` and is labelled
    `"<name> (area = <pr_aucs[selected_fold]>)"` with `n_float_numbers` decimals.

    Returns the axes returned by `sns.lineplot`.
    """
    # Variation across folds (+/- 1 std) around the selected fold
    error = 1 * resampled_pr_std.values
    selected_precision = resampled_pr_df[resampled_pr_df.fold == selected_fold].precision
    ax.fill_between(resampled_pr_std.index, selected_precision - error, selected_precision + error,
                    alpha=0.2, color=color)

    # Curve of the selected fold
    label = f'{name} (area = %0.{n_float_numbers}f)' % pr_aucs[selected_fold]
    return sns.lineplot(data=pr_df[pr_df.fold == selected_fold], x='recall', y='precision',
                        color=color, label=label, ax=ax, errorbar=None)


def plot_pr_curve(transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std, transformer_pr_aucs, transformer_selected_fold,
                    xgb_pr_df, xgb_resampled_pr_df, xgb_resampled_pr_std, xgb_pr_aucs, xgb_selected_fold,
                    lstm_pr_df, lstm_resampled_pr_df, lstm_resampled_pr_std, lstm_pr_aucs, lstm_selected_fold,
                      thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,
                      ax1, plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Plot precision-recall curves of the transformer, THRIVE-C, XGBoost and LSTM on `ax1`.

    For each fold-based model (transformer, XGBoost, LSTM):
        *_pr_df: table with `fold`, `recall` and `precision` columns; only the rows of
            the selected fold are drawn as the curve.
        *_resampled_pr_df: table with `fold` and `precision` columns; the selected
            fold's precision is the centre of the shaded band.
        *_resampled_pr_std: pandas object whose `.values` give the band half-width
            and whose index gives the band's x positions.
        *_pr_aucs: indexable by the selected fold; that entry is shown in the legend.
        *_selected_fold: fold identifier used for all of the lookups above.

    THRIVE-C:
        thrivec_resampled_precision: precision values plotted against
            `np.linspace(0, 1, 200)`.
        thrivec_resampled_pr_std: object whose `.values` give the band half-width.
        thrivec_pr_auc: number shown in the legend (two decimals).

    Other parameters:
        ax1: matplotlib axes to draw on.
        plot_legend: if True, draw a legend including a combined '± s.d.' entry;
            otherwise remove the legend created while plotting.
        tick_label_size: font size of the tick labels.
        label_font_size: font size of the axis labels and the legend.

    Returns:
        The figure that owns `ax1`.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    n_float_numbers = 3

    # NOTE: the drawing order below determines both the z-order and the legend order.

    ## Main model: Transformer
    ax1 = _plot_fold_pr_curve(ax1, transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std,
                              transformer_pr_aucs, transformer_selected_fold,
                              'Transformer', all_colors_palette[0], n_float_numbers)

    ## Comparators: THRIVE-C
    # THRIVE-C has no per-fold table here; it is drawn on a fixed 200-point grid over [0, 1]
    # (assumes the resampled precision was computed on that same grid - TODO confirm).
    thrivec_recall_grid = np.linspace(0, 1, 200)
    error = 1 * thrivec_resampled_pr_std.values
    ax1.fill_between(thrivec_recall_grid, thrivec_resampled_precision - error, thrivec_resampled_precision + error,
                     alpha=0.1, color=all_colors_palette[1])
    sns.lineplot(x=thrivec_recall_grid, y=thrivec_resampled_precision, color=all_colors_palette[1],
                 label='THRIVE-C (area = %0.2f)' % thrivec_pr_auc,
                 ax=ax1, linewidth=2)

    ## Comparators: XGBoost
    ax1 = _plot_fold_pr_curve(ax1, xgb_pr_df, xgb_resampled_pr_df, xgb_resampled_pr_std,
                              xgb_pr_aucs, xgb_selected_fold,
                              'XGBoost', all_colors_palette[2], n_float_numbers)

    ## Comparators: LSTM
    ax1 = _plot_fold_pr_curve(ax1, lstm_pr_df, lstm_resampled_pr_df, lstm_resampled_pr_std,
                              lstm_pr_aucs, lstm_selected_fold,
                              'LSTM', all_colors_palette[3], n_float_numbers)

    ax1.set_xlabel('Recall', fontsize=label_font_size)
    ax1.set_ylabel('Precision', fontsize=label_font_size)
    ax1.tick_params('x', labelsize=tick_label_size)
    ax1.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax1.get_legend_handles_labels()
        # One patch per model colour, rendered side by side as a single '± s.d.' legend entry
        sd_marker = tuple(mpatches.Patch(color=color, alpha=0.3) for color in all_colors_palette)
        sd_labels = '± s.d.'
        legend_markers.append(sd_marker)
        legend_labels.append(sd_labels)
        ax1.legend(legend_markers, legend_labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})
    else:
        # remove the legend created by the labelled line plots (if there is one)
        legend = ax1.get_legend()
        if legend is not None:
            legend.remove()

    fig1 = ax1.get_figure()
    return fig1

#### Plot PR curve for mrs02

In [None]:
# Stand-alone precision-recall figure for the mrs02 outcome.
# Theme: white grid, top/right spines hidden, figure size (10, 10).
custom_params = {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'figure.figsize': (10, 10),
}
sns.set_theme(context="paper", style="whitegrid", font_scale=1, rc=custom_params)

ax1 = plt.subplot(1, 1, 1)

# Each model contributes one row of arguments; the per-fold PR AUCs are passed as-is
# and indexed by the selected fold inside plot_pr_curve.
fig1 = plot_pr_curve(
    # Transformer
    mrs02_transformer_pr_df,
    mrs02_transformer_resampled_pr_df,
    mrs02_transformer_resampled_pr_std,
    mrs02_transformer_pr_aucs,
    mrs02_transformer_selected_fold,
    # XGBoost
    mrs02_xgb_pr_df,
    mrs02_xgb_resampled_pr_df,
    mrs02_xgb_resampled_pr_std,
    mrs02_xgb_pr_aucs,
    mrs02_xgb_selected_fold,
    # LSTM
    mrs02_lstm_pr_df,
    mrs02_lstm_resampled_pr_df,
    mrs02_lstm_resampled_pr_std,
    mrs02_lstm_pr_aucs,
    mrs02_lstm_selected_fold,
    # THRIVE-C
    mrs02_thrivec_resampled_precision,
    mrs02_thrivec_resampled_pr_std,
    mrs02_thrivec_pr_auc,
    # Target axes and styling
    ax1,
    plot_legend=True,
    tick_label_size=11,
    label_font_size=13,
)

plt.show()

#### Plot PR curve for death

In [None]:
# Stand-alone precision-recall figure for the death outcome.
# Theme: white grid, top/right spines hidden, figure size (10, 10).
custom_params = {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'figure.figsize': (10, 10),
}
sns.set_theme(context="paper", style="whitegrid", font_scale=1, rc=custom_params)

ax1 = plt.subplot(1, 1, 1)

# Each model contributes one row of arguments; the per-fold PR AUCs are passed as-is
# and indexed by the selected fold inside plot_pr_curve.
fig1 = plot_pr_curve(
    # Transformer
    transformer_death_pr_df,
    transformer_death_resampled_pr_df,
    transformer_death_resampled_pr_std,
    transformer_death_pr_aucs,
    death_transformer_selected_fold,
    # XGBoost
    death_xgb_pr_df,
    death_xgb_resampled_pr_df,
    death_xgb_resampled_pr_std,
    death_xgb_pr_aucs,
    death_xgb_selected_fold,
    # LSTM
    death_lstm_pr_df,
    death_lstm_resampled_pr_df,
    death_lstm_resampled_pr_std,
    death_lstm_pr_aucs,
    death_lstm_selected_fold,
    # THRIVE-C
    death_thrivec_resampled_precision,
    death_thrivec_resampled_pr_std,
    death_thrivec_pr_auc,
    # Target axes and styling
    ax1,
    plot_legend=True,
    tick_label_size=11,
    label_font_size=13,
)

plt.show()

#### Plot PR curve for death in hospital

In [None]:
# Stand-alone precision-recall figure for death in hospital.
# Theme: white grid, top/right spines hidden, figure size (10, 10).
custom_params = {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'figure.figsize': (10, 10),
}
sns.set_theme(context="paper", style="whitegrid", font_scale=1, rc=custom_params)

ax1 = plt.subplot(1, 1, 1)

# Each model contributes one row of arguments; the per-fold PR AUCs are passed as-is
# and indexed by the selected fold inside plot_pr_curve.
fig1 = plot_pr_curve(
    # Transformer
    transformer_death_in_hosp_pr_df,
    transformer_death_in_hosp_resampled_pr_df,
    transformer_death_in_hosp_resampled_pr_std,
    transformer_death_in_hosp_pr_aucs,
    death_in_hosp_transformer_selected_fold,
    # XGBoost
    death_in_hosp_xgb_pr_df,
    death_in_hosp_xgb_resampled_pr_df,
    death_in_hosp_xgb_resampled_pr_std,
    death_in_hosp_xgb_pr_aucs,
    death_in_hosp_xgb_selected_fold,
    # LSTM
    death_in_hosp_lstm_pr_df,
    death_in_hosp_lstm_resampled_pr_df,
    death_in_hosp_lstm_resampled_pr_std,
    death_in_hosp_lstm_pr_aucs,
    death_in_hosp_lstm_selected_fold,
    # THRIVE-C
    death_in_hosp_thrivec_resampled_precision,
    death_in_hosp_thrivec_resampled_pr_std,
    death_in_hosp_thrivec_pr_auc,
    # Target axes and styling
    ax1,
    plot_legend=True,
    tick_label_size=11,
    label_font_size=13,
)

plt.show()

### Combined plot

In [None]:
# Combined figure: one row (sub-figure) per outcome, ROC curve on the left and
# precision-recall curve on the right.
sns.set_theme(style="whitegrid", context="paper", font_scale = 1)

cm = 1/2.54  # centimeters in inches
main_fig2 = plt.figure(figsize=(18 * cm, 30 * cm))

tick_label_size = 6
label_font_size = 7
subplot_number_font_size = 9
suptitle_font_size = 10
plot_subplot_titles = True

subfigs = main_fig2.subfigures(3, 1, wspace=0.07, height_ratios=[1, 1, 1])

# MRS02
subfigs[0].suptitle('I. Prediction of functional outcome (3 months)', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=1.)

ax1, ax2 = subfigs[0].subplots(1, 2)
subfigs[0].subplots_adjust(wspace=0.3)

# Panel A: ROC curve (the median of the per-fold AUCs is passed for the fold-based models)
plot_roc_auc_curve(mrs02_transformer_roc_df, mrs02_transformer_resampled_roc_df, mrs02_transformer_resampled_roc_std, np.median(mrs02_transformer_roc_aucs), mrs02_transformer_selected_fold,
                   mrs02_xgb_roc_df, mrs02_xgb_resampled_roc_df, mrs02_xgb_resampled_roc_std, np.median(mrs02_xgb_roc_aucs), mrs02_xgb_selected_fold,
                   mrs02_lstm_roc_df, mrs02_lstm_resampled_roc_df, mrs02_lstm_resampled_roc_std, np.median(mrs02_lstm_roc_aucs), mrs02_lstm_selected_fold,
                   mrs02_thrivec_resampled_tpr, mrs02_thrivec_resampled_roc_std, thrivec_mrs02_roc_auc,
                   ax1, plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    ax1.set_title('A.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)

# Panel B: precision-recall curve
plot_pr_curve(mrs02_transformer_pr_df, mrs02_transformer_resampled_pr_df, mrs02_transformer_resampled_pr_std, mrs02_transformer_pr_aucs, mrs02_transformer_selected_fold,
              mrs02_xgb_pr_df, mrs02_xgb_resampled_pr_df, mrs02_xgb_resampled_pr_std, mrs02_xgb_pr_aucs, mrs02_xgb_selected_fold,
              mrs02_lstm_pr_df, mrs02_lstm_resampled_pr_df, mrs02_lstm_resampled_pr_std, mrs02_lstm_pr_aucs, mrs02_lstm_selected_fold,
              mrs02_thrivec_resampled_precision, mrs02_thrivec_resampled_pr_std, mrs02_thrivec_pr_auc,
              ax2, plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    ax2.set_title('B.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)

# Death
subfigs[1].suptitle('II. Prediction of mortality (3 months)', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=1.0)

ax3, ax4 = subfigs[1].subplots(1, 2)
subfigs[1].subplots_adjust(wspace=0.3)

# Panel C: ROC curve
plot_roc_auc_curve(transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std, np.median(transformer_death_roc_aucs), death_transformer_selected_fold,
                   death_xgb_roc_df, death_xgb_resampled_roc_df, death_xgb_resampled_roc_std, np.median(death_xgb_roc_aucs), death_xgb_selected_fold,
                   death_lstm_roc_df, death_lstm_resampled_roc_df, death_lstm_resampled_roc_std, np.median(death_lstm_roc_aucs), death_lstm_selected_fold,
                   death_thrivec_resampled_tpr, death_thrivec_resampled_roc_std, thrivec_death_roc_auc,
                   ax3, plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    ax3.set_title('C.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)

# Panel D: precision-recall curve
plot_pr_curve(transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std, transformer_death_pr_aucs, death_transformer_selected_fold,
              death_xgb_pr_df, death_xgb_resampled_pr_df, death_xgb_resampled_pr_std, death_xgb_pr_aucs, death_xgb_selected_fold,
              death_lstm_pr_df, death_lstm_resampled_pr_df, death_lstm_resampled_pr_std, death_lstm_pr_aucs, death_lstm_selected_fold,
              death_thrivec_resampled_precision, death_thrivec_resampled_pr_std, death_thrivec_pr_auc,
              ax4, plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    ax4.set_title('D.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)

# Death in hospital
subfigs[2].suptitle('III. Prediction of mortality (in hospital)', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=1.0)

ax5, ax6 = subfigs[2].subplots(1, 2)
subfigs[2].subplots_adjust(wspace=0.3)

# Panel E: ROC curve
plot_roc_auc_curve(transformer_death_in_hosp_roc_df, transformer_death_in_hosp_resampled_roc_df, transformer_death_in_hosp_resampled_roc_std, np.median(transformer_death_in_hosp_roc_aucs), death_in_hosp_transformer_selected_fold,
                   death_in_hosp_xgb_roc_df, death_in_hosp_xgb_resampled_roc_df, death_in_hosp_xgb_resampled_roc_std, np.median(death_in_hosp_xgb_roc_aucs), death_in_hosp_xgb_selected_fold,
                   death_in_hosp_lstm_roc_df, death_in_hosp_lstm_resampled_roc_df, death_in_hosp_lstm_resampled_roc_std, np.median(death_in_hosp_lstm_roc_aucs), death_in_hosp_lstm_selected_fold,
                   death_in_hosp_thrivec_resampled_tpr, death_in_hosp_thrivec_resampled_roc_std, thrivec_death_in_hosp_roc_auc,
                   ax5, plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    ax5.set_title('E.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)

# Panel F: precision-recall curve for death in hospital (uses the *_death_in_hosp_* inputs,
# not the 3-month death inputs shown in panel D)
plot_pr_curve(transformer_death_in_hosp_pr_df, transformer_death_in_hosp_resampled_pr_df, transformer_death_in_hosp_resampled_pr_std, transformer_death_in_hosp_pr_aucs, death_in_hosp_transformer_selected_fold,
              death_in_hosp_xgb_pr_df, death_in_hosp_xgb_resampled_pr_df, death_in_hosp_xgb_resampled_pr_std, death_in_hosp_xgb_pr_aucs, death_in_hosp_xgb_selected_fold,
              death_in_hosp_lstm_pr_df, death_in_hosp_lstm_resampled_pr_df, death_in_hosp_lstm_resampled_pr_std, death_in_hosp_lstm_pr_aucs, death_in_hosp_lstm_selected_fold,
              death_in_hosp_thrivec_resampled_precision, death_in_hosp_thrivec_resampled_pr_std, death_in_hosp_thrivec_pr_auc,
              ax6, plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    ax6.set_title('F.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1)


plt.show()

In [None]:
# Write the combined figure to `output_dir` as a 1200 dpi TIFF with a tight bounding box.
main_fig2.savefig(
    os.path.join(output_dir, 'cross_validation_sets_comparative_performances.tiff'),
    format='tiff',
    dpi=1200,
    bbox_inches="tight",
)