In [None]:
import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.filters import threshold_otsu
from functools import reduce
from tqdm import tqdm

root_dir = "../../"
sys.path.append(root_dir)
from dataset import CTICHDataset
from utils import models, explainers, window, annotate

data_dir = os.path.join(root_dir, "data")
ctich_dir = os.path.join(data_dir, "CT-ICH")
image_dir = os.path.join(ctich_dir, "images")
mask_dir = os.path.join(ctich_dir, "masks")
figure_dir = os.path.join(root_dir, "figures", "image_level", "explanations", "CT-ICH")
os.makedirs(figure_dir, exist_ok=True)

# One explanation directory per model, in the same order as `models`
# (the plotting loop zips the two together).
explanation_dirs = [
    os.path.join(ctich_dir, "explanations", model) for _, model, _ in models
]

# Each (model, explainer) pair stores the ids of the slices it explained; only
# slices explained by *every* pair are plotted, so intersect all id arrays.
slice_id_arrays = [
    np.load(os.path.join(explanation_dir, explainer["name"], "explained_slice_ids.npy"))
    for explanation_dir in explanation_dirs
    for explainer in explainers
]
if not slice_id_arrays:
    # reduce() would otherwise fail with an unhelpful "empty iterable" TypeError.
    raise ValueError("`models` and `explainers` must both be non-empty")
explained_slice_ids = reduce(np.intersect1d, slice_id_arrays)

# Both CSVs are looked up by (patient number, slice number) in the loop below.
diagnosis_df = pd.read_csv(
    os.path.join(ctich_dir, "hemorrhage_diagnosis_raw_ct.csv")
).set_index(["PatientNumber", "SliceNumber"])
annotation_df = pd.read_csv(
    os.path.join(ctich_dir, "annotations.csv")
).set_index(["PatientNumber", "SliceNumber"])

dataset = CTICHDataset(ctich_dir)

sns.set_theme()
sns.set_context("paper", font_scale=1.5)

In [None]:
# Display window passed to `window` for every CT slice (level 40 / width 80,
# presumably Hounsfield units, i.e. a brain window — TODO confirm against `window`).
WINDOW_LEVEL = 40
WINDOW_WIDTH = 80
# Histogram bins for Otsu thresholding and the percentile used as colour limit.
OTSU_NBINS = 1024
COLOR_LIMIT_PERCENTILE = 99


def _hide_axes(ax):
    """Hide both the x and y axis of `ax` (ticks and labels)."""
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)


def _save_figure(fig, path_stem):
    """Save `fig` to `<path_stem>.pdf` and `<path_stem>.jpg`, then close it.

    Closing explicitly keeps memory bounded while generating many figures.
    """
    for extension in ("pdf", "jpg"):
        fig.savefig(f"{path_stem}.{extension}", bbox_inches="tight")
    plt.close(fig)


def _load_thresholded_explanation(path):
    """Load an explanation map and keep only values above its Otsu threshold.

    Returns:
        (explanation, limit): the map with every value <= the Otsu threshold
        set to 0, and a positive symmetric colour limit for plotting — the
        COLOR_LIMIT_PERCENTILE-th percentile of the absolute values.
    """
    explanation = np.load(path)
    threshold = threshold_otsu(explanation.ravel(), nbins=OTSU_NBINS)
    explanation = explanation * (explanation > threshold)

    magnitudes = np.abs(explanation).ravel()
    limit = np.nanpercentile(magnitudes, COLOR_LIMIT_PERCENTILE)
    if limit == 0:
        # Too few pixels survived the threshold for the percentile to be
        # non-zero. With vmin == vmax == 0 matplotlib maps *every* pixel to the
        # bottom of the colormap, so fall back to the largest magnitude (or 1
        # for an all-zero map, which then renders at the colormap's centre).
        limit = np.nanmax(magnitudes) or 1.0
    return explanation, limit


for explained_slice_id in tqdm(explained_slice_ids):
    series_idx, slice_idx = explained_slice_id.split("_")

    patient_id = dataset.series[int(series_idx)]
    # int() already accepts leading zeros in a string; stripping them first
    # would break on an all-zero id ("000" -> "").
    patient_number = int(patient_id)
    slice_idx = int(slice_idx)
    # The CSV indices use slice_idx + 1; the image files use slice_idx.
    slice_number = slice_idx + 1

    slice_row = diagnosis_df.loc[patient_number, slice_number]
    if int(slice_row["No_Hemorrhage"]):
        # Only slices with a hemorrhage are plotted.
        continue

    slice_figure_dir = os.path.join(figure_dir, f"{patient_number}_{slice_number}")
    os.makedirs(slice_figure_dir, exist_ok=True)

    image = np.load(os.path.join(image_dir, f"{patient_id}_{slice_idx}.npy"))
    image = window(image, window_level=WINDOW_LEVEL, window_width=WINDOW_WIDTH)

    # Ground-truth figure: the windowed slice with its annotations drawn on top.
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    annotate(annotation_df.loc[patient_number, slice_number], ax)
    ax.set_title("Ground-truth")
    _hide_axes(ax)
    _save_figure(fig, os.path.join(slice_figure_dir, "ground_truth"))

    # One overlay figure per (model, explainer) pair.
    for (model_title, model, _), explanation_dir in zip(models, explanation_dirs):
        model_figure_dir = os.path.join(slice_figure_dir, model)
        os.makedirs(model_figure_dir, exist_ok=True)

        for explainer in explainers:
            explainer_name = explainer["name"]
            explanation, limit = _load_thresholded_explanation(
                os.path.join(explanation_dir, explainer_name, f"{explained_slice_id}.npy")
            )

            fig, ax = plt.subplots()
            ax.imshow(image, cmap="gray")
            ax.imshow(explanation, cmap="bwr", vmin=-limit, vmax=limit, alpha=0.5)
            ax.set_title(f"{model_title}, {explainer['title']}")
            _hide_axes(ax)
            # Replace "/" so the explainer name cannot introduce a subdirectory.
            _save_figure(
                fig,
                os.path.join(model_figure_dir, explainer_name.replace("/", "_")),
            )