# ROC and precision-recall curve comparison with bootstrapped ± s.d. bands

In [None]:
import pickle
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from sklearn.metrics import roc_curve, auc, precision_recall_curve
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

In [None]:
# Configuration. Each *_predictions_path points to a pickle holding a
# (ground_truth, predictions) pair, where each element is indexed per
# bootstrap sample (see the loading and curve-computation cells below).
# LSTM predictions. NOTE(review): the path suggests a 3M mRS 0-1 model, and this
# path is not switched by `outcome` — confirm it matches the outcome being plotted.
lstm_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS01/2023_01_06_1847/test_LSTM_sigmoid_all_unchanged_0.0_2_True_RMSprop_3M mRS 0-1_128_3/bootstrapped_gt_and_pred.pkl'
# THRIVE-C baseline predictions, one file per supported outcome (selected via `outcome`).
thrive_c_bs_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_3m_mrs02_predictions/bootstrapped_gt_and_pred.pkl'
thrive_c_bs_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/THRIVE_C/THRIVE_C_3m_death_predictions/3m_death_bootstrapped_gt_and_pred.pkl'
# XGBoost predictions; currently unused because the corresponding load is commented out.
xgb_bs_predictions_path = '/Users/jk1/temp/opsum_prediction_output/linear_72h_xgb/with_feature_aggregration/testing/bootstrapped_gt_and_pred.pkl'
# Transformer predictions, one file per supported outcome (selected via `outcome`).
transformer_bs_death_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_Death/testing/fold_1_bootstrapped_gt_and_pred.pkl'
transformer_bs_mrs02_predictions_path = '/Users/jk1/temp/opsum_prediction_output/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation/fold_2_bootstrapped_gt_and_pred.pkl'
# Directory the ROC / precision-recall SVG figures are written to.
output_dir = '/Users/jk1/Downloads'
# Outcome to plot: '3M Death' or '3M mRS 0-2'; any other value raises ValueError in the next cell.
outcome = '3M Death'

In [None]:
# Resolve the outcome-specific Transformer and THRIVE-C prediction files.
# Unsupported outcomes are rejected up front.
if outcome not in ('3M Death', '3M mRS 0-2'):
    raise ValueError('Outcome not supported')

transformer_bs_predictions_path, thrive_c_bs_predictions_path = (
    (transformer_bs_death_predictions_path, thrive_c_bs_death_predictions_path)
    if outcome == '3M Death'
    else (transformer_bs_mrs02_predictions_path, thrive_c_bs_mrs02_predictions_path)
)

In [None]:
def load_bootstrapped_gt_and_pred(path):
    """Load a pickled ``(bootstrapped_gt, bootstrapped_predictions)`` pair from *path*.

    The file is opened in a ``with`` block so the handle is closed promptly,
    instead of leaking until garbage collection.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file; only
    load pickles produced by this project's own evaluation pipeline.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


lstm_bs_gt, lstm_bs_predictions = load_bootstrapped_gt_and_pred(lstm_bs_predictions_path)
thrivec_bs_gt, thrivec_bs_predictions = load_bootstrapped_gt_and_pred(thrive_c_bs_predictions_path)
# xgb_bs_gt, xgb_bs_predictions = load_bootstrapped_gt_and_pred(xgb_bs_predictions_path)
transformer_bs_gt, transformer_bs_predictions = load_bootstrapped_gt_and_pred(transformer_bs_predictions_path)

In [None]:
# Shared 4-colour palette for the model curves in the plots below.
all_colors_palette = sns.color_palette(
    ['#f61067', '#049b9a', '#012D98', '#a76dfe'],
    n_colors=4,
)
# Bare trailing expression so Jupyter displays the palette as the cell output.
all_colors_palette

Compute the ROC curve and ROC AUC, and the precision-recall (PR) curve and PR AUC, for each model over the bootstrapped samples, resampling every curve onto a common grid

In [None]:
def compute_boostrapped_auc_and_pr_curves(bs_gt, bs_predictions, interpolations_points=200):
    """Compute per-bootstrap ROC/PR AUCs and curves resampled onto a common grid.

    For every bootstrap sample ``idx``, the labels ``bs_gt[idx]`` and scores
    ``bs_predictions[idx]`` are used to build a ROC curve and a precision-recall
    curve. Each curve is linearly interpolated onto ``interpolations_points``
    evenly spaced points in [0, 1]. This lets the bootstrap samples be
    aggregated point-by-point when plotting (e.g. mean ± s.d.).

    Args:
        bs_gt: Sequence of per-bootstrap ground-truth label arrays.
        bs_predictions: Sequence of per-bootstrap score arrays, aligned with ``bs_gt``.
        interpolations_points: Number of grid points each curve is resampled onto.

    Returns:
        Tuple ``(bs_roc_aucs, bs_pr_aucs, resampled_roc_df, resampled_pr_df)``:

        - ``bs_roc_aucs``: list with one ROC AUC per bootstrap sample.
        - ``bs_pr_aucs``: list with one PR AUC per bootstrap sample
          (trapezoidal area under the precision-recall curve).
        - ``resampled_roc_df``: long-format DataFrame with columns
          ``fpr``, ``tpr`` and ``bootstrap_idx``.
        - ``resampled_pr_df``: long-format DataFrame with columns
          ``precision``, ``recall`` and ``bootstrap_idx``.

        Both DataFrames are empty when ``bs_predictions`` is empty.
    """
    bs_roc_aucs = []
    bs_pr_aucs = []

    # Collect per-bootstrap frames and concatenate once at the end.
    # DataFrame.append was removed in pandas 2.0, and it copied the whole
    # accumulated frame on every call (quadratic in the number of bootstraps).
    roc_frames = []
    pr_frames = []

    # The common resampling grid does not depend on the bootstrap sample.
    grid = np.linspace(0, 1, interpolations_points)

    for idx in tqdm(range(len(bs_predictions))):
        gt = bs_gt[idx]
        predictions = bs_predictions[idx]

        # ROC curve and AUC.
        bs_fpr, bs_tpr, _ = roc_curve(gt, predictions)
        bs_roc_aucs.append(auc(bs_fpr, bs_tpr))

        # Precision-recall curve and (trapezoidal) AUC.
        bs_precision, bs_recall, _ = precision_recall_curve(gt, predictions)
        bs_pr_aucs.append(auc(bs_recall, bs_precision))

        roc_frame = pd.DataFrame({'fpr': grid, 'tpr': np.interp(grid, bs_fpr, bs_tpr)})
        roc_frame['bootstrap_idx'] = idx
        roc_frames.append(roc_frame)

        # np.interp needs increasing x values, so sort the (recall, precision)
        # pairs by recall before resampling precision onto the recall grid.
        sorted_recall, sorted_precision = zip(*sorted(zip(bs_recall, bs_precision)))
        pr_frame = pd.DataFrame({
            'precision': np.interp(grid, sorted_recall, sorted_precision),
            'recall': grid,
        })
        pr_frame['bootstrap_idx'] = idx
        pr_frames.append(pr_frame)

    # pd.concat rejects an empty list; keep the original "empty DataFrame" result.
    resampled_roc_df = pd.concat(roc_frames) if roc_frames else pd.DataFrame()
    resampled_pr_df = pd.concat(pr_frames) if pr_frames else pd.DataFrame()

    return bs_roc_aucs, bs_pr_aucs, resampled_roc_df, resampled_pr_df

In [None]:
# Bootstrapped AUCs and resampled ROC / PR curves for every loaded model.
(lstm_bs_roc_aucs, lstm_bs_pr_aucs,
 lstm_resampled_roc_df, lstm_resampled_pr_df) = compute_boostrapped_auc_and_pr_curves(
    bs_gt=lstm_bs_gt, bs_predictions=lstm_bs_predictions)

(thrivec_bs_roc_aucs, thrivec_bs_pr_aucs,
 thrivec_resampled_roc_df, thrivec_resampled_pr_df) = compute_boostrapped_auc_and_pr_curves(
    bs_gt=thrivec_bs_gt, bs_predictions=thrivec_bs_predictions)

# xgb_bs_roc_aucs, xgb_bs_pr_aucs, xgb_resampled_roc_df, xgb_resampled_pr_df = compute_boostrapped_auc_and_pr_curves(xgb_bs_gt, xgb_bs_predictions)

(transformer_bs_roc_aucs, transformer_bs_pr_aucs,
 transformer_resampled_roc_df, transformer_resampled_pr_df) = compute_boostrapped_auc_and_pr_curves(
    bs_gt=transformer_bs_gt, bs_predictions=transformer_bs_predictions)

### Bootstrapped resampled ROC curve


In [None]:
# Styling: white grid, no top/right spines, 10x10 inch figure.
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize': (10, 10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale=1)

ax = plt.subplot(111)

plot_legend = True

tick_label_size = 11
label_font_size = 13

# One curve per model. Seaborn aggregates the bootstrap samples at each fpr grid
# point and shades a ± s.d. band (errorbar='sd'). The label reports the median
# bootstrap AUC.
# ax = sns.lineplot(data=lstm_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[0], label='LSTM (area = %0.2f)' % np.median(lstm_bs_roc_aucs),
#                    ax=ax, errorbar='sd')

ax = sns.lineplot(data=thrivec_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[1], label='THRIVE-C (area = %0.2f)' % np.median(thrivec_bs_roc_aucs),
                   ax=ax, errorbar='sd')

# ax = sns.lineplot(data=xgb_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_roc_aucs),
#                    ax=ax, errorbar='sd')

ax = sns.lineplot(data=transformer_resampled_roc_df, x='fpr', y='tpr', color=all_colors_palette[0], label='Transformer (area = %0.2f)' % np.median(transformer_bs_roc_aucs),
                   ax=ax, errorbar='sd')

# Chance-level diagonal (unlabelled, so it stays out of the legend).
ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
ax.tick_params('x', labelsize=tick_label_size)
ax.tick_params('y', labelsize=tick_label_size)

if plot_legend:
    legend_markers, legend_labels = ax.get_legend_handles_labels()
    # Build the "± s.d." entry from the plotted lines themselves: one translucent
    # swatch per model, in legend order and in that model's line colour.
    # Hard-coding palette indices here put the swatches in the opposite order to
    # the model entries, and broke whenever a model was (un)commented above.
    sd_marker = tuple(
        mpatches.Patch(color=marker.get_color(), alpha=0.3) for marker in legend_markers
    )
    legend_markers.append(sd_marker)
    legend_labels.append('± s.d.')
    ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
              handler_map={tuple: HandlerTuple(ndivide=None)})

else:
    # remove legend
    ax.get_legend().remove()

fig = ax.get_figure()

plt.show()

In [None]:
# Export the ROC figure as SVG, named after the selected outcome (spaces -> underscores).
fig.savefig(
    os.path.join(output_dir, f'roc_curve_{outcome.replace(" ", "_")}.svg'),
    format='svg',
    bbox_inches="tight",
    dpi=1200,
)

### Overall Precision-Recall curve

In [None]:
# Styling: white grid, no top/right spines, 10x10 inch figure.
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize': (10, 10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale=1)

ax1 = plt.subplot(111)

plot_legend = True

tick_label_size = 11
label_font_size = 13

# One curve per model. Seaborn aggregates the bootstrap samples at each recall
# grid point and shades a ± s.d. band (errorbar='sd'). The label reports the
# median bootstrap PR AUC.
# ax1 = sns.lineplot(data=lstm_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[1], label='LSTM (area = %0.2f)' % np.median(lstm_bs_pr_aucs),
#                    ax=ax1, errorbar='sd')

ax1 = sns.lineplot(data=thrivec_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[1], label='THRIVE-C (area = %0.2f)' % np.median(thrivec_bs_pr_aucs),
                   ax=ax1, errorbar='sd')

# ax1 = sns.lineplot(data=xgb_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[2], label='XGBoost (area = %0.2f)' % np.median(xgb_bs_pr_aucs),
#                    ax=ax1, errorbar='sd')

ax1 = sns.lineplot(data=transformer_resampled_pr_df, x='recall', y='precision', color=all_colors_palette[0], label='Transformer (area = %0.2f)' % np.median(transformer_bs_pr_aucs),
                   ax=ax1, errorbar='sd')

# ax1.set_ylim(0, 1)

ax1.set_xlabel('Recall', fontsize=label_font_size)
ax1.set_ylabel('Precision', fontsize=label_font_size)
ax1.tick_params('x', labelsize=tick_label_size)
ax1.tick_params('y', labelsize=tick_label_size)

if plot_legend:
    legend_markers, legend_labels = ax1.get_legend_handles_labels()
    # Build the "± s.d." entry from the plotted lines themselves: one translucent
    # swatch per model, in legend order and in that model's line colour.
    # Hard-coding palette indices here put the swatches in the opposite order to
    # the model entries, and broke whenever a model was (un)commented above.
    sd_marker = tuple(
        mpatches.Patch(color=marker.get_color(), alpha=0.3) for marker in legend_markers
    )
    legend_markers.append(sd_marker)
    legend_labels.append('± s.d.')
    ax1.legend(legend_markers, legend_labels, fontsize=label_font_size,
               handler_map={tuple: HandlerTuple(ndivide=None)})

else:
    # remove legend
    ax1.get_legend().remove()

fig1 = ax1.get_figure()

plt.show()

In [None]:
# Export the precision-recall figure as SVG, named after the selected outcome (spaces -> underscores).
fig1.savefig(
    os.path.join(output_dir, f'precision_recall_curve_{outcome.replace(" ", "_")}.svg'),
    format='svg',
    bbox_inches="tight",
    dpi=1200,
)
