In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm

# When True, explanations are read from (and figures written to) the
# "random_model_explanations" directories instead of "explanations".
RANDOMIZE_MODEL = False

# Sub-directory name shared by the explanation inputs and the figure outputs.
explanation_subdir = (
    "random_model_explanations" if RANDOMIZE_MODEL else "explanations"
)

# Input locations.
root_dir = "../"
data_dir = os.path.join(root_dir, "data")
ground_truth_dir = os.path.join(data_dir, "ground_truth")
explanation_dir = os.path.join(root_dir, explanation_subdir)

# Output locations; the ground-truth figure folder is created eagerly.
figure_dir = os.path.join(data_dir, "figures", explanation_subdir)
gt_figure_dir = os.path.join(figure_dir, "ground_truth")
os.makedirs(gt_figure_dir, exist_ok=True)

# Explanation sub-directory (relative to `explanation_dir`) -> plot title.
# Insertion order is the order in which figures are rendered below.
exp_mapper = {
    "gradcam": r"Grad-CAM",
    "lime": r"LIME",
    "gradexp": r"GradientExp",
    "deepexp": r"DeepExp",
    "partexp/500": r"PartitionExp ($m = 500$)",
    "partexp/64": r"PartitionExp ($m = 64$)",
    "partexp/32": r"PartitionExp ($m = 32$)",
    "partexp/16": r"PartitionExp ($m = 16$)",
    "hexp/absolute_0": r"h-Shap ($\tau = 0$)",
    "hexp/relative_70": r"h-Shap ($\tau=70\%$)",
}

# The .npy file holds a pickled 0-d object array that `.item()` unwraps
# (presumably a mapping from key -> list of image paths; see its use below).
# allow_pickle=True executes arbitrary code on load: only use trusted files.
true_positives_path = os.path.join(explanation_dir, "true_positive.npy")
true_positives = np.load(true_positives_path, allow_pickle=True).item()

In [None]:
# Render each selected true-positive image, plus one signed heat-map per
# explainer in `exp_mapper`, and save every figure as both .jpg and .pdf.
sns.set_context("paper", font_scale=1.5)

# Number of images rendered per key of `true_positives`.
n_images = 4 if RANDOMIZE_MODEL else 5


def _hide_axes(ax):
    """Hide the x and y axes (ticks and tick labels) of ``ax``."""
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)


def _save_figure(fig, out_dir, stem):
    """Save ``fig`` as ``<stem>.jpg`` and ``<stem>.pdf`` in ``out_dir``, then close it.

    Closing the specific figure (rather than the implicit "current" one)
    keeps memory bounded while many figures are produced in a loop.
    """
    for ext in ("jpg", "pdf"):
        fig.savefig(os.path.join(out_dir, f"{stem}.{ext}"), bbox_inches="tight")
    plt.close(fig)


def _symmetric_limit(explanation):
    """Return the colour limit ``L`` so the map is drawn on ``[-L, L]``.

    ``L`` is the largest absolute attribution. For an all-zero map this would
    be 0, giving ``vmin == vmax``; matplotlib's ``Normalize`` then maps every
    pixel to 0.0 (solid blue under "bwr"), which looks like strong negative
    attribution. Falling back to 1.0 places zeros at the white midpoint.
    """
    limit = np.abs(explanation).max()
    if limit == 0:
        limit = 1.0
    return limit


# Output folders do not depend on the image, so create them once up front.
explainer_figure_dirs = {
    exp_name: os.path.join(figure_dir, exp_name) for exp_name in exp_mapper
}
for out_dir in explainer_figure_dirs.values():
    os.makedirs(out_dir, exist_ok=True)

for n in [1, 6]:  # keys of `true_positives` to render
    for image_path in tqdm(true_positives[n][:n_images]):
        # Explanation files are looked up by this stem, so keep the same
        # derivation (text before the first ".") used to name them.
        image_name = os.path.basename(image_path).split(".")[0]

        # Context manager closes the underlying file once the figure is saved.
        with Image.open(os.path.join(root_dir, image_path)) as image:
            fig, ax = plt.subplots(1, 1)
            ax.imshow(image)
            _hide_axes(ax)
            ax.set_title("image")
            _save_figure(fig, gt_figure_dir, image_name)

        for exp_name, exp_title in exp_mapper.items():
            explanation = np.load(
                os.path.join(explanation_dir, exp_name, f"{image_name}.npy")
            )
            limit = _symmetric_limit(explanation)

            fig, ax = plt.subplots(1, 1)
            # Symmetric limits keep zero attribution at the colormap midpoint.
            ax.imshow(explanation, cmap="bwr", vmin=-limit, vmax=limit)
            _hide_axes(ax)
            ax.set_title(exp_title)
            _save_figure(fig, explainer_figure_dirs[exp_name], image_name)