# ROC and precision-recall curve comparison with inter-fold variation (± 1 s.d.)

In [None]:
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from sklearn.metrics import roc_curve, auc, precision_recall_curve
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

In [None]:
# Root directory of all saved model predictions; change it here to relocate every input below.
prediction_output_root = '/Users/jk1/temp/opsum_prediction_output'

# THRIVE-C and transformer predictions for the 3-month mRS 0-2 outcome
thrive_c_mrs02_predictions_path = f'{prediction_output_root}/THRIVE_C/THRIVE_C_3m_mrs02_predictions/test_gt_and_pred.pkl'
transformer_mrs02_predictions_path = f'{prediction_output_root}/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation'

# THRIVE-C and transformer predictions for 3-month death (hold-out test set and external validation set)
thrive_c_death_predictions_path = f'{prediction_output_root}/THRIVE_C/THRIVE_C_3m_death_predictions/3m_death_test_gt_and_pred.pkl'
transformer_death_predictions_path = f'{prediction_output_root}/transformer/3M_Death/testing'
transformer_ext_death_predictions_path = f'{prediction_output_root}/transformer/3M_Death/external_validation'

In [None]:
# Destination directory for exported figures.
output_dir: str = '/Users/jk1/Downloads'

In [None]:
# Number of cross-validation folds and random seed.
n_folds, seed = 5, 42
# Fold whose curve is drawn in bold for each outcome; all folds contribute to the s.d. band.
mrs02_selected_fold, death_selected_fold = 2, 1

In [None]:
# When True, the data behind each figure is pickled to disk so the plot can be redrawn later.
save_plot_data: bool = False

Load data

In [None]:
# Load THRIVE-C (ground truth, prediction) pairs; `with` guarantees the file handles are closed.
with open(thrive_c_mrs02_predictions_path, 'rb') as f:
    thrivec_mrs02_gt, thrivec_mrs02_predictions = pickle.load(f)
with open(thrive_c_death_predictions_path, 'rb') as f:
    thrivec_death_gt, thrivec_death_predictions = pickle.load(f)

In [None]:
# Load the transformer's per-fold test-set pickles for 3M mRS 0-2, closing each file after reading.
transformer_mrs02_folds = []
for fidx in range(n_folds):
    fold_path = os.path.join(transformer_mrs02_predictions_path, f'fold_{fidx}_test_gt_and_pred.pkl')
    with open(fold_path, 'rb') as f:
        transformer_mrs02_folds.append(pickle.load(f))

In [None]:
# Load the transformer's per-fold hold-out test pickles for 3M death, closing each file after reading.
transformer_death_folds = []
for fidx in range(n_folds):
    fold_path = os.path.join(transformer_death_predictions_path, f'fold_{fidx}_test_gt_and_pred.pkl')
    with open(fold_path, 'rb') as f:
        transformer_death_folds.append(pickle.load(f))

In [None]:
# Load the transformer's per-fold external-validation pickles for 3M death, closing each file after reading.
transformer_ext_death_folds = []
for fidx in range(n_folds):
    fold_path = os.path.join(transformer_ext_death_predictions_path, f'fold_{fidx}_test_gt_and_pred.pkl')
    with open(fold_path, 'rb') as f:
        transformer_ext_death_folds.append(pickle.load(f))

In [None]:
# Shared palette: [0] transformer, [1] THRIVE-C, [3] transformer on external data.
palette_hex_codes = ['#f61067', '#049b9a', '#012D98', '#a76dfe']
all_colors_palette = sns.color_palette(palette_hex_codes, n_colors=len(palette_hex_codes))
all_colors_palette

# Helper functions

In [None]:
def _concat_fold_frames(frames):
    """Stack per-fold frames keeping their own indices (as DataFrame.append did); empty list -> empty frame."""
    return pd.concat(frames) if frames else pd.DataFrame()


def compute_roc_and_pr_curves(folds, n_interpolated_points=200):
    """Compute per-fold ROC and precision-recall curves, raw and resampled on a common grid.

    Args:
        folds: sequence of per-fold items where item[0] holds ground-truth labels and
            item[1] the predicted scores (passed to sklearn's roc_curve /
            precision_recall_curve).
        n_interpolated_points: number of evenly spaced points in [0, 1] onto which each
            fold's curve is linearly interpolated, so folds can be compared point-wise.

    Returns:
        Tuple (roc_df, resampled_roc_df, roc_aucs, pr_df, resampled_pr_df, pr_aucs):
        roc_df / pr_df hold the raw curve points ('fpr', 'tpr' / 'recall', 'precision')
        plus a 'fold' column; the resampled_* frames hold the same columns evaluated on
        the common grid; roc_aucs / pr_aucs are lists with one area per fold.
    """
    grid = np.linspace(0, 1, n_interpolated_points)
    roc_frames, resampled_roc_frames, roc_aucs = [], [], []
    pr_frames, resampled_pr_frames, pr_aucs = [], [], []

    # Iterate over the folds actually passed in (not a global fold count), and collect
    # frames for one pd.concat: DataFrame.append was removed in pandas 2.0.
    for fidx, fold in enumerate(tqdm(folds)):
        gt, scores = fold[0], fold[1]

        fpr, tpr, _ = roc_curve(gt, scores)
        roc_aucs.append(auc(fpr, tpr))
        roc_frames.append(pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'fold': fidx}))
        resampled_roc_frames.append(
            pd.DataFrame({'fpr': grid, 'tpr': np.interp(grid, fpr, tpr), 'fold': fidx}))

        precision, recall, _ = precision_recall_curve(gt, scores)
        # precision_recall_curve returns recall in decreasing order, while np.interp
        # needs increasing x-values, so sort the points by recall.
        recall, precision = zip(*sorted(zip(recall, precision)))
        pr_aucs.append(auc(recall, precision))
        pr_frames.append(pd.DataFrame({'recall': recall, 'precision': precision, 'fold': fidx}))
        resampled_pr_frames.append(
            pd.DataFrame({'recall': grid, 'precision': np.interp(grid, recall, precision), 'fold': fidx}))

    return (_concat_fold_frames(roc_frames), _concat_fold_frames(resampled_roc_frames), roc_aucs,
            _concat_fold_frames(pr_frames), _concat_fold_frames(resampled_pr_frames), pr_aucs)


# Functional Outcome

In [None]:
# Settings for the functional-outcome section (3M mRS 0-2).
outcome, selected_fold = '3M mRS 0-2', mrs02_selected_fold

### Compute ROC and PR curve standard deviation for THRIVE-C

In [None]:
# Split the THRIVE-C test-set predictions into n_folds interleaved subsets
# (sample i goes to subset i % n_folds) to estimate its variability.
thrivec_folds = [
    (thrivec_mrs02_gt[start::n_folds], thrivec_mrs02_predictions[start::n_folds])
    for start in range(n_folds)
]

In [None]:
# THRIVE-C on the whole test set: ROC and PR curves, their areas, and each curve
# resampled on 200 evenly spaced points in [0, 1].
thrivec_grid = np.linspace(0, 1, 200)

thrivec_fpr, thrivec_tpr, _ = roc_curve(thrivec_mrs02_gt, thrivec_mrs02_predictions)
thrivec_roc_auc = auc(thrivec_fpr, thrivec_tpr)
thrivec_resampled_tpr = np.interp(thrivec_grid, thrivec_fpr, thrivec_tpr)

thrivec_precision, thrivec_recall, _ = precision_recall_curve(thrivec_mrs02_gt, thrivec_mrs02_predictions)
# order points by increasing recall so np.interp gets increasing x-values
pr_points = sorted(zip(thrivec_recall, thrivec_precision))
thrivec_recall, thrivec_precision = zip(*pr_points)
thrivec_pr_auc = auc(thrivec_recall, thrivec_precision)
thrivec_resampled_precision = np.interp(thrivec_grid, thrivec_recall, thrivec_precision)

In [None]:
# Per-subset THRIVE-C curves; their spread gives the THRIVE-C s.d. band.
(thrivec_roc_df, thrivec_resampled_roc_df, thrivec_roc_aucs,
 thrivec_pr_df, thrivec_resampled_pr_df, thrivec_pr_aucs) = compute_roc_and_pr_curves(thrivec_folds)

In [None]:
# Point-wise standard deviation across subsets, as a Series named 'std' indexed by the grid value.
thrivec_resampled_roc_std = thrivec_resampled_roc_df.groupby('fpr')['tpr'].std().rename('std')
thrivec_resampled_pr_std = thrivec_resampled_pr_df.groupby('recall')['precision'].std().rename('std')

### Transformer curves with fold variation

In [None]:
# Per-fold transformer curves for 3M mRS 0-2.
(transformer_roc_df, transformer_resampled_roc_df, transformer_roc_aucs,
 transformer_pr_df, transformer_resampled_pr_df, transformer_pr_aucs) = compute_roc_and_pr_curves(transformer_mrs02_folds)

In [None]:
# Point-wise standard deviation across folds, as a Series named 'std' indexed by the grid value.
transformer_resampled_roc_std = transformer_resampled_roc_df.groupby('fpr')['tpr'].std().rename('std')
transformer_resampled_pr_std = transformer_resampled_pr_df.groupby('recall')['precision'].std().rename('std')

### Resampled ROC curve with fold variation


In [None]:
def plot_mrs_roc_auc_curve(transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std, transformer_roc_aucs,
                           thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc, selected_fold,
                           ax, plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Draw transformer vs. THRIVE-C ROC curves with ± 1 s.d. bands on ``ax``.

    Args:
        transformer_roc_df: raw per-fold ROC points ('fpr', 'tpr', 'fold').
        transformer_resampled_roc_df: per-fold ROC curves resampled on a common fpr grid.
        transformer_resampled_roc_std: point-wise tpr s.d. across folds, indexed by fpr.
        transformer_roc_aucs: per-fold AUCs; the selected fold's value is shown in the legend.
        thrivec_resampled_tpr: THRIVE-C tpr on an evenly spaced fpr grid over [0, 1].
        thrivec_resampled_roc_std: point-wise THRIVE-C s.d., same length as that grid.
        thrivec_roc_auc: THRIVE-C AUC shown in the legend.
        selected_fold: transformer fold whose curve is drawn.
        ax: matplotlib Axes to draw on.
        plot_legend: add a legend with a combined '± s.d.' entry if True, otherwise remove it.
        tick_label_size, label_font_size: font sizes for ticks and for axis labels / legend.

    Returns:
        The Figure that contains ``ax``.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)

    ## Main model: Transformer
    # Band = selected fold's resampled curve +/- 1 s.d. of the resampled curves across folds.
    selected_tpr = transformer_resampled_roc_df[transformer_resampled_roc_df.fold == selected_fold].tpr.values
    error = transformer_resampled_roc_std.values
    ax.fill_between(transformer_resampled_roc_std.index, selected_tpr - error, selected_tpr + error,
                    alpha=0.2, color=all_colors_palette[0])

    # Raw ROC curve of the selected fold
    ax = sns.lineplot(data=transformer_roc_df[transformer_roc_df.fold == selected_fold], x='fpr', y='tpr',
                      color=all_colors_palette[0],
                      label='Transformer (area = %0.2f)' % transformer_roc_aucs[selected_fold],
                      ax=ax, errorbar=None)

    ## Comparator: THRIVE-C
    # Derive the fpr grid from the curve's own length instead of hard-coding 200 points,
    # so x and y stay the same length if the resampling resolution changes.
    thrivec_fpr_grid = np.linspace(0, 1, len(thrivec_resampled_tpr))
    error = thrivec_resampled_roc_std.values
    ax.fill_between(thrivec_fpr_grid, thrivec_resampled_tpr - error, thrivec_resampled_tpr + error,
                    alpha=0.2, color=all_colors_palette[1])
    sns.lineplot(x=thrivec_fpr_grid, y=thrivec_resampled_tpr, color=all_colors_palette[1],
                 label='THRIVE-C (area = %0.2f)' % thrivec_roc_auc, ax=ax, linewidth=2)

    # Chance diagonal
    ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

    ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
    ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
    ax.tick_params('x', labelsize=tick_label_size)
    ax.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax.get_legend_handles_labels()
        # One combined legend entry for the shaded bands, with one swatch per band colour.
        legend_markers.append(tuple(mpatches.Patch(color=all_colors_palette[i], alpha=0.3) for i in (0, 1)))
        legend_labels.append('± s.d.')
        ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})
    elif ax.get_legend() is not None:
        # A legend may already exist from the labelled line plots; drop it.
        ax.get_legend().remove()

    return ax.get_figure()

In [None]:
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize':(10,10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale = 1)

ax = plt.subplot(111)

# Keep the returned Figure so the export cell can call fig.savefig(...).
fig = plot_mrs_roc_auc_curve(transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std, transformer_roc_aucs,
                             thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc, selected_fold,
                             ax, plot_legend = True, tick_label_size = 11, label_font_size = 13)

plt.show()

In [None]:
if save_plot_data:
    # Pickle the inputs of the ROC figure so it can be redrawn without recomputation.
    with open(os.path.join(output_dir, 'transformer_roc_auc_figure_data.pkl'), 'wb') as f:
        pickle.dump((transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std, transformer_roc_aucs), f)
    with open(os.path.join(output_dir, 'thrivec_roc_auc_figure_data.pkl'), 'wb') as f:
        pickle.dump((thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc), f)

In [None]:
# Export the ROC figure as SVG (disabled). Uncomment to use; `fig` must be the Figure
# returned by plot_mrs_roc_auc_curve, and the file name is derived from `outcome`.
# fig.savefig(os.path.join(output_dir, f'roc_curve_{outcome.replace(" ", "_")}.svg'), bbox_inches="tight", format='svg', dpi=1200)

### Overall Precision-Recall curve

In [None]:
def plot_mrs_pr_curve(transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std, transformer_pr_aucs,
                      thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,
                      selected_fold,
                      ax1, plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Draw transformer vs. THRIVE-C precision-recall curves with ± 1 s.d. bands on ``ax1``.

    Args:
        transformer_pr_df: raw per-fold PR points ('recall', 'precision', 'fold').
        transformer_resampled_pr_df: per-fold PR curves resampled on a common recall grid.
        transformer_resampled_pr_std: point-wise precision s.d. across folds, indexed by recall.
        transformer_pr_aucs: per-fold PR areas; the selected fold's value is shown in the legend.
        thrivec_resampled_precision: THRIVE-C precision on an evenly spaced recall grid over [0, 1].
        thrivec_resampled_pr_std: point-wise THRIVE-C s.d., same length as that grid.
        thrivec_pr_auc: THRIVE-C PR area shown in the legend.
        selected_fold: transformer fold whose curve is drawn.
        ax1: matplotlib Axes to draw on.
        plot_legend: add a legend with a combined '± s.d.' entry if True, otherwise remove it.
        tick_label_size, label_font_size: font sizes for ticks and for axis labels / legend.

    Returns:
        The Figure that contains ``ax1``.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)

    ## Main model: Transformer
    # Band = selected fold's resampled curve +/- 1 s.d. of the resampled curves across folds.
    selected_precision = transformer_resampled_pr_df[transformer_resampled_pr_df.fold == selected_fold].precision.values
    error = transformer_resampled_pr_std.values
    ax1.fill_between(transformer_resampled_pr_std.index, selected_precision - error, selected_precision + error,
                     alpha=0.2, color=all_colors_palette[0])

    # Raw PR curve of the selected fold
    ax1 = sns.lineplot(data=transformer_pr_df[transformer_pr_df.fold == selected_fold], x='recall', y='precision',
                       color=all_colors_palette[0],
                       label='Transformer (area = %0.2f)' % transformer_pr_aucs[selected_fold],
                       ax=ax1, errorbar=None)

    ## Comparator: THRIVE-C
    # Derive the recall grid from the curve's own length instead of hard-coding 200 points.
    thrivec_recall_grid = np.linspace(0, 1, len(thrivec_resampled_precision))
    error = thrivec_resampled_pr_std.values
    ax1.fill_between(thrivec_recall_grid, thrivec_resampled_precision - error, thrivec_resampled_precision + error,
                     alpha=0.2, color=all_colors_palette[1])
    sns.lineplot(x=thrivec_recall_grid, y=thrivec_resampled_precision, color=all_colors_palette[1],
                 label='THRIVE-C (area = %0.2f)' % thrivec_pr_auc, ax=ax1, linewidth=2)

    ax1.set_xlabel('Recall', fontsize=label_font_size)
    ax1.set_ylabel('Precision', fontsize=label_font_size)
    ax1.tick_params('x', labelsize=tick_label_size)
    ax1.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax1.get_legend_handles_labels()
        # One combined legend entry for the shaded bands, with one swatch per band colour.
        legend_markers.append(tuple(mpatches.Patch(color=all_colors_palette[i], alpha=0.3) for i in (0, 1)))
        legend_labels.append('± s.d.')
        ax1.legend(legend_markers, legend_labels, fontsize=label_font_size,
                   handler_map={tuple: HandlerTuple(ndivide=None)})
    elif ax1.get_legend() is not None:
        # A legend may already exist from the labelled line plots; drop it.
        ax1.get_legend().remove()

    return ax1.get_figure()

In [None]:
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize':(10,10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale = 1)

ax1 = plt.subplot(111)

# Pass every positional argument of plot_mrs_pr_curve, including the per-fold transformer PR
# areas and the THRIVE-C PR area; omitting them shifted all later arguments and left `ax1` unset.
fig1 = plot_mrs_pr_curve(transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std, transformer_pr_aucs,
                         thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,
                         selected_fold,
                         ax1, plot_legend = True, tick_label_size = 11, label_font_size = 13)

plt.show()

In [None]:
if save_plot_data:
    # Pickle the inputs of the PR figure so it can be redrawn without recomputation.
    with open(os.path.join(output_dir, 'transformer_pr_figure_data.pkl'), 'wb') as f:
        pickle.dump((transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std, transformer_pr_aucs), f)
    with open(os.path.join(output_dir, 'thrivec_pr_figure_data.pkl'), 'wb') as f:
        pickle.dump((thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc), f)

In [None]:
# Export the precision-recall figure as SVG (disabled). Uncomment to use; `fig1` must be the
# Figure returned by plot_mrs_pr_curve, and the file name is derived from `outcome`.
# fig1.savefig(os.path.join(output_dir, f'precision_recall_curve_{outcome.replace(" ", "_")}.svg'), bbox_inches="tight", format='svg', dpi=1200)

# Survival outcome

In [None]:
# Settings for the 3-month death section.
outcome, selected_fold = '3M Death', death_selected_fold

### Compute ROC and PR curve standard deviation for THRIVE-C


In [None]:
# THRIVE-C for 3M death: n_folds interleaved subsets (sample i -> subset i % n_folds)
# provide the s.d. band; the whole test set provides the plotted curves and areas.
thrivec_folds = [
    (thrivec_death_gt[start::n_folds], thrivec_death_predictions[start::n_folds])
    for start in range(n_folds)
]

thrivec_grid = np.linspace(0, 1, 200)

thrivec_fpr, thrivec_tpr, _ = roc_curve(thrivec_death_gt, thrivec_death_predictions)
thrivec_roc_auc = auc(thrivec_fpr, thrivec_tpr)
thrivec_resampled_tpr = np.interp(thrivec_grid, thrivec_fpr, thrivec_tpr)

thrivec_precision, thrivec_recall, _ = precision_recall_curve(thrivec_death_gt, thrivec_death_predictions)
# order points by increasing recall so np.interp gets increasing x-values
pr_points = sorted(zip(thrivec_recall, thrivec_precision))
thrivec_recall, thrivec_precision = zip(*pr_points)
thrivec_pr_auc = auc(thrivec_recall, thrivec_precision)
thrivec_resampled_precision = np.interp(thrivec_grid, thrivec_recall, thrivec_precision)

(thrivec_roc_df, thrivec_resampled_roc_df, thrivec_roc_aucs,
 thrivec_pr_df, thrivec_resampled_pr_df, thrivec_pr_aucs) = compute_roc_and_pr_curves(thrivec_folds)

# Point-wise s.d. across subsets, as Series named 'std' indexed by the grid value.
thrivec_resampled_roc_std = thrivec_resampled_roc_df.groupby('fpr')['tpr'].std().rename('std')
thrivec_resampled_pr_std = thrivec_resampled_pr_df.groupby('recall')['precision'].std().rename('std')

### Compute ROC and PR curve standard deviation for Transformer

Hold-out test data

In [None]:
# Per-fold transformer curves on the hold-out test set for 3M death.
(transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_roc_aucs,
 transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_pr_aucs) = compute_roc_and_pr_curves(transformer_death_folds)

# Point-wise s.d. across folds, as Series named 'std' indexed by the grid value.
transformer_death_resampled_roc_std = transformer_death_resampled_roc_df.groupby('fpr')['tpr'].std().rename('std')
transformer_death_resampled_pr_std = transformer_death_resampled_pr_df.groupby('recall')['precision'].std().rename('std')

External validation data

In [None]:
# Per-fold transformer curves on the external validation set for 3M death.
(transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_roc_aucs,
 transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_pr_aucs) = compute_roc_and_pr_curves(transformer_ext_death_folds)

# Point-wise s.d. across folds, as Series named 'std' indexed by the grid value.
transformer_ext_death_resampled_roc_std = transformer_ext_death_resampled_roc_df.groupby('fpr')['tpr'].std().rename('std')
transformer_ext_death_resampled_pr_std = transformer_ext_death_resampled_pr_df.groupby('recall')['precision'].std().rename('std')

## Resampled ROC curve with inter-fold variation

In [None]:
def plot_death_roc_curve(
                    transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std, transformer_death_roc_aucs,
                    transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_resampled_roc_std, transformer_ext_death_roc_aucs,
                    thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc, 
                    ax, selected_fold, 
                        plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Draw 3M-death ROC curves on ``ax``: the transformer on hold-out and external data, and THRIVE-C.

    Each model gets a line plus a ± 1 s.d. band. The transformer lines are the raw curves
    of ``selected_fold``, and their bands surround that fold's resampled curve.

    Args:
        transformer_death_*: raw per-fold ROC points, resampled per-fold curves, point-wise
            s.d. (indexed by fpr) and per-fold AUCs on the hold-out test set.
        transformer_ext_death_*: the same quantities on the external validation set.
        thrivec_resampled_tpr: THRIVE-C tpr on an evenly spaced fpr grid over [0, 1].
        thrivec_resampled_roc_std: point-wise THRIVE-C s.d., same length as that grid.
        thrivec_roc_auc: THRIVE-C AUC shown in the legend.
        ax: matplotlib Axes to draw on.
        selected_fold: transformer fold whose curves are drawn.
        plot_legend: add a legend with a combined '± s.d.' entry if True, otherwise remove it.
        tick_label_size, label_font_size: font sizes for ticks and for axis labels / legend.

    Returns:
        The Figure that contains ``ax``.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)

    ## Main model: Transformer (hold-out test set)
    # Band = selected fold's resampled curve +/- 1 s.d. of the resampled curves across folds.
    selected_tpr = transformer_death_resampled_roc_df[transformer_death_resampled_roc_df.fold == selected_fold].tpr.values
    error = transformer_death_resampled_roc_std.values
    ax.fill_between(transformer_death_resampled_roc_std.index, selected_tpr - error, selected_tpr + error,
                    alpha=0.2, color=all_colors_palette[0])
    # The per-fold AUC is a scalar, so it is formatted directly (np.median of a scalar was a no-op).
    ax = sns.lineplot(data=transformer_death_roc_df[transformer_death_roc_df.fold == selected_fold], x='fpr', y='tpr',
                      color=all_colors_palette[0],
                      label='Transformer (area = %0.2f)' % transformer_death_roc_aucs[selected_fold],
                      ax=ax, errorbar=None)

    ## Main model on external data
    selected_ext_tpr = transformer_ext_death_resampled_roc_df[transformer_ext_death_resampled_roc_df.fold == selected_fold].tpr.values
    error = transformer_ext_death_resampled_roc_std.values
    ax.fill_between(transformer_ext_death_resampled_roc_std.index, selected_ext_tpr - error, selected_ext_tpr + error,
                    alpha=0.2, color=all_colors_palette[3])
    ax = sns.lineplot(data=transformer_ext_death_roc_df[transformer_ext_death_roc_df.fold == selected_fold], x='fpr', y='tpr',
                      color=all_colors_palette[3],
                      label='Transformer MIMIC (area = %0.2f)' % transformer_ext_death_roc_aucs[selected_fold],
                      ax=ax, errorbar=None)

    ## Comparator: THRIVE-C
    # Derive the fpr grid from the curve's own length instead of hard-coding 200 points.
    thrivec_fpr_grid = np.linspace(0, 1, len(thrivec_resampled_tpr))
    error = thrivec_resampled_roc_std.values
    ax.fill_between(thrivec_fpr_grid, thrivec_resampled_tpr - error, thrivec_resampled_tpr + error,
                    alpha=0.2, color=all_colors_palette[1])
    sns.lineplot(x=thrivec_fpr_grid, y=thrivec_resampled_tpr, color=all_colors_palette[1],
                 label='THRIVE-C (area = %0.2f)' % thrivec_roc_auc, ax=ax, linewidth=2)

    # Chance diagonal
    ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

    ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
    ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
    ax.tick_params('x', labelsize=tick_label_size)
    ax.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax.get_legend_handles_labels()
        # One combined legend entry whose swatches match the three band colours.
        legend_markers.append(tuple(mpatches.Patch(color=all_colors_palette[i], alpha=0.3) for i in (0, 1, 3)))
        legend_labels.append('± s.d.')
        ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})
    elif ax.get_legend() is not None:
        # A legend may already exist from the labelled line plots; drop it.
        ax.get_legend().remove()

    return ax.get_figure()

In [None]:
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize':(10,10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale = 1)

ax = plt.subplot(111)

# Keep the returned Figure so the export cell saves this figure (not an earlier `fig`).
fig = plot_death_roc_curve(
    transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std, transformer_death_roc_aucs,
    transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_resampled_roc_std, transformer_ext_death_roc_aucs,
    thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc,
    ax, selected_fold
)

plt.show()

In [None]:
if save_plot_data:
    # Pickle the inputs of the 3M-death ROC figure so it can be redrawn without recomputation.
    with open(os.path.join(output_dir, 'transformer_death_roc_figure_data.pkl'), 'wb') as f:
        pickle.dump((transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std, transformer_death_roc_aucs), f)
    with open(os.path.join(output_dir, 'transformer_ext_death_roc_figure_data.pkl'), 'wb') as f:
        pickle.dump((transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_resampled_roc_std, transformer_ext_death_roc_aucs), f)
    with open(os.path.join(output_dir, 'thrivec_death_roc_figure_data.pkl'), 'wb') as f:
        pickle.dump((thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc), f)

In [None]:
# Export the 3M-death ROC figure as SVG (disabled). Uncomment to use; `fig` must be the Figure
# returned by plot_death_roc_curve, and the file name is derived from `outcome`.
# fig.savefig(os.path.join(output_dir, f'roc_curve_{outcome.replace(" ", "_")}.svg'), bbox_inches="tight", format='svg', dpi=1200)

## Resampled precision-recall curve with inter-fold variation

In [None]:
def plot_death_pr_curve(transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std, transformer_death_pr_aucs,
                        transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_resampled_pr_std, transformer_ext_death_pr_aucs,
                        thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,  
                        ax1, selected_fold, 
                                            plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Draw 3M-death precision-recall curves on ``ax1``: the transformer on hold-out and external data, and THRIVE-C.

    Each model gets a line plus a ± 1 s.d. band. The transformer lines are the raw curves
    of ``selected_fold``, and their bands surround that fold's resampled curve.

    Args:
        transformer_death_*: raw per-fold PR points, resampled per-fold curves, point-wise
            s.d. (indexed by recall) and per-fold PR areas on the hold-out test set.
        transformer_ext_death_*: the same quantities on the external validation set.
        thrivec_resampled_precision: THRIVE-C precision on an evenly spaced recall grid over [0, 1].
        thrivec_resampled_pr_std: point-wise THRIVE-C s.d., same length as that grid.
        thrivec_pr_auc: THRIVE-C PR area shown in the legend.
        ax1: matplotlib Axes to draw on.
        selected_fold: transformer fold whose curves are drawn.
        plot_legend: add a legend with a combined '± s.d.' entry if True, otherwise remove it.
        tick_label_size, label_font_size: font sizes for ticks and for axis labels / legend.

    Returns:
        The Figure that contains ``ax1``.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)

    ## Main model: Transformer (hold-out test set)
    # Band = selected fold's resampled curve +/- 1 s.d. of the resampled curves across folds.
    selected_precision = transformer_death_resampled_pr_df[transformer_death_resampled_pr_df.fold == selected_fold].precision.values
    error = transformer_death_resampled_pr_std.values
    ax1.fill_between(transformer_death_resampled_pr_std.index, selected_precision - error, selected_precision + error,
                     alpha=0.2, color=all_colors_palette[0])
    ax1 = sns.lineplot(data=transformer_death_pr_df[transformer_death_pr_df.fold == selected_fold], x='recall', y='precision',
                       color=all_colors_palette[0],
                       label='Transformer (area = %0.2f)' % transformer_death_pr_aucs[selected_fold],
                       ax=ax1, errorbar=None)

    ## Main model on external data
    selected_ext_precision = transformer_ext_death_resampled_pr_df[transformer_ext_death_resampled_pr_df.fold == selected_fold].precision.values
    error = transformer_ext_death_resampled_pr_std.values
    ax1.fill_between(transformer_ext_death_resampled_pr_std.index, selected_ext_precision - error, selected_ext_precision + error,
                     alpha=0.2, color=all_colors_palette[3])
    ax1 = sns.lineplot(data=transformer_ext_death_pr_df[transformer_ext_death_pr_df.fold == selected_fold], x='recall', y='precision',
                       color=all_colors_palette[3],
                       label='Transformer MIMIC (area = %0.2f)' % transformer_ext_death_pr_aucs[selected_fold],
                       ax=ax1, errorbar=None)

    ## Comparator: THRIVE-C
    # Derive the recall grid from the curve's own length instead of hard-coding 200 points.
    thrivec_recall_grid = np.linspace(0, 1, len(thrivec_resampled_precision))
    error = thrivec_resampled_pr_std.values
    ax1.fill_between(thrivec_recall_grid, thrivec_resampled_precision - error, thrivec_resampled_precision + error,
                     alpha=0.2, color=all_colors_palette[1])
    sns.lineplot(x=thrivec_recall_grid, y=thrivec_resampled_precision, color=all_colors_palette[1],
                 label='THRIVE-C (area = %0.2f)' % thrivec_pr_auc, ax=ax1, linewidth=2)

    ax1.set_xlabel('Recall', fontsize=label_font_size)
    ax1.set_ylabel('Precision', fontsize=label_font_size)
    ax1.tick_params('x', labelsize=tick_label_size)
    ax1.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax1.get_legend_handles_labels()
        # Swatches must match the band colours actually drawn: 0 (hold-out), 1 (THRIVE-C),
        # 3 (external). Palette index 2 is never used for a band.
        legend_markers.append(tuple(mpatches.Patch(color=all_colors_palette[i], alpha=0.3) for i in (0, 1, 3)))
        legend_labels.append('± s.d.')
        ax1.legend(legend_markers, legend_labels, fontsize=label_font_size,
                   handler_map={tuple: HandlerTuple(ndivide=None)})
    elif ax1.get_legend() is not None:
        # A legend may already exist from the labelled line plots; drop it.
        ax1.get_legend().remove()

    return ax1.get_figure()

In [None]:
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize':(10,10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale = 1)

ax1 = plt.subplot(111)

# Keep the returned Figure so the export cell saves this figure (not an earlier `fig1`).
fig1 = plot_death_pr_curve(transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std, transformer_death_pr_aucs,
                           transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_resampled_pr_std, transformer_ext_death_pr_aucs,
                           thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,
                           ax1, selected_fold, plot_legend=True, tick_label_size=11, label_font_size=13)

plt.show()

In [None]:
if save_plot_data:
    # Pickle the inputs of the 3M-death PR figure so it can be redrawn without recomputation.
    with open(os.path.join(output_dir, 'transformer_death_pr_figure_data.pkl'), 'wb') as f:
        pickle.dump((transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std, transformer_death_pr_aucs), f)
    with open(os.path.join(output_dir, 'transformer_ext_death_pr_figure_data.pkl'), 'wb') as f:
        pickle.dump((transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_resampled_pr_std, transformer_ext_death_pr_aucs), f)
    with open(os.path.join(output_dir, 'thrivec_death_pr_figure_data.pkl'), 'wb') as f:
        pickle.dump((thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc), f)

In [None]:
# Export the 3M-death precision-recall figure as SVG (disabled). Uncomment to use; `fig1` must be
# the Figure returned by plot_death_pr_curve, and the file name is derived from `outcome`.
# fig1.savefig(os.path.join(output_dir, f'precision_recall_curve_{outcome.replace(" ", "_")}.svg'), bbox_inches="tight", format='svg', dpi=1200)