# ROC Curve comparison with confidence intervals

In [None]:
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from sklearn.metrics import roc_curve, auc, precision_recall_curve
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

In [None]:
# Locations of the stored test-set predictions for every model and outcome.
# THRIVE-C: a single pickle holding a (ground truth, predictions) pair for the whole test set.
# Transformer / XGBoost: a directory with one pickle per fold.
# LSTM: a directory whose subdirectories contain the per-fold pickles.
thrive_c_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_3m_mrs02_predictions/test_gt_and_pred.pkl'
transformer_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation'
lstm_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS02/2023_01_02_1057/all_folds'
xgb_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/3M_mrs02/with_feature_aggregration/testing/all_folds'

# Same layout for the 3-month mortality outcome.
thrive_c_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_3m_death_predictions/3m_death_test_gt_and_pred.pkl'
transformer_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing'
lstm_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_Death/2023_01_04_2020/all_folds'
xgb_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/3M_Death/testing/all_folds'

In [None]:
output_dir = '/Users/jk1/Downloads'  # destination for the (commented-out) figure export below

In [None]:
# Cross-validation setup shared by all models.
n_folds = 5
seed = 42  # NOTE(review): not used in the cells visible here; confirm whether it is still needed

# Fold selected for each model; it is highlighted in red in the ROC AUC strip plots.
# NOTE(review): the transformer values are used directly as list indices (no "- 1"), and the
# transformer result files are named with 0-based indices (fold_{fidx}_...). Confirm these are
# already 0-based indices rather than 1-based fold numbers.
mrs02_transformer_selected_fold = 2
death_transformer_selected_fold = 1

# subtract one to go from fold number to index
# NOTE(review): XGB result files are named with 0-based indices (test_gt_and_pred_cv_{fidx}),
# whereas LSTM files are 1-based (test_gt_and_pred_fold_{fidx+1}). Confirm the XGB fold numbers
# below really are 1-based.
mrs02_xgb_selected_fold = 3 - 1
death_xgb_selected_fold = 3 - 1

mrs02_lstm_selected_fold = 3 - 1 
death_lstm_selected_fold = 2 - 1

## Load data

In [None]:
# THRIVE-C baseline: each pickle holds a (ground truth, predictions) pair for the whole test set.
# Context managers close the file handles right after unpickling (the files are the authors'
# own outputs; never unpickle untrusted data).
with open(thrive_c_mrs02_predictions_path, 'rb') as fh:
    thrivec_mrs02_gt, thrivec_mrs02_predictions = pickle.load(fh)
with open(thrive_c_death_predictions_path, 'rb') as fh:
    thrivec_death_gt, thrivec_death_predictions = pickle.load(fh)

In [None]:
# Transformer: one pickle per fold, named with the 0-based fold index.
# Context managers close each file once it has been unpickled.
transformer_mrs02_folds = []
for fidx in range(n_folds):
    with open(os.path.join(transformer_mrs02_predictions_path, f'fold_{fidx}_test_gt_and_pred.pkl'), 'rb') as fh:
        transformer_mrs02_folds.append(pickle.load(fh))
transformer_death_folds = []
for fidx in range(n_folds):
    with open(os.path.join(transformer_death_predictions_path, f'fold_{fidx}_test_gt_and_pred.pkl'), 'rb') as fh:
        transformer_death_folds.append(pickle.load(fh))

In [None]:
# XGBoost: one pickle per fold, named with the 0-based fold index (..._cv_{fidx}.pkl).
# Context managers close each file once it has been unpickled.
xgb_mrs02_folds = []
for fidx in range(n_folds):
    with open(os.path.join(xgb_mrs02_predictions_path, f'test_gt_and_pred_cv_{fidx}.pkl'), 'rb') as fh:
        xgb_mrs02_folds.append(pickle.load(fh))

xgb_death_folds = []
for fidx in range(n_folds):
    with open(os.path.join(xgb_death_predictions_path, f'test_gt_and_pred_cv_{fidx}.pkl'), 'rb') as fh:
        xgb_death_folds.append(pickle.load(fh))

In [None]:
def _load_lstm_folds(predictions_root):
    """Load the per-fold LSTM test-set pickles located one level below ``predictions_root``.

    The pickle for fold index ``fidx`` (0-based, ``fidx < n_folds``) is named
    ``test_gt_and_pred_fold_{fidx + 1}.pkl`` (1-based on disk). It may live in any
    immediate subdirectory of ``predictions_root``.

    Returns:
        list: the unpickled objects, one per fold, ordered by fold index.

    Raises:
        FileNotFoundError: a fold's pickle is not present in any subdirectory. Skipping it
            would shift every later fold to the wrong index.
        ValueError: a fold's pickle is present in more than one subdirectory. Loading every
            copy would misalign the fold list.
    """
    subdirs = [
        os.path.join(predictions_root, entry)
        for entry in sorted(os.listdir(predictions_root))
        if os.path.isdir(os.path.join(predictions_root, entry))
    ]
    folds = []
    for fidx in range(n_folds):
        file_name = f'test_gt_and_pred_fold_{fidx + 1}.pkl'
        candidates = [
            os.path.join(subdir, file_name)
            for subdir in subdirs
            if os.path.isfile(os.path.join(subdir, file_name))
        ]
        if not candidates:
            raise FileNotFoundError(f'{file_name} not found in any subdirectory of {predictions_root}')
        if len(candidates) > 1:
            raise ValueError(f'{file_name} found in several subdirectories: {candidates}')
        with open(candidates[0], 'rb') as fh:
            folds.append(pickle.load(fh))
    return folds


lstm_mrs02_folds = _load_lstm_folds(lstm_mrs02_predictions_path)
lstm_death_folds = _load_lstm_folds(lstm_death_predictions_path)

In [None]:
# Four-colour palette built from fixed hex codes.
# NOTE(review): the colour-to-model mapping is decided where the palette is used; confirm it there.
_palette_hex_codes = ['#f61067', '#049b9a', '#012D98', '#a76dfe']
all_colors_palette = sns.color_palette(_palette_hex_codes, n_colors=len(_palette_hex_codes))
del _palette_hex_codes
all_colors_palette

# Helper functions

In [None]:
def compute_roc_and_pr_curves(folds, n_interpolated_points=200):
    """Compute per-fold ROC and precision-recall curves, their AUCs and re-sampled versions.

    Parameters
    ----------
    folds : sequence
        One entry per fold. ``fold[0]`` holds the ground-truth labels and ``fold[1]`` the
        predicted scores; both are passed to ``sklearn.metrics.roc_curve`` and
        ``precision_recall_curve``. Every entry of ``folds`` is processed.
    n_interpolated_points : int, default 200
        Number of evenly spaced points in [0, 1] on which each fold's curve is re-sampled.
        The common grid lets the folds be aggregated point-wise.

    Returns
    -------
    roc_df : DataFrame with columns ``fpr``, ``tpr``, ``fold``; the raw ROC points of all folds.
    resampled_roc_df : DataFrame with columns ``fpr``, ``tpr``, ``fold``; TPR interpolated on the grid.
    roc_aucs : list of float, the ROC AUC of each fold.
    pr_df : DataFrame with columns ``recall``, ``precision``, ``fold``; PR points sorted by recall.
    resampled_pr_df : DataFrame with columns ``recall``, ``precision``, ``fold``; precision on the grid.
    pr_aucs : list of float, the trapezoidal PR AUC of each fold.
    """
    grid = np.linspace(0, 1, n_interpolated_points)

    roc_frames, resampled_roc_frames, roc_aucs = [], [], []
    pr_frames, resampled_pr_frames, pr_aucs = [], [], []

    for fidx, fold in enumerate(tqdm(folds)):
        y_true, y_score = fold[0], fold[1]

        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_aucs.append(auc(fpr, tpr))
        roc_frames.append(pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'fold': fidx}))
        resampled_roc_frames.append(
            pd.DataFrame({'fpr': grid, 'tpr': np.interp(grid, fpr, tpr), 'fold': fidx}))

        precision, recall, _ = precision_recall_curve(y_true, y_score)
        # np.interp needs increasing x-coordinates, so order the PR points by recall first.
        recall, precision = zip(*sorted(zip(recall, precision)))
        pr_aucs.append(auc(recall, precision))
        pr_frames.append(pd.DataFrame({'recall': recall, 'precision': precision, 'fold': fidx}))
        resampled_pr_frames.append(
            pd.DataFrame({'recall': grid, 'precision': np.interp(grid, recall, precision), 'fold': fidx}))

    def _stack(frames):
        # DataFrame.append was removed in pandas 2.0. Concatenate once instead, keeping each
        # frame's own index like append did. pd.concat rejects an empty list, so fall back
        # to an empty frame.
        return pd.concat(frames) if frames else pd.DataFrame()

    return (_stack(roc_frames), _stack(resampled_roc_frames), roc_aucs,
            _stack(pr_frames), _stack(resampled_pr_frames), pr_aucs)


## Prepare data for mrs02 outcome

### Compute ROC and PR curve standard deviation for THRIVE-C

In [None]:
# Interleaved split of the THRIVE-C mRS 0-2 test set into n_folds pseudo-folds:
# fold k takes every n_folds-th sample, starting at position k.
mrs02_thrivec_folds = [
    (thrivec_mrs02_gt[start::n_folds], thrivec_mrs02_predictions[start::n_folds])
    for start in range(n_folds)
]

In [None]:
# Pooled THRIVE-C curves on the complete mRS 0-2 test set (no fold split),
# re-sampled on a 200-point grid over [0, 1].
_resample_grid = np.linspace(0, 1, 200)

mrs02_thrivec_fpr, mrs02_thrivec_tpr, _ = roc_curve(thrivec_mrs02_gt, thrivec_mrs02_predictions)
mrs02_thrivec_roc_auc = auc(mrs02_thrivec_fpr, mrs02_thrivec_tpr)
mrs02_thrivec_resampled_tpr = np.interp(_resample_grid, mrs02_thrivec_fpr, mrs02_thrivec_tpr)

# The PR points are ordered by increasing recall before integration and interpolation.
mrs02_thrivec_precision, mrs02_thrivec_recall, _ = precision_recall_curve(thrivec_mrs02_gt, thrivec_mrs02_predictions)
_recall_sorted_pairs = sorted(zip(mrs02_thrivec_recall, mrs02_thrivec_precision))
mrs02_thrivec_recall, mrs02_thrivec_precision = zip(*_recall_sorted_pairs)
mrs02_thrivec_pr_auc = auc(mrs02_thrivec_recall, mrs02_thrivec_precision)
mrs02_thrivec_resampled_precision = np.interp(_resample_grid, mrs02_thrivec_recall, mrs02_thrivec_precision)

del _resample_grid, _recall_sorted_pairs

In [None]:
# Per-fold THRIVE-C curves and AUCs on the interleaved pseudo-folds.
(
    mrs02_thrivec_roc_df,
    mrs02_thrivec_resampled_roc_df,
    mrs02_thrivec_roc_aucs,
    mrs02_thrivec_pr_df,
    mrs02_thrivec_resampled_pr_df,
    mrs02_thrivec_pr_aucs,
) = compute_roc_and_pr_curves(mrs02_thrivec_folds)

In [None]:
# Point-wise standard deviation across folds of the re-sampled curves
# (a Series named 'std', indexed by the grid value).
mrs02_thrivec_resampled_roc_std = mrs02_thrivec_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
mrs02_thrivec_resampled_pr_std = mrs02_thrivec_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

### Transformer curves with fold variation

In [None]:
# Per-fold transformer curves and AUCs for the mRS 0-2 outcome.
(
    mrs02_transformer_roc_df,
    mrs02_transformer_resampled_roc_df,
    mrs02_transformer_roc_aucs,
    mrs02_transformer_pr_df,
    mrs02_transformer_resampled_pr_df,
    mrs02_transformer_pr_aucs,
) = compute_roc_and_pr_curves(transformer_mrs02_folds)

In [None]:
# Point-wise standard deviation across folds (Series named 'std', indexed by grid value).
mrs02_transformer_resampled_roc_std = mrs02_transformer_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
mrs02_transformer_resampled_pr_std = mrs02_transformer_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

### XGB curves with fold variation

In [None]:
# Per-fold XGBoost curves and AUCs for the mRS 0-2 outcome.
(
    mrs02_xgb_roc_df,
    mrs02_xgb_resampled_roc_df,
    mrs02_xgb_roc_aucs,
    mrs02_xgb_pr_df,
    mrs02_xgb_resampled_pr_df,
    mrs02_xgb_pr_aucs,
) = compute_roc_and_pr_curves(xgb_mrs02_folds)

In [None]:
# Point-wise standard deviation across folds (Series named 'std', indexed by grid value).
mrs02_xgb_resampled_roc_std = mrs02_xgb_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
mrs02_xgb_resampled_pr_std = mrs02_xgb_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

### LSTM curves with fold variation

In [None]:
# Per-fold LSTM curves and AUCs for the mRS 0-2 outcome.
(
    mrs02_lstm_roc_df,
    mrs02_lstm_resampled_roc_df,
    mrs02_lstm_roc_aucs,
    mrs02_lstm_pr_df,
    mrs02_lstm_resampled_pr_df,
    mrs02_lstm_pr_aucs,
) = compute_roc_and_pr_curves(lstm_mrs02_folds)

In [None]:
# Point-wise standard deviation across folds (Series named 'std', indexed by grid value).
mrs02_lstm_resampled_roc_std = mrs02_lstm_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
mrs02_lstm_resampled_pr_std = mrs02_lstm_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

## Prepare data for mortality prediction

In [None]:
# Interleaved split of the THRIVE-C mortality test set into n_folds pseudo-folds:
# fold k takes every n_folds-th sample, starting at position k.
death_thrivec_folds = [
    (thrivec_death_gt[start::n_folds], thrivec_death_predictions[start::n_folds])
    for start in range(n_folds)
]

# Pooled curves on the complete mortality test set, re-sampled on a 200-point grid.
_resample_grid = np.linspace(0, 1, 200)

death_thrivec_fpr, death_thrivec_tpr, _ = roc_curve(thrivec_death_gt, thrivec_death_predictions)
death_thrivec_roc_auc = auc(death_thrivec_fpr, death_thrivec_tpr)
death_thrivec_resampled_tpr = np.interp(_resample_grid, death_thrivec_fpr, death_thrivec_tpr)

# The PR points are ordered by increasing recall before integration and interpolation.
death_thrivec_precision, death_thrivec_recall, _ = precision_recall_curve(thrivec_death_gt, thrivec_death_predictions)
_recall_sorted_pairs = sorted(zip(death_thrivec_recall, death_thrivec_precision))
death_thrivec_recall, death_thrivec_precision = zip(*_recall_sorted_pairs)
death_thrivec_pr_auc = auc(death_thrivec_recall, death_thrivec_precision)
death_thrivec_resampled_precision = np.interp(_resample_grid, death_thrivec_recall, death_thrivec_precision)

del _resample_grid, _recall_sorted_pairs

# Per-fold curves, then their point-wise standard deviation across folds.
(
    death_thrivec_roc_df,
    death_thrivec_resampled_roc_df,
    death_thrivec_roc_aucs,
    death_thrivec_pr_df,
    death_thrivec_resampled_pr_df,
    death_thrivec_pr_aucs,
) = compute_roc_and_pr_curves(death_thrivec_folds)

death_thrivec_resampled_roc_std = death_thrivec_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
death_thrivec_resampled_pr_std = death_thrivec_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

In [None]:
# Per-fold transformer curves for mortality and their point-wise spread across folds.
(
    transformer_death_roc_df,
    transformer_death_resampled_roc_df,
    transformer_death_roc_aucs,
    transformer_death_pr_df,
    transformer_death_resampled_pr_df,
    transformer_death_pr_aucs,
) = compute_roc_and_pr_curves(transformer_death_folds)

transformer_death_resampled_roc_std = transformer_death_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
transformer_death_resampled_pr_std = transformer_death_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']


In [None]:
# Per-fold XGBoost curves for mortality and their point-wise spread across folds.
(
    death_xgb_roc_df,
    death_xgb_resampled_roc_df,
    death_xgb_roc_aucs,
    death_xgb_pr_df,
    death_xgb_resampled_pr_df,
    death_xgb_pr_aucs,
) = compute_roc_and_pr_curves(xgb_death_folds)
death_xgb_resampled_roc_std = death_xgb_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
death_xgb_resampled_pr_std = death_xgb_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

In [None]:
# Per-fold LSTM curves for mortality and their point-wise spread across folds.
(
    death_lstm_roc_df,
    death_lstm_resampled_roc_df,
    death_lstm_roc_aucs,
    death_lstm_pr_df,
    death_lstm_resampled_pr_df,
    death_lstm_pr_aucs,
) = compute_roc_and_pr_curves(lstm_death_folds)
death_lstm_resampled_roc_std = death_lstm_resampled_roc_df.groupby('fpr')['tpr'].agg(['std'])['std']
death_lstm_resampled_pr_std = death_lstm_resampled_pr_df.groupby('recall')['precision'].agg(['std'])['std']

## Plot ROC AUCs

In [None]:
# Long-format table of every fold's mRS 0-2 ROC AUC, one block of rows per model.
mrs02_roc_auc_df = pd.concat([
    pd.DataFrame({'model': model_name, 'auc': fold_aucs})
    for model_name, fold_aucs in (
        ('Transformer', mrs02_transformer_roc_aucs),
        ('XGBoost', mrs02_xgb_roc_aucs),
        ('LSTM', mrs02_lstm_roc_aucs),
        ('THRIVE-C', mrs02_thrivec_roc_aucs),
    )
])

# One row per model with the AUC of its selected fold. For THRIVE-C the pooled AUC over
# the whole test set is used. Every row keeps index 0.
selected_fold_roc_auc_df = pd.concat([
    pd.DataFrame({'model': model_name, 'auc': model_auc}, index=[0])
    for model_name, model_auc in (
        ('Transformer', mrs02_transformer_roc_aucs[mrs02_transformer_selected_fold]),
        ('XGBoost', mrs02_xgb_roc_aucs[mrs02_xgb_selected_fold]),
        ('LSTM', mrs02_lstm_roc_aucs[mrs02_lstm_selected_fold]),
        ('THRIVE-C', mrs02_thrivec_roc_auc),
    )
])

In [None]:
# Per-fold mRS 0-2 ROC AUCs; each model's selected fold (THRIVE-C: pooled AUC) is drawn in red.
ax = sns.stripplot(data=mrs02_roc_auc_df, x='model', y='auc')
sns.stripplot(data=selected_fold_roc_auc_df, x='model', y='auc', color='red', ax=ax)

In [None]:
# Long-format table of every fold's mortality ROC AUC, one block of rows per model.
death_roc_auc_df = pd.concat([
    pd.DataFrame({'model': model_name, 'auc': fold_aucs})
    for model_name, fold_aucs in (
        ('Transformer', transformer_death_roc_aucs),
        ('XGBoost', death_xgb_roc_aucs),
        ('LSTM', death_lstm_roc_aucs),
        ('THRIVE-C', death_thrivec_roc_aucs),
    )
])

# One row per model with the AUC of its selected fold. For THRIVE-C the pooled AUC over
# the whole test set is used. Every row keeps index 0.
selected_fold_death_roc_auc_df = pd.concat([
    pd.DataFrame({'model': model_name, 'auc': model_auc}, index=[0])
    for model_name, model_auc in (
        ('Transformer', transformer_death_roc_aucs[death_transformer_selected_fold]),
        ('XGBoost', death_xgb_roc_aucs[death_xgb_selected_fold]),
        ('LSTM', death_lstm_roc_aucs[death_lstm_selected_fold]),
        ('THRIVE-C', death_thrivec_roc_auc),
    )
])

In [None]:
# Per-fold mortality ROC AUCs; each model's selected fold (THRIVE-C: pooled AUC) is drawn in red.
ax = sns.stripplot(data=death_roc_auc_df, x='model', y='auc')
sns.stripplot(data=selected_fold_death_roc_auc_df, x='model', y='auc', color='red', ax=ax)

In [None]:
import matplotlib
from matplotlib.font_manager import FontProperties

# Calibri font variants (absolute paths on the author's machine).
font_files = [
    '/Library/Fonts/calibri-bold-italic.ttf',
    '/Library/Fonts/calibri-bold.ttf',
    '/Library/Fonts/calibri-italic.ttf',
    '/Library/Fonts/calibri-regular.ttf',
    '/Library/Fonts/calibril.ttf',
]

# The last file is the face whose family name the plotting cell assigns to rcParams['font.family'].
font_path = font_files[-1]
calibri_font = FontProperties(fname=font_path)
calibri_font.get_name()

# Register every variant with matplotlib's global font manager so the faces can be found by family name.
for font_file in font_files:
    matplotlib.font_manager.fontManager.addfont(font_file)

In [None]:
 # combined plot: per-fold ROC AUCs for both outcomes, selected folds in red
sns.set_theme(style="whitegrid", context="paper", font_scale=1)
plt.rcParams['font.family'] = calibri_font.get_name()

cm = 1/2.54  # centimeters in inches
main_fig = plt.figure(figsize=(18 * cm, 9 * cm))


tick_label_size = 6
label_font_size = 7
subplot_number_font_size = 9  # NOTE(review): not used in this cell
suptitle_font_size = 10

ax = main_fig.subplots(1, 2, sharex=False, sharey=True)

sns.stripplot(x='model', y='auc', data=mrs02_roc_auc_df, ax=ax[0])
sns.stripplot(x='model', y='auc', data=selected_fold_roc_auc_df, ax=ax[0], color='red')
ax[0].set_title('Prediction of functional outcome')
ax[0].set_ylabel('ROC AUC', fontsize=label_font_size)
ax[0].set_xlabel(None)

sns.stripplot(x='model', y='auc', data=death_roc_auc_df, ax=ax[1])
sns.stripplot(x='model', y='auc', data=selected_fold_death_roc_auc_df, ax=ax[1], color='red')
ax[1].set_title('Prediction of mortality')
ax[1].set_ylabel(None)
# Drop the redundant "model" axis label on this panel too, matching the left panel.
ax[1].set_xlabel(None)

# Apply the intended tick label size (it was defined but never used) to both panels.
for panel_ax in ax:
    panel_ax.tick_params(axis='both', labelsize=tick_label_size)

# adjust spacing between subplots (on this figure explicitly, not whichever one is current)
main_fig.subplots_adjust(wspace=0.1)

main_fig.suptitle('ROC AUC performance (inter-fold variability)', fontsize=suptitle_font_size, x=0.5, y=1.02)

In [None]:
# Uncomment to export the combined figure as SVG into output_dir.
# main_fig.savefig(os.path.join(output_dir, 'inter_fold_variation_roc_auc_performance.svg'), bbox_inches="tight", format='svg', dpi=1200)