# Model Interpretations of Treatment Effect Estimation 

In [None]:
import os
from pathlib import Path

# Switch the working directory to the parent of the current one; the relative
# "./Images/..." output paths used further down are resolved against it.
os.chdir(Path(os.getcwd()).parents[0])
# In a notebook this trailing expression displays the new working directory.
os.getcwd()

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import plotly.graph_objects as go

from predimgbmanalysis.eval_biomarkers import *

from predimgbmanalysis.train import ToyModelImgModule
from predimgbmanalysis.get_toydata import CUB2011, ISIC2018
import scipy

import yaml

import tqdm.notebook as tq
import time

from torch import nn
from matplotlib.colors import LinearSegmentedColormap

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import (
    ClassifierOutputTarget,
    RawScoresOutputTarget,
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from captum.attr import (
    GradientShap,
    GuidedGradCam,
    GradientShap,
    GuidedBackprop,
    NoiseTunnel,
    IntegratedGradients,
)
from captum.attr import visualization as vis
from gradcam3D import GradCAM3D

import warnings

# Suppress every warning for the rest of the session.
warnings.filterwarnings("ignore")

# Three-stop colormap for signed attributions: red at 0, white at 0.5 and
# blue at 1, sampled into 256 colours.
heat_cmap = LinearSegmentedColormap.from_list(
    name="heatmap attribution",
    colors=[(0, "#FF0051"), (0.5, "#ffffff"), (1, "#008BFB")],
    N=256,
)

# Two-stop colormap for absolute attributions: white at 0 to purple at 1.
abs_cmap = LinearSegmentedColormap.from_list(
    name="absolute attribution",
    colors=[(0, "#ffffff"), (1, "#8F37BB")],
    N=256,
)


class TwoHead_Wrapper(nn.Module):
    """Expose a multi-output model as a module returning a single tensor.

    The wrapped model's outputs are collected into a list and joined with
    ``torch.hstack``, so an integer ``target`` index can select one of them.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        """Run the wrapped model and horizontally stack its outputs."""
        heads = list(self.model(input))
        return torch.hstack(heads)


class CATE_Wrapper(nn.Module):
    """Wrap a two-output model so it returns the difference of its outputs.

    The forward pass yields ``output[1] - output[0]`` of the wrapped model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        """Return the second output of the wrapped model minus the first."""
        heads = self.model(input)
        second_head = heads[1]
        first_head = heads[0]
        return second_head - first_head


def normalize_pos_zero(data):
    """Min-max rescale ``data`` so its smallest value maps to 0.

    A tiny constant (1e-11) is added to the value range in the denominator so
    that constant input does not divide by zero; it then maps to all zeros.
    """
    lowest = data.min()
    value_range = data.max() - lowest
    return (data - lowest) / (value_range + 1e-11)


def normalize_pos_neg(data):
    """
    Normalize a PyTorch tensor of data to the range of -1 to 1.
    Negative values are min-max scaled into [-1, 0] and positive values into
    [0, 1]; each group is scaled independently and zeros are left unchanged.
    Consequently the negative value closest to zero maps to 0 and the smallest
    positive value maps to 0 as well.

    The input tensor is not modified; the result is written to a copy.

    Parameters:
    - data: A PyTorch tensor of numerical values.

    Returns:
    - A new PyTorch tensor where negative values are normalized to [-1, 0] and positive values to [0, 1].
      If all negative (positive) values are equal they are set to -0.5 (0.5).
    """
    # Work on a copy so the caller's tensor is not silently overwritten.
    result = data.clone()

    # Handle negative values
    negative_mask = result < 0
    negative_data = result[negative_mask]
    if negative_data.nelement() > 0:
        neg_min = torch.min(negative_data)
        neg_max = torch.max(negative_data)
        # Avoid division by zero
        if neg_min != neg_max:
            normalized_negative_data = (negative_data - neg_min) / (
                neg_max - neg_min
            ) - 1
        else:
            normalized_negative_data = torch.full_like(
                negative_data, -0.5
            )  # Arbitrary choice when all values are equal
        result[negative_mask] = normalized_negative_data

    # Handle positive values (the rescaled negatives are <= 0, so they are
    # never picked up by this mask)
    positive_mask = result > 0
    positive_data = result[positive_mask]
    if positive_data.nelement() > 0:
        pos_min = torch.min(positive_data)
        pos_max = torch.max(positive_data)
        # Avoid division by zero
        if pos_min != pos_max:
            normalized_positive_data = (positive_data - pos_min) / (pos_max - pos_min)
        else:
            normalized_positive_data = torch.full_like(
                positive_data, 0.5
            )  # Arbitrary choice when all values are equal
        result[positive_mask] = normalized_positive_data

    return result

### Paths

In [None]:
# Base folder that contains all datasets; taken from the DATASET_LOCATION
# environment variable when it is set.
network_drive_path = os.getenv("DATASET_LOCATION", "/absolute/path/to/datasets")

# Root directory of each dataset, keyed by dataset name. CMNIST joins an empty
# subfolder, i.e. it resolves to the base folder itself.
dataset_root = {
    dataset_key: os.path.join(network_drive_path, subfolder)
    for dataset_key, subfolder in (
        ("CMNIST", ""),
        ("cub", "CUB_200_2011"),
        ("isic", "ISIC2018"),
        ("lungCT", "NSCLC_Radiomics"),
    )
}

# Base folder that contains all experiment outputs. The correctly spelled
# EXPERIMENTS_LOCATION variable takes precedence; the historical misspelling
# EPXERIMENTS_LOCATION is still honoured so existing setups keep working.
experiments = os.getenv(
    "EXPERIMENTS_LOCATION",
    os.getenv("EPXERIMENTS_LOCATION", "/absolute/path/to/experiments"),
)

# Experiment folder per experiment key.
experiment_dirs = {
    "CMNIST_a": os.path.join(
        experiments, "2022-10-25_toymodel_miniresnetmtl_cmnist3a_02"
    ),
    "CMNIST_b": os.path.join(
        experiments, "2022-10-25_toymodel_miniresnetmtl_cmnist3b_02"
    ),
    "CMNISTvline_a": os.path.join(
        experiments, "2022-10-31_toymodel_miniresnetmtl_cmnist5a"
    ),
    "CMNISTvline_b": os.path.join(
        experiments, "2022-10-31_toymodel_miniresnetmtl_cmnist5b"
    ),
    "cub_a": os.path.join(
        experiments, "2022-10-27_toymodel_resnet18mtl_cub2011allclassesa_01"
    ),
    "cub_b": os.path.join(
        experiments, "2022-10-27_toymodel_resnet18mtl_cub2011allclassesb_01"
    ),
    "isic_a": os.path.join(
        experiments, "2022-10-19_toymodel_resnet18mtl_isic2018a_binary"
    ),
    "isic_b": os.path.join(
        experiments, "2022-10-19_toymodel_resnet18mtl_isic2018b_binary"
    ),
    "isic_a_newarch": os.path.join(
        experiments, "2022-10-28_toymodel_resnet18mtl_isic2018a_binary_05"
    ),
    "isic_b_newarch": os.path.join(
        experiments, "2022-10-28_toymodel_resnet18mtl_isic2018b_binary_05"
    ),
}

# lungCT cross-validation runs: folds 0-4 for variant "a", then for "b".
# The folder names differ only in the variant letter and the fold index.
experiment_dirs.update(
    {
        f"lungCT_{variant}_fold{fold}": os.path.join(
            experiments,
            "test_nsclctumourpatchesnnunet_linear_zscore_mtl4fc_"
            f"{variant}_extendedrandtransformnew_fold{fold}_complete",
        )
        for variant in ("a", "b")
        for fold in range(5)
    }
)

# lungCT runs without a fold index ("fulltraincv" folders).
experiment_dirs.update(
    {
        f"lungCT_{variant}": os.path.join(
            experiments,
            "test_nsclctumourpatchesnnunet_linear_zscore_mtl4fc_"
            f"{variant}_extendedrandtransformnewfulltraincv_complete",
        )
        for variant in ("a", "b")
    }
)


# Log folder names per experiment key (same keys as experiment_dirs).
# Every list starts with the bpred=1.0 / bprog=1.0 run at index 0, followed by
# the runs with smaller bpred/bprog values; the loading cells below select a
# run by its list index.
log_names = {
    "CMNIST_a": [
        "2022-10-27_13-05-44_120_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-27_04-03-14_96_bpred=0.8_bprog=0.8_finalact=None",
    ],
    "CMNIST_b": [
        "2022-10-26_21-51-30_120_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-26_15-56-12_96_bpred=0.8_bprog=0.8_finalact=None",
    ],
    "CMNISTvline_a": [
        "2022-11-02_11-35-26_120_bpred=1.0_bprog=1.0_finalact=None",
        "2022-11-01_21-32-49_84_bpred=0.7_bprog=0.7_finalact=None",
        "2022-11-01_12-13-19_60_bpred=0.5_bprog=0.5_finalact=None",
    ],
    "CMNISTvline_b": [
        "2022-11-02_11-47-04_120_bpred=1.0_bprog=1.0_finalact=None",
        "2022-11-01_21-42-48_84_bpred=0.7_bprog=0.7_finalact=None",
        "2022-11-01_12-20-51_60_bpred=0.5_bprog=0.5_finalact=None",
    ],
    "cub_a": [
        "2022-10-30_07-41-48_35_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-29_21-04-59_28_bpred=0.8_bprog=0.8_finalact=None",
        "2022-10-29_09-31-50_21_bpred=0.6_bprog=0.6_finalact=None",
    ],
    "cub_b": [
        "2022-10-30_08-12-09_35_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-29_21-35-28_28_bpred=0.8_bprog=0.8_finalact=None",
        "2022-10-29_09-59-09_21_bpred=0.6_bprog=0.6_finalact=None",
    ],
    "isic_a": [
        "2022-10-23_17-37-07_35_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-23_03-42-00_28_bpred=0.8_bprog=0.8_finalact=None",
        "2022-10-22_12-42-59_21_bpred=0.6_bprog=0.6_finalact=None",
    ],
    "isic_b": [
        "2022-10-23_19-03-41_35_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-23_05-53-30_28_bpred=0.8_bprog=0.8_finalact=None",
        "2022-10-22_15-42-35_21_bpred=0.6_bprog=0.6_finalact=None",
    ],
    "isic_a_newarch": [
        "2022-11-01_04-19-04_35_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-31_12-00-21_28_bpred=0.8_bprog=0.8_finalact=None",
        "2022-10-30_19-40-27_21_bpred=0.6_bprog=0.6_finalact=None",
    ],
    "isic_b_newarch": [
        "2022-10-30_07-50-50_35_bpred=1.0_bprog=1.0_finalact=None",
        "2022-10-30_01-22-52_28_bpred=0.8_bprog=0.8_finalact=None",
        "2022-10-29_17-37-53_21_bpred=0.6_bprog=0.6_finalact=None",
    ],
    "lungCT_a_fold0": [
        "2024-01-26_14-28-16_35_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=0",
        "2024-01-25_03-39-05_28_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=0",
    ],
    "lungCT_a_fold1": [
        "2024-01-29_19-17-43_0_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=1",
        "2024-01-30_19-53-21_4_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=1",
    ],
    "lungCT_a_fold2": [
        "2024-01-25_19-45-39_35_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=2",
        "2024-01-24_13-26-30_28_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=2",
    ],  
    "lungCT_a_fold3": [
        "2024-01-26_16-34-54_35_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=3",
        "2024-01-25_04-24-35_28_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=3",
    ],  
    "lungCT_a_fold4": [
        "2024-01-28_06-23-02_5_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=4",
        "2024-01-28_20-20-36_0_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=4",
    ],
    "lungCT_b_fold0": [
        "2024-01-25_19-48-28_35_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=0",
        "2024-01-24_13-29-39_28_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=0",
    ],
    "lungCT_b_fold1": [
        "2024-01-31_04-00-52_17_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=1",
        "2024-01-30_00-06-48_13_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=1",
    ],
    "lungCT_b_fold2": [
        "2024-01-29_14-41-45_0_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=2",
        "2024-01-30_14-13-38_4_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=2",
    ],
    "lungCT_b_fold3": [
        "2024-01-31_14-17-02_11_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=3",
        "2024-01-30_20-09-55_8_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=3",
    ],  
    "lungCT_b_fold4": [
        "2024-02-02_05-20-30_17_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=4",
        "2024-02-01_04-58-41_13_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=4",
    ],  
    "lungCT_a": [
        "2024-01-26_22-15-32_0_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=None",
        "2024-01-28_16-10-15_4_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=None",
    ],
    "lungCT_b": [
        "2024-01-26_21-54-06_0_bpred=1.0_bprog=1.0_finalact=None_kfold_idx=None",
        "2024-01-29_01-23-41_4_bpred=0.8_bprog=0.8_finalact=None_kfold_idx=None",
    ],
}

## Natural Image Datasets:
### Coloured MNIST

In [None]:
# Load the "CMNIST_a" experiment. log_names[name][1] selects the second listed
# run (bpred=0.8 / bprog=0.8). The four returned values are stored as model,
# data loader, bprog and bpred.
name = "CMNIST_a"
data_name = "CMNIST"
model_CMNIST_a, dl_CMNIST_a, bprog_CMNIST_a, bpred_CMNIST_a = (
    get_interpretation(  # Two head
        experiment_dir=experiment_dirs[name],
        use_cuda=False,
        n_batch=1000,
        log_name=log_names[name][1],
        get_saliency_maps=False,
        dataset_root=dataset_root[data_name],
        env="test",
    )
)

# Same for the "CMNIST_b" experiment (note the smaller n_batch).
name = "CMNIST_b"
data_name = "CMNIST"
model_CMNIST_b, dl_CMNIST_b, bprog_CMNIST_b, bpred_CMNIST_b = (
    get_interpretation(  # Cate
        experiment_dir=experiment_dirs[name],
        use_cuda=False,
        n_batch=100,
        log_name=log_names[name][1],
        get_saliency_maps=False,
        dataset_root=dataset_root[data_name],
        env="test",
    )
)


# Take the first batch of the "a" loader and keep only its first element,
# which is used as the image tensor below; both models are explained on it.
data_CMNIST = next(iter(dl_CMNIST_a))
data_CMNIST = data_CMNIST[0]

In [None]:
# Index of the sample in data_CMNIST that is explained.
n = 17  # 4, 12, 61, 10
# NoiseTunnel settings used by the Integrated Gradients runs; stdevs is also
# passed to GradientShap.
nt_type = "smoothgrad"
nt_samples = 100
nt_samples_batch_size = 10
stdevs = 0.1

# Every section below fills its list in the order
#   [a/output 0, a/output 1, b/output 0, b/output 1, a/difference, b/difference]
# and then reorders it with [0, 1, 4, 2, 3, 5] to
#   [a/output 0, a/output 1, a/difference, b/output 0, b/output 1, b/difference],
# which matches the model_type labels used when the figures are saved.

# Expected Gradients: GradientShap with the whole loaded batch as baselines.
attr_eg_cmnist = []

for model in (model_CMNIST_a, model_CMNIST_b):
    # Moves the model to the CPU (data_CMNIST is not moved to another device).
    model.cpu()
    attr_method = GradientShap(TwoHead_Wrapper(model))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_CMNIST[n].unsqueeze(0),
            n_samples=150,
            stdevs=stdevs,
            baselines=data_CMNIST,
            target=target,
        )
        attr_eg_cmnist.append(attr_values.squeeze(0))

for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = GradientShap(CATE_Wrapper(model))
    attr_values = attr_method.attribute(
        data_CMNIST[n].unsqueeze(0),
        n_samples=150,
        stdevs=stdevs,
        baselines=data_CMNIST,
        target=0,
    )
    attr_eg_cmnist.append(attr_values.squeeze(0))

attr_eg_cmnist = [attr_eg_cmnist[i] for i in [0, 1, 4, 2, 3, 5]]

# Integrated Gradients wrapped in a NoiseTunnel.
# NOTE(review): the scalar baseline -0.4242 + 0.0001 is presumably (just above)
# the normalized value of a black pixel - confirm against the dataset transform.

attr_ig_cmnist = []

for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = NoiseTunnel(IntegratedGradients(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_CMNIST[n].unsqueeze(0),
            target=target,
            baselines=-0.4242 + 0.0001,
            n_steps=50,
            stdevs=stdevs,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )
        attr_ig_cmnist.append(attr_values.squeeze(0))

# NOTE(review): n_steps is 20 here but 50 for the two-output runs above -
# confirm that the difference is intended.
for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = NoiseTunnel(IntegratedGradients(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_CMNIST[n].unsqueeze(0),
        target=0,
        baselines=-0.4242 + 0.0001,
        n_steps=20,
        stdevs=stdevs,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_ig_cmnist.append(attr_values.squeeze(0))

attr_ig_cmnist = [attr_ig_cmnist[i] for i in [0, 1, 4, 2, 3, 5]]

# Guided GradCAM, computed with respect to the first block of layer1.

attr_gcam_cmnist = []

for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = GuidedGradCam(
        TwoHead_Wrapper(model), TwoHead_Wrapper(model).model.model.layer1[0]
    )
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_CMNIST[n].unsqueeze(0),
            target=target,
        )
        attr_gcam_cmnist.append(attr_values.squeeze(0))

for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = GuidedGradCam(
        CATE_Wrapper(model), CATE_Wrapper(model).model.model.layer1[0]
    )
    attr_values = attr_method.attribute(
        data_CMNIST[n].unsqueeze(0),
        target=0,
    )
    attr_gcam_cmnist.append(attr_values.squeeze(0))

attr_gcam_cmnist = [attr_gcam_cmnist[i] for i in [0, 1, 4, 2, 3, 5]]

# Guided Backprop

attr_gbp_cmnist = []

for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = GuidedBackprop(TwoHead_Wrapper(model))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_CMNIST[n].unsqueeze(0),
            target=target,
        )
        attr_gbp_cmnist.append(attr_values.squeeze(0))

for model in (model_CMNIST_a, model_CMNIST_b):
    attr_method = GuidedBackprop(CATE_Wrapper(model))
    attr_values = attr_method.attribute(
        data_CMNIST[n].unsqueeze(0),
        target=0,
    )
    attr_gbp_cmnist.append(attr_values.squeeze(0))

attr_gbp_cmnist = [attr_gbp_cmnist[i] for i in [0, 1, 4, 2, 3, 5]]

In [None]:
# Move the first axis of the selected sample to the end (channels-last for
# plotting) and min-max rescale it to [0, 1].
img = np.transpose(data_CMNIST[n].cpu().detach().numpy(), (1, 2, 0))
img = (img - img.min()) / (img.max() - img.min())
channel = ["red", "green", "blue"]
attr_type = ["EG", "IG", "GGCAM", "GBP"]
# Labels in the same order as the (reordered) attribution lists.
model_type = [
    "TwoHead_0_a",
    "TwoHead_1_a",
    "CATE_a",
    "TwoHead_0_b",
    "TwoHead_1_b",
    "CATE_b",
]

# One output folder per sample index. exist_ok=True replaces the separate
# os.path.exists check, which could race with a concurrent creation.
out_dir = "./Images/cmnist/" + str(n) + "/"
os.makedirs(out_dir, exist_ok=True)

plt.imshow(img)
plt.axis("off")
plt.tight_layout()
# NOTE(review): this sets an attribute on the pyplot module, not on a figure,
# and so appears to have no effect on the saved file - confirm intent.
plt._original_dpi = 200

plt.savefig(out_dir + "original.png", bbox_inches="tight", pad_inches=0)
plt.close()


# One image per attribution method, model output and colour channel.
for idx, attr in enumerate(
    [attr_eg_cmnist, attr_ig_cmnist, attr_gcam_cmnist, attr_gbp_cmnist]
):
    for model in range(6):
        for rgb in range(3):
            # Visualize a single channel of the attribution as an (H, W, 1) map.
            fig, axis = vis.visualize_image_attr_multiple(
                attr[model][rgb].unsqueeze(-1).cpu().detach().numpy(),
                img,
                ["heat_map"],
                ["all"],
                cmap=heat_cmap,
                show_colorbar=False,
                use_pyplot=False,
            )

            # With use_pyplot=False the figure is not managed by pyplot; attach
            # it to a fresh pyplot-managed canvas so plt.savefig / plt.close
            # operate on it.
            managed_fig = plt.figure()
            canvas_manager = managed_fig.canvas.manager
            canvas_manager.canvas.figure = fig
            fig.set_canvas(canvas_manager.canvas)
            fig._original_dpi = 200

            plt.savefig(
                out_dir
                + str(attr_type[idx])
                + "_"
                + model_type[model]
                + "_"
                + channel[rgb]
                + ".png",
                bbox_inches="tight",
                pad_inches=0,
            )
            plt.close()

### CUB 2011

In [None]:
# Load the "cub_a" experiment. log_names[name][0] selects the first listed run
# (bpred=1.0 / bprog=1.0). The four returned values are stored as model, data
# loader, bprog and bpred.
name = "cub_a"
data_name = "cub"
model_cub_a, dl_cub_a, bprog_cub_a, bpred_cub_a = get_interpretation(
    experiment_dir=experiment_dirs[name],
    use_cuda=True,
    n_batch=1100,
    log_name=log_names[name][0],
    get_saliency_maps=False,
    dataset_root=dataset_root[data_name],
    env="val",
)

# Same for the "cub_b" experiment (note the smaller n_batch).
name = "cub_b"
data_name = "cub"
model_cub_b, dl_cub_b, bprog_cub_b, bpred_cub_b = get_interpretation(
    experiment_dir=experiment_dirs[name],
    use_cuda=True,
    n_batch=100,
    log_name=log_names[name][0],
    get_saliency_maps=False,
    dataset_root=dataset_root[data_name],
    env="val",
)

# Take the first batch of the "a" loader: element 1 is kept as labels, element
# 0 is moved to the first CUDA device and used as the image tensor below.
data_cub = next(iter(dl_cub_a))
label_cub = data_cub[1]
data_cub = data_cub[0].to("cuda:0")

In [None]:
# Index of the sample in data_cub that is explained.
n = 775  # 335 #1015 #8 22
# NoiseTunnel settings shared by the IG, Guided Backprop and Guided GradCAM runs.
nt_type = "smoothgrad"
nt_samples = 50
nt_samples_batch_size = 5

# Every section below fills its list in the order
#   [a/output 0, a/output 1, b/output 0, b/output 1, a/difference, b/difference]
# and then reorders it with [0, 1, 4, 2, 3, 5] to
#   [a/output 0, a/output 1, a/difference, b/output 0, b/output 1, b/difference],
# which matches the model_type labels used when the figures are saved.

# Expected Gradients: GradientShap with the whole loaded batch as baselines.
attr_eg_cub = []

for model in (model_cub_a, model_cub_b):
    attr_method = GradientShap(TwoHead_Wrapper(model))
    for target in (0, 1):

        attr_values = attr_method.attribute(
            data_cub[n].unsqueeze(0),
            n_samples=300,
            # stdevs=0.001,
            baselines=data_cub,
            target=target,
            stdevs=0.2,
        )
        attr_eg_cub.append(attr_values.squeeze(0))

for model in (model_cub_a, model_cub_b):
    attr_method = GradientShap(CATE_Wrapper(model))
    attr_values = attr_method.attribute(
        data_cub[n].unsqueeze(0),
        n_samples=300,
        # stdevs=0.001,
        baselines=data_cub,
        target=0,
        stdevs=0.2,
    )
    attr_eg_cub.append(attr_values.squeeze(0))

attr_eg_cub = [attr_eg_cub[i] for i in [0, 1, 4, 2, 3, 5]]

# Integrated Gradients wrapped in a NoiseTunnel, with an all-zeros baseline of
# the sample's shape.

attr_ig_cub = []

for model in (model_cub_a, model_cub_b):
    attr_method = NoiseTunnel(IntegratedGradients(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_cub[n].unsqueeze(0),
            target=target,
            baselines=(torch.ones_like(data_cub[n]) * 0).unsqueeze(0),
            stdevs=0.2,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )
        attr_ig_cub.append(attr_values.squeeze(0))

for model in (model_cub_a, model_cub_b):
    attr_method = NoiseTunnel(IntegratedGradients(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_cub[n].unsqueeze(0),
        target=0,
        baselines=(torch.ones_like(data_cub[n]) * 0).unsqueeze(0),
        stdevs=0.2,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_ig_cub.append(attr_values.squeeze(0))

attr_ig_cub = [attr_ig_cub[i] for i in [0, 1, 4, 2, 3, 5]]

# GradCAM (pytorch_grad_cam) on the last block of layer4. The two-output
# wrapper selects an output index through ClassifierOutputTarget; the
# difference wrapper uses the raw score.

attr_gcam_cub = []
for model in (model_cub_a, model_cub_b):
    attr_method = GradCAM(
        model=TwoHead_Wrapper(model),
        target_layers=[TwoHead_Wrapper(model).model.model.layer4[-1]],
    )
    for target in (0, 1):
        attr_values = attr_method(
            input_tensor=data_cub[n].unsqueeze(0),
            targets=[ClassifierOutputTarget(target)],
        )
        attr_gcam_cub.append(attr_values.squeeze(0))

for model in (model_cub_a, model_cub_b):
    attr_method = GradCAM(
        model=CATE_Wrapper(model),
        target_layers=[CATE_Wrapper(model).model.model.layer4[-1]],
    )
    attr_values = attr_method(
        input_tensor=data_cub[n].unsqueeze(0), targets=[RawScoresOutputTarget()]
    )
    attr_gcam_cub.append(attr_values.squeeze(0))

attr_gcam_cub = [attr_gcam_cub[i] for i in [0, 1, 4, 2, 3, 5]]

# Guided Backprop wrapped in a NoiseTunnel.

attr_gbp_cub = []
for model in (model_cub_a, model_cub_b):
    attr_method = NoiseTunnel(GuidedBackprop(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_cub[n].unsqueeze(0),
            target=target,
            stdevs=0.2,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )
        attr_gbp_cub.append(attr_values.squeeze(0))

for model in (model_cub_a, model_cub_b):
    attr_method = NoiseTunnel(GuidedBackprop(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_cub[n].unsqueeze(0),
        target=0,
        stdevs=0.2,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_gbp_cub.append(attr_values.squeeze(0))

attr_gbp_cub = [attr_gbp_cub[i] for i in [0, 1, 4, 2, 3, 5]]


# Guided GradCAM wrapped in a NoiseTunnel, computed with respect to the
# model's `conv` layer.
attr_ggcam_cub = []

for model in (model_cub_a, model_cub_b):
    attr_method = NoiseTunnel(
        GuidedGradCam(TwoHead_Wrapper(model), TwoHead_Wrapper(model).model.model.conv)
    )
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_cub[n].unsqueeze(0),
            target=target,
            stdevs=0.2,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )

        attr_ggcam_cub.append(attr_values.squeeze(0))

for model in (model_cub_a, model_cub_b):
    attr_method = NoiseTunnel(
        GuidedGradCam(CATE_Wrapper(model), CATE_Wrapper(model).model.model.conv)
    )
    attr_values = attr_method.attribute(
        data_cub[n].unsqueeze(0),
        target=0,
        stdevs=0.2,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_ggcam_cub.append(attr_values.squeeze(0))

attr_ggcam_cub = [attr_ggcam_cub[i] for i in [0, 1, 4, 2, 3, 5]]

In [None]:
# Move the first axis of the selected sample to the end (channels-last for
# plotting) and min-max rescale it to [0, 1].
img = np.transpose(data_cub[n].cpu().detach().numpy(), (1, 2, 0))
img = (img - img.min()) / (img.max() - img.min())
attr_type = ["EG", "IG", "GCAM", "GBP", "GGCAM"]
# Labels in the same order as the (reordered) attribution lists.
model_type = [
    "TwoHead_0_a",
    "TwoHead_1_a",
    "CATE_a",
    "TwoHead_0_b",
    "TwoHead_1_b",
    "CATE_b",
]

# One output folder per sample index. exist_ok=True replaces the separate
# os.path.exists check, which could race with a concurrent creation.
out_dir = "./Images/cub2011/" + str(n) + "/"
os.makedirs(out_dir, exist_ok=True)

plt.imshow(img)
plt.axis("off")
plt.tight_layout()
# NOTE(review): this sets an attribute on the pyplot module, not on a figure,
# and so appears to have no effect on the saved file - confirm intent.
plt._original_dpi = 200

plt.savefig(out_dir + "original.png", bbox_inches="tight", pad_inches=0)
plt.close()

# Per-method magnitude threshold: attribution values whose absolute value is
# below it are set to 0 before plotting. The GradCAM entry (index 2) is unused
# because GradCAM is drawn through show_cam_on_image instead.
norm = [0.0005, 0.0005, None, 0.0005, 0.0]

for idx, attr in enumerate(
    [attr_eg_cub, attr_ig_cub, attr_gcam_cub, attr_gbp_cub, attr_ggcam_cub]
):
    for model in range(6):
        if idx == 2:
            # GradCAM: overlay the class-activation map on the image.
            gcam_vis = show_cam_on_image(img, attr[model], use_rgb=True)
            plt.imshow(gcam_vis)
            plt.axis("off")
            plt.tight_layout()
            plt._original_dpi = 200
        else:
            # Convert once, then zero out values below the method's threshold.
            attr_hwc = np.transpose(attr[model].cpu().detach().numpy(), (1, 2, 0))
            fig, axis = vis.visualize_image_attr_multiple(
                np.where(np.abs(attr_hwc) < norm[idx], 0, attr_hwc),
                img,
                ["heat_map"],
                ["all"],
                cmap=heat_cmap,
                show_colorbar=False,
                use_pyplot=False,
            )

            # With use_pyplot=False the figure is not managed by pyplot; attach
            # it to a fresh pyplot-managed canvas so plt.savefig / plt.close
            # operate on it.
            managed_fig = plt.figure()
            canvas_manager = managed_fig.canvas.manager
            canvas_manager.canvas.figure = fig
            fig.set_canvas(canvas_manager.canvas)
            fig._original_dpi = 200

        # Both branches leave the figure to save as pyplot's current figure.
        plt.savefig(
            out_dir + str(attr_type[idx]) + "_" + model_type[model] + ".png",
            bbox_inches="tight",
            pad_inches=0,
        )
        plt.close()

## Medical Image Dataset

### ISIC 2018 Skin Lesions

In [None]:
# Load the "isic_a" experiment. log_names[name][1] selects the second listed
# run (bpred=0.8 / bprog=0.8). The four returned values are stored as model,
# data loader, bprog and bpred.
name = "isic_a"
# name = "isic_a_newarch"
data_name = "isic"
model_isic_a, dl_isic_a, bprog_isic_a, bpred_isic_a = get_interpretation(
    experiment_dir=experiment_dirs[name],
    use_cuda=False,
    n_batch=300,
    log_name=log_names[name][1],
    get_saliency_maps=False,
    dataset_root=dataset_root[data_name],
    env="val",
)  # use validation data for preliminary experiments

# Same for the "isic_b" experiment (note the smaller n_batch).
name = "isic_b"
# name = "isic_b_newarch"
data_name = "isic"
model_isic_b, dl_isic_b, bprog_isic_b, bpred_isic_b = get_interpretation(
    experiment_dir=experiment_dirs[name],
    use_cuda=False,
    n_batch=150,
    log_name=log_names[name][1],
    get_saliency_maps=False,
    dataset_root=dataset_root[data_name],
    env="val",
)  # use validation data for preliminary experiments


# Take the first batch of the "a" loader: element 1 is kept as labels, element
# 0 is used as the image tensor below.
data_isic = next(iter(dl_isic_a))
label_isic = data_isic[1]
# Put the images on the device the model parameters live on instead of a
# hard-coded "cuda:0": the models above are requested with use_cuda=False, so
# a fixed CUDA device can mismatch the model (or be unavailable altogether).
# NOTE(review): assumes both models end up on the same device - confirm
# against get_interpretation.
data_isic = data_isic[0].to(next(model_isic_a.parameters()).device)

In [None]:
# Index of the sample in data_isic that is explained.
n = 13  # 77 39  76  101 80
# NoiseTunnel settings shared by the IG, Guided Backprop and Guided GradCAM runs.
nt_type = "smoothgrad"
nt_samples = 50
nt_samples_batch_size = 5

# Every section below fills its list in the order
#   [a/output 0, a/output 1, b/output 0, b/output 1, a/difference, b/difference]
# and then reorders it with [0, 1, 4, 2, 3, 5] to
#   [a/output 0, a/output 1, a/difference, b/output 0, b/output 1, b/difference],
# which matches the model_type labels used when the figures are saved.

# Expected Gradients: GradientShap with the whole loaded batch as baselines.
attr_eg_isic = []

for model in (model_isic_a, model_isic_b):
    attr_method = GradientShap(TwoHead_Wrapper(model))
    for target in (0, 1):

        attr_values = attr_method.attribute(
            data_isic[n].unsqueeze(0),
            n_samples=300,
            # stdevs=0.001,
            baselines=data_isic,
            target=target,
            stdevs=0.0001,
        )
        attr_eg_isic.append(attr_values.squeeze(0))

for model in (model_isic_a, model_isic_b):
    attr_method = GradientShap(CATE_Wrapper(model))
    attr_values = attr_method.attribute(
        data_isic[n].unsqueeze(0),
        n_samples=300,
        # stdevs=0.001,
        baselines=data_isic,
        target=0,
        stdevs=0.0001,
    )
    attr_eg_isic.append(attr_values.squeeze(0))

attr_eg_isic = [attr_eg_isic[i] for i in [0, 1, 4, 2, 3, 5]]

# Integrated Gradients wrapped in a NoiseTunnel, with an all-zeros baseline of
# the sample's shape.

attr_ig_isic = []

for model in (model_isic_a, model_isic_b):
    attr_method = NoiseTunnel(IntegratedGradients(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_isic[n].unsqueeze(0),
            target=target,
            baselines=(torch.ones_like(data_isic[n]) * 0).unsqueeze(0),
            stdevs=0.2,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )
        attr_ig_isic.append(attr_values.squeeze(0))

for model in (model_isic_a, model_isic_b):
    attr_method = NoiseTunnel(IntegratedGradients(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_isic[n].unsqueeze(0),
        target=0,
        baselines=(torch.ones_like(data_isic[n]) * 0).unsqueeze(0),
        stdevs=0.2,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_ig_isic.append(attr_values.squeeze(0))

attr_ig_isic = [attr_ig_isic[i] for i in [0, 1, 4, 2, 3, 5]]

# GradCAM (pytorch_grad_cam) on the last block of layer4. The two-output
# wrapper selects an output index through ClassifierOutputTarget; the
# difference wrapper uses the raw score.

attr_gcam_isic = []
for model in (model_isic_a, model_isic_b):
    attr_method = GradCAM(
        model=TwoHead_Wrapper(model),
        target_layers=[TwoHead_Wrapper(model).model.model.layer4[-1]],
    )
    for target in (0, 1):
        attr_values = attr_method(
            input_tensor=data_isic[n].unsqueeze(0),
            targets=[ClassifierOutputTarget(target)],
        )
        attr_gcam_isic.append(attr_values.squeeze(0))

for model in (model_isic_a, model_isic_b):
    attr_method = GradCAM(
        model=CATE_Wrapper(model),
        target_layers=[CATE_Wrapper(model).model.model.layer4[-1]],
    )
    attr_values = attr_method(
        input_tensor=data_isic[n].unsqueeze(0), targets=[RawScoresOutputTarget()]
    )
    attr_gcam_isic.append(attr_values.squeeze(0))

attr_gcam_isic = [attr_gcam_isic[i] for i in [0, 1, 4, 2, 3, 5]]

# Guided Backprop wrapped in a NoiseTunnel.
attr_gbp_isic = []
for model in (model_isic_a, model_isic_b):
    attr_method = NoiseTunnel(GuidedBackprop(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_isic[n].unsqueeze(0),
            target=target,
            stdevs=0.2,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )
        attr_gbp_isic.append(attr_values.squeeze(0))

for model in (model_isic_a, model_isic_b):
    attr_method = NoiseTunnel(GuidedBackprop(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_isic[n].unsqueeze(0),
        target=0,
        stdevs=0.2,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_gbp_isic.append(attr_values.squeeze(0))

attr_gbp_isic = [attr_gbp_isic[i] for i in [0, 1, 4, 2, 3, 5]]


# Guided GradCAM wrapped in a NoiseTunnel, computed with respect to the
# model's `conv` layer. Note the much smaller noise level (stdevs=0.0001)
# compared with the IG and Guided Backprop runs above (0.2).
attr_ggcam_isic = []

for model in (model_isic_a, model_isic_b):
    attr_method = NoiseTunnel(
        GuidedGradCam(TwoHead_Wrapper(model), TwoHead_Wrapper(model).model.model.conv)
    )
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_isic[n].unsqueeze(0),
            target=target,
            stdevs=0.0001,
            nt_type=nt_type,
            nt_samples=nt_samples,
            nt_samples_batch_size=nt_samples_batch_size,
        )

        attr_ggcam_isic.append(attr_values.squeeze(0))

for model in (model_isic_a, model_isic_b):
    attr_method = NoiseTunnel(
        GuidedGradCam(CATE_Wrapper(model), CATE_Wrapper(model).model.model.conv)
    )
    attr_values = attr_method.attribute(
        data_isic[n].unsqueeze(0),
        target=0,
        stdevs=0.0001,
        nt_type=nt_type,
        nt_samples=nt_samples,
        nt_samples_batch_size=nt_samples_batch_size,
    )
    attr_ggcam_isic.append(attr_values.squeeze(0))

attr_ggcam_isic = [attr_ggcam_isic[i] for i in [0, 1, 4, 2, 3, 5]]

In [None]:
# Export the selected ISIC sample and all of its attribution maps as PNG files
# under ./Images/isic/<n>/ (one file per attribution method and model output).

# Channel-last copy of the sample, min-max scaled to [0, 1] for display.
img = np.transpose(data_isic[n].cpu().detach().numpy(), (1, 2, 0))
img = (img - img.min()) / (img.max() - img.min())
attr_type = ["EG", "IG", "GCAM", "GBP", "GGCAM"]
model_type = [
    "TwoHead_0_a",
    "TwoHead_1_a",
    "CATE_a",
    "TwoHead_0_b",
    "TwoHead_1_b",
    "CATE_b",
]

isic_out_dir = f"./Images/isic/{n}/"
os.makedirs(isic_out_dir, exist_ok=True)

# Resolution of every exported image. It is passed to savefig explicitly:
# assigning `_original_dpi` on the pyplot module has no effect on the output.
isic_export_dpi = 200

plt.imshow(img)
plt.axis("off")
plt.tight_layout()
plt.savefig(
    isic_out_dir + "original.png",
    dpi=isic_export_dpi,
    bbox_inches="tight",
    pad_inches=0,
)
plt.close()

# Per-method magnitude threshold: attribution values whose absolute value is
# below it are zeroed before plotting. GradCAM (None) is drawn as an overlay
# and is not thresholded.
norm = [0.0007, 0.0007, None, 0.0007, 0.0]

for idx, attr in enumerate(
    [attr_eg_isic, attr_ig_isic, attr_gcam_isic, attr_gbp_isic, attr_ggcam_isic]
):
    for model_idx, model_name in enumerate(model_type):
        out_path = isic_out_dir + str(attr_type[idx]) + "_" + model_name + ".png"

        if attr_type[idx] == "GCAM":
            # GradCAM maps are blended onto the image instead of being shown
            # as a signed heat map.
            gcam_vis = show_cam_on_image(img, attr[model_idx], use_rgb=True)
            plt.imshow(gcam_vis)
            plt.axis("off")
            plt.tight_layout()
            plt.savefig(
                out_path, dpi=isic_export_dpi, bbox_inches="tight", pad_inches=0
            )
            plt.close()
        else:
            attr_map = np.transpose(
                attr[model_idx].cpu().detach().numpy(), (1, 2, 0)
            )
            fig, _ = vis.visualize_image_attr_multiple(
                np.where(np.abs(attr_map) < norm[idx], 0, attr_map),
                img,
                ["heat_map"],
                ["all"],
                cmap=heat_cmap,
                show_colorbar=False,
                use_pyplot=False,
            )
            # The figure is not managed by pyplot (use_pyplot=False), so it is
            # saved through its own savefig and needs no plt.close().
            fig.savefig(
                out_path, dpi=isic_export_dpi, bbox_inches="tight", pad_inches=0
            )

### Lung CT

In [None]:
# Load both lung-CT experiments (test split, last checkpoint) and fetch one
# batch of images plus the matching segmentation batch.
data_name = "lungCT"
lungCT_load_kwargs = dict(
    use_cuda=False,
    get_saliency_maps=False,
    dataset_root=dataset_root[data_name],
    env="test",
    checkpoint_num=-1,
)

# Experiment "a": additionally request the segmentation data loader.
name = "lungCT_a"
(
    model_lungCT_a,
    dl_lungCT_a,
    bprog_lungCT_a,
    bpred_lungCT_a,
    dl_lungCT_seg_a,
) = get_interpretation(
    experiment_dir=experiment_dirs[name],
    n_batch=83,
    log_name=log_names[name][1],
    get_segmentation_dl=True,
    **lungCT_load_kwargs,
)

# Experiment "b": no segmentation loader, so only four values are returned.
name = "lungCT_b"
(
    model_lungCT_b,
    dl_lungCT_b,
    bprog_lungCT_b,
    bpred_lungCT_b,
) = get_interpretation(
    experiment_dir=experiment_dirs[name],
    n_batch=20,
    log_name=log_names[name][1],
    **lungCT_load_kwargs,
)

# First batch of each loader of experiment "a"; element 0 of the image batch
# is used as the input tensor.
seg_lungCT = next(iter(dl_lungCT_seg_a)).numpy()
data_lungCT = next(iter(dl_lungCT_a))
# NOTE(review): the models are loaded with use_cuda=False while the batch is
# moved to cuda:0 -- confirm both end up on the same device.
data_lungCT = torch.Tensor(data_lungCT[0]).to("cuda:0")

In [None]:
# Sample index and noise settings used by the lung-CT attribution runs.
n = 8
nt_type = "smoothgrad"
nt_samples = 10
nt_samples_batch_size = 2
stdevs = 80.0

# Exp. Gradients
attr_eg_lungCT = []

# Arguments shared by every GradientShap call; the loaded batch itself is
# passed as the baseline distribution.
eg_lungCT_kwargs = dict(n_samples=50, baselines=data_lungCT, stdevs=stdevs)

# Two-head models: one map per output head (target 0 and 1).
for model in (model_lungCT_a, model_lungCT_b):
    attr_method = GradientShap(TwoHead_Wrapper(model))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_lungCT[n].unsqueeze(0), target=target, **eg_lungCT_kwargs
        )
        attr_eg_lungCT.append(normalize_pos_neg(attr_values.squeeze(0)))

# CATE models: a single output, hence target 0.
for model in (model_lungCT_a, model_lungCT_b):
    attr_method = GradientShap(CATE_Wrapper(model))
    attr_values = attr_method.attribute(
        data_lungCT[n].unsqueeze(0), target=0, **eg_lungCT_kwargs
    )
    attr_eg_lungCT.append(normalize_pos_neg(attr_values.squeeze(0)))

# Appended as [a_0, a_1, b_0, b_1, cate_a, cate_b]; regroup per model as
# [a_0, a_1, cate_a, b_0, b_1, cate_b].
eg_lungCT_order = (0, 1, 4, 2, 3, 5)
attr_eg_lungCT = [attr_eg_lungCT[pos] for pos in eg_lungCT_order]

# Int. Gradients

attr_ig_lungCT = []
from captum.attr import Saliency

# Arguments shared by every noise-tunnelled Integrated Gradients call: a
# constant scalar baseline of -1150 and 20 integration steps.
ig_lungCT_kwargs = dict(
    baselines=-1150,
    n_steps=20,
    stdevs=stdevs,
    nt_type=nt_type,
    nt_samples=nt_samples,
    nt_samples_batch_size=nt_samples_batch_size,
)

# Two-head models: one map per output head (target 0 and 1).
for model in (model_lungCT_a, model_lungCT_b):
    attr_method = NoiseTunnel(IntegratedGradients(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_lungCT[n].unsqueeze(0), target=target, **ig_lungCT_kwargs
        )
        attr_ig_lungCT.append(normalize_pos_neg(attr_values.squeeze(0)))

# CATE models: a single output, hence target 0.
for model in (model_lungCT_a, model_lungCT_b):
    attr_method = NoiseTunnel(IntegratedGradients(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_lungCT[n].unsqueeze(0), target=0, **ig_lungCT_kwargs
    )
    attr_ig_lungCT.append(normalize_pos_neg(attr_values.squeeze(0)))

# Appended as [a_0, a_1, b_0, b_1, cate_a, cate_b]; regroup per model as
# [a_0, a_1, cate_a, b_0, b_1, cate_b].
ig_lungCT_order = (0, 1, 4, 2, 3, 5)
attr_ig_lungCT = [attr_ig_lungCT[pos] for pos in ig_lungCT_order]

# GradCAM

attr_gcam_lungCT = []

# Two-head models: one CAM per output head, hooked on layer1..layer3 of the
# wrapped network.
for model in (model_lungCT_a, model_lungCT_b):
    two_head = TwoHead_Wrapper(model)
    # two_head.model is `model`, so the layers are model.model.layer1..3.
    backbone = two_head.model.model
    attr_method = GradCAM3D(
        model=two_head,
        target_layers=[backbone.layer1, backbone.layer2, backbone.layer3],
    )
    for target in (0, 1):
        attr_values = attr_method(
            input_tensor=data_lungCT[n].unsqueeze(0),
            targets=[ClassifierOutputTarget(target)],
        )
        attr_gcam_lungCT.append(normalize_pos_zero(attr_values.squeeze(0)))

# CATE models: the raw (single) output score is used as the CAM target.
for model in (model_lungCT_a, model_lungCT_b):
    cate_model = CATE_Wrapper(model)
    backbone = cate_model.model.model
    attr_method = GradCAM3D(
        model=cate_model,
        target_layers=[backbone.layer1, backbone.layer2, backbone.layer3],
    )
    attr_values = attr_method(
        input_tensor=data_lungCT[n].unsqueeze(0),
        targets=[RawScoresOutputTarget()],
    )
    attr_gcam_lungCT.append(normalize_pos_zero(attr_values.squeeze(0)))

# Appended as [a_0, a_1, b_0, b_1, cate_a, cate_b]; regroup per model as
# [a_0, a_1, cate_a, b_0, b_1, cate_b].
gcam_lungCT_order = (0, 1, 4, 2, 3, 5)
attr_gcam_lungCT = [attr_gcam_lungCT[pos] for pos in gcam_lungCT_order]

# Guided Backprop
attr_gbp_lungCT = []

# Noise-tunnel settings shared by every Guided Backprop call in this cell.
gbp_lungCT_kwargs = dict(
    stdevs=stdevs,
    nt_type=nt_type,
    nt_samples=nt_samples,
    nt_samples_batch_size=nt_samples_batch_size,
)

# Two-head models: one map per output head (target 0 and 1).
for model in (model_lungCT_a, model_lungCT_b):
    attr_method = NoiseTunnel(GuidedBackprop(TwoHead_Wrapper(model)))
    for target in (0, 1):
        attr_values = attr_method.attribute(
            data_lungCT[n].unsqueeze(0), target=target, **gbp_lungCT_kwargs
        )
        attr_gbp_lungCT.append(normalize_pos_neg(attr_values.squeeze(0)))

# CATE models: a single output, hence target 0.
for model in (model_lungCT_a, model_lungCT_b):
    attr_method = NoiseTunnel(GuidedBackprop(CATE_Wrapper(model)))
    attr_values = attr_method.attribute(
        data_lungCT[n].unsqueeze(0), target=0, **gbp_lungCT_kwargs
    )
    attr_gbp_lungCT.append(normalize_pos_neg(attr_values.squeeze(0)))

# Appended as [a_0, a_1, b_0, b_1, cate_a, cate_b]; regroup per model as
# [a_0, a_1, cate_a, b_0, b_1, cate_b].
gbp_lungCT_order = (0, 1, 4, 2, 3, 5)
attr_gbp_lungCT = [attr_gbp_lungCT[pos] for pos in gbp_lungCT_order]

In [None]:
# Export, for every attribution method and model output, a two-row figure of
# the lung-CT sample: top row = CT slices (with the segmentation contour where
# one exists), bottom row = the attribution for the same slices.
import cv2

attr_type = [
    "EG",
    "IG",
    "GCAM",
    "GGCAM",
]
model_type = [
    "TwoHead_0_a",
    "TwoHead_1_a",
    "CATE_a",
    "TwoHead_0_b",
    "TwoHead_1_b",
    "CATE_b",
]

lungCT_out_dir = f"./Images/lungCT/{n}/"
os.makedirs(lungCT_out_dir, exist_ok=True)

# Per-method magnitude threshold below which attribution values are zeroed
# (None: GradCAM is drawn as an overlay and is not thresholded).
norm = [0.2, 0.15, None, 0.05]
# Intensity window used to map the CT values to [0, 1] for display.
vmin = -1150
vmax = 550

# Guided GradCAM: element-wise product of the GradCAM volume (last axis moved
# to position 1 so it lines up with the Guided Backprop tensor) and the Guided
# Backprop map, rescaled with normalize_pos_neg.
attr_ggcam_lungCT = [
    normalize_pos_neg(torch.Tensor(np.moveaxis(a[None, :], 3, 1)) * b.cpu())
    for a, b in zip(attr_gcam_lungCT, attr_gbp_lungCT)
]

# The plotted slice range depends only on the sample, so it is computed once:
# from the first to the last slice whose mean windowed intensity exceeds that
# of slice 0.
volume = np.transpose(data_lungCT[n, 0].cpu().detach().numpy(), (1, 2, 0))
volume = np.clip((volume - vmin) / (vmax - vmin), 0, 1)
slice_means = volume.mean((0, 1))
brighter_than_first = np.where(slice_means > slice_means[0])[0]
first = brighter_than_first[0]
last = brighter_than_first[-1]
dist = last - first
step = 2 if dist <= 10 else 4

slice_ids = list(range(first, last, step))
n_cols = len(slice_ids)
if n_cols == 0:
    raise ValueError(f"No slices to plot for sample {n} (first == last == {first}).")

# Wider figures are saved at a higher resolution. The value is passed to
# savefig explicitly: assigning `_original_dpi` on the pyplot module has no
# effect on the output.
export_dpi = 200 if dist <= 10 else 300

for idx, attr in enumerate(
    [attr_eg_lungCT, attr_ig_lungCT, attr_gcam_lungCT, attr_ggcam_lungCT]
):
    is_gcam = attr_type[idx] == "GCAM"

    for model_idx, model_name in enumerate(model_type):
        # squeeze=False keeps `axes` two-dimensional even for a single column.
        fig, axes = plt.subplots(
            2,
            n_cols,
            figsize=(n_cols * 2, 3),
            sharey=True,
            sharex=True,
            squeeze=False,
        )

        for col, slice_idx in enumerate(slice_ids):
            # Windowed CT slice, channel-last, repeated to three channels.
            img = np.transpose(
                data_lungCT[n, :, slice_idx].cpu().detach().numpy(), (1, 2, 0)
            )
            img = np.clip((img - vmin) / (vmax - vmin), 0, 1)
            img = np.repeat(img, 3, axis=2)

            if is_gcam:
                # Build the overlay before the contour is drawn into `img`,
                # so the contour only appears in the top row.
                gcam_vis = show_cam_on_image(
                    img, attr[model_idx][:, :, slice_idx], use_rgb=True
                )

            seg_slice = seg_lungCT[n, 0, slice_idx]
            if seg_slice.sum() > 0:
                contours, _hierarchy = cv2.findContours(
                    seg_slice, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
                )
                cv2.drawContours(img, contours, -1, (1, 0.525, 0), 2)

            axes[0, col].imshow(img)
            axes[0, col].axis("off")
            axes[0, col].set_title("Slice: " + str(slice_idx), fontsize=10)

            if is_gcam:
                axes[1, col].imshow(gcam_vis)
            else:
                attr_slice = np.transpose(
                    attr[model_idx][:, slice_idx].cpu().detach().numpy(), (1, 2, 0)
                )
                vis.visualize_image_attr(
                    np.where(np.abs(attr_slice) < norm[idx], 0, attr_slice),
                    img,
                    "heat_map",
                    "all",
                    plt_fig_axis=(fig, axes[1, col]),
                    cmap=heat_cmap,
                    show_colorbar=False,
                    use_pyplot=False,
                )
            axes[1, col].axis("off")

        if is_gcam:
            # Only the GradCAM figures are laid out with tight_layout.
            fig.tight_layout()

        fig.savefig(
            lungCT_out_dir + str(attr_type[idx]) + "_" + model_name + ".png",
            dpi=export_dpi,
            bbox_inches="tight",
            pad_inches=0,
        )
        plt.close(fig)

#### Single Slice Export

In [None]:
# Export one lung-CT slice: the slice itself with its segmentation contour,
# plus the GradCAM overlay and the Guided GradCAM heat map for the first three
# model outputs.
import cv2

attr_type = ["GCAM", "GGCAM"]
model_type = [
    "TwoHead_0_a",
    "TwoHead_1_a",
    "CATE_a",
    "TwoHead_0_b",
    "TwoHead_1_b",
    "CATE_b",
]

lungCT_out_dir = f"./Images/lungCT/{n}/"
os.makedirs(lungCT_out_dir, exist_ok=True)

# Magnitude threshold below which Guided GradCAM values are zeroed (None:
# GradCAM is drawn as an overlay and is not thresholded).
norm = [None, 0.09]
# Intensity window used to map the CT values to [0, 1] for display.
vmin = -1150
vmax = 550

# Guided GradCAM: element-wise product of the GradCAM volume (last axis moved
# to position 1 so it lines up with the Guided Backprop tensor) and the Guided
# Backprop map, rescaled with normalize_pos_neg.
attr_ggcam_lungCT = [
    normalize_pos_neg(torch.Tensor(np.moveaxis(a[None, :], 3, 1)) * b.cpu())
    for a, b in zip(attr_gcam_lungCT, attr_gbp_lungCT)
]

# NOTE: this name shadows the builtin `slice`; it is kept because it is a
# notebook-level variable that other cells may read.
slice = 23

# Export resolutions, passed to savefig explicitly: assigning `_original_dpi`
# on the pyplot module has no effect on the output.
pyplot_export_dpi = 200
ggcam_export_dpi = 100

# Windowed CT slice, channel-last, repeated to three channels. It is computed
# once and reused for every export below.
img = np.transpose(data_lungCT[n, :, slice].cpu().detach().numpy(), (1, 2, 0))
img = np.clip((img - vmin) / (vmax - vmin), 0, 1)
img = np.repeat(img, 3, axis=2)

# Draw the contour on a copy so `img` stays clean for the overlays.
img_with_contour = img.copy()
contours, _hierarchy = cv2.findContours(
    seg_lungCT[n, 0, slice], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
cv2.drawContours(img_with_contour, contours, -1, (1, 0.525, 0), 2)

plt.imshow(img_with_contour)
plt.axis("off")
plt.tight_layout()
plt.savefig(
    lungCT_out_dir + "Slice_" + str(slice) + ".png",
    dpi=pyplot_export_dpi,
    bbox_inches="tight",
    pad_inches=0,
)
plt.close()

# Only the first three entries of model_type (the "_a" outputs) are exported.
for model_idx, model_name in enumerate(model_type[:3]):
    for idx, attr in enumerate([attr_gcam_lungCT, attr_ggcam_lungCT]):
        out_path = (
            lungCT_out_dir
            + "Slice_"
            + str(slice)
            + "_"
            + attr_type[idx]
            + "_"
            + model_name
            + ".png"
        )

        if idx == 0:
            # GradCAM: blend the CAM for this slice onto the CT slice.
            gcam_vis = show_cam_on_image(
                img, attr[model_idx][:, :, slice], use_rgb=True
            )

            plt.imshow(gcam_vis)
            plt.axis("off")
            plt.tight_layout()
            plt.savefig(
                out_path,
                dpi=pyplot_export_dpi,
                bbox_inches="tight",
                pad_inches=0,
            )
            plt.close()
        else:
            # Guided GradCAM: signed heat map with small magnitudes zeroed.
            attr_slice = np.transpose(
                attr[model_idx][:, slice].cpu().detach().numpy(), (1, 2, 0)
            )
            fig, ax = vis.visualize_image_attr(
                np.where(np.abs(attr_slice) < norm[idx], 0, attr_slice),
                img,
                "heat_map",
                "all",
                cmap=heat_cmap,
                show_colorbar=False,
                use_pyplot=False,
            )

            ax.axis("off")
            fig.tight_layout()
            # The figure is not managed by pyplot (use_pyplot=False), so it is
            # saved through its own savefig and needs no plt.close().
            fig.savefig(
                out_path,
                dpi=ggcam_export_dpi,
                bbox_inches="tight",
                pad_inches=0,
            )