In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from tqdm import tqdm

# Relative path, two directory levels up; relative sys.path entries and the
# os.path.join results below are resolved against the kernel's working directory.
root_dir = "../../"
# Must run before the `utils` import below so that package can be found.
# NOTE(review): presumably `utils` lives directly under root_dir -- confirm.
sys.path.append(root_dir)

# `models` is iterated later as (model_title, model, weak_supervision) triples.
from utils import label_complexity_models as models

# Per-dataset directories are looked up under data_dir; figures are written to figure_dir.
data_dir = os.path.join(root_dir, "data")
figure_dir = os.path.join(root_dir, "figures", "global")
# Create the output directory up front; exist_ok makes re-running this cell safe.
os.makedirs(figure_dir, exist_ok=True)

# Apply seaborn's default theme first, then override the plotting context:
# "paper" context with fonts scaled by 1.5. The order matters because
# set_theme() also sets a context of its own.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
def _collect_global_auc_records(dataset_dir, dataset_name, prediction_file_name):
    """Compute the global AUC of every entry in ``models`` on one dataset.

    For each ``(model_title, model, weak_supervision)`` triple in the
    module-level ``models`` iterable, the pickled predictions stored at
    ``<dataset_dir>/predictions/<model>/<prediction_file_name>`` are loaded
    and scored against the ``"target"`` column.

    Returns:
        pandas.DataFrame with one row per entry of ``models`` and the columns
        ``dataset_name``, ``model_title``, ``m`` and ``auc``.
    """
    records = []
    for model_title, model, weak_supervision in tqdm(models):
        # The second-to-last "/"-separated component of the model identifier
        # is parsed as the integer m used for the x axis.
        m = int(model.split("/")[-2])

        # NOTE: unpickling can execute arbitrary code -- only load prediction
        # files that come from a trusted source.
        prediction_df = pd.read_pickle(
            os.path.join(dataset_dir, "predictions", model, prediction_file_name)
        )

        if weak_supervision:
            # These models provide a single global score per row.
            global_prediction = prediction_df["global_logit"].to_numpy()
        else:
            # Otherwise the global score is the largest of the per-slice logits.
            global_prediction = (
                prediction_df["single_slice_logits"]
                .apply(lambda logits: max(logits))
                .to_numpy()
            )
        target = prediction_df["target"].to_numpy()

        fpr, tpr, _ = metrics.roc_curve(target, global_prediction)
        records.append(
            {
                "dataset_name": dataset_name,
                "model_title": model_title,
                "m": m,
                "auc": metrics.auc(fpr, tpr),
            }
        )
    return pd.DataFrame(records)


def _plot_auc_label_complexity(auc_df, dataset_name):
    """Draw AUC versus m (log-scaled x axis), save it as JPG and PDF, and show it."""
    fig, ax = plt.subplots(figsize=(5, 5))
    # Draw on the explicit Axes instead of relying on pyplot's "current axes".
    # Rows that share the same (model_title, m) are aggregated with the mean.
    sns.lineplot(
        data=auc_df,
        x="m",
        y="auc",
        hue="model_title",
        estimator=np.mean,
        ax=ax,
    )
    ax.set_xlabel(r"$m$")
    ax.set_ylabel("AUC")
    ax.set_xscale("log")
    ax.legend(loc="lower right")
    ax.set_title(dataset_name)

    # Save through the Figure handle so the files always contain this figure,
    # regardless of which figure pyplot currently considers active.
    for extension in ("jpg", "pdf"):
        fig.savefig(
            os.path.join(
                figure_dir, f"{dataset_name}_auc_label_complexity.{extension}"
            ),
            bbox_inches="tight",
        )
    plt.show()
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def global_auc_label_complexity(dataset):
    """Plot and save the global AUC of every model in ``models`` for one dataset.

    Args:
        dataset: 3-tuple ``(dataset_dir, dataset_name, prediction_file_name)``.
            ``dataset_dir`` is the directory containing the ``predictions``
            folder, ``dataset_name`` is used for the plot title and the output
            file names, and ``prediction_file_name`` is the pickle file read
            for each model.

    Returns:
        None. As side effects the figure is written to
        ``<figure_dir>/<dataset_name>_auc_label_complexity.jpg`` and ``.pdf``
        and displayed with ``plt.show()``.
    """
    dataset_dir, dataset_name, prediction_file_name = dataset

    auc_df = _collect_global_auc_records(
        dataset_dir, dataset_name, prediction_file_name
    )
    _plot_auc_label_complexity(auc_df, dataset_name)

In [None]:
# Each dataset is described by (dataset directory, display name, prediction
# file name); the directory is always data_dir/<display name>.
datasets = [
    (os.path.join(data_dir, name), name, prediction_file_name)
    for name, prediction_file_name in (
        ("RSNA", "predictions_fixed"),
        ("CT-ICH", "predictions"),
        ("CQ500", "predictions"),
    )
]

# Produce and save one AUC-vs-m figure per dataset.
for dataset in datasets:
    global_auc_label_complexity(dataset)