In [None]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics

# Repository root, relative to this notebook. It is appended to sys.path so
# the project-local `utils` package can be imported below.
root_dir = "../../"
sys.path.append(root_dir)

from utils import models

# Outputs: ROC figures are written under <root>/figures/global (created if
# missing). Inputs: per-dataset predictions are read from under <root>/data.
figure_dir = os.path.join(root_dir, "figures", "global")
os.makedirs(figure_dir, exist_ok=True)
data_dir = os.path.join(root_dir, "data")

# Seaborn styling for paper-sized figures. set_theme() applies seaborn's
# default plotting context, so the "paper" context must be set after it.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
def _global_scores(prediction_df, weak_supervision):
    """Return one global score per row of a model's prediction frame.

    Args:
        prediction_df: Predictions for one model. It must contain a
            ``global_logit`` column when ``weak_supervision`` is true, and a
            ``single_slice_logits`` column (one iterable of logits per row)
            otherwise.
        weak_supervision: Whether the model stores a global logit directly.

    Returns:
        A numpy array aligned with the rows of ``prediction_df``.
    """
    if weak_supervision:
        return prediction_df["global_logit"].to_numpy()
    # No global logit available: score each row by its largest slice logit.
    # The explicit lambda (rather than passing ``max``) keeps this a plain
    # element-wise call instead of letting pandas treat it as an aggregation.
    return (
        prediction_df["single_slice_logits"]
        .apply(lambda logits: max(logits))
        .to_numpy()
    )


def global_roc(dataset_name):
    """Plot, save, and show the ROC curve of every model for one dataset.

    For each ``(model_title, model_name, weak_supervision)`` entry in
    ``models``, predictions are loaded from
    ``<data_dir>/<dataset_name>/predictions/<model_name>/predictions``.
    The model's ROC curve is drawn with its AUC in the legend. The figure is
    saved as ``<figure_dir>/<dataset_name>_roc.jpg`` and ``.pdf``, then
    shown and closed.

    Args:
        dataset_name: Dataset directory name under ``data_dir``. It is also
            used as the plot title and as the prefix of the saved files.
    """
    prediction_dir = os.path.join(data_dir, dataset_name, "predictions")

    # Keep an explicit figure handle so saving and closing act on this
    # figure, not on whatever pyplot currently considers "current".
    fig, ax = plt.subplots(figsize=(5, 5))
    for model_title, model_name, weak_supervision in models:
        # NOTE(review): read_pickle can execute arbitrary code; only load
        # prediction files produced by this project.
        prediction_df = pd.read_pickle(
            os.path.join(prediction_dir, model_name, "predictions")
        )
        global_prediction = _global_scores(prediction_df, weak_supervision)
        target = prediction_df["target"].to_numpy()

        fpr, tpr, _ = metrics.roc_curve(target, global_prediction)
        auc = metrics.auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{model_title} (AUC={auc:.3f})")

    ax.legend(title="", loc="lower right")
    ax.set_aspect("equal", "box")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])
    ax.set_title(dataset_name)

    # Same figure, one file per output format.
    for extension in ("jpg", "pdf"):
        fig.savefig(
            os.path.join(figure_dir, f"{dataset_name}_roc.{extension}"),
            bbox_inches="tight",
        )
    plt.show()
    # Release this figure explicitly; the function is called once per dataset.
    plt.close(fig)

In [None]:
# Datasets to evaluate. Each global_roc() call plots all models on one
# dataset and saves the resulting figure.
datasets = [
    "RSNA",
    "CQ500",
    "CT-ICH",
]

for dataset in datasets:
    global_roc(dataset)