# Figure with AUROC over time, ROC and PR curves for both functional and mortality outcomes

In [None]:
import pandas as pd
import pickle
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import matplotlib.patches as mpatches
from sklearn.utils import resample
from tqdm import tqdm
from matplotlib.legend_handler import HandlerTuple
import matplotlib.lines as mlines


In [None]:
# Directory the exported figure is written to.
output_dir = "/Users/jk1/Downloads/"

In [None]:
# Root folder holding all pre-computed performance data for this figure;
# relocating the data only requires changing this one value.
figure_data_root = '/Users/jk1/temp/opsum_figure_temp_data/performance'

# data for AUROC over time figure
auroc_over_time_data_path = f'{figure_data_root}/auroc_over_time_fig/Transformer_roc_auc_scores_over_time.pkl'
auroc_std_over_time_path = f'{figure_data_root}/auroc_over_time_fig/Transformer_roc_auc_scores_over_time_std.pkl'
# data for ROC and PR curves (directory containing one pickle per curve)
pr_and_roc_data_path = f'{figure_data_root}/roc_and_pr_curve_figs'

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code; only use this on trusted,
    locally generated cache files such as the figure data loaded below.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# AUROC over time: bootstrapped scores and their std, for functional outcome (mrs02) and death
mr02_selected_fold_bootstrapped_roc_auc_scores, death_selected_fold_bootstrapped_roc_auc_scores = _load_pickle(auroc_over_time_data_path)
mrs02_roc_auc_std, death_roc_auc_std = _load_pickle(auroc_std_over_time_path)

# ROC mrs02
(transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std,
 transformer_roc_aucs) = _load_pickle(os.path.join(pr_and_roc_data_path, 'transformer_roc_auc_figure_data.pkl'))
(thrivec_resampled_tpr, thrivec_resampled_roc_std,
 thrivec_roc_auc) = _load_pickle(os.path.join(pr_and_roc_data_path, 'thrivec_roc_auc_figure_data.pkl'))
# PR mrs02
(transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std,
 transformer_pr_aucs) = _load_pickle(os.path.join(pr_and_roc_data_path, 'transformer_pr_figure_data.pkl'))
(thrivec_resampled_precision, thrivec_resampled_pr_std,
 thrivec_pr_auc) = _load_pickle(os.path.join(pr_and_roc_data_path, 'thrivec_pr_figure_data.pkl'))

# ROC death
(transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std,
 transformer_death_roc_aucs) = _load_pickle(os.path.join(pr_and_roc_data_path, 'transformer_death_roc_figure_data.pkl'))
(transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_resampled_roc_std,
 transformer_ext_death_roc_aucs) = _load_pickle(os.path.join(pr_and_roc_data_path, 'transformer_ext_death_roc_figure_data.pkl'))
(thrivec_death_resampled_tpr, thrivec_death_resampled_roc_std,
 thrivec_death_roc_auc) = _load_pickle(os.path.join(pr_and_roc_data_path, 'thrivec_death_roc_figure_data.pkl'))

# PR death
(transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std,
 transformer_death_pr_aucs) = _load_pickle(os.path.join(pr_and_roc_data_path, 'transformer_death_pr_figure_data.pkl'))
(transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_resampled_pr_std,
 transformer_ext_death_pr_aucs) = _load_pickle(os.path.join(pr_and_roc_data_path, 'transformer_ext_death_pr_figure_data.pkl'))
(thrivec_death_resampled_precision, thrivec_death_resampled_pr_std,
 thrivec_death_pr_auc) = _load_pickle(os.path.join(pr_and_roc_data_path, 'thrivec_death_pr_figure_data.pkl'))

In [None]:
# Fold whose curves are highlighted (drawn in bold) for each outcome.
mrs02_selected_fold, death_selected_fold = 2, 1

### Plotting functions for AUROC over time

In [None]:
def _plot_median_with_sd_band(roc_aucs, roc_auc_std, ax, color):
    """Draw the per-hour median ROC AUC of *roc_aucs* on *ax* with a shaded ±1 s.d. band.

    roc_aucs: DataFrame with 'n_hours' and 'roc_auc_score' columns.
    roc_auc_std: per-time-point standard deviation; its index supplies the x values of the band.
        NOTE(review): the band centre is matched to it positionally, so it is assumed to be
        ordered by n_hours like the groupby result below - confirm.
    """
    baseline = roc_aucs.groupby('n_hours')['roc_auc_score'].median().values
    error = 1 * roc_auc_std.values
    ax.fill_between(roc_auc_std.index, baseline - error, baseline + error, alpha=0.2, color=color)
    # estimator='median' so the drawn line is the centre of the shaded band
    # (seaborn's default estimator is the mean).
    sns.lineplot(x='n_hours', y='roc_auc_score', data=roc_aucs, legend=True, ax=ax,
                 estimator='median', errorbar=None, color=color)


def plot_auroc_over_time(mr02_roc_aucs, mrs02_roc_auc_std, death_roc_aucs, death_roc_auc_std, ax, plot_zoom = True, plot_title = False, plot_legend = True, tick_label_size = 11, label_font_size = 13, errorbar = 'sd',
                         model_name = 'Transformer', zoom_ybound = (0.8, 0.92)):
    """Plot ROC AUC as a function of time after admission for both outcomes.

    Each outcome is drawn as its per-hour median line with a ±1 s.d. band.
    The functional outcome (mrs02) uses palette colour 2 and mortality uses palette colour 3.

    Args:
        mr02_roc_aucs, death_roc_aucs: DataFrames with 'n_hours' and 'roc_auc_score' columns.
        mrs02_roc_auc_std, death_roc_auc_std: per-time-point std, indexed by the x values.
        ax: matplotlib Axes to draw on.
        plot_zoom: add an inset axes zoomed to ``zoom_ybound``.
        plot_title: set an axes title mentioning ``model_name``.
        plot_legend: add a legend with the s.d. swatches and one line per outcome.
        tick_label_size, label_font_size: font sizes.
        errorbar: unused; kept for backward compatibility.
        model_name: model name used in the optional title.
        zoom_ybound: (lower, upper) y-bounds of the zoomed inset.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    mrs02_color = all_colors_palette[2]
    death_color = all_colors_palette[3]

    def _draw_both_outcomes(target_ax):
        """Draw the functional-outcome and mortality curves on *target_ax*."""
        _plot_median_with_sd_band(mr02_roc_aucs, mrs02_roc_auc_std, target_ax, mrs02_color)
        _plot_median_with_sd_band(death_roc_aucs, death_roc_auc_std, target_ax, death_color)

    _draw_both_outcomes(ax)

    if plot_title:
        ax.set_title(f'{model_name} performance in the holdout test dataset as a function of observation period')

    ax.set_xlabel('Time after admission (hours)', fontsize=label_font_size)
    ax.set_ylabel('ROC AUC', fontsize=label_font_size)
    ax.set_ylim([0, 1])
    ax.tick_params('x', labelsize=tick_label_size)
    ax.tick_params('y', labelsize=tick_label_size)

    if plot_zoom:
        # Inset (in axes-relative coordinates) showing the same curves over a narrow y-range.
        ax2 = ax.inset_axes([0.2, 0.25, .7, .5], facecolor='w')
        _draw_both_outcomes(ax2)

        ax2.set_title('Zoomed in', fontsize=label_font_size)
        ax2.set_ybound(*zoom_ybound)
        ax2.set_xlabel('Time after admission (hours)', fontsize=label_font_size)
        ax2.set_ylabel('ROC AUC', fontsize=label_font_size)
        ax2.tick_params('x', labelsize=tick_label_size)
        ax2.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax.get_legend_handles_labels()
        # One legend entry showing both band colours side by side.
        sd1_patch = mpatches.Patch(color=mrs02_color, alpha=0.3)
        sd2_patch = mpatches.Patch(color=death_color, alpha=0.3)
        legend_markers.append((sd1_patch, sd2_patch))
        legend_labels.append('± s.d.')

        legend_markers.append(mlines.Line2D([], [], color=mrs02_color, linestyle='-'))
        legend_labels.append('ROC AUC for functional outcome')

        legend_markers.append(mlines.Line2D([], [], color=death_color, linestyle='-'))
        legend_labels.append('ROC AUC for mortality')

        ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})

In [None]:
# Quick standalone preview of the AUROC-over-time panel.
fig, ax = plt.subplots(figsize=(10, 10))
plot_auroc_over_time(
    mr02_roc_aucs=mr02_selected_fold_bootstrapped_roc_auc_scores,
    mrs02_roc_auc_std=mrs02_roc_auc_std,
    death_roc_aucs=death_selected_fold_bootstrapped_roc_auc_scores,
    death_roc_auc_std=death_roc_auc_std,
    ax=ax,
)
plt.show()

### Plotting functions for ROC & PR curves

In [None]:
def plot_mrs_roc_auc_curve(transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std, transformer_roc_aucs,
                           thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc, selected_fold,
                           ax, plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Plot ROC curves for the functional outcome: Transformer vs. THRIVE-C.

    The Transformer curve of ``selected_fold`` is drawn with a ±1 s.d. band built
    from the resampled TPR of that fold and ``transformer_resampled_roc_std``.
    THRIVE-C is drawn on a fixed 200-point FPR grid over [0, 1] with its own band.
    A dashed diagonal marks chance level.

    When ``plot_legend`` is False the automatically created legend is removed.
    Returns the Figure that owns ``ax``.
    """
    palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    transformer_color, thrivec_color = palette[0], palette[1]
    fpr_grid = np.linspace(0, 1, 200)

    # Transformer: shaded band around the selected fold's resampled TPR, then its raw curve.
    fold_tpr = transformer_resampled_roc_df.loc[transformer_resampled_roc_df.fold == selected_fold, 'tpr']
    transformer_sd = 1 * transformer_resampled_roc_std.values
    ax.fill_between(transformer_resampled_roc_std.index, fold_tpr - transformer_sd, fold_tpr + transformer_sd,
                    alpha=0.2, color=transformer_color)
    selected_curve = transformer_roc_df[transformer_roc_df.fold == selected_fold]
    ax = sns.lineplot(data=selected_curve, x='fpr', y='tpr', color=transformer_color,
                      label='Transformer (area = %0.2f)' % transformer_roc_aucs[selected_fold],
                      ax=ax, errorbar=None)

    # THRIVE-C comparator on the fixed FPR grid.
    thrivec_sd = 1 * thrivec_resampled_roc_std.values
    ax.fill_between(fpr_grid, thrivec_resampled_tpr - thrivec_sd, thrivec_resampled_tpr + thrivec_sd,
                    alpha=0.2, color=thrivec_color)
    sns.lineplot(x=fpr_grid, y=thrivec_resampled_tpr, color=thrivec_color,
                 label='THRIVE-C (area = %0.2f)' % thrivec_roc_auc, ax=ax, linewidth=2)

    # Chance diagonal.
    ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

    ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
    ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
    ax.tick_params(axis='both', labelsize=tick_label_size)

    if not plot_legend:
        ax.get_legend().remove()
    else:
        handles, labels = ax.get_legend_handles_labels()
        # Single legend entry showing both band colours.
        handles.append((mpatches.Patch(color=transformer_color, alpha=0.3),
                        mpatches.Patch(color=thrivec_color, alpha=0.3)))
        labels.append('± s.d.')
        ax.legend(handles, labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})

    return ax.get_figure()

In [None]:
def plot_mrs_pr_curve(transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std, transformer_pr_aucs,
                      thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,
                      selected_fold,
                      ax1, plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Plot precision-recall curves for the functional outcome: Transformer vs. THRIVE-C.

    The Transformer curve of ``selected_fold`` is drawn with a ±1 s.d. band built
    from the resampled precision of that fold and ``transformer_resampled_pr_std``.
    THRIVE-C is drawn on a fixed 200-point recall grid over [0, 1] with its own band.

    When ``plot_legend`` is False the automatically created legend is removed.
    Returns the Figure that owns ``ax1``.
    """
    palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    transformer_color, thrivec_color = palette[0], palette[1]
    recall_grid = np.linspace(0, 1, 200)

    # Transformer: shaded band around the selected fold's resampled precision, then its raw curve.
    fold_precision = transformer_resampled_pr_df.loc[transformer_resampled_pr_df.fold == selected_fold, 'precision']
    transformer_sd = 1 * transformer_resampled_pr_std.values
    ax1.fill_between(transformer_resampled_pr_std.index, fold_precision - transformer_sd, fold_precision + transformer_sd,
                     alpha=0.2, color=transformer_color)
    selected_curve = transformer_pr_df[transformer_pr_df.fold == selected_fold]
    ax1 = sns.lineplot(data=selected_curve, x='recall', y='precision', color=transformer_color,
                       label='Transformer (area = %0.2f)' % transformer_pr_aucs[selected_fold],
                       ax=ax1, errorbar=None)

    # THRIVE-C comparator on the fixed recall grid.
    thrivec_sd = 1 * thrivec_resampled_pr_std.values
    ax1.fill_between(recall_grid, thrivec_resampled_precision - thrivec_sd, thrivec_resampled_precision + thrivec_sd,
                     alpha=0.2, color=thrivec_color)
    sns.lineplot(x=recall_grid, y=thrivec_resampled_precision, color=thrivec_color,
                 label='THRIVE-C (area = %0.2f)' % thrivec_pr_auc, ax=ax1, linewidth=2)

    ax1.set_xlabel('Recall', fontsize=label_font_size)
    ax1.set_ylabel('Precision', fontsize=label_font_size)
    ax1.tick_params(axis='both', labelsize=tick_label_size)

    if not plot_legend:
        ax1.get_legend().remove()
    else:
        handles, labels = ax1.get_legend_handles_labels()
        # Single legend entry showing both band colours.
        handles.append((mpatches.Patch(color=transformer_color, alpha=0.3),
                        mpatches.Patch(color=thrivec_color, alpha=0.3)))
        labels.append('± s.d.')
        ax1.legend(handles, labels, fontsize=label_font_size,
                   handler_map={tuple: HandlerTuple(ndivide=None)})

    return ax1.get_figure()

In [None]:
def plot_death_roc_curve(
                    transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std, transformer_death_roc_aucs,
                    transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_resampled_roc_std, transformer_ext_death_roc_aucs,
                    thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc,
                    ax, selected_fold,
                        plot_legend = True, tick_label_size = 11, label_font_size = 13):
    """Plot ROC curves for mortality on *ax*.

    Three curves are drawn, each with a shaded ±1 s.d. band:
      * the Transformer (selected fold), palette colour 0;
      * the Transformer on the external dataset, labelled 'Transformer MIMIC', palette colour 3;
      * the THRIVE-C comparator on a fixed 200-point FPR grid, palette colour 1.
    A dashed diagonal marks chance level.

    When ``plot_legend`` is False the automatically created legend is removed.
    Returns the Figure that owns ``ax``.
    """
    palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    transformer_color, thrivec_color, external_color = palette[0], palette[1], palette[3]

    def _draw_selected_fold(curve_df, resampled_df, resampled_std, aucs, color, name):
        """Shade ±1 s.d. around the fold's resampled TPR and draw its raw ROC curve."""
        band_sd = 1 * resampled_std.values
        fold_tpr = resampled_df.loc[resampled_df.fold == selected_fold, 'tpr']
        ax.fill_between(resampled_std.index, fold_tpr - band_sd, fold_tpr + band_sd, alpha=0.2, color=color)
        sns.lineplot(data=curve_df[curve_df.fold == selected_fold], x='fpr', y='tpr', color=color,
                     label='%s (area = %0.2f)' % (name, np.median(aucs[selected_fold])),
                     ax=ax, errorbar=None)

    # Main model, then the same model evaluated on external data.
    _draw_selected_fold(transformer_death_roc_df, transformer_death_resampled_roc_df,
                        transformer_death_resampled_roc_std, transformer_death_roc_aucs,
                        transformer_color, 'Transformer')
    _draw_selected_fold(transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df,
                        transformer_ext_death_resampled_roc_std, transformer_ext_death_roc_aucs,
                        external_color, 'Transformer MIMIC')

    # THRIVE-C comparator on the fixed FPR grid.
    fpr_grid = np.linspace(0, 1, 200)
    thrivec_sd = 1 * thrivec_resampled_roc_std.values
    ax.fill_between(fpr_grid, thrivec_resampled_tpr - thrivec_sd, thrivec_resampled_tpr + thrivec_sd,
                    alpha=0.2, color=thrivec_color)
    sns.lineplot(x=fpr_grid, y=thrivec_resampled_tpr, color=thrivec_color,
                 label='THRIVE-C (area = %0.2f)' % thrivec_roc_auc, ax=ax, linewidth=2)

    # Chance diagonal.
    ax.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--', alpha=0.5)

    ax.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=label_font_size)
    ax.set_ylabel('Sensitivity (True Positive Rate)', fontsize=label_font_size)
    ax.tick_params(axis='both', labelsize=tick_label_size)

    if not plot_legend:
        ax.get_legend().remove()
    else:
        handles, labels = ax.get_legend_handles_labels()
        # Single legend entry showing all three band colours.
        handles.append((mpatches.Patch(color=transformer_color, alpha=0.3),
                        mpatches.Patch(color=thrivec_color, alpha=0.3),
                        mpatches.Patch(color=external_color, alpha=0.3)))
        labels.append('± s.d.')
        ax.legend(handles, labels, fontsize=label_font_size,
                  handler_map={tuple: HandlerTuple(ndivide=None)})

    return ax.get_figure()

In [None]:
def plot_death_pr_curve(transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std, transformer_death_pr_aucs,
                        transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_resampled_pr_std, transformer_ext_death_pr_aucs,
                        thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc,
                        ax1, selected_fold,
                                            plot_legend = True, legend_outside_plot=False, tick_label_size = 11, label_font_size = 13):
    """Plot precision-recall curves for mortality on *ax1*.

    Three curves are drawn, each with a shaded ±1 s.d. band:
      * the Transformer (selected fold), palette colour 0;
      * the Transformer on the external dataset, labelled 'Transformer MIMIC', palette colour 3;
      * the THRIVE-C comparator on a fixed 200-point recall grid, palette colour 1.

    Args:
        legend_outside_plot: anchor the legend to the right of the axes instead of inside.
        plot_legend: when False the automatically created legend is removed.

    Returns the Figure that owns ``ax1``.
    """
    all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)
    transformer_color = all_colors_palette[0]
    thrivec_color = all_colors_palette[1]
    external_color = all_colors_palette[3]

    def _draw_selected_fold(curve_df, resampled_df, resampled_std, aucs, color, name):
        """Shade ±1 s.d. around the fold's resampled precision and draw its raw PR curve."""
        error = 1 * resampled_std.values
        fold_precision = resampled_df[resampled_df.fold == selected_fold].precision
        ax1.fill_between(resampled_std.index, fold_precision - error, fold_precision + error, alpha=0.2, color=color)
        sns.lineplot(data=curve_df[curve_df.fold == selected_fold], x='recall', y='precision', color=color,
                     label='%s (area = %0.2f)' % (name, aucs[selected_fold]),
                     ax=ax1, errorbar=None)

    ## Main model: Transformer
    _draw_selected_fold(transformer_death_pr_df, transformer_death_resampled_pr_df,
                        transformer_death_resampled_pr_std, transformer_death_pr_aucs,
                        transformer_color, 'Transformer')

    ## Main model in external data
    _draw_selected_fold(transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df,
                        transformer_ext_death_resampled_pr_std, transformer_ext_death_pr_aucs,
                        external_color, 'Transformer MIMIC')

    ## Comparators: THRIVE-C on the fixed recall grid
    recall_grid = np.linspace(0, 1, 200)
    error = 1 * thrivec_resampled_pr_std.values
    ax1.fill_between(recall_grid, thrivec_resampled_precision - error, thrivec_resampled_precision + error,
                     alpha=0.2, color=thrivec_color)
    sns.lineplot(x=recall_grid, y=thrivec_resampled_precision, color=thrivec_color,
                 label='THRIVE-C (area = %0.2f)' % thrivec_pr_auc, ax=ax1, linewidth=2)

    ax1.set_xlabel('Recall', fontsize=label_font_size)
    ax1.set_ylabel('Precision', fontsize=label_font_size)
    ax1.tick_params('x', labelsize=tick_label_size)
    ax1.tick_params('y', labelsize=tick_label_size)

    if plot_legend:
        legend_markers, legend_labels = ax1.get_legend_handles_labels()
        # One "± s.d." entry showing every band colour actually drawn above.
        # The external-data band uses external_color (palette[3]), matching plot_death_roc_curve;
        # palette[2] is not used anywhere in this plot.
        sd_marker = (mpatches.Patch(color=transformer_color, alpha=0.3),
                     mpatches.Patch(color=thrivec_color, alpha=0.3),
                     mpatches.Patch(color=external_color, alpha=0.3))
        legend_markers.append(sd_marker)
        legend_labels.append('± s.d.')

        if legend_outside_plot:
            ax1.legend(legend_markers, legend_labels, fontsize=label_font_size,
                handler_map={tuple: HandlerTuple(ndivide=None)}, bbox_to_anchor=(1.1, 0.1), loc=2, borderaxespad=0.)
        else:
            ax1.legend(legend_markers, legend_labels, fontsize=label_font_size,
                handler_map={tuple: HandlerTuple(ndivide=None)})

    else:
        # remove legend
        ax1.get_legend().remove()

    return ax1.get_figure()

# Plotting combined figure

In [None]:
import matplotlib
from matplotlib.font_manager import FontProperties

# Calibri variants installed system-wide (macOS font folder).
font_files = [
    '/Library/Fonts/calibri-bold-italic.ttf',
    '/Library/Fonts/calibri-bold.ttf',
    '/Library/Fonts/calibri-italic.ttf',
    '/Library/Fonts/calibri-regular.ttf',
    '/Library/Fonts/calibril.ttf',
]

# The last entry (calibril.ttf) provides the family name used for the figure.
*_, font_path = font_files
calibri_font = FontProperties(fname=font_path)
# Reads the font file once; presumably this fails early if the file is missing.
calibri_font.get_name()

# Register every variant with matplotlib so it can be selected by family name.
for font_file in font_files:
    matplotlib.font_manager.fontManager.addfont(font_file)

In [None]:
sns.set_theme(style="whitegrid", context="paper", font_scale = 1)
plt.rcParams['font.family'] = calibri_font.get_name()

cm = 1/2.54  # centimeters in inches
fig = plt.figure(figsize=(18 * cm, 20 * cm))
# 2 x 2 grid of sub-figures; the left column and the top row get twice the space.
subfigs = fig.subfigures(2, 2, wspace=0.07, width_ratios=[2, 1], height_ratios=[2, 1])

# Typography shared by all panels.
tick_label_size = 6
label_font_size = 7
subplot_number_font_size = 9
suptitle_font_size = 10
suptitle_font_weight = 'regular'
plot_subplot_titles = True

#######
# Upper left subfigure: ROC AUC over time
# title should be aligned to left
subfigs[0, 0].suptitle('I. Performance over time', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=0.95, weight=suptitle_font_weight)

axsULeft = subfigs[0, 0].subplots(1, 1)
plot_auroc_over_time(mr02_selected_fold_bootstrapped_roc_auc_scores, mrs02_roc_auc_std,
                     death_selected_fold_bootstrapped_roc_auc_scores, death_roc_auc_std,
                     axsULeft,
                     tick_label_size=tick_label_size, label_font_size=label_font_size)

#######
# Lower left subfigure: ROC and PR curves for death
subfigs[1, 0].suptitle('III. Prediction of mortality', fontsize=suptitle_font_size, horizontalalignment='left', x=0.1, y=1.025, weight=suptitle_font_weight)

axsLLeft = subfigs[1, 0].subplots(1, 2)
# increase space between subplots
subfigs[1, 0].subplots_adjust(wspace=0.3)

# ROC curve for death
plot_death_roc_curve(
    transformer_death_roc_df, transformer_death_resampled_roc_df, transformer_death_resampled_roc_std, transformer_death_roc_aucs,
    transformer_ext_death_roc_df, transformer_ext_death_resampled_roc_df, transformer_ext_death_resampled_roc_std, transformer_ext_death_roc_aucs,
    thrivec_death_resampled_tpr, thrivec_death_resampled_roc_std, thrivec_death_roc_auc,
    axsLLeft[0], death_selected_fold,
    plot_legend=False,
    tick_label_size=tick_label_size, label_font_size=label_font_size
)
if plot_subplot_titles:
    axsLLeft[0].set_title('A.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1, weight=suptitle_font_weight)

# PR curve for death (its legend is created here only so the handles can be reused below)
plot_death_pr_curve(transformer_death_pr_df, transformer_death_resampled_pr_df, transformer_death_resampled_pr_std, transformer_death_pr_aucs,
                    transformer_ext_death_pr_df, transformer_ext_death_resampled_pr_df, transformer_ext_death_resampled_pr_std, transformer_ext_death_pr_aucs,
                    thrivec_death_resampled_precision, thrivec_death_resampled_pr_std, thrivec_death_pr_auc,
                    axsLLeft[1], death_selected_fold, plot_legend=True, legend_outside_plot=False,
                    tick_label_size=tick_label_size, label_font_size=label_font_size)
if plot_subplot_titles:
    axsLLeft[1].set_title('B.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1, weight=suptitle_font_weight)

# remove upper and right spines
for ax in axsLLeft:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

#######
# Upper right subfigure: PR and ROC curves for MRS
axsRight = subfigs[0, 1].subplots(2, 1)
# increase space between subplots
subfigs[0, 1].subplots_adjust(hspace=0.3)
subfigs[0, 1].suptitle('II. Prediction of functional outcome', fontsize=suptitle_font_size, horizontalalignment='left', x=-0.1, y=0.95, weight=suptitle_font_weight)

# ROC curve for MRS
plot_mrs_roc_auc_curve(transformer_roc_df, transformer_resampled_roc_df, transformer_resampled_roc_std, transformer_roc_aucs,
                       thrivec_resampled_tpr, thrivec_resampled_roc_std, thrivec_roc_auc, mrs02_selected_fold,
                       axsRight[0], plot_legend = False, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    axsRight[0].set_title('A.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1, weight=suptitle_font_weight)

# PR curve for MRS (its legend is created here only so the handles can be reused below)
plot_mrs_pr_curve(transformer_pr_df, transformer_resampled_pr_df, transformer_resampled_pr_std, transformer_pr_aucs,
                  thrivec_resampled_precision, thrivec_resampled_pr_std, thrivec_pr_auc, mrs02_selected_fold,
                  axsRight[1], plot_legend = True, tick_label_size = tick_label_size, label_font_size = label_font_size)
if plot_subplot_titles:
    axsRight[1].set_title('B.', fontsize=subplot_number_font_size, horizontalalignment='left', x=-0.1, weight=suptitle_font_weight)


# remove upper and right spines
for ax in axsRight:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


#######
# Lower right subfigure: Legends
legaxs = subfigs[1, 1].subplots(1, 2)
# turn off axis and grid for all legaxs
for ax in legaxs:
    ax.axis('off')
    ax.grid(False)

all_colors_palette = sns.color_palette(['#f61067', '#049b9a', '#012D98', '#a76dfe'], n_colors=4)

# legend for Death
# Reuse the PR panel's line handles (Transformer, Transformer MIMIC, THRIVE-C) but
# replace their labels with combined ROC/PR AUC labels.
legend_markers, legend_labels = axsLLeft[1].get_legend_handles_labels()
legend_labels = [
    f'Transformer (ROC AUC = {transformer_death_roc_aucs[death_selected_fold]:.2f}; PR AUC = {transformer_death_pr_aucs[death_selected_fold]:.2f})',
    f'Transformer MIMIC (ROC AUC = {transformer_ext_death_roc_aucs[death_selected_fold]:.2f}; PR AUC = {transformer_ext_death_pr_aucs[death_selected_fold]:.2f})',
    f'THRIVE-C (ROC AUC = {thrivec_death_roc_auc:.2f}; PR AUC = {thrivec_death_pr_auc:.2f})'
    ]
sd1_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)
sd2_patch = mpatches.Patch(color=all_colors_palette[1], alpha=0.3)
# palette[3] is the colour of the external-data (MIMIC) band in both death panels.
sd3_patch = mpatches.Patch(color=all_colors_palette[3], alpha=0.3)
sd_marker = (sd1_patch, sd2_patch, sd3_patch)
sd_labels = '± s.d.'
legend_markers.append(sd_marker)
legend_labels.append(sd_labels)
legaxs[0].legend(legend_markers, legend_labels, fontsize=label_font_size,
                handler_map={tuple: HandlerTuple(ndivide=None)},
                bbox_to_anchor=(-0.75, 0), loc='lower left', borderaxespad=0.)
axsLLeft[1].get_legend().remove()

# legend for MRS
legend_markers, legend_labels = axsRight[1].get_legend_handles_labels()
legend_labels = [
    f'Transformer (ROC AUC = {transformer_roc_aucs[mrs02_selected_fold]:.2f}; PR AUC = {transformer_pr_aucs[mrs02_selected_fold]:.2f})',
    f'THRIVE-C (ROC AUC = {thrivec_roc_auc:.2f}; PR AUC = {thrivec_pr_auc:.2f})'
    ]
sd1_patch = mpatches.Patch(color=all_colors_palette[0], alpha=0.3)
sd2_patch = mpatches.Patch(color=all_colors_palette[1], alpha=0.3)
sd_marker = (sd1_patch, sd2_patch)
sd_labels = '± s.d.'
legend_markers.append(sd_marker)
legend_labels.append(sd_labels)
legaxs[1].legend(legend_markers, legend_labels, fontsize=label_font_size,
                handler_map={tuple: HandlerTuple(ndivide=None)},
                bbox_to_anchor=(-1.7, 1.2), loc='upper left', borderaxespad=0.)
axsRight[1].get_legend().remove()

# NOTE(review): placeholder title - replace or remove before exporting the final figure.
fig.suptitle('Temp', fontsize='xx-large')

plt.show()


In [None]:
# fig.savefig(os.path.join(output_dir, 'performance_combined_figure.svg'), bbox_inches="tight", format='svg', dpi=1200)