In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from skimage.filters import threshold_otsu
from functools import reduce
from tqdm import tqdm

# Make the repository root importable so the shared project helpers resolve.
root_dir = "../../"
sys.path.append(root_dir)
from utils import models, explainers, window, annotate

# Dataset and output locations, all relative to the repository root.
data_dir = os.path.join(root_dir, "data")
bhx_dir = os.path.join(data_dir, "BHX")
cq500_dir = os.path.join(data_dir, "CQ500")
image_dir = os.path.join(cq500_dir, "images")
figure_dir = os.path.join(
    root_dir, "figures", "image_level", "explanations", "CQ500"
)
os.makedirs(figure_dir, exist_ok=True)

# One saved-explanation directory per model. Each entry of `models` is
# unpacked as a 3-tuple; only its middle element is used (as a directory name).
explanation_dirs = [
    os.path.join(cq500_dir, "explanations", model_id)
    for _title, model_id, _extra in models
]

# Keep only the SOP instance UIDs that every (model, explainer) pair has an
# explanation for. Column 1 of each `explained_sop_ids.npy` is taken as the
# UID (presumably mirroring the (image_name, sop_id) rows of `sop_ids.npy` —
# TODO confirm).
explained_sop_ids = reduce(
    np.intersect1d,
    (
        np.load(
            os.path.join(exp_dir, method["name"], "explained_sop_ids.npy")
        )[:, 1]
        for exp_dir in explanation_dirs
        for method in explainers
    ),
)

# Series metadata and the BHX manual annotation table.
plain_thick_series_df = pd.read_csv(
    os.path.join(cq500_dir, "plain_thick_series.csv")
)
manual_annotation_df = pd.read_csv(
    os.path.join(bhx_dir, "1_Initial_Manual_Labeling.csv")
)

# Hemorrhage type name -> abbreviation; "Chronic" is grouped with SDH.
hem_type_dict = dict(
    Intraventricular="IVH",
    Intraparenchymal="IPH",
    Subarachnoid="SAH",
    Epidural="EDH",
    Subdural="SDH",
    Chronic="SDH",
)

# Global plotting style for every figure produced below.
sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
# Render, for every CQ500 slice that was explained AND has manual annotations,
# one ground-truth figure plus one overlay per (model, explainer) pair:
#   <figure_dir>/<sop_id>/ground_truth.{pdf,jpg}
#   <figure_dir>/<sop_id>/<model>/<explainer name with "/" -> "_">.{pdf,jpg}

WINDOW_LEVEL = 40  # window centre passed to `window` (presumably HU — confirm)
WINDOW_WIDTH = 80  # window width passed to `window` (presumably HU — confirm)
OTSU_NBINS = 1024  # histogram bins for Otsu thresholding of explanations
COLOR_PERCENTILE = 99  # |explanation| percentile used as the symmetric colour limit


def _save_figure(fig, directory, stem):
    """Save ``fig`` as ``<stem>.pdf`` and ``<stem>.jpg`` in ``directory``, then close it.

    Closing by handle releases the figure; many figures are produced in a loop.
    """
    for extension in ("pdf", "jpg"):
        fig.savefig(
            os.path.join(directory, f"{stem}.{extension}"), bbox_inches="tight"
        )
    plt.close(fig)


def _plot_ground_truth(image, sop_annotation_df, out_dir):
    """Draw the windowed slice with its manual annotations; save as ``ground_truth``."""
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    annotate(sop_annotation_df, ax)
    ax.set_title("Ground-truth")
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    _save_figure(fig, out_dir, "ground_truth")


def _plot_explanation(image, explanation, title, out_dir, stem):
    """Overlay the Otsu-thresholded ``explanation`` on ``image``; save as ``stem``.

    Values not above Otsu's threshold are zeroed. The colour range is symmetric,
    bounded by the ``COLOR_PERCENTILE``-th percentile of the absolute values.
    """
    threshold = threshold_otsu(explanation.flatten(), nbins=OTSU_NBINS)
    salient = explanation > threshold
    thresholded = explanation * salient
    limit = np.nanpercentile(np.abs(thresholded.flatten()), COLOR_PERCENTILE)

    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    # NOTE(review): zeroed pixels map to the centre of "bwr" (white), so they are
    # still painted at alpha 0.5 and lighten the whole slice. Use
    # np.ma.masked_where(~salient, thresholded) if only salient pixels should
    # be tinted.
    ax.imshow(thresholded, cmap="bwr", vmin=-limit, vmax=limit, alpha=0.5)
    ax.set_title(title)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    _save_figure(fig, out_dir, stem)


# Hoisted lookups: set membership instead of a linear scan of the numpy array
# for every slice, and one groupby instead of re-filtering the whole annotation
# table for every slice (each group keeps the original rows, order and index).
explained_sop_id_set = set(explained_sop_ids)
annotations_by_sop = {
    sop: rows for sop, rows in manual_annotation_df.groupby("SOPInstanceUID")
}

for _, series_row in tqdm(
    plain_thick_series_df.iterrows(), total=len(plain_thick_series_df)
):
    series_image_dir = os.path.join(
        image_dir, series_row["exam_dir"], series_row["series_dir"]
    )
    sop_ids = np.load(os.path.join(series_image_dir, "sop_ids.npy"))
    for image_name, sop_id in sop_ids:
        sop_annotation_df = annotations_by_sop.get(sop_id)
        # Only explained slices that also carry manual annotations get figures,
        # so output directories are created for those slices only.
        if sop_id not in explained_sop_id_set or sop_annotation_df is None:
            continue

        sop_figure_dir = os.path.join(figure_dir, sop_id)
        os.makedirs(sop_figure_dir, exist_ok=True)

        image = np.load(
            os.path.join(series_image_dir, image_name.replace(".dcm", ".npy"))
        )
        image = window(image, window_level=WINDOW_LEVEL, window_width=WINDOW_WIDTH)

        _plot_ground_truth(image, sop_annotation_df, sop_figure_dir)

        for (model_title, model, _), explanation_dir in zip(
            models, explanation_dirs
        ):
            model_figure_dir = os.path.join(sop_figure_dir, model)
            os.makedirs(model_figure_dir, exist_ok=True)

            for explainer in explainers:
                explainer_name = explainer["name"]
                explanation = np.load(
                    os.path.join(explanation_dir, explainer_name, f"{sop_id}.npy")
                )
                _plot_explanation(
                    image,
                    explanation,
                    f"{model_title}, {explainer['title']}",
                    model_figure_dir,
                    explainer_name.replace("/", "_"),
                )