In [None]:
import os

import pandas as pd
import numpy as np
import torch
from sklearn.metrics import roc_curve, precision_recall_curve, auc
import matplotlib.pyplot as plt
from collections import defaultdict

In [None]:
def _plot_mean_band(ax, x, curves, scores, target, metric_name):
    """Draw the across-split mean of ``curves`` on ``ax`` with a ±1 std. dev. band.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    x : array-like
        Common grid that every curve in ``curves`` was interpolated onto.
    curves : list of 1-D arrays
        One curve per split, each evaluated on ``x``.
    scores : list of float
        One summary score per split (e.g. AUROC); mean ± std goes in the legend.
    target : int
        Class label, used as the legend entry prefix.
    metric_name : str
        Name of the summary score shown in the legend (e.g. ``"AUROC"``).
    """
    mean_curve = np.mean(curves, axis=0)
    std_curve = np.std(curves, axis=0)
    (line,) = ax.plot(
        x,
        mean_curve,
        label=f"{target} ({metric_name}={np.mean(scores):0.2f}±{np.std(scores):0.2f})",
    )
    # Shade with the colour of the line just drawn on *this* axes, so the band
    # always matches its curve regardless of other axes' colour cycles.
    ax.fill_between(
        x,
        np.maximum(mean_curve - std_curve, 0),
        np.minimum(mean_curve + std_curve, 1),
        color=line.get_color(),
        alpha=0.2,
    )


def plot_exp(exp, title, n_splits=5):
    """Plot mean ROC and PR curves (±1 std. dev. across test splits) for one experiment.

    For each split, loads
    ``runs/predict-{exp}/predict-{exp}-test-{split}/predictions_rank_0.pt``.
    That file must hold ``"trues"`` (class labels, a tensor with ``.numpy()``)
    and ``"preds"`` (per-class scores, indexed as ``preds[:, class]``).
    Every class is evaluated one-vs-rest. Each split's curves are interpolated
    onto a common grid so they can be averaged.

    Parameters
    ----------
    exp : str
        Experiment name used to build the predictions path.
    title : str
        Prefix for the axes titles and the legend titles.
    n_splits : int, default 5
        Number of test splits to load (``0 .. n_splits - 1``).

    Returns
    -------
    matplotlib.figure.Figure
        Figure with ROC curves on the left axes and PR curves on the right axes.
    """
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    # data[target][metric] -> list with one entry per split
    data = defaultdict(lambda: defaultdict(list))
    mean_fpr = np.linspace(0, 1, 100)
    mean_recall = np.linspace(0, 1, 100)

    for split in range(n_splits):
        results = torch.load(f"runs/predict-{exp}/predict-{exp}-test-{split}/predictions_rank_0.pt")
        trues = results["trues"]
        # Assumes labels are 0..K-1; a split that lacks the highest label yields fewer targets.
        n_targets = int(trues.numpy().max()) + 1

        for target in range(n_targets):
            preds = results["preds"][:, target]

            fpr, tpr, _ = roc_curve(trues, preds, pos_label=target)
            interp_tpr = np.interp(mean_fpr, fpr, tpr)
            # Pin the ROC curve to (0, 0) and (1, 1).
            interp_tpr[0] = 0.0
            interp_tpr[-1] = 1.0
            data[target]["tpr"].append(interp_tpr)
            data[target]["auroc"].append(auc(fpr, tpr))

            precision, recall, _ = precision_recall_curve(trues, preds, pos_label=target)
            # precision_recall_curve returns recall in decreasing order; np.interp
            # needs increasing sample points, so reverse both arrays.
            interp_precision = np.interp(mean_recall, recall[::-1], precision[::-1])
            # Precision is 1 at recall 0 by sklearn's convention. The recall == 1 end
            # is deliberately NOT overridden: it is the precision at full recall
            # (at least the class prevalence), not 0.
            interp_precision[0] = 1.0
            data[target]["precision"].append(interp_precision)
            # NOTE(review): trapezoidal area under the PR curve; sklearn's
            # average_precision_score is the less optimistic alternative.
            data[target]["auprc"].append(auc(recall, precision))

    for target, metrics in data.items():
        _plot_mean_band(ax1, mean_fpr, metrics["tpr"], metrics["auroc"], target, "AUROC")
        _plot_mean_band(ax2, mean_recall, metrics["precision"], metrics["auprc"], target, "AUPRC")

    ax1.legend(title=f"{title} Class", loc="lower right")
    ax2.legend(title=f"{title} Class", loc="lower right")

    ax1.set_title(f"{title} ROC Curve")
    ax2.set_title(f"{title} PR Curve")

    ax1.set_xlabel("1 - Specificity")
    ax1.set_ylabel("Sensitivity")

    ax2.set_xlabel("Recall")
    ax2.set_ylabel("Precision")

    return fig

In [None]:
# savefig does not create missing directories; ensure figs/ exists on a fresh checkout.
os.makedirs("figs", exist_ok=True)
fig = plot_exp(exp="cls-iss-24h", title="ISS (24h Notes)")
fig.tight_layout()
fig.savefig("figs/cls-iss-24h.png", dpi=300)

In [None]:
# savefig does not create missing directories; ensure figs/ exists on a fresh checkout.
os.makedirs("figs", exist_ok=True)
fig = plot_exp(exp="cls-iss-48h", title="ISS (48h Notes)")
fig.tight_layout()
fig.savefig("figs/cls-iss-48h.png", dpi=300)

In [None]:
# savefig does not create missing directories; ensure figs/ exists on a fresh checkout.
os.makedirs("figs", exist_ok=True)
fig = plot_exp(exp="cls-mort-24h", title="Mortality (24h Notes)")
fig.tight_layout()
fig.savefig("figs/cls-mort-24h.png", dpi=300)

In [None]:
# savefig does not create missing directories; ensure figs/ exists on a fresh checkout.
os.makedirs("figs", exist_ok=True)
fig = plot_exp(exp="cls-mort-48h", title="Mortality (48h Notes)")
fig.tight_layout()
fig.savefig("figs/cls-mort-48h.png", dpi=300)