In [None]:
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
from visualization import HeatmapVisualizer, visualize_original
import torch.nn as nn
import torch
from methods import ig_pipeline, sm_pipeline, sg_pipeline, dl_pipeline
# Font for the subplot titles, loaded from a path relative to the notebook's
# working directory.
my_font = fm.FontProperties(fname="fonts/SimHei.ttf")
# Heatmap renderer shared by every attribution section below.
mask_viz = HeatmapVisualizer(normalization_type="signed_max", blur=7)
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [None]:
# Placeholder: replace with an iterable (e.g. a torch DataLoader) that yields
# (data, target) batches -- the attribution cells below unpack it that way.
dataloader = Ellipsis

In [None]:
# Placeholder: replace with your trained classifier (presumably a torch
# nn.Module whose output's last dimension holds class scores -- the cells
# below take an argmax over dim=-1).
model = ...
# Move the weights to the selected device, then switch to inference mode.
model = model.to(device)
model = model.eval()

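### IG — Integrated Gradients

Attributions for the correctly classified samples of one batch. Each row shows the original image, the tiled heatmap overlay, and two `visualize_original` renderings of the attribution (plain and overlaid).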

In [None]:
# Integrated Gradients: visualize attributions for one batch of correctly
# classified samples. Assumes `ig_pipeline` returns a NumPy array shaped like
# `data`, i.e. (N, C, H, W) -- inferred from the `.transpose(1, 2, 0)` calls
# below; TODO confirm.
for data, target in dataloader:
    data = data.to(device)
    target = target.to(device)
    # Only predicted labels are needed here, so don't build an autograd graph
    # for this full-batch forward pass.
    with torch.no_grad():
        correct_index = torch.argmax(model(data), dim=-1) == target
    # No correct predictions in this batch: try the next one rather than
    # handing an empty batch to the attribution pipeline.
    if not correct_index.any():
        continue
    data = data[correct_index]
    target = target[correct_index]
    attribution_map = ig_pipeline(model, data, target)
    for i in range(len(data)):
        # Single-sample NumPy copy of the input, reused by both visualizers.
        sample_np = data[i:i+1].detach().cpu().numpy()
        # `_normalize` is a private HeatmapVisualizer helper ("01" mode,
        # presumably rescaling to [0, 1] -- verify). The result is converted to
        # HWC in [0, 255]; squeeze() assumes a multi-channel (C > 1) image.
        img = mask_viz._normalize(sample_np, "01")
        img = img.squeeze().transpose(1, 2, 0) * 255
        attr_hwc = attribution_map[i].transpose(1, 2, 0)
        heatmap_overlay = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=True, mask_mode=True)
        heatmap = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=False)
        fig, axes = plt.subplots(1, 4, figsize=(15, 5))
        tiled_img, tiled_mask = mask_viz(
            attribution_map[i:i+1], sample_np,
            overlay_opacity=0.5, imshow=False, return_tiled=True)
        # axis('off') already hides spines, ticks and tick labels.
        for ax in axes:
            ax.axis('off')
        axes[0].set_title("Original Image", fontproperties=my_font)
        axes[0].imshow(img / 255)
        axes[1].set_title("IntegratedGradient", fontproperties=my_font)
        axes[1].imshow(tiled_img)
        axes[1].imshow(tiled_mask, alpha=0.5, cmap='jet')
        axes[2].imshow(heatmap / 255)
        axes[3].imshow(heatmap_overlay / 255)
        plt.show()
        # One figure is created per sample; release each explicitly.
        plt.close(fig)
    # Only one non-empty batch is visualized.
    break


### SM — Saliency (vanilla gradient)

In [None]:
# Saliency gradient: visualize attributions for one batch of correctly
# classified samples. Assumes `sm_pipeline` returns a NumPy array shaped like
# `data`, i.e. (N, C, H, W) -- inferred from the `.transpose(1, 2, 0)` calls
# below; TODO confirm.
for data, target in dataloader:
    data = data.to(device)
    target = target.to(device)
    # Only predicted labels are needed here, so don't build an autograd graph
    # for this full-batch forward pass.
    with torch.no_grad():
        correct_index = torch.argmax(model(data), dim=-1) == target
    # No correct predictions in this batch: try the next one rather than
    # handing an empty batch to the attribution pipeline.
    if not correct_index.any():
        continue
    data = data[correct_index]
    target = target[correct_index]
    attribution_map = sm_pipeline(model, data, target)
    for i in range(len(data)):
        # Single-sample NumPy copy of the input, reused by both visualizers.
        sample_np = data[i:i+1].detach().cpu().numpy()
        # `_normalize` is a private HeatmapVisualizer helper ("01" mode,
        # presumably rescaling to [0, 1] -- verify). The result is converted to
        # HWC in [0, 255]; squeeze() assumes a multi-channel (C > 1) image.
        img = mask_viz._normalize(sample_np, "01")
        img = img.squeeze().transpose(1, 2, 0) * 255
        attr_hwc = attribution_map[i].transpose(1, 2, 0)
        heatmap_overlay = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=True, mask_mode=True)
        heatmap = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=False)
        fig, axes = plt.subplots(1, 4, figsize=(15, 5))
        tiled_img, tiled_mask = mask_viz(
            attribution_map[i:i+1], sample_np,
            overlay_opacity=0.5, imshow=False, return_tiled=True)
        # axis('off') already hides spines, ticks and tick labels.
        for ax in axes:
            ax.axis('off')
        axes[0].set_title("Original Image", fontproperties=my_font)
        axes[0].imshow(img / 255)
        axes[1].set_title("SaliencyGradient", fontproperties=my_font)
        axes[1].imshow(tiled_img)
        axes[1].imshow(tiled_mask, alpha=0.5, cmap='jet')
        axes[2].imshow(heatmap / 255)
        axes[3].imshow(heatmap_overlay / 255)
        plt.show()
        # One figure is created per sample; release each explicitly.
        plt.close(fig)
    # Only one non-empty batch is visualized.
    break


### SG — SmoothGrad (noise-averaged gradients)

In [None]:
# SmoothGrad: visualize attributions for one batch of correctly classified
# samples. Assumes `sg_pipeline` returns a NumPy array shaped like `data`,
# i.e. (N, C, H, W) -- inferred from the `.transpose(1, 2, 0)` calls below;
# TODO confirm.
for data, target in dataloader:
    data = data.to(device)
    target = target.to(device)
    # Only predicted labels are needed here, so don't build an autograd graph
    # for this full-batch forward pass.
    with torch.no_grad():
        correct_index = torch.argmax(model(data), dim=-1) == target
    # No correct predictions in this batch: try the next one rather than
    # handing an empty batch to the attribution pipeline.
    if not correct_index.any():
        continue
    data = data[correct_index]
    target = target[correct_index]
    attribution_map = sg_pipeline(model, data, target)
    for i in range(len(data)):
        # Single-sample NumPy copy of the input, reused by both visualizers.
        sample_np = data[i:i+1].detach().cpu().numpy()
        # `_normalize` is a private HeatmapVisualizer helper ("01" mode,
        # presumably rescaling to [0, 1] -- verify). The result is converted to
        # HWC in [0, 255]; squeeze() assumes a multi-channel (C > 1) image.
        img = mask_viz._normalize(sample_np, "01")
        img = img.squeeze().transpose(1, 2, 0) * 255
        attr_hwc = attribution_map[i].transpose(1, 2, 0)
        heatmap_overlay = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=True, mask_mode=True)
        heatmap = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=False)
        fig, axes = plt.subplots(1, 4, figsize=(15, 5))
        tiled_img, tiled_mask = mask_viz(
            attribution_map[i:i+1], sample_np,
            overlay_opacity=0.5, imshow=False, return_tiled=True)
        # axis('off') already hides spines, ticks and tick labels.
        for ax in axes:
            ax.axis('off')
        axes[0].set_title("Original Image", fontproperties=my_font)
        axes[0].imshow(img / 255)
        axes[1].set_title("SmoothGradient", fontproperties=my_font)
        axes[1].imshow(tiled_img)
        axes[1].imshow(tiled_mask, alpha=0.5, cmap='jet')
        axes[2].imshow(heatmap / 255)
        axes[3].imshow(heatmap_overlay / 255)
        plt.show()
        # One figure is created per sample; release each explicitly.
        plt.close(fig)
    # Only one non-empty batch is visualized.
    break


### DL — DeepLIFT

In [None]:
# DeepLift: visualize attributions for one batch of correctly classified
# samples. Assumes `dl_pipeline` returns a NumPy array shaped like `data`,
# i.e. (N, C, H, W) -- inferred from the `.transpose(1, 2, 0)` calls below;
# TODO confirm.
for data, target in dataloader:
    data = data.to(device)
    target = target.to(device)
    # Only predicted labels are needed here, so don't build an autograd graph
    # for this full-batch forward pass.
    with torch.no_grad():
        correct_index = torch.argmax(model(data), dim=-1) == target
    # No correct predictions in this batch: try the next one rather than
    # handing an empty batch to the attribution pipeline.
    if not correct_index.any():
        continue
    data = data[correct_index]
    target = target[correct_index]
    attribution_map = dl_pipeline(model, data, target)
    for i in range(len(data)):
        # Single-sample NumPy copy of the input, reused by both visualizers.
        sample_np = data[i:i+1].detach().cpu().numpy()
        # `_normalize` is a private HeatmapVisualizer helper ("01" mode,
        # presumably rescaling to [0, 1] -- verify). The result is converted to
        # HWC in [0, 255]; squeeze() assumes a multi-channel (C > 1) image.
        img = mask_viz._normalize(sample_np, "01")
        img = img.squeeze().transpose(1, 2, 0) * 255
        attr_hwc = attribution_map[i].transpose(1, 2, 0)
        heatmap_overlay = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=True, mask_mode=True)
        heatmap = visualize_original(
            attr_hwc, img, clip_above_percentile=99, clip_below_percentile=0,
            overlay=False)
        fig, axes = plt.subplots(1, 4, figsize=(15, 5))
        tiled_img, tiled_mask = mask_viz(
            attribution_map[i:i+1], sample_np,
            overlay_opacity=0.5, imshow=False, return_tiled=True)
        # axis('off') already hides spines, ticks and tick labels.
        for ax in axes:
            ax.axis('off')
        axes[0].set_title("Original Image", fontproperties=my_font)
        axes[0].imshow(img / 255)
        axes[1].set_title("DeepLift", fontproperties=my_font)
        axes[1].imshow(tiled_img)
        axes[1].imshow(tiled_mask, alpha=0.5, cmap='jet')
        axes[2].imshow(heatmap / 255)
        axes[3].imshow(heatmap_overlay / 255)
        plt.show()
        # One figure is created per sample; release each explicitly.
        plt.close(fig)
    # Only one non-empty batch is visualized.
    break
