# ROC Curve comparison with confidence intervals

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
from sklearn.metrics import roc_curve, auc
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

In [None]:
# Directory of the LSTM test run whose predictions are analysed here; the
# bootstrapped and the plain test prediction pickles both live in it.
lstm_output_dir = '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS01/2023_01_06_1847/test_LSTM_sigmoid_all_unchanged_0.0_2_True_RMSprop_3M mRS 0-1_128_3'
lstm_bs_predictions_path = lstm_output_dir + '/bootstrapped_gt_and_pred.pkl'
lstm_predictions_path = lstm_output_dir + '/test_gt_and_pred.pkl'
outcome = '3M mRS 0-1'

In [None]:
# Load (ground truth, predictions) pairs for the bootstrapped and the plain
# test evaluations. Each pickle is unpacked into exactly two objects.
# NOTE(review): the bootstrapped pair presumably holds one ground-truth /
# prediction sequence per bootstrap sample — confirm against the producer.
# `with` closes the file handles deterministically instead of leaving them
# open until garbage collection.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(lstm_bs_predictions_path, 'rb') as bs_file:
    lstm_bs_gt, lstm_bs_predictions = pickle.load(bs_file)
with open(lstm_predictions_path, 'rb') as test_file:
    lstm_gt, lstm_test_predictions = pickle.load(test_file)

In [None]:
# Fixed four-colour palette for the curves; the bare name on the last line
# makes the notebook display the swatches.
palette_hex_codes = ['#f61067', '#049b9a', '#012D98', '#a76dfe']
all_colors_palette = sns.color_palette(palette_hex_codes, n_colors=len(palette_hex_codes))
all_colors_palette

Compute the ROC curve and ROC AUC on the full (non-bootstrapped) test set

In [None]:
# ROC curve of the test-set predictions and the area under it; the
# thresholds returned by roc_curve are not needed and are discarded.
test_fpr, test_tpr, _ = roc_curve(y_true=lstm_gt, y_score=lstm_test_predictions)
test_roc_auc = auc(x=test_fpr, y=test_tpr)

In [None]:
# Two-column (fpr, tpr) table of the test ROC curve, for seaborn plotting.
lstm_test_df = pd.DataFrame(dict(fpr=test_fpr, tpr=test_tpr))

Compute the ROC curve and ROC AUC for each bootstrap sample

In [None]:
# Compute the ROC curve and AUC of every bootstrap sample and build two
# long-form tables:
#   lstm_bs_df      - raw (fpr, tpr) points of each bootstrap curve, with its
#                     auc and bootstrap_idx
#   resampled_bs_df - each curve linearly interpolated onto one shared grid of
#                     200 fpr values in [0, 1], so curves can be aggregated
#                     point by point
# Per-sample frames are collected in lists and concatenated once at the end.
# DataFrame.append was removed in pandas 2.0, and appending inside the loop
# re-copied the accumulated frame on every iteration (quadratic time).
n_resampled_points = 200
resampled_fpr_grid = np.linspace(0, 1, n_resampled_points)

lstm_bs_fprs = []
lstm_bs_tprs = []
lstm_bs_aucs = []
bs_frames = []
resampled_frames = []
for idx in tqdm(range(len(lstm_bs_predictions))):
    # calculate the ROC curve and AUC
    bs_fpr, bs_tpr, _ = roc_curve(lstm_bs_gt[idx], lstm_bs_predictions[idx])
    bs_roc_auc = auc(bs_fpr, bs_tpr)
    lstm_bs_fprs.append(bs_fpr)
    lstm_bs_tprs.append(bs_tpr)
    lstm_bs_aucs.append(bs_roc_auc)

    bs_df = pd.DataFrame({'fpr': bs_fpr, 'tpr': bs_tpr, 'auc': bs_roc_auc})
    bs_df['bootstrap_idx'] = idx
    bs_frames.append(bs_df)

    bs_resampled_tpr = np.interp(resampled_fpr_grid, bs_fpr, bs_tpr)
    bs_resampled_df = pd.DataFrame({'fpr': resampled_fpr_grid, 'tpr': bs_resampled_tpr})
    bs_resampled_df['bootstrap_idx'] = idx
    resampled_frames.append(bs_resampled_df)

# pd.concat keeps each frame's own 0..k-1 index (it is not reset). pd.concat
# rejects an empty list, so fall back to an empty frame when there are no
# bootstrap samples.
lstm_bs_df = pd.concat(bs_frames) if bs_frames else pd.DataFrame()
resampled_bs_df = pd.concat(resampled_frames) if resampled_frames else pd.DataFrame()

In [None]:
# Summarise the bootstrap AUC distribution: its median and the bounds of
# the central 95% percentile interval.
confidence_level = 95
alpha = 100 - confidence_level  # total tail mass outside the interval, in percent
half_alpha = alpha / 2

median_roc_auc = np.percentile(lstm_bs_aucs, 50)
lower_ci_roc_auc, upper_ci_roc_auc = (
    np.percentile(lstm_bs_aucs, q) for q in (half_alpha, 100 - half_alpha)
)

In [None]:
# Locate the bootstrap samples whose AUC lies at the lower CI bound, the
# upper CI bound and the median, so their actual ROC curves can be plotted.
bs_aucs_arr = np.asarray(lstm_bs_aucs)
sorted_bs_aucs = np.sort(bs_aucs_arr)


def _rank_of_percentile(q, n_samples):
    """Return the 0-based rank nearest percentile *q* (0-100) among *n_samples* sorted values."""
    # Rounding can land one past the end (e.g. q=97.5 with 20 samples gives
    # 20), so clamp to the last valid rank.
    return min(int(q * n_samples / 100 + 0.5), n_samples - 1)


def _first_index_with_value(values, value):
    """Return the first position in the ndarray *values* equal to *value*."""
    # Must compare an ndarray: if the AUCs are plain Python floats,
    # `list == value` is a single False rather than an element-wise mask,
    # and the lookup finds nothing.
    return np.flatnonzero(values == value)[0]


sorted_lower_ci_idx = _rank_of_percentile(alpha / 2, len(bs_aucs_arr))
unsorted_lower_ci_idx = _first_index_with_value(bs_aucs_arr, sorted_bs_aucs[sorted_lower_ci_idx])
sorted_upper_ci_idx = _rank_of_percentile(100 - alpha / 2, len(bs_aucs_arr))
unsorted_upper_ci_idx = _first_index_with_value(bs_aucs_arr, sorted_bs_aucs[sorted_upper_ci_idx])
sorted_median_idx = _rank_of_percentile(50, len(bs_aucs_arr))
unsorted_median_idx = _first_index_with_value(bs_aucs_arr, sorted_bs_aucs[sorted_median_idx])

In [None]:
# (fpr, tpr) arrays of the bootstrap samples at the lower CI bound, the
# upper CI bound and the median AUC.
lower_ci_fpr, lower_ci_tpr = lstm_bs_fprs[unsorted_lower_ci_idx], lstm_bs_tprs[unsorted_lower_ci_idx]
upper_ci_fpr, upper_ci_tpr = lstm_bs_fprs[unsorted_upper_ci_idx], lstm_bs_tprs[unsorted_upper_ci_idx]
median_fpr, median_tpr = lstm_bs_fprs[unsorted_median_idx], lstm_bs_tprs[unsorted_median_idx]

In [None]:
# Interpolate the lower and upper CI curves onto 150 evenly spaced FPR values
# in [0, 1], so both share one x grid that can be used for plotting.
ci_fpr_grid = np.linspace(0, 1, 150)
sub_lower_ci_tpr, sub_upper_ci_tpr = (
    np.interp(ci_fpr_grid, curve_fpr, curve_tpr)
    for curve_fpr, curve_tpr in ((lower_ci_fpr, lower_ci_tpr), (upper_ci_fpr, upper_ci_tpr))
)

### Overall ROC curve


In [None]:
# Test-set ROC curve of the LSTM, with the chance diagonal and optional legend.
plot_legend = True

tick_label_size = 11
label_font_size = 13

# Apply the theme BEFORE creating the figure and axes. matplotlib reads
# figure.figsize and the axes style rcParams (grid, spines) at creation time,
# so setting them afterwards leaves this plot unaffected.
custom_params = {"axes.spines.right": False, "axes.spines.top": False, 'figure.figsize': (10, 10)}
sns.set_theme(style="whitegrid", rc=custom_params, context="paper", font_scale=1)

plt.figure()
ax = plt.subplot(111)

# NOTE(review): lstm_test_df holds a single curve. seaborn therefore
# aggregates only rows that share an fpr value (vertical ROC steps) into a
# mean tpr, and the 'sd' band reflects the spread within those steps, not
# sampling uncertainty — confirm this is the intended visualisation.
ax = sns.lineplot(data=lstm_test_df, x='fpr', y='tpr', color=all_colors_palette[0],
                  label='LSTM (area = %0.2f)' % test_roc_auc, ax=ax, errorbar='sd')

ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
ax.tick_params('x', labelsize=tick_label_size)
ax.tick_params('y', labelsize=tick_label_size)

if plot_legend:
    legend_markers, legend_labels = ax.get_legend_handles_labels()
    # Proxy artist for the shaded band, in the curve's colour.
    sd_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)
    # A real one-element tuple, so the HandlerTuple registered below renders
    # it; more band patches can be added to share this legend entry.
    legend_markers.append((sd_patch,))
    legend_labels.append('± s.d.')
    ax.legend(legend_markers, legend_labels, fontsize=tick_label_size,
              handler_map={tuple: HandlerTuple(ndivide=None)})
else:
    # remove legend
    ax.get_legend().remove()






### ROC curve after bootstrapping

Confidence interval from the ROC curves of the bootstrap samples whose AUCs lie at the lower and upper percentiles

In [None]:
# Plot the median bootstrap ROC curve. Fill the region between the ROC curves
# of the bootstrap samples at the lower and upper AUC percentiles as the 95% CI.
plt.figure(figsize=(10, 10))
ax = plt.subplot(111)

# plot median
ax.plot(median_fpr, median_tpr, color=all_colors_palette[0], lw=2,
        label='Median ROC curve (area = %0.2f)' % median_roc_auc)
# Chance diagonal, styled like the test-set ROC plot.
ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

# Derive the x grid from the interpolated arrays' length, so it always
# matches them instead of relying on a second hard-coded point count.
ci_fpr_grid = np.linspace(0, 1, len(sub_lower_ci_tpr))
ax.fill_between(ci_fpr_grid, sub_lower_ci_tpr, sub_upper_ci_tpr, color=all_colors_palette[0], alpha=.2,
                label=r'95% CI')

ax.set_xlabel('1 - Specificity (False Positive Rate)')
ax.set_ylabel('Sensitivity (True Positive Rate)')
# Labels are attached above; the legend must be drawn for them to appear.
ax.legend()

Confidence interval from pointwise quantiles of the bootstrap ROC curves, interpolated onto a common FPR grid

In [None]:
# Plot the pointwise median of the interpolated bootstrap ROC curves. Shade
# the band between the (alpha/2) and (1 - alpha/2) tpr quantiles at each
# point of the shared fpr grid.
plt.figure(figsize=(10, 10))
ax1 = plt.subplot(111)

lower_q = alpha / 2 / 100
upper_q = 1 - alpha / 2 / 100
# One groupby pass for all three quantiles; unstack() turns the quantile
# level into columns keyed by lower_q, 0.5 and upper_q.
tpr_quantiles = resampled_bs_df.groupby('fpr').tpr.quantile([lower_q, 0.5, upper_q]).unstack()
fpr_grid = tpr_quantiles.index.values

ax1 = sns.lineplot(x=fpr_grid, y=tpr_quantiles[0.5].values, color=all_colors_palette[0], lw=2, ax=ax1)
ax1.fill_between(fpr_grid,
                 tpr_quantiles[lower_q].values,
                 tpr_quantiles[upper_q].values,
                 color=all_colors_palette[0], alpha=.2)
# Plotting raw arrays gives seaborn no column names, so label the axes explicitly.
ax1.set_xlabel('fpr')
ax1.set_ylabel('tpr')

In [None]:
# Aggregate the interpolated bootstrap curves at each fpr of the shared grid
# with seaborn's default estimator, shading a ±1 s.d. band.
sns.lineplot(x='fpr', y='tpr', data=resampled_bs_df, errorbar='sd',
             lw=2, color=all_colors_palette[0])

In [None]:
# Aggregate the raw (non-interpolated) bootstrap ROC points by fpr with a
# ±1 s.d. band.
# NOTE(review): raw fpr values generally differ between bootstrap samples, so
# only exactly-equal fpr values are pooled here — the interpolated table is
# likely the better basis for this band; confirm intent.
sns.lineplot(x='fpr', y='tpr', data=lstm_bs_df, errorbar='sd',
             lw=2, color=all_colors_palette[0])
