# Out-of-Distribution Testing (for the main result in the paper)
## The FasterRisk models evaluated here are uncalibrated

In [None]:
import mimic_pipeline as mmp
import mimic_pipeline.utils as utils
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
from sklearn.metrics import roc_curve, auc, precision_recall_curve

# Raise the default figure resolution to 200 DPI for every plot drawn below.
plt.rcParams['figure.dpi'] = 200

In [None]:
# Load the eICU table used as the out-of-distribution test set.
# The path is relative to the current working directory.
eicu_df = pd.read_csv('data/eICU-union.csv')

In [None]:
# Features: everything except the identifiers, the outcome label, and the
# probabilities already produced by the baseline severity scores.
X_test = eicu_df.drop(
    columns=[
        'uniquepid',
        'patientunitstayid',
        'hospital_expire_flag',
        'apache_iv_prob',
        'apache_iva_prob',
        'oasis_prob',
        'sapsii_prob',
    ]
)
# Outcome label.
y_test = eicu_df['hospital_expire_flag']

# Baseline score probabilities, one series per scoring system.
apacheiv = eicu_df['apache_iv_prob']
apacheiva = eicu_df['apache_iva_prob']
oasis = eicu_df['oasis_prob']
sapsii = eicu_df['sapsii_prob']

In [None]:
def _score_stats(name: str, y_true, y_prob) -> dict:
    """Compute ROC, precision-recall and calibration statistics for one score.

    Args:
        name: Prefix used for every key of the returned dictionary.
        y_true: Ground-truth binary labels.
        y_prob: Predicted probabilities aligned with ``y_true``.

    Returns:
        A dict with the keys ``{name}_y_prob``, ``{name}_fpr``,
        ``{name}_tpr``, ``{name}_auroc``, ``{name}_precision``,
        ``{name}_recall``, ``{name}_auprc``, ``{name}_prob_true``,
        ``{name}_prob_pred``, ``{name}_h_stat`` and ``{name}_pvalue``.
    """
    out = {f"{name}_y_prob": y_prob}

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    out[f"{name}_fpr"], out[f"{name}_tpr"] = fpr, tpr
    out[f"{name}_auroc"] = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    out[f"{name}_precision"], out[f"{name}_recall"] = precision, recall
    # Area under the PR curve: recall is the x axis, precision the y axis.
    out[f"{name}_auprc"] = auc(recall, precision)

    # The last two values are the statistic and p-value reported by
    # mmp.metric.get_calibration_curve alongside the curve itself.
    prob_true, prob_pred, h_stat, pvalue = mmp.metric.get_calibration_curve(y_true, y_prob)
    out[f"{name}_prob_true"], out[f"{name}_prob_pred"] = prob_true, prob_pred
    out[f"{name}_h_stat"], out[f"{name}_pvalue"] = h_stat, pvalue
    return out


def evaluate_fasterrisk(load_path: str, X_test, y_test, baselines: dict) -> dict:
    """Evaluate a saved FasterRisk model and a set of baseline scores.

    The model is loaded from ``load_path`` and its binarizer from
    ``f"{load_path}-binarizer"``. The last ``/``-separated component of
    ``load_path`` becomes the key prefix for the model's statistics.

    Args:
        load_path: Path of the joblib-serialized model.
        X_test: Raw (un-binarized) test features accepted by the binarizer.
        y_test: Ground-truth binary labels.
        baselines: Maps a baseline name to its predicted probabilities,
            aligned with ``y_test``.

    Returns:
        One flat dict holding, for the model and for every baseline, the
        entries produced by ``_score_stats`` under that name's prefix.
    """
    fasterrisk_name = load_path.split('/')[-1]
    # NOTE: joblib.load unpickles its input; only load trusted model files.
    fasterrisk = joblib.load(load_path)
    binarizer = joblib.load(f"{load_path}-binarizer")

    # transform() returns a pair; only the first element (the binarized
    # features) is needed here.
    X_binarized, _ = binarizer.transform(X_test)
    y_prob = fasterrisk.predict_proba(X_binarized.to_numpy())

    stats = _score_stats(fasterrisk_name, y_test, y_prob)
    for name, baseline_prob in baselines.items():
        stats.update(_score_stats(name, y_test, baseline_prob))

    return stats

In [None]:
def visualize_results(stats: dict, names: list, name_dict: dict, title: str):
    """Plot ROC, precision-recall and calibration curves for several scores.

    Three figures are shown in sequence, each containing one line per entry
    of ``names``. Entries whose name starts with ``fasterrisk-`` (or equals
    ``fasterrisk``) are drawn opaque and thicker; all others are drawn
    thinner and semi-transparent.

    Args:
        stats: Flat dict with the ``{name}_fpr``, ``{name}_tpr``,
            ``{name}_auroc``, ``{name}_precision``, ``{name}_recall``,
            ``{name}_auprc``, ``{name}_prob_true``, ``{name}_prob_pred``,
            ``{name}_h_stat`` and ``{name}_pvalue`` entries for every name.
        names: Names of the scores to draw, in plotting order.
        name_dict: Maps each name to the label shown in the legend.
        title: Title used for all three figures and the printed banner.
    """

    def _emphasis(score_name):
        # Returns (alpha, linewidth) for the given score.
        if score_name.split('-')[0] == 'fasterrisk':
            return 1, 1.5
        return 0.5, 1

    sns.set_style("white")
    print(f"{'-'*50} {title} {'-'*50}")

    # ---- ROC curves ----
    for score in names:
        x_vals, y_vals = stats[f"{score}_fpr"], stats[f"{score}_tpr"]
        opacity, width = _emphasis(score)
        shown_name = name_dict[score]
        auroc = stats[f"{score}_auroc"]
        axes = sns.lineplot(
            x=x_vals,
            y=y_vals,
            label=f"{shown_name}, {auroc:.3f}",
            linewidth=width,
            alpha=opacity,
        )
        axes.figure.set_size_inches(8, 8)
    # Chance-level diagonal.
    sns.lineplot(
        x=np.linspace(0, 1),
        y=np.linspace(0, 1),
        label='Random',
        linestyle='--',
        color='grey',
        linewidth=1,
    )
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
    plt.show()

    # ---- precision-recall curves ----
    for score in names:
        y_vals, x_vals = stats[f"{score}_precision"], stats[f"{score}_recall"]
        opacity, width = _emphasis(score)
        shown_name = name_dict[score]
        auprc = stats[f"{score}_auprc"]
        axes = sns.lineplot(
            x=x_vals,
            y=y_vals,
            label=f"{shown_name}, {auprc:.3f}",
            linewidth=width,
            alpha=opacity,
        )
        axes.figure.set_size_inches(8, 8)
        # The legend is refreshed after each line is added.
        plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title)
    plt.show()

    # ---- calibration curves ----
    for score in names:
        y_vals, x_vals = stats[f"{score}_prob_true"], stats[f"{score}_prob_pred"]
        opacity, width = _emphasis(score)

        # Very small p-values are reported as a bound rather than as 0.000.
        p_val = stats[f"{score}_pvalue"]
        p_text = 'p < 0.0001' if p_val < 0.0001 else f'p = {p_val:.3f}'

        shown_name = name_dict[score]
        h_val = stats[f"{score}_h_stat"]
        axes = sns.lineplot(
            x=x_vals,
            y=y_vals,
            label=f"{shown_name}, H = {h_val:.3f}, {p_text} ",
            linewidth=width,
            alpha=opacity,
            marker='s',
        )
        axes.figure.set_size_inches(8, 8)
    # Perfect-calibration diagonal.
    sns.lineplot(
        x=np.linspace(0, 1),
        y=np.linspace(0, 1),
        label='Perfect',
        linestyle='--',
        color='grey',
        linewidth=1,
    )
    plt.xlabel('Predicted Probability')
    plt.ylabel('True Probability')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
    plt.show()

In [None]:
# The baseline probabilities and their legend labels are the same for every
# model, so build them once outside the loop.
baseline_probs = {
    'apacheiv': apacheiv,
    'apacheiva': apacheiva,
    'oasis': oasis,
    'sapsii': sapsii,
}
baseline_labels = {
    'apacheiv': 'APACHE IV',
    'apacheiva': 'APACHE IVa',
    'oasis': 'OASIS',
    'sapsii': 'SAPS II',
}

# Evaluate and plot each saved model: sparsity 10, 15, ..., 45.
for sparsity in range(10, 50, 5):
    model_key = f'fasterrisk-{sparsity}'
    stats = evaluate_fasterrisk(f'models/{model_key}', X_test, y_test, baseline_probs)
    visualize_results(
        stats,
        [model_key, *baseline_probs],
        {model_key: f'FasterRisk-{sparsity}', **baseline_labels},
        title=f"FasterRisk with Group Sparsity of {sparsity}",
    )