In [1]:
# Importing Libraries
import torch
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
from PIL import Image

In [2]:
def compute_map(X, Y, model):
    """
    Compute vanilla-gradient saliency maps for a batch of images.

    The saliency of each pixel is the absolute gradient of the image's
    correct-class score with respect to that pixel, maximised over the
    channel axis.

    Input:
        X: float tensor of images with shape (N, C, H, W), where N is the number
           of images, C the number of colour channels (3 for RGB) and H/W the
           height/width of each image. X itself is not modified.
        Y: integer class labels for X, shape (N,); one label per image.
           Anything accepted by torch.as_tensor works (tensor, list, ndarray).
        model: the CNN used to compute class scores; it is put in eval mode.

    Returns:
        saliency: a tensor of shape (N, H, W) with one saliency map per input
        image. The batch axis is kept even when N == 1.
    """
    model.eval()

    # Detached view of the pixels that we can differentiate with respect to,
    # without touching the caller's tensor or its requires_grad flag.
    X_var = X.detach().requires_grad_(True)

    # Calling model(...) rather than model.forward(...) also runs module hooks.
    scores = model(X_var)  # (N, num_classes)

    # gather() needs int64 indices on the same device as the scores.
    labels = torch.as_tensor(Y, dtype=torch.long, device=scores.device).reshape(-1, 1)

    # Summing the correct-class scores lets a single backward pass produce
    # every image's gradient. In eval mode, typical CNN layers do not mix
    # samples across the batch.
    loss = scores.gather(1, labels).sum()

    # autograd.grad returns d(loss)/d(X_var) directly and, unlike
    # loss.backward(), does not accumulate into the model parameters' .grad.
    (grad,) = torch.autograd.grad(loss, X_var)

    # Max of |gradient| over the channel axis: (N, C, H, W) -> (N, H, W).
    # No squeeze(), so the batch axis survives N == 1.
    saliency = grad.abs().max(dim=1)[0]

    return saliency

In [3]:
def visualize(X, Y):
    """
    Plot each image in X (top row) above its saliency map (bottom row).

    X: sequence/array of N images, each something PIL.Image.fromarray accepts
       (e.g. a uint8 array of shape (H, W, 3)).
    Y: sequence of N integer class labels, one per image in X.

    NOTE(review): this relies on the notebook-level names `preprocess`
    (PIL image -> tensor), `model` and `compute_map` defined elsewhere.
    It assumes `preprocess` returns a batched (1, 3, H', W') tensor, so that
    concatenating along dim 0 yields an (N, 3, H', W') batch. Confirm this.
    """
    # Preprocess every image and stack the results into one batch tensor.
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
    # Use the Y parameter (not a stray global) as int64 class indices.
    Y_tensor = torch.as_tensor(Y, dtype=torch.long)

    N = len(X)
    saliency = compute_map(X_tensor, Y_tensor, model)
    # Move to host memory, then restore the (N, H', W') layout in case the
    # batch axis was squeezed away (e.g. when N == 1).
    saliency = saliency.detach().cpu().numpy().reshape(N, *X_tensor.shape[-2:])

    # squeeze=False keeps `axes` 2-D even for a single image.
    fig, axes = plt.subplots(2, N, figsize=(12, 5), squeeze=False)
    for i in range(N):
        image_ax, saliency_ax = axes[0, i], axes[1, i]
        image_ax.imshow(X[i])
        image_ax.axis('off')
        saliency_ax.imshow(saliency[i], cmap=plt.cm.hot)
        saliency_ax.axis('off')
    plt.show()