In [None]:
import torch
from torch import nn
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt

import warnings
# Silence every Python warning for the rest of the kernel session
# (action 'ignore' with no category/message filter).
# NOTE(review): this also hides deprecation warnings, e.g. for
# `models.resnet18(pretrained=True)` used below -- consider narrowing it with
# `category=` / `message=` so genuine problems stay visible.
warnings.filterwarnings('ignore')

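# Saliency map as a reusable function

A vanilla-gradient saliency map: the absolute gradient of the classification loss with respect to each input pixel, min-max scaled to [0, 1] and overlaid on the image.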
----------

In [None]:
def saliency_map(input_path,
                 input_model,
                 device: "str | torch.device" = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 target_label: "int | float | np.ndarray | torch.Tensor" = torch.empty(0),
                 visualize: bool = True):
    """Compute a vanilla-gradient saliency map for a single image.

    The image is passed through ``input_model``. The cross-entropy loss is taken
    against ``target_label`` (or, if none is given, the model's own top-1
    prediction). Its gradient with respect to the input pixels is then computed,
    reduced to magnitudes and min-max scaled to [0, 1].

    Args:
        input_path: path of an image readable by ``plt.imread``. Integer images
            (e.g. uint8 JPEG) are rescaled by 1/255; float images (e.g. PNG)
            are assumed to already be in [0, 1]. Grayscale images are
            replicated to 3 channels and an alpha channel is dropped.
        input_model: ``nn.Module`` called on a (1, 3, H, W) float tensor; its
            output is used as (1, num_classes) logits for CrossEntropyLoss.
            NOTE: the model is moved to ``device``, switched to eval mode and
            has ``requires_grad`` disabled on every parameter. These changes
            persist after the call.
        device: torch device (or device string) for the forward/backward pass.
        target_label: class index to explain. An empty tensor (the default)
            means "explain the model's argmax prediction".
        visualize: if True, show the channel-summed saliency as a 'hot'
            heatmap blended with the image.

    Returns:
        CPU tensor of shape (3, H, W) holding per-channel gradient magnitudes
        scaled to [0, 1]. It is all zeros if the gradient is constant.
    """
    device = torch.device(device)

    # CrossEntropyLoss with (1, C) logits needs a (1,) int64 target on the same
    # device as the logits; reshape(-1) also accepts 0-dim tensors, and an
    # empty tensor stays empty (= "use the prediction").
    if not isinstance(target_label, torch.Tensor):
        target_label = torch.as_tensor(int(target_label)).unsqueeze(0)
    target_label = target_label.reshape(-1).to(device=device, dtype=torch.long)

    # Load the image and bring it to float RGB in [0, 1].
    input_image = plt.imread(input_path)
    if np.issubdtype(input_image.dtype, np.integer):
        input_image = input_image / 255.
    if input_image.ndim == 2:
        input_image = np.stack([input_image] * 3, axis=-1)
    input_image = input_image[..., :3]
    # NOTE(review): no dataset mean/std normalization is applied here.
    # Move to the device *before* enabling grad so the tensor is a leaf.
    input_tensor = (torch.from_numpy(input_image)
                    .permute(2, 0, 1).unsqueeze(0).float()
                    .to(device).requires_grad_())

    # Loss is only needed so it can be differentiated w.r.t. the input image.
    criterion = nn.CrossEntropyLoss()

    # Freeze the model: only the input image needs a gradient.
    input_model.to(device)
    input_model.eval()
    for p in input_model.parameters():
        p.requires_grad = False

    # Always use the model passed in, never a notebook-level global.
    y_hat = input_model(input_tensor)
    y_true = target_label if target_label.numel() else torch.argmax(y_hat, dim=1)
    loss = criterion(y_hat, y_true)

    # Gradient w.r.t. the input, first (only) batch item, magnitude only.
    input_grad = torch.autograd.grad(loss, input_tensor)[0][0].detach().cpu().abs()

    # Min-max scale for visualization; a constant gradient would give 0/0 = NaN.
    grad_min, grad_max = input_grad.min(), input_grad.max()
    if grad_max > grad_min:
        input_grad = (input_grad - grad_min) / (grad_max - grad_min)
    else:
        input_grad = torch.zeros_like(input_grad)

    if visualize:
        plt.imshow(input_grad.sum(0), alpha=0.8, cmap='hot')
        plt.imshow(input_image, alpha=0.2)
        plt.axis('off')
        plt.show()

    return input_grad

In [None]:
# Demo: saliency map of a pretrained ResNet-18 for one sample image.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=`; the blanket warnings filter at the top hides that warning.
model = models.resnet18(pretrained=True)
image_path = "pics/golden_retriever_1.jpeg"
_ = saliency_map(input_path=image_path,
                 input_model=model,
                 visualize=True)

# Step-by-step walkthrough

The same computation as `saliency_map`, split into one stage per cell.
----------

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained classifier on the chosen device, in eval mode, with every
# parameter frozen so that only the input image will receive a gradient.
model = models.resnet18(pretrained=True).to(device).eval()
for weight in model.parameters():
    weight.requires_grad_(False)

In [None]:
# Load the sample image as float RGB in [0, 1]. plt.imread returns integer
# data for JPEG but floats already in [0, 1] for PNG, so only integer images
# are rescaled. Grayscale is replicated to 3 channels and an alpha channel,
# if present, is dropped.
image = plt.imread("pics/golden_retriever_1.jpeg")
if np.issubdtype(image.dtype, np.integer):
    image = image / 255.
if image.ndim == 2:
    image = np.stack([image] * 3, axis=-1)
image = image[..., :3]
# Move to `device` before enabling grad so input_tensor is a leaf tensor.
input_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device).requires_grad_()
plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
criterion = nn.CrossEntropyLoss()

# Explicit ImageNet-1k target for the sample image. In the standard ordering,
# index 207 is "golden retriever" (285 would be "Egyptian cat").
# It is kept for reference; the gradient cell below instead explains the
# model's own top-1 prediction.
GOLDEN_RETRIEVER_IDX = 207
y_true = torch.tensor([GOLDEN_RETRIEVER_IDX], dtype=torch.int64)
print(y_true.dtype)

In [None]:
# Forward pass; the model's own top-1 class serves as the target label.
y_hat = model(input_tensor)
loss = criterion(y_hat, y_hat.argmax().unsqueeze(0))

# Gradient of the loss w.r.t. the input pixels: keep the single batch item, on CPU.
input_grad = torch.autograd.grad(outputs=loss, inputs=input_tensor)[0][0].detach().cpu()
print(input_grad.size())

In [None]:
# Keep only gradient magnitude, then min-max scale to [0, 1] for display.
# A constant gradient (e.g. all zeros) maps to zeros instead of 0/0 = NaN.
input_grad = input_grad.abs()
grad_min, grad_max = input_grad.min(), input_grad.max()
if grad_max > grad_min:
    input_grad = (input_grad - grad_min) / (grad_max - grad_min)
else:
    input_grad = torch.zeros_like(input_grad)

In [None]:
# Channel-summed saliency as a 'jet' heatmap, with the photo blended on top.
plt.imshow(input_grad.sum(dim=0), cmap='jet', alpha=0.7)
plt.imshow(image, alpha=0.3)
plt.show()

In [None]:
# 'hot' colormap variant with a weaker photo blend (the same layers that
# saliency_map draws when visualize=True, but with the axes left on).
plt.imshow(input_grad.sum(dim=0), cmap='hot', alpha=0.8)
plt.imshow(image, alpha=0.2)
plt.show()

In [None]:
# Per-channel saliency shown directly as an (H, W, 3) RGB image: each colour
# channel's brightness is that channel's normalized gradient magnitude.
# No colormap is passed because matplotlib ignores `cmap` for RGB(A) data.
plt.imshow(input_grad.permute(1, 2, 0))
plt.show()