In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
from PIL import Image
from tqdm import tqdm

# Every path below is relative to the notebook's working directory via root_dir.
root_dir = "../"
data_dir, explanation_dir = (
    os.path.join(root_dir, sub) for sub in ("data", "explanations")
)
# Rendered figures go under data/figures/explanations/: one sub-folder per
# explainer (created by the plotting cell) plus ground_truth/ for the
# annotated input images, which is created up front here.
figure_dir = os.path.join(data_dir, "figures", "explanations")
gt_figure_dir = os.path.join(figure_dir, "ground_truth")
os.makedirs(gt_figure_dir, exist_ok=True)

# Ground-truth annotations for the training split and the cropped test split.
df_train = pd.read_json(os.path.join(data_dir, "training.json"))
df_test = pd.read_json(os.path.join(data_dir, "test_cropped.json"))

# One table for both splits, indexed by image stem so the plotting cell can
# look up an image's "objects" (bounding boxes) by name.
# The stem is the basename up to the FIRST "." — the plotting cell derives its
# lookup key the same way, so the two derivations must stay in sync.
# NOTE(review): .at lookups downstream assume stems are unique across the two
# splits — TODO confirm.
gt_df = pd.concat([df_train, df_test], ignore_index=True)
gt_df = gt_df.assign(
    image_name=gt_df["image"].map(
        lambda image: os.path.basename(image["pathname"]).split(".")[0]
    )
).set_index("image_name")

# Explainer output sub-directory (under explanation_dir / figure_dir) -> figure
# title. Titles are rendered by matplotlib, so $...$ spans are mathtext.
# Insertion order is the order in which the plotting cell renders explainers.
exp_mapper = {
    "gradcam": "Grad-CAM",
    "lime": "LIME",
    "gradexp": "GradientExp",
    "deepexp": "DeepExp",
    # PartitionExp was run with several values of m, one sub-folder each.
    **{f"partexp/{m}": f"PartitionExp ($m = {m}$)" for m in (500, 128, 64, 32, 16)},
    "hexp/absolute_0": r"h-Shap ($\tau = 0$)",
    "hexp/relative_70": r"h-Shap ($\tau=70\%$)",
}

# Image paths (relative to root_dir, as the plotting cell joins them) of the
# images to visualise — presumably the classifier's true positives; TODO confirm.
with open(os.path.join(explanation_dir, "true_positive.npy"), "rb") as fh:
    true_positive = np.load(fh)

In [None]:
sns.set_context("paper", font_scale=1.5)

# How many images from true_positive to render (None renders all of them,
# since arr[:None] is the whole array).
N_EXAMPLES = 5
# Only ground-truth boxes of this category are drawn on the input image.
TARGET_CATEGORY = "trophozoite"


def _new_axes(title: str):
    """Create a single-panel figure with both axes hidden and ``title`` set.

    Returns:
        The ``(fig, ax)`` pair from ``plt.subplots``.
    """
    fig, ax = plt.subplots(1, 1)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_title(title)
    return fig, ax


def _save_figure(fig, out_dir: str, stem: str) -> None:
    """Save ``fig`` as ``<stem>.jpg`` and ``<stem>.pdf`` in ``out_dir``, then close it.

    The figure is closed even if saving fails, so a long loop does not
    accumulate open figures.
    """
    try:
        for ext in ("jpg", "pdf"):
            fig.savefig(os.path.join(out_dir, f"{stem}.{ext}"), bbox_inches="tight")
    finally:
        plt.close(fig)


def _symmetric_limit(explanation) -> float:
    """Return the largest finite |value| in ``explanation``, or 1.0 if there is none.

    Plotting on [-limit, limit] centres 0 on the white midpoint of the
    diverging "bwr" colormap. An all-zero (or all-NaN) map would otherwise give
    vmin == vmax, which matplotlib normalises to the bottom of the colormap,
    i.e. a misleading solid-blue "negative" map.
    """
    magnitudes = np.abs(np.asarray(explanation, dtype=float))
    finite = magnitudes[np.isfinite(magnitudes)]
    limit = float(finite.max()) if finite.size else 0.0
    return limit if limit > 0 else 1.0


def _plot_ground_truth(image_file: str, objects, out_dir: str, stem: str) -> None:
    """Render the input image with green boxes around TARGET_CATEGORY objects.

    Args:
        image_file: Path of the image to draw.
        objects: The image's annotation list (entries with "category" and
            "bounding_box" -> "minimum"/"maximum" -> "r"/"c").
        out_dir: Directory that receives ``<stem>.jpg`` / ``<stem>.pdf``.
        stem: Output file stem.
    """
    fig, ax = _new_axes("image")
    # Image.open is lazy and holds the file open; the context manager closes
    # it. imshow converts the PIL image to an array on the spot, so the pixels
    # survive the file being closed.
    with Image.open(image_file) as image:
        ax.imshow(image)

    for obj in objects:
        if obj["category"] != TARGET_CATEGORY:
            continue
        corner_min = obj["bounding_box"]["minimum"]
        corner_max = obj["bounding_box"]["maximum"]
        # Rectangle takes its anchor as (x, y) = (c, r), then width, height.
        ax.add_patch(
            patches.Rectangle(
                (corner_min["c"], corner_min["r"]),
                abs(corner_max["c"] - corner_min["c"]),
                abs(corner_max["r"] - corner_min["r"]),
                linewidth=2,
                edgecolor="g",
                facecolor="none",
            )
        )

    _save_figure(fig, out_dir, stem)


def _plot_explanation(explanation_file: str, title: str, out_dir: str, stem: str) -> None:
    """Render one saved explanation map on a symmetric "bwr" colour scale."""
    explanation = np.load(explanation_file)
    limit = _symmetric_limit(explanation)
    fig, ax = _new_axes(title)
    ax.imshow(explanation, cmap="bwr", vmin=-limit, vmax=limit)
    _save_figure(fig, out_dir, stem)


# Output folders depend only on the explainer, so create them once up front.
for exp_name in exp_mapper:
    os.makedirs(os.path.join(figure_dir, exp_name), exist_ok=True)

for image_path in tqdm(true_positive[:N_EXAMPLES]):
    # Same stem derivation as the gt_df index (basename up to the first ".").
    image_name = os.path.basename(image_path).split(".")[0]
    _plot_ground_truth(
        os.path.join(root_dir, image_path),
        gt_df.at[image_name, "objects"],
        gt_figure_dir,
        image_name,
    )
    for exp_name, exp_title in exp_mapper.items():
        _plot_explanation(
            os.path.join(explanation_dir, exp_name, f"{image_name}.npy"),
            exp_title,
            os.path.join(figure_dir, exp_name),
            image_name,
        )