In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Repository root, given relative to the notebook's working directory.
root_dir = "../../"
# Must run before the two project imports below so they can be resolved
# (presumably `dataset` and `utils` live under `root_dir` -- confirm).
sys.path.append(root_dir)

from dataset import RSNADataset
from utils import models, location_performance_t, find_hemorrhages

# Input locations, all derived from `root_dir`.
data_dir = os.path.join(root_dir, "data")
rsna_dir = os.path.join(data_dir, "RSNA")
prediction_dir = os.path.join(rsna_dir, "predictions")
# Output location for the figures saved by this notebook; created on demand.
figure_dir = os.path.join(root_dir, "figures", "exam_level")
os.makedirs(figure_dir, exist_ok=True)

# Plot styling: seaborn defaults, then the "paper" context with enlarged fonts.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
# Validation split of the RSNA dataset, without augmentation.
dataset = RSNADataset(rsna_dir, op="val", augment=False, weak_supervision=True)

# Collect the length (number of images) of every hemorrhage sequence that
# `find_hemorrhages` reports across all series in the split.
true_hemorrhage = []
for i in tqdm(range(len(dataset))):
    series_id = dataset.series_ids[i]
    series_obj = dataset.series_dictionary[series_id]

    series = np.array(series_obj["series"])

    # Column 1 is used as the per-image label (presumably a binary
    # hemorrhage flag -- confirm against RSNADataset).
    labels = series[:, 1].astype(int)

    # Reorder the labels with the index array stored next to the series
    # (presumably the anatomical slice order -- confirm).
    sorted_ids = np.load(
        os.path.join(rsna_dir, "series", f"{series_id}_sorted_ids.npy")
    )
    labels = labels[sorted_ids]

    # Extending with an empty sequence is a no-op, so no emptiness check is
    # needed before accumulating the lengths.
    _true_hemorrhage = find_hemorrhages(labels)
    true_hemorrhage.extend(len(hem) for hem in _true_hemorrhage)
true_hemorrhage = np.array(true_hemorrhage)

if true_hemorrhage.size == 0:
    # mean()/std() of an empty array are NaN and int(NaN) raises ValueError,
    # so report the empty case explicitly instead of crashing.
    print("0 hemorrhages")
else:
    print(
        f"{len(true_hemorrhage)} hemorrhages, {int(true_hemorrhage.mean())} images +- {int(true_hemorrhage.std())}"
    )

In [None]:
def exam_level_tpr(cutoff):
    """Plot hit rate against hemorrhage sequence length for every model.

    For each entry of the module-level `models` list, the model's pickled
    predictions are loaded, indexed by ``series_idx`` and passed to
    `location_performance_t`. The per-model results are concatenated, the
    list-like ``true_hemorrhage_length`` and ``hit`` columns are exploded to
    one row per hemorrhage, and a seaborn line plot (one line per
    ``est_name``) is drawn, saved as JPG and PDF in `figure_dir`, and shown.

    Parameters
    ----------
    cutoff : tuple
        ``(cutoff_method, cutoff_name)``. ``cutoff_method`` is forwarded to
        `location_performance_t` and used in the output file names;
        ``cutoff_name`` becomes the figure title.
    """
    cutoff_method, cutoff_name = cutoff

    per_model_frames = []
    for model_title, model, weak_supervision in tqdm(models):
        predictions_path = os.path.join(prediction_dir, model, "predictions")
        prediction_df = pd.read_pickle(predictions_path).set_index("series_idx")

        per_model_frames.append(
            location_performance_t(
                model_title,
                prediction_df,
                weak_supervision,
                cutoff=cutoff_method,
            )
        )

    # One row per (model, hemorrhage) pair after exploding the list columns.
    performance = pd.concat(per_model_frames)
    performance = performance.explode(["true_hemorrhage_length", "hit"]).reset_index()

    _, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
    sns.lineplot(
        data=performance,
        x="true_hemorrhage_length",
        y="hit",
        hue="est_name",
        ax=ax,
    )
    ax.set(xlabel="Hemorrhage sequence length", ylabel="TPR", title=cutoff_name)
    ax.legend(title="", loc="lower right")

    # Save the same figure in both a raster and a vector format.
    output_paths = [
        os.path.join(figure_dir, f"RSNA_TPR_{cutoff_method}.jpg"),
        os.path.join(figure_dir, f"RSNA_TPR_{cutoff_method}.pdf"),
    ]
    for output_path in output_paths:
        plt.savefig(output_path, bbox_inches="tight")

    plt.show()
    plt.close()

In [None]:
# Each entry is (cutoff method key, human-readable title used on the figure).
cutoffs = [
    ("youden", r"Youden's $J$ statistic"),
    ("d", r"Distance to $(0,1)$ point"),
]
# Produce and save one TPR-vs-length figure per cutoff criterion.
for cutoff in cutoffs:
    exam_level_tpr(cutoff)