In [None]:
import os
import pickle

import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.legend_handler import HandlerTuple
from sklearn.metrics import roc_auc_score
from sklearn.utils import check_random_state, resample
from tqdm import tqdm

from prediction.outcome_prediction.data_loading.data_loader import load_data


In [None]:
# Root directory of the local prediction / preprocessing outputs. Set the
# OPSUM_TEMP_DIR environment variable to point elsewhere instead of editing
# the absolute paths below; the default reproduces the original location.
opsum_temp_dir = os.environ.get('OPSUM_TEMP_DIR', '/Users/jk1/temp')

# Timestamp identifying the preprocessing run (shared by features and outcomes).
prepro_timestamp = '01012023_233050'
prepro_dir = os.path.join(opsum_temp_dir, 'opsum_prepro_output', f'gsu_prepro_{prepro_timestamp}')

mrs02_predictions_over_time_path = os.path.join(
    opsum_temp_dir,
    'opsum_prediction_output/transformer/3M_mrs02/transformer_20230402_184459_test_set_evaluation/predictions_over_timesteps_cv2.pkl',
)
death_predictions_over_time_path = os.path.join(
    opsum_temp_dir,
    'opsum_prediction_output/transformer/3M_Death/testing/predictions_over_timesteps_cv1.pkl',
)
features_path = os.path.join(prepro_dir, f'preprocessed_features_{prepro_timestamp}.csv')
labels_path = os.path.join(prepro_dir, f'preprocessed_outcomes_{prepro_timestamp}.csv')

In [None]:
# Label used in figure titles / output file names.
model_name: str = 'Transformer'

# Split settings forwarded to load_data (fraction held out, number of CV splits, RNG seed).
test_size: float = 0.2
n_splits: int = 5
seed: int = 42

# Number of time steps (hours after admission) evaluated in get_roc_auc_scores.
n_time_steps: int = 72

In [None]:
def _load_pickled_predictions(pickle_path):
    """Return the object unpickled from `pickle_path`.

    NOTE(review): pickle.load can execute arbitrary code; only use it on trusted,
    locally produced prediction files.
    """
    with open(pickle_path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


mrs02_predictions_over_time = _load_pickled_predictions(mrs02_predictions_over_time_path)
death_predictions_over_time = _load_pickled_predictions(death_predictions_over_time_path)

In [None]:
# Load the same features once per outcome label. Both calls pass the same
# test_size / n_splits / seed split settings.
(
    mrs02_pids,
    mrs02_train_data,
    mrs02_test_data,
    mrs02_train_splits,
    mrs02_test_features_lookup_table,
) = load_data(features_path, labels_path, '3M mRS 0-2', test_size, n_splits, seed)

(
    death_pids,
    death_train_data,
    death_test_data,
    death_train_splits,
    death_test_features_lookup_table,
) = load_data(features_path, labels_path, '3M Death', test_size, n_splits, seed)

In [None]:
# Each *_test_data object is unpacked as a (features, labels) pair.
(mrs02_test_X_np, mrs02_test_y_np), (death_test_X_np, death_test_y_np) = mrs02_test_data, death_test_data

In [None]:
def get_roc_auc_scores(predictions_over_time, test_y_np, n_time_steps, n_bs_samples=1000, random_state=None):
    """Compute the test-set ROC AUC at every time step, plus bootstrap resamples of it.

    Parameters
    ----------
    predictions_over_time : sequence or mapping indexed by time step
        ``predictions_over_time[ts]`` holds the predicted scores for the test set at
        time step ``ts``; assumed element-wise aligned with ``test_y_np`` (TODO confirm).
    test_y_np : array-like
        True (binary) labels of the test set.
    n_time_steps : int
        Time steps ``0 .. n_time_steps - 1`` are evaluated.
    n_bs_samples : int, default 1000
        Number of bootstrap resamples drawn per time step.
    random_state : None, int or numpy.random.RandomState, default None
        Controls the bootstrap draws. ``None`` uses numpy's global RNG (previous
        behaviour); pass an int (e.g. the notebook ``seed``) for reproducible results.

    Returns
    -------
    roc_auc_scores_df : pandas.DataFrame
        Columns ``n_hours`` (the time-step index) and ``roc_auc_score``: the
        non-bootstrapped score at each time step.
    roc_auc_scores_bs : pandas.DataFrame
        Columns ``roc_auc_score`` and ``n_hours``: one row per usable bootstrap
        resample, with a fresh 0..N-1 index.
    """
    # One generator shared by all resamples. Passing an int straight to `resample`
    # would re-seed it on each call and draw the identical sample every iteration.
    rng = check_random_state(random_state)

    roc_auc_scores = []
    bs_frames = []
    for ts in tqdm(range(n_time_steps)):
        y_pred = predictions_over_time[ts]

        # Bootstrap the test set to estimate the spread of the ROC AUC at this step.
        bs_scores_at_ts = []
        for _ in range(n_bs_samples):
            y_pred_bs, y_bs = resample(y_pred, test_y_np, replace=True, random_state=rng)
            # ROC AUC is undefined when a resample contains a single class
            # (roc_auc_score raises ValueError), so such draws are skipped.
            if len(np.unique(y_bs)) < 2:
                continue
            bs_scores_at_ts.append(roc_auc_score(y_bs, y_pred_bs))

        bs_df_at_ts = pd.DataFrame(bs_scores_at_ts, columns=['roc_auc_score'])
        bs_df_at_ts['n_hours'] = ts
        bs_frames.append(bs_df_at_ts)

        # Non-bootstrapped score on the full test set.
        roc_auc_scores.append([ts, roc_auc_score(test_y_np, y_pred)])

    roc_auc_scores_df = pd.DataFrame(roc_auc_scores, columns=['n_hours', 'roc_auc_score'])

    # Concatenate once (not once per time step) and avoid repeating index labels
    # 0..n_bs_samples-1 for every time step.
    if bs_frames:
        roc_auc_scores_bs = pd.concat(bs_frames, ignore_index=True)
    else:
        roc_auc_scores_bs = pd.DataFrame(columns=['roc_auc_score', 'n_hours'])
    return roc_auc_scores_df, roc_auc_scores_bs

In [None]:
# Per-time-step ROC AUC (and its bootstrap resamples) for each outcome.
mrs02_roc_auc_scores, mrs02_roc_auc_scores_bs = get_roc_auc_scores(
    predictions_over_time=mrs02_predictions_over_time,
    test_y_np=mrs02_test_y_np,
    n_time_steps=n_time_steps,
)
death_roc_auc_scores, death_roc_auc_scores_bs = get_roc_auc_scores(
    predictions_over_time=death_predictions_over_time,
    test_y_np=death_test_y_np,
    n_time_steps=n_time_steps,
)

In [None]:
mrs02_roc_auc_scores_bs.head(n=5)  # first bootstrap rows (functional outcome)

In [None]:
death_roc_auc_scores_bs.head(n=5)  # first bootstrap rows (mortality)

In [None]:
# Fixed four-colour palette. The figure below uses entry 2 for the functional
# outcome curves and entry 3 for the mortality curves.
all_colors_palette = sns.color_palette(
    [
        '#f61067',
        '#049b9a',
        '#012D98',
        '#a76dfe',
    ],
    n_colors=4,
)
all_colors_palette

In [None]:
plot_zoom = True
plot_title = False
plot_legend = True

tick_label_size = 11
label_font_size = 13

# Band drawn around each curve: 'sd' shades mean ± 1 standard deviation of the
# bootstrap ROC AUC samples at each hour (not a percentile confidence interval).
errorbar = 'sd'

# Zoomed inset: [left, bottom, width, height] in figure coordinates, and its y-range.
zoom_axes_rect = [0.2, 0.2, .7, .5]
zoom_y_bounds = (0.8, 0.92)

mrs02_color = all_colors_palette[2]
death_color = all_colors_palette[3]


def _plot_roc_auc_curves(target_ax, legend):
    """Draw the bootstrapped ROC-AUC-over-time curves of both outcomes on `target_ax`."""
    for bs_scores, color in ((mrs02_roc_auc_scores_bs, mrs02_color),
                             (death_roc_auc_scores_bs, death_color)):
        sns.lineplot(x='n_hours', y='roc_auc_score', data=bs_scores, legend=legend,
                     ax=target_ax, errorbar=errorbar, color=color)


def _style_axes(target_ax):
    """Set the shared axis labels and apply the configured label / tick font sizes."""
    target_ax.set_xlabel('Time after admission (hours)', fontsize=label_font_size)
    target_ax.set_ylabel('ROC AUC', fontsize=label_font_size)
    target_ax.tick_params(axis='both', labelsize=tick_label_size)


fig, ax = plt.subplots(figsize=(10, 10))

_plot_roc_auc_curves(ax, legend=True)
if plot_title:
    ax.set_title(f'{model_name} performance in the holdout test dataset as a function of observation period')
_style_axes(ax)
ax.set_ylim([0, 1])

if plot_zoom:
    # Attach the inset explicitly to `fig` (plt.axes targets whichever figure is current).
    ax2 = fig.add_axes(zoom_axes_rect, facecolor='w')
    _plot_roc_auc_curves(ax2, legend='auto')
    ax2.set_title('Zoomed in')
    ax2.set_ybound(*zoom_y_bounds)
    _style_axes(ax2)

if plot_legend:
    legend_markers, legend_labels = ax.get_legend_handles_labels()

    # Both shaded bands share a single legend entry, drawn as two side-by-side patches.
    sd_marker = (mpatches.Patch(color=mrs02_color, alpha=0.3),
                 mpatches.Patch(color=death_color, alpha=0.3))
    legend_markers.append(sd_marker)
    legend_labels.append('± s.d.')

    for color, label in ((mrs02_color, 'ROC AUC for functional outcome'),
                         (death_color, 'ROC AUC for mortality')):
        legend_markers.append(mlines.Line2D([], [], color=color, linestyle='-'))
        legend_labels.append(label)

    ax.legend(legend_markers, legend_labels, fontsize=label_font_size,
              handler_map={tuple: HandlerTuple(ndivide=None)})

# NOTE: the inset added with fig.add_axes is not managed by tight_layout
# (matplotlib may warn); it keeps its figure-coordinate position.
plt.tight_layout()
plt.show()

In [None]:
# Set to True to export the ROC-AUC-over-time figure created above.
save_figure = False
figure_output_dir = '/Users/jk1/Downloads'

if save_figure:
    fig.savefig(os.path.join(figure_output_dir, f'{model_name}_roc_auc_scores_over_time.png'), bbox_inches='tight')


## Mortality prediction: bootstrapped ROC AUC at 24 h (median and 95% CI)

In [None]:
# Bootstrapped ROC AUC of the mortality model at time step `ts` (the index into
# death_predictions_over_time, shown as n_hours == ts in the figure above).
ts = 24
n_iterations = 1000
ci_level = 95

# One generator seeded from the notebook-level `seed` and shared by all resamples.
# The interval is reproducible, and each iteration still draws a different sample.
bootstrap_rng = np.random.RandomState(seed)

roc_auc_scores_bs = []
for _ in range(n_iterations):
    y_pred_bs, y_bs = resample(death_predictions_over_time[ts], death_test_y_np, replace=True,
                               random_state=bootstrap_rng)
    # ROC AUC is undefined for a resample containing only one class
    # (roc_auc_score raises ValueError), so such draws are skipped.
    if len(np.unique(y_bs)) < 2:
        continue
    roc_auc_scores_bs.append(roc_auc_score(y_bs, y_pred_bs))

# Percentile bootstrap: median and central `ci_level`% interval.
alpha = 100 - ci_level
median_roc_auc = np.percentile(roc_auc_scores_bs, 50)
lower_ci_roc_auc = np.percentile(roc_auc_scores_bs, alpha / 2)
upper_ci_roc_auc = np.percentile(roc_auc_scores_bs, 100 - alpha / 2)

print(median_roc_auc, lower_ci_roc_auc, upper_ci_roc_auc)