In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from sklearn import metrics
from skimage.filters import threshold_otsu
from functools import reduce
from tqdm import tqdm

# Base directory, two levels above the notebook's working directory.
root_dir = "../../"
# Extend the import path so the `utils` import below can resolve
# (presumably `utils` lives directly under root_dir — confirm).
sys.path.append(root_dir)
from utils import models, image_explainers as explainers, window

# Input locations, all derived from root_dir.
data_dir = os.path.join(root_dir, "data")
cq500_dir = os.path.join(data_dir, "CQ500")
bhx_dir = os.path.join(data_dir, "BHX")
image_dir = os.path.join(cq500_dir, "images")
# Output folder for the figures saved by this notebook; created up front so
# the later savefig calls do not fail on a missing directory.
figure_dir = os.path.join(root_dir, "figures", "image_level")
os.makedirs(figure_dir, exist_ok=True)

# One explanation folder per model, under cq500_dir/explanations.
explanation_dirs = [
    os.path.join(cq500_dir, "explanations", model_name) for _, model_name, _ in models
]

# SOP instance ids that were explained by every (model, explainer) pair:
# column 1 of each "explained_sop_ids.npy" table is loaded lazily and the
# arrays are folded together with a pairwise intersection.
explained_sop_ids = reduce(
    np.intersect1d,
    (
        np.load(os.path.join(model_dir, method["name"], "explained_sop_ids.npy"))[:, 1]
        for model_dir in explanation_dirs
        for method in explainers
    ),
)

# Series table: each row supplies the "exam_dir" / "series_dir" pair used
# later to locate a series' image folder.
plain_thick_series_df = pd.read_csv(os.path.join(cq500_dir, "plain_thick_series.csv"))
# Manual bounding-box annotations. Columns read by this notebook:
# "SOPInstanceUID" (image key), "data" (box geometry) and "labelName".
manual_annotation_df = pd.read_csv(
    os.path.join(bhx_dir, "1_Initial_Manual_Labeling.csv")
)

# Maps annotation label names to short hemorrhage-type codes.
# "Subdural" and "Chronic" share the code SDH, so both are pooled into the
# same category in the per-type analysis.
hem_type_dict = {
    "Intraventricular": "IVH",
    "Intraparenchymal": "IPH",
    "Subarachnoid": "SAH",
    "Epidural": "EDH",
    "Subdural": "SDH",
    "Chronic": "SDH",
}

# Default seaborn theme, with the "paper" context and fonts scaled 1.5x.
sns.set_theme(context="paper", font_scale=1.5)

In [None]:
# Image-level localization scores.
#
# For every annotated image that all (model, explainer) pairs have explained,
# the manual bounding boxes are rasterized into a binary mask and compared
# against the Otsu-thresholded explanation map. One record (precision, recall,
# f1) is produced per image / model / explainer. The resulting frame is cached
# on disk, so the expensive loop only runs when the cache file is missing.
f1_cache_path = os.path.join(cq500_dir, "explanations", "image_level_f1")
if os.path.exists(f1_cache_path):
    # NOTE(review): unpickling can execute arbitrary code; only load a cache
    # file that was written by this notebook.
    f1_df = pd.read_pickle(f1_cache_path)
else:
    f1_records = []
    count = 0
    expected_shape = (512, 512)
    # Hash-based membership test instead of scanning the id array per image.
    explained_sop_id_set = set(np.asarray(explained_sop_ids).tolist())
    for _, series_row in tqdm(
        plain_thick_series_df.iterrows(), total=len(plain_thick_series_df)
    ):
        series_image_dir = os.path.join(
            image_dir, series_row["exam_dir"], series_row["series_dir"]
        )
        sop_ids = np.load(os.path.join(series_image_dir, "sop_ids.npy"))
        for image_name, sop_id in sop_ids:
            if sop_id not in explained_sop_id_set:
                continue
            sop_annotation_df = manual_annotation_df[
                manual_annotation_df["SOPInstanceUID"] == sop_id
            ]
            if len(sop_annotation_df) == 0:
                continue

            # The image is only needed for its shape: the ground-truth mask
            # below is built on a fixed 512x512 grid.
            image = np.load(
                os.path.join(series_image_dir, image_name.replace(".dcm", ".npy"))
            )
            if image.shape != expected_shape:
                print("Reshape needed")
                continue
            # Count only images that are actually scored, so the summary
            # printed at the end matches the contents of f1_df.
            count += 1

            # Rasterize every bounding box of this image into one binary mask
            # and collect the hemorrhage types that were annotated.
            hemorrhage_types = set()
            ground_truth = np.zeros(expected_shape)
            for _, annotation_row in sop_annotation_df.iterrows():
                # The "data" column holds a dict written with single quotes;
                # swap them so it parses as JSON.
                annotation = json.loads(annotation_row["data"].replace("'", '"'))
                bbox_x = int(annotation["x"])
                bbox_y = int(annotation["y"])
                bbox_width = int(annotation["width"])
                bbox_height = int(annotation["height"])
                # Clamp to the image origin: a negative start would otherwise
                # be read as a wrap-around index and yield an empty slice.
                # The far edge needs no clamp, slicing stops at the array end.
                ground_truth[
                    max(bbox_y, 0) : max(bbox_y + bbox_height, 0),
                    max(bbox_x, 0) : max(bbox_x + bbox_width, 0),
                ] = 1
                hemorrhage_types.add(hem_type_dict[annotation_row["labelName"]])
            # Images with several hemorrhage types are only reported under
            # "Any"; single-type images are reported under their type and "Any".
            if len(hemorrhage_types) > 1:
                hemorrhage_types = {"Any"}
            else:
                hemorrhage_types.add("Any")
            ground_truth = ground_truth.flatten()

            for (model_title, _, _), explanation_dir in zip(models, explanation_dirs):
                for explainer in explainers:
                    explainer_title = explainer["title"]
                    explainer_name = explainer["name"]

                    explanation = np.load(
                        os.path.join(explanation_dir, explainer_name, f"{sop_id}.npy")
                    ).flatten()
                    # Binarize the explanation map with Otsu's threshold.
                    otsu_threshold = threshold_otsu(explanation, nbins=1024)
                    (
                        precision,
                        recall,
                        score,
                        _,
                    ) = metrics.precision_recall_fscore_support(
                        ground_truth,
                        explanation > otsu_threshold,
                        beta=1,
                        pos_label=1,
                        average="binary",
                        zero_division=0,
                    )
                    f1_records.append(
                        {
                            "model": model_title,
                            "explainer": explainer_title,
                            "model_explainer": f"{model_title}, {explainer_title}",
                            "sop_id": sop_id,
                            # Sorted so the stored list does not depend on
                            # set iteration order.
                            "hemorrhage_types": sorted(hemorrhage_types),
                            "t": "otsu",
                            "precision": precision,
                            "recall": recall,
                            "f1": score,
                        }
                    )
    f1_df = pd.DataFrame(f1_records)
    f1_df.to_pickle(f1_cache_path)
    print(f"f1 score evaluated on {count} images")

In [None]:
# Box plot of the per-image f1 scores, grouped by hemorrhage type and colored
# by (model, explainer) pair. "hemorrhage_types" holds a list per record, so
# it is exploded once into one row per (record, type).
df = f1_df.explode("hemorrhage_types", ignore_index=True)

_, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
ax = sns.boxplot(
    data=df,
    x="hemorrhage_types",
    y="f1",
    hue="model_explainer",
    ax=ax,
    # Two default and two pastel colors, i.e. one color for each of four
    # hue levels.
    palette=sns.color_palette()[:2] + sns.color_palette("pastel")[:2],
    order=sorted(pd.unique(df["hemorrhage_types"])),
)
ax.set_xlabel("Hemorrhage type")
ax.set_ylabel(r"$f_1$ score")
ax.set_ylim(None, 1)

# Append the number of images to each tick label. Every image contributes one
# row per (model, explainer) pair, so images are counted by unique sop_id
# rather than by dividing the row count by a fixed number of pairs.
images_per_type = df.groupby("hemorrhage_types")["sop_id"].nunique()
xticklabels = ax.get_xticklabels()
for tick in xticklabels:
    hem_type = tick.get_text()
    tick.set_text(f"{hem_type} ({images_per_type[hem_type]})")
ax.set_xticklabels(xticklabels)
ax.legend(title="", loc="upper left", bbox_to_anchor=(1, 1))

plt.savefig(os.path.join(figure_dir, "CQ500_f1_ICH_type.jpg"), bbox_inches="tight")
plt.savefig(os.path.join(figure_dir, "CQ500_f1_ICH_type.pdf"), bbox_inches="tight")
plt.show()