In [None]:
import os
import sys
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from skimage.filters import threshold_otsu
from functools import reduce
from tqdm import tqdm

# The project-local packages (`dataset`, `utils`) live at the repository root,
# so it has to be on the import path before they can be imported.
root_dir = "../../"
sys.path.append(root_dir)
from dataset import CTICHDataset
from utils import models, image_explainers as explainers, window

# Data locations for the CT-ICH dataset and the output folder for figures.
data_dir = os.path.join(root_dir, "data")
ctich_dir = os.path.join(data_dir, "CT-ICH")
image_dir = os.path.join(ctich_dir, "images")
mask_dir = os.path.join(ctich_dir, "masks")
figure_dir = os.path.join(root_dir, "figures", "image_level")
os.makedirs(figure_dir, exist_ok=True)

# One explanation directory per model; each `models` entry is a 3-tuple whose
# middle element is the model's sub-directory name.
explanation_dirs = [
    os.path.join(ctich_dir, "explanations", model_name)
    for _, model_name, _ in models
]

# Restrict the analysis to slices explained by every (model, explainer) pair.
explained_slice_ids = reduce(
    np.intersect1d,
    [
        np.load(
            os.path.join(model_dir, explainer_cfg["name"], "explained_slice_ids.npy")
        )
        for model_dir in explanation_dirs
        for explainer_cfg in explainers
    ],
)

# Both tables are looked up by (PatientNumber, SliceNumber) later on.
diagnosis_df = pd.read_csv(
    os.path.join(ctich_dir, "hemorrhage_diagnosis_raw_ct.csv")
).set_index(["PatientNumber", "SliceNumber"])

annotation_df = pd.read_csv(
    os.path.join(ctich_dir, "annotations.csv")
).set_index(["PatientNumber", "SliceNumber"])

dataset = CTICHDataset(ctich_dir)

# Global plotting style for every figure in this notebook.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
# Per-image precision / recall / F1 of every (model, explainer) pair against the
# ground-truth bounding boxes. Results are cached on disk; delete the cache file
# to force a recomputation.
# NOTE(review): `pd.read_pickle` can execute arbitrary code -- only load a cache
# that this notebook wrote itself.
f1_cache_path = os.path.join(ctich_dir, "explanations", "image_level_f1")

if os.path.exists(f1_cache_path):
    f1_df = pd.read_pickle(f1_cache_path)
else:
    count = 0  # number of hemorrhage slices actually scored
    f1_records = []
    for explained_slice_id in tqdm(explained_slice_ids):
        # Slice ids have the form "<series_idx>_<slice_idx>".
        series_idx, slice_idx = explained_slice_id.split("_")

        patient_id = dataset.series[int(series_idx)]
        patient_number = int(patient_id.lstrip("0"))
        slice_idx = int(slice_idx)
        # The CSV tables are keyed by `slice_idx + 1`, while the image and
        # explanation files are named with `slice_idx` itself.
        slice_number = slice_idx + 1

        slice_row = diagnosis_df.loc[patient_number, slice_number]
        # Only slices with a hemorrhage have boxes to score against.
        if int(slice_row["No_Hemorrhage"]):
            continue
        count += 1

        # Single-subtype slices are grouped under their subtype AND "Any";
        # slices with several subtypes are grouped under "Any" only.
        hemorrhage_types = {
            abbreviation
            for column, abbreviation in (
                ("Intraventricular", "IVH"),
                ("Intraparenchymal", "IPH"),
                ("Subarachnoid", "SAH"),
                ("Epidural", "EDH"),
                ("Subdural", "SDH"),
            )
            if slice_row[column] == 1
        }
        if len(hemorrhage_types) > 1:
            hemorrhage_types = set()
        hemorrhage_types.add("Any")

        image = np.load(os.path.join(image_dir, f"{patient_id}_{slice_idx}.npy"))
        # NOTE(review): the windowed image is only used for the shape/dtype of
        # the ground-truth mask below.
        image = window(image, window_level=40, window_width=80)

        # Rasterize every annotated box of this slice into one binary mask.
        ground_truth = np.zeros_like(image)
        # A list key always yields a DataFrame holding every matching row; a
        # bare (patient, slice) key returns a Series when it matches a single
        # row of a uniquely indexed table, and a Series has no `.iterrows()`.
        slice_annotations = annotation_df.loc[[(patient_number, slice_number)]]
        for _, annotation_row in slice_annotations.iterrows():
            # Presumably a Python-style dict repr with single quotes; swap the
            # quotes so it parses as JSON.
            annotation = json.loads(annotation_row["data"].replace("'", '"'))
            bbox_x = int(annotation["x"])
            bbox_y = int(annotation["y"])
            bbox_width = int(annotation["width"])
            bbox_height = int(annotation["height"])
            # Clamp to the image origin: a negative start index would wrap
            # around to the opposite edge instead of clipping the box.
            ground_truth[
                max(bbox_y, 0) : max(bbox_y + bbox_height, 0),
                max(bbox_x, 0) : max(bbox_x + bbox_width, 0),
            ] = 1
        ground_truth = ground_truth.flatten()

        for (model_title, _, _), explanation_dir in zip(models, explanation_dirs):
            for explainer in explainers:
                explainer_title = explainer["title"]
                explanation = np.load(
                    os.path.join(
                        explanation_dir,
                        explainer["name"],
                        f"{series_idx}_{slice_idx}.npy",
                    )
                ).flatten()
                # Binarize the explanation with Otsu's threshold before scoring.
                threshold = threshold_otsu(explanation, nbins=1024)

                precision, recall, score, _ = metrics.precision_recall_fscore_support(
                    ground_truth,
                    explanation > threshold,
                    beta=1,
                    pos_label=1,
                    average="binary",
                    zero_division=0,
                )
                f1_records.append(
                    {
                        "model": model_title,
                        "explainer": explainer_title,
                        "model_explainer": f"{model_title}, {explainer_title}",
                        "patient_number": patient_number,
                        "slice_idx": slice_idx,
                        "hemorrhage_types": hemorrhage_types,
                        "t": "otsu",
                        "precision": precision,
                        "recall": recall,
                        "f1": score,
                    }
                )
    f1_df = pd.DataFrame(f1_records)
    f1_df.to_pickle(f1_cache_path)
    print(f"f1 score evaluated on {count} images")

In [None]:
# One row per (slice, model/explainer, hemorrhage group): a slice listed under
# several groups (e.g. "IVH" and "Any") appears once per group.
df = f1_df.explode("hemorrhage_types", ignore_index=True)
hemorrhage_order = sorted(pd.unique(df["hemorrhage_types"]))

# Number of distinct slices per hemorrhage group, shown in the tick labels.
# Counting unique (patient, slice) pairs avoids hard-coding how many
# model/explainer rows each slice contributes.
slices_per_type = (
    df[["hemorrhage_types", "patient_number", "slice_idx"]]
    .drop_duplicates()
    .groupby("hemorrhage_types")
    .size()
)

fig, ax = plt.subplots(figsize=(16 / 2, 9 / 2))
sns.boxplot(
    data=df,
    x="hemorrhage_types",
    y="f1",
    hue="model_explainer",
    ax=ax,
    # NOTE(review): four colors -- assumes exactly four model/explainer
    # combinations, colored in the order they first appear in `df`.
    palette=sns.color_palette()[:2] + sns.color_palette("pastel")[:2],
    order=hemorrhage_order,
)
ax.set_xlabel("Hemorrhage type")
ax.set_ylabel(r"$f_1$ score")
ax.set_ylim(None, 1)
# Categories are drawn at integer positions 0..n-1 in `hemorrhage_order`; set
# the ticks explicitly so the labels are pinned to them.
ax.set_xticks(range(len(hemorrhage_order)))
ax.set_xticklabels(
    [f"{hem_type} ({slices_per_type[hem_type]})" for hem_type in hemorrhage_order]
)
ax.legend(title="", loc="upper left", bbox_to_anchor=(1, 1))

fig.savefig(os.path.join(figure_dir, "CT-ICH_f1_ICH_type.jpg"), bbox_inches="tight")
fig.savefig(os.path.join(figure_dir, "CT-ICH_f1_ICH_type.pdf"), bbox_inches="tight")
plt.show()