# Intro
In this lab scenario you will implement [Grad-CAM](https://arxiv.org/abs/1610.02391) and guided Grad-CAM.
These two methods are used to explain which parts of the image are important for the model's prediction.

Grad-CAM produces a heatmap of the important parts of the image.  
Guided Grad-CAM combines the output of Grad-CAM with gradients computed with respect to the input pixels.

# Data and Visualizations

## Dataset preparation
First, let's get the input images.
We will download a few zebra and a few magpie images from Wikimedia Commons.

In [None]:
from pathlib import Path
import os

# Create one directory per class (zebra, magpie); `-p` makes this safe to re-run.
!mkdir -p animals/zebra
!mkdir -p animals/magpie

# Download the zebra images (each preceded by its attribution / license).
# By Ltshears at English Wikipedia - Transferred from en.wikipedia to Commons by Calliopejen1 using CommonsHelper., Public Domain, https://commons.wikimedia.org/w/index.php?curid=17695258
!wget https://upload.wikimedia.org/wikipedia/commons/7/77/Equus_zebra_hartmannae_%281%29.jpg?download -O animals/zebra/zebra1.jpg
# By Photographie de André ALLIOT - Own work, CC0, https://commons.wikimedia.org/w/index.php?curid=95062050
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Z%C3%A8bres_de_Gr%C3%A9vy.jpg/640px-Z%C3%A8bres_de_Gr%C3%A9vy.jpg?download -O animals/zebra/zebra2.jpg
# By George Brits georgebrits_cableandgrain - https://unsplash.com/photos/wvO5tPfTpug (archive copy), CC0, https://commons.wikimedia.org/w/index.php?curid=62235911
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/b/b0/Zebra_squared_%28Unsplash%29.jpg/640px-Zebra_squared_%28Unsplash%29.jpg?download -O animals/zebra/zebra3.jpg

# Download the magpie images (each preceded by its attribution / license).
# By Adrian Pingstone (Arpingstone) - Own work, Public Domain, https://commons.wikimedia.org/w/index.php?curid=3343840
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/e7/Magpie_arp.jpg/373px-Magpie_arp.jpg?download -O animals/magpie/mp1.jpg
# By Vincent Oostelbos - https://www.inaturalist.org/photos/122369356, CC0, https://commons.wikimedia.org/w/index.php?curid=104787230
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/4/44/Pica_pica_122369356.jpg/640px-Pica_pica_122369356.jpg?download -O animals/magpie/mp2.jpg
# By Аимаина хикари - Own work, CC0, https://commons.wikimedia.org/w/index.php?curid=27314113
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/e4/Pica_pica_fledgling2.JPG/556px-Pica_pica_fledgling2.JPG?download -O animals/magpie/mp3.jpg

# Root directory holding one sub-directory per class.
ROOT = Path("animals/")


In [None]:
!pip install opencv-python



In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image

In [None]:
class ImagesDataset(torch.utils.data.Dataset):
    """
    Loads images from a directory and applies transforms to them.

    Every regular file directly inside ``directory`` is treated as an image.
    Files are served in sorted path order, so indices are deterministic.
    Sub-directories (e.g. a stray ``.ipynb_checkpoints`` folder) are skipped
    instead of making ``Image.open`` fail later.

    Args:
      directory - directory containing the image files
      transforms - callable applied to each image after it is converted to RGB
    """

    def __init__(self, directory: Path, transforms):
        self.transforms = transforms
        # glob("*") also yields directories; keep only files that can be opened
        self.img_paths = sorted(p for p in directory.glob("*") if p.is_file())

    def __getitem__(self, idx):
        """Return the ``idx``-th image, converted to RGB and passed through ``self.transforms``."""
        # the context manager closes the underlying file; convert() returns a new image
        with Image.open(self.img_paths[idx]) as raw:
            img = raw.convert("RGB")
        return self.transforms(img)

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.img_paths)

In [None]:
# as we are going to use a pre-trained model later
# we want to match its data distribution: the images are resized,
# center-cropped to 224x224 and normalized with the per-channel
# mean / std below (the ImageNet statistics)
image_net_mean = [0.485, 0.456, 0.406]
image_net_std = [0.229, 0.224, 0.225]
input_transforms = torchvision.transforms.Compose(
    [
        # NOTE: resizes to a fixed 256x256 square, so the aspect ratio is not preserved
        torchvision.transforms.Resize(
            (256, 256), interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        ),
        torchvision.transforms.CenterCrop((224, 224)),
        torchvision.transforms.ToTensor(),  # PIL image -> float tensor [C, H, W] with values in [0, 1]
        torchvision.transforms.Normalize(image_net_mean, image_net_std),
    ]
)


# One dataset per class directory, sharing the same preprocessing.
ZEBRA = ImagesDataset(ROOT / "zebra", input_transforms)
MAGPIE = ImagesDataset(ROOT / "magpie", input_transforms)

In [None]:
# The download cell fetches three images per class, so one batch of this size holds all of them.
SAMPLE_SIZE = 3

# shuffle=False keeps the dataset's sorted file order, so the batches are reproducible
LOADER_ZEBRA = torch.utils.data.DataLoader(
    ZEBRA, shuffle=False, batch_size=SAMPLE_SIZE
)

LOADER_MAGPIE = torch.utils.data.DataLoader(
    MAGPIE, shuffle=False, batch_size=SAMPLE_SIZE
)


# Only the first batch of each class is used: tensors of shape [<=SAMPLE_SIZE, 3, 224, 224].
SAMPLE_ZEBRA = next(iter(LOADER_ZEBRA))
SAMPLE_MAGPIE = next(iter(LOADER_MAGPIE))

## Visualizations
Below is a function that helps us inspect both the images and the Grad-CAM class-discriminative
localization maps.

In [None]:
def denormalize(imgs):
    """
    Approximately undo the input normalization for display purposes.

    A single scalar std (0.225) and mean (0.5) is used for every channel
    instead of the exact per-channel statistics; in particular, 0 maps to
    mid-grey 0.5. Works on anything supporting ``*`` and ``+`` with floats
    (numpy arrays, torch tensors, plain numbers).
    """
    scale, shift = 0.225, 0.5
    return shift + scale * imgs


def show_images(images):
    """
    Plot every element of a batch side by side in a single row.

    Args:
      images - tensor of shape [BATCH, C, H, W] where
        C = 1: the channel is min-max scaled and rendered as a JET heatmap;
        C = 3: a normalized RGB image, denormalized and shown as-is;
        C = 4: channel 0 is rendered as a heatmap and blended 50/50 over
               the RGB image stored in channels 1..3.
    """
    assert len(images.shape) == 4
    num_images = images.shape[0]
    # squeeze=False always returns a 2-D array of Axes, so a batch of one works too
    _, axes = plt.subplots(1, num_images, figsize=(28, 28), squeeze=False)
    # channels-last numpy arrays for matplotlib / OpenCV; .cpu() so CUDA tensors work as well
    images = images.detach().cpu().permute(0, 2, 3, 1).numpy()

    def to_heatmap(channel):
        """Min-max scale a 2-D map and colour it with the JET colormap (uint8, RGB)."""
        channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-10)
        heatmap = cv2.applyColorMap(np.uint8(channel * 255), colormap=cv2.COLORMAP_JET)
        # OpenCV colormaps are BGR, while matplotlib (and the blended RGB image) expect RGB
        return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    def handle_img(img, axe):
        axe.axis("off")
        channels = img.shape[-1]
        if channels == 1:
            img = to_heatmap(img[..., 0])
        elif channels == 3:
            img = np.clip(denormalize(img), 0, 1)
        elif channels == 4:
            heatmap = to_heatmap(img[..., 0])
            rgb = np.uint8(np.clip(denormalize(img[..., 1:]), 0, 1) * 255)
            img = cv2.addWeighted(heatmap, 0.5, rgb, 0.5, 0)

        axe.imshow(img)

    for img, axe in zip(images, axes[0]):
        handle_img(img, axe)


# Preview the (denormalized) input batches of both classes.
show_images(SAMPLE_ZEBRA)
show_images(SAMPLE_MAGPIE)

# Grad-CAM

## The Model
First, we download and inspect the model whose predictions we will later explain with Grad-CAM.

In [None]:
# for showing the model structure
!pip install torchinfo

In [None]:
import copy
import gc

def create_fresh_model():
    """
    Build a new pre-trained VGG-19 ready for the Grad-CAM experiments.

    Every call constructs an independent model, so hooks registered on one
    instance never affect another. The model is put in eval mode (disables
    dropout), and all ReLUs are made out-of-place because the full backward
    hooks registered later do not work with in-place ops.
    """
    # A freshly constructed model shares no state with earlier ones, so no deepcopy
    # is needed (copying it would only double the peak memory of VGG-19's weights).
    model = torchvision.models.vgg19(weights="DEFAULT")
    model.eval()  # we disable dropout

    for m in (*model.features, *model.classifier):
        if isinstance(m, torch.nn.ReLU):
            m.inplace = False

    return model


# Model used for inspection and the prediction sanity check; the Grad-CAM
# wrappers further down build their own copies via create_fresh_model.
model = create_fresh_model()

In [None]:
import torchinfo
# Layer-by-layer summary (output shapes, parameter counts) for one 224x224 RGB input.
torchinfo.summary(model, input_size=(1, 3, 224, 224), device="cpu")

In [None]:
# Printing the module shows the indices inside `model.features`; these are the
# valid `layer_id` values for the Grad-CAM wrappers.
model

This model discriminates between 1000 classes. The class for zebra is 340, whereas the class for magpie is 18.  
If you are interested in other classes check this [link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).

In [None]:
ZEBRA_ID = 340
MAGPIE_ID = 18

# Sanity check: the predicted classes should match ZEBRA_ID / MAGPIE_ID.
# These are plain predictions, so no autograd graph needs to be recorded.
with torch.no_grad():
    logits = model(SAMPLE_ZEBRA)
    argmax = torch.argmax(logits, dim=-1)
    print(f"ZEBRA: {argmax}")

    logits = model(SAMPLE_MAGPIE)
    argmax = torch.argmax(logits, dim=-1)
    print(f"MAGPIE: {argmax}")

## The Grad-CAM
Below is a brief description of Grad-CAM; a more detailed one can be found in [Section 3 (Grad-CAM) of the paper](https://arxiv.org/abs/1610.02391).  
We omit the batch dimension for simplicity but your implementation should be able to handle it.  

Let $I$ be the image of a zebra.  
Let $M$ be our model.   
Let $M(I)[z]$ be the logit corresponding to the zebra class; that is, for each image our model outputs 1000 numbers, which we call logits.
The higher the $i$-th logit, the more likely it is that the image belongs to the $i$-th class.  
Let $A$ be the **activation maps** produced by some (usually the last) convolutional layer when computing $M(I)$.

We want to calculate $G[k, i, j] = \frac{\partial M(I)[z]}{\partial A[k, i, j]}$.  
That is, we want to calculate the gradient of the logit with respect to each value of $A$.  
Note that since $A$ consists of the activation maps of a convolutional layer, it has channel, height, and width dimensions.  

Now we average $G$ over the spatial dimensions (height $H'$ and width $W'$), obtaining $G'[k] = \frac{1}{H'W'}\sum_{i,j} G[k, i, j]$, which we use to compute the class-discriminative
localization map as follows:

$L^c_{\text{Grad-CAM}} = \mathrm{ReLU}(\sum_{k}{G'[k]A[k]})$

Finally, we rescale the result to the range $[0, 1]$ and resize it to the image size using bilinear interpolation.

Here your task is to finish the implementation of Grad-CAM.


A few hints:
* Let `x` be a tensor. If `x` is not a scalar, calling `x.backward()` results in an error, because a starting gradient has to be provided (the `gradient` argument). For details, see [this doc](https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html).
* The outputs of a layer, and the gradients with respect to those outputs, can be captured with [forward](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_forward_hook#torch.nn.Module.register_forward_hook) and [backward](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_full_backward#torch.nn.Module.register_full_backward_hook) hooks.


Below you can find examples of forward and backward hooks.

In [None]:
## EXAMPLE FORWARD HOOK
def forward_hook(module, input, output):
    """
    Example forward hook, called as hook(module, input, output) after the
    module's forward pass. Printing shows its arguments; returning a value
    replaces the module's output, so this hook doubles it.
    """
    altered_output = output * 2
    print(
        f"""Forward hook called
          module = {module}
          input = {input}
          output = {output}
          altered_output = {altered_output}"""
    )
    return altered_output


# small MLP: Linear -> ReLU -> Linear
n = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1))
print(n)
# Hook the ReLU and the last Linear: the ReLU's doubled output feeds the last
# layer, whose output is then doubled again by its own hook.
n[1].register_forward_hook(forward_hook)
n[2].register_forward_hook(forward_hook)

x = torch.randn((1, 10))
l = n(x)
print(l)


In [None]:
## EXAMPLE BACKWARD HOOK
# Module-level switch read by backward_hook below.
backward_hook_on = True


def backward_hook(module, grad_input, grad_output):
    """
    Example full backward hook, called as hook(module, grad_input, grad_output).

    ``grad_input`` holds the gradients w.r.t. the module's inputs and
    ``grad_output`` those w.r.t. its outputs. Returning a tuple replaces
    ``grad_input`` for the rest of the backward pass, so while
    ``backward_hook_on`` is True every input gradient is multiplied by 10
    (``None`` entries, for inputs that need no gradient, are passed through).
    """
    # the flag is only read here, so no `global` declaration is needed
    if backward_hook_on:
        altered_grad_input = tuple(None if g is None else g * 10 for g in grad_input)
    else:
        altered_grad_input = grad_input
    print(
        f"""Backward hook called
          module = {module}
          grad_input = {grad_input}
          grad_output = {grad_output}
          altered_grad_input = {altered_grad_input}"""
    )
    return altered_grad_input


n = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1))
print(n)
n[0].register_full_backward_hook(backward_hook)

x = torch.randn((1, 10))
x.requires_grad = True
l = n(x)
l.backward()
print(f"x.grad {x.grad}")  # 10x the true gradient: the hook on n[0] scales dl/dx

backward_hook_on = False
x.grad = None  # reset, because .backward() accumulates into .grad
l = n(x)
l.backward()
print(f"x.grad {x.grad}")  # the unmodified gradient

In [None]:
class ModelWithGradCam:
    """
    Wraps a freshly created model and computes Grad-CAM maps for one layer of
    its ``features``.

    A forward hook stores the layer's output (A) and a full backward hook
    stores the gradient of the backpropagated logit w.r.t. that output (G).
    """

    def __init__(self, model_creator, layer_id):
        """
        Args:
          model_creator - a function that creates a model upon a call
          layer_id - id of the layer whose output we are paying attention to
        """
        self.model: torch.nn.Module = model_creator()
        self.model_layer: torch.nn.Module = self.model.features[
            layer_id
        ]  # the layer whose output we are paying attention to
        self.forward_pass = None  # the result of the forward pass on self.model_layer (A in the description above)
        self.grad_pass = None  # the gradient with respect to self.forward_pass (G in the description above)

        self.register_hooks()

    def register_hooks(self):
        """
        Registers hooks for getting the output of self.model_layer
        and the gradient with respect to this output
        """

        def save_forward(module, inputs, output):
            # A: detached so that storing it does not keep the autograd graph alive
            self.forward_pass = output.detach()

        def save_grad(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the layer's *output* (G);
            # grad_input would be w.r.t. the layer's input. Returning None
            # leaves the gradients flowing further back untouched.
            self.grad_pass = grad_output[0].detach()

        self.model_layer.register_forward_hook(save_forward)
        self.model_layer.register_full_backward_hook(save_grad)

    def get_grad_cam(self, input: torch.Tensor, class_id: int) -> torch.Tensor:
        """
        For a given input of shape [BATCH, C, H, W]
        calculates the class-discriminative
        localization map of shape [BATCH, 1, H, W]
        (one for each element of the batch) for class class_id.
        Each map is min-max scaled to [0, 1] independently.
        """
        self.model.zero_grad()
        # clear results of a previous call so stale tensors are never reused
        self.forward_pass = None
        self.grad_pass = None

        logits = self.model(input)  # [BATCH, NUM_CLASSES]
        # Seed the backward pass with a one-hot gradient: this backpropagates
        # logits[b, class_id] for every batch element b at once (assuming, as for
        # a model in eval mode, that batch elements do not interact).
        seed = torch.zeros_like(logits)
        seed[:, class_id] = 1.0
        logits.backward(gradient=seed)

        assert self.forward_pass is not None and self.grad_pass is not None, (
            "hooks on model_layer did not fire"
        )
        assert self.forward_pass.shape == self.grad_pass.shape
        assert len(self.forward_pass.shape) == 4  # [BATCH, C', H', W']
        assert len(self.grad_pass.shape) == 4

        # G': average the gradients over the spatial dimensions -> [BATCH, C', 1, 1]
        weights = self.grad_pass.mean(dim=(2, 3), keepdim=True)
        # ReLU(sum_k G'[k] * A[k]) -> [BATCH, 1, H', W']
        cam = torch.relu((weights * self.forward_pass).sum(dim=1, keepdim=True))
        # bilinear upsampling to the input resolution -> [BATCH, 1, H, W]
        cam = torch.nn.functional.interpolate(
            cam, size=tuple(input.shape[2:]), mode="bilinear", align_corners=False
        )
        # min-max scale each map to [0, 1]; the epsilon guards against an all-zero map
        flat = cam.flatten(start_dim=1)
        lo = flat.min(dim=1).values.view(-1, 1, 1, 1)
        hi = flat.max(dim=1).values.view(-1, 1, 1, 1)
        gcam = (cam - lo) / (hi - lo + 1e-10)

        assert gcam.shape[0] == input.shape[0]
        assert gcam.shape[1] == 1
        assert gcam.shape[2:] == input.shape[2:]

        return gcam

Let's test the implementation. You can experiment with different values of `layer_id` (keep in mind the description above). In general, deeper layers contain more high-level information.

In [None]:
# Grad-CAM for the zebra class at index 35 of `model.features`
# (NOTE(review): presumably the last ReLU of VGG-19 — confirm with the model printout).
model_with_cam = ModelWithGradCam(create_fresh_model, 35)
try:
    gcam = model_with_cam.get_grad_cam(SAMPLE_ZEBRA, ZEBRA_ID)
finally:
    # release the model copy even if the computation fails
    del model_with_cam
    gc.collect()
# [B, 1, H, W] map + [B, 3, H, W] images -> 4 channels, rendered by show_images as a heatmap overlay
show_images(torch.concat([gcam, SAMPLE_ZEBRA], dim=-3))

In [None]:
# Grad-CAM for the magpie class, using the same layer as for the zebras.
model_with_cam = ModelWithGradCam(create_fresh_model, 35)
try:
    gcam = model_with_cam.get_grad_cam(SAMPLE_MAGPIE, MAGPIE_ID)
finally:
    # release the model copy even if the computation fails
    del model_with_cam
    gc.collect()
# [B, 1, H, W] map + [B, 3, H, W] images -> 4 channels, rendered by show_images as a heatmap overlay
show_images(torch.concat([gcam, SAMPLE_MAGPIE], dim=-3))

## Guided Grad-CAM
Finish the implementation of guided Grad-CAM below.  
A brief explanation of this method is given in [Section 3.2 (Guided Grad-CAM) of the paper](https://arxiv.org/abs/1610.02391).  
The main idea is to additionally calculate the gradient of the class logit with respect to each pixel of the input image, but during this computation
we set negative gradients with respect to ReLU inputs to zero.  
Finally, we multiply the image gradients point-wise by the class-discriminative
localization map calculated by Grad-CAM.  

A few hints:
* `input.requires_grad = True`
* `x.detach()`


In [None]:
class ModelWithGuidedGradCam:
    """
    Computes guided Grad-CAM: guided-backpropagation gradients w.r.t. the input
    pixels, optionally multiplied point-wise by the Grad-CAM map.

    Two independent models are created with ``model_creator``: ``self.model``
    gets ReLU backward hooks for guided backpropagation, while the wrapped
    ``ModelWithGradCam`` keeps ordinary gradients for Grad-CAM.
    """

    def __init__(self, model_creator, layer_id):
        """
        Args:
          model_creator - a function that creates a model upon a call
          layer_id - id of the layer whose output we are paying attention to
        """
        self.model = model_creator()
        self.model_with_gc = ModelWithGradCam(model_creator, layer_id)
        self.register_hooks()

    def register_hooks(self):
        """
        Registers hook for each ReLU in the model
        for updating the gradient with respect to
        the ReLU input
        """

        def modify_backward(module, i, o):
            # i holds the gradients w.r.t. the ReLU inputs (already zero where the
            # input was negative); guided backprop additionally zeroes negative
            # gradients. Returning a tuple replaces i for the rest of the backward pass.
            return tuple(None if g is None else torch.clamp(g, min=0.0) for g in i)

        for module in (*self.model.features, *self.model.classifier):
            if isinstance(module, torch.nn.ReLU):
                module.register_full_backward_hook(modify_backward)

    def get_guided_grad_cam(self, input, class_id, use_grad_cam=True):
        """
        For an input of shape [BATCH, C, H, W] return a tensor of the same shape:
        the guided-backprop gradient of logit ``class_id`` w.r.t. the input
        pixels, multiplied by the Grad-CAM map (broadcast over channels) when
        ``use_grad_cam`` is True. ``input`` itself is not modified.
        """
        gcam = self.model_with_gc.get_grad_cam(input, class_id)

        self.model.zero_grad()
        # Differentiate w.r.t. a detached copy: setting requires_grad on the
        # caller's tensor would mutate it and accumulate .grad across calls.
        pixels = input.detach().clone().requires_grad_(True)
        logits = self.model(pixels)
        # one-hot seed backpropagates logits[b, class_id] for every batch element
        seed = torch.zeros_like(logits)
        seed[:, class_id] = 1.0
        logits.backward(gradient=seed)
        img_grad = pixels.grad.detach()

        assert img_grad.shape == input.shape
        assert len(img_grad.shape) == len(gcam.shape)
        if use_grad_cam:
            res = img_grad * gcam
        else:
            res = img_grad  # gradient of class logit with respect to the input image pixels

        return res

In [None]:
model_with_ggcam = ModelWithGuidedGradCam(create_fresh_model, 35)
try:
    # first without Grad-CAM: plain guided-backprop gradients
    grad = model_with_ggcam.get_guided_grad_cam(
        SAMPLE_ZEBRA, ZEBRA_ID, use_grad_cam=False
    )
    gcam = model_with_ggcam.get_guided_grad_cam(SAMPLE_ZEBRA, ZEBRA_ID)
finally:
    # free both model copies even if either call above raises
    del model_with_ggcam
    gc.collect()
show_images(grad)
show_images(gcam)

In [None]:
model_with_ggcam = ModelWithGuidedGradCam(create_fresh_model, 35)
try:
    # first without Grad-CAM: plain guided-backprop gradients
    grad = model_with_ggcam.get_guided_grad_cam(
        SAMPLE_MAGPIE, MAGPIE_ID, use_grad_cam=False
    )
    gcam = model_with_ggcam.get_guided_grad_cam(SAMPLE_MAGPIE, MAGPIE_ID)
finally:
    # free both model copies even if either call above raises
    del model_with_ggcam
    gc.collect()

show_images(grad)
show_images(gcam)