In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from f1_utils import ctich_f1

# All paths are relative to the notebook's working directory
# (presumably two levels below the repository root — TODO confirm).
root_dir = "../../"

# CT-ICH dataset layout: <root>/data/CT-ICH/{images,masks}
data_dir = os.path.join(root_dir, "data")
ctich_dir = os.path.join(data_dir, "CT-ICH")
image_dir, mask_dir = (os.path.join(ctich_dir, sub) for sub in ("images", "masks"))

# Destination for every figure saved below; created up front so that
# savefig() never fails on a missing folder.
figure_dir = os.path.join(root_dir, "figures", "image_level")
os.makedirs(figure_dir, exist_ok=True)

# Shared seaborn styling for all plots in this notebook.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
# Per-image F1 results are cached on disk. If the cache is missing, run
# ctich_f1() (presumably it writes this exact file — TODO confirm), then load it.
# NOTE: read_pickle executes arbitrary code from the file; only load this
# locally generated cache, never an untrusted file.
f1_path = os.path.join(ctich_dir, "explanations", "image_level_f1")
if not os.path.exists(f1_path):
    ctich_f1()
f1_df = pd.read_pickle(f1_path)

In [None]:
from scipy.stats import iqr

# One row per (image, hemorrhage type, model/explainer): an image annotated with
# several hemorrhage types is counted under every type it carries.
df = f1_df.explode("hemorrhage_types", ignore_index=True)

# Each image appears once per model/explainer combination (assumption — TODO
# confirm), so rows-per-type divided by the number of combinations gives the
# number of images per type. Previously hard-coded as `// 4`.
n_model_explainers = df["model_explainer"].nunique()
images_per_type = df.groupby("hemorrhage_types").size() // n_model_explainers

# Stable, sorted x-axis order. NaN is dropped first: explode() turns empty lists
# into NaN, and sorting NaN together with strings would raise a TypeError.
hemorrhage_order = sorted(df["hemorrhage_types"].dropna().unique())

for metric, name in [
    ("f1", "Dice score"),
    ("precision", "precision"),
    ("recall", "recall"),
]:
    fig, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
    sns.boxplot(
        data=df,
        x="hemorrhage_types",
        y=metric,
        hue="model_explainer",
        ax=ax,
        # NOTE(review): this palette has exactly four colours, i.e. it assumes
        # four model/explainer combinations — TODO extend if more are added.
        palette=sns.color_palette()[:2] + sns.color_palette("pastel")[:2],
        order=hemorrhage_order,
    )
    ax.set_xlabel("Hemorrhage type")
    ax.set_ylabel(name)
    ax.set_ylim(-0.05, 1.05)
    # Annotate each category with its image count. Labels are built from
    # `hemorrhage_order`, which is the exact order seaborn used for the ticks.
    ax.set_xticklabels(
        [f"{hem_type} ({images_per_type[hem_type]})" for hem_type in hemorrhage_order]
    )
    ax.legend(title="", loc="upper left", bbox_to_anchor=(1, 1))
    ax.set_title("CT-ICH")

    for ext in ("jpg", "pdf"):
        fig.savefig(
            os.path.join(figure_dir, f"CT-ICH_{metric}_ICH_type.{ext}"),
            bbox_inches="tight",
        )
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# Median and interquartile range of the F1 (Dice) score for every
# hemorrhage type x model/explainer pair.
median_iqr_df = (
    df
    .groupby(["hemorrhage_types", "model_explainer"])
    .agg({"f1": ["median", iqr]})
)
median_iqr_df

In [None]:
# For every hemorrhage type x model/explainer pair, pick the row with the
# lowest F1 score. The index labels are unique because `df` was exploded
# with ignore_index=True.
worst_idx = df.groupby(["hemorrhage_types", "model_explainer"])["f1"].idxmin()
worst_df = df.loc[worst_idx]

for _, row in worst_df.iterrows():
    # slice_idx is printed 1-based (+1), presumably to match the slice
    # numbering used when browsing the scans — TODO confirm.
    fields = (
        row["hemorrhage_types"],
        row["model"],
        row["explainer"],
        row["patient_number"],
        row["slice_idx"] + 1,
        row["f1"],
    )
    print(*fields)