<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/gradcam/gradcam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Visualization of CNN: Grad-CAM

* **Objective**: Convolutional Neural Networks (CNNs) are widely used in computer vision and are powerful for processing grid-like data. However, we hardly know how and why they work, because they do not decompose into individually intuitive components. In this assignment, we introduce Grad-CAM, which produces a heatmap over the input image that highlights the regions that matter most for a visual question answering (VQA) task.

* NB: if `PIL` is not installed, try `pip install pillow`.

In [None]:
try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    !pip install jaxtyping einops -q

    # Download necessary files
    %cd /content
    !wget https://cozyfractal.com/static/ml4g-gradcam.zip
    !unzip -o ml4g-gradcam.zip
    %cd /content/gradCam
else:
    !wget https://cozyfractal.com/static/ml4g-gradcam.zip
    !unzip -o ml4g-gradcam.zip
    %cd gradCam/

In [None]:
import cv2
import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from jaxtyping import Float, Int
from PIL import Image
from torch import Tensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Visual Question Answering problem
Given an image and a question in natural language, the model chooses the most likely answer from 3,000 classes according to the content of the image. The VQA task is therefore a multi-class classification problem.
<img src="https://github.com/EffiSciencesResearch/ML4G/blob/main/days/w1d4/gradCam/vqa_model.PNG?raw=1">

We provide you a pretrained model `vqa_resnet` for VQA tasks.

In [None]:
# load model
from load_model import load_model

vqa_resnet = load_model()

In [None]:
# Fixes a strange bug. Ideally, we would run the model in eval mode though.
vqa_resnet.train()
# A dropout of 0.5 is too large; for deterministic behavior, disable dropout.

# Loop through all the modules in the model
for module in vqa_resnet.modules():
    if isinstance(module, nn.Dropout):
        # Update the dropout probability for each dropout layer
        module.p = 0.0

The model uses two sets of tokens, one for the questions and one for the answers. All tokens are words.

In [None]:
checkpoint = "2017-08-04_00.55.19.pth"
saved_state = torch.load(checkpoint, map_location=device)
# reading vocabulary from saved model
vocab = saved_state["vocab"]
print("Vocab:", set(vocab.keys()))

# reading word tokens from saved model
question_word_to_index = vocab["question"]
print("Tokens for questions:", question_word_to_index)

# reading answers from saved model
answer_word_to_index = vocab["answer"]
print("Tokens for answers:", answer_word_to_index)

num_tokens = len(question_word_to_index) + 1
print(f"{num_tokens=}")

# Mapping from integer to token string
index_to_answer_word = ["unk"] * len(answer_word_to_index)
for w, idx in answer_word_to_index.items():
    index_to_answer_word[idx] = w

print(index_to_answer_word)

### Inputs
To use the pretrained model, the input image must be resized to `(448, 448)` and normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`. The function `get_transform` below builds this image preprocessing, and `tokenize` encodes a question into a vector of indices. The `preprocess` function handles both the image and the question.
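
As a worked example of the normalization arithmetic, the sketch below applies `(value - mean) / std` per channel to a single pixel. The mid-gray pixel is an arbitrary value chosen for illustration; pixel values are assumed to already be in `[0, 1]`, as they are after `transforms.ToTensor()`.

```python
# Normalization statistics used by the notebook's transform
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# A mid-gray RGB pixel with values in [0, 1] (arbitrary, for illustration)
pixel = [0.5, 0.5, 0.5]

# Normalize each channel independently: (value - mean) / std
normalized = [(p - m) / s for p, m, s in zip(pixel, mean, std)]
print([round(v, 3) for v in normalized])  # [0.066, 0.196, 0.418]
```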

In [None]:
def get_transform():
    target_size = 448
    central_fraction = 1.0
    return transforms.Compose(
        [
            transforms.Resize(int(target_size / central_fraction)),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

In [None]:
def tokenize(question: str) -> Int[Tensor, "nb_words"]:
    """Turn a question into a vector of tokens."""
    # For this model, tokens are lowercase words, so we split on whitespace
    words = question.lower().split()
    # Then map each word to its index in the dictionary
    return torch.tensor([question_word_to_index[word] for word in words], device=device)
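
Note that `tokenize` raises a `KeyError` as soon as a word is missing from `question_word_to_index`. The sketch below shows one possible way to check a question before passing it to the model. `toy_vocab` and `split_known_unknown` are hypothetical names used only for illustration, and the toy vocabulary is not the model's real one.

```python
# Hypothetical toy vocabulary; the real one is `question_word_to_index`.
toy_vocab = {"what": 1, "animal": 2, "color": 3}


def split_known_unknown(question: str, vocab: dict):
    """Return the indices of the known words and the list of unknown words."""
    indices, unknown = [], []
    # Same splitting rule as `tokenize`: lowercase, then split on whitespace
    for word in question.lower().split():
        if word in vocab:
            indices.append(vocab[word])
        else:
            unknown.append(word)
    return indices, unknown


indices, unknown = split_known_unknown("What animal ?", toy_vocab)
print(indices)  # [1, 2]
print(unknown)  # ['?']
```

If `unknown` is not empty, the question can be rephrased before calling `tokenize`.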

In [None]:
def preprocess(dir_path: str, question: str):
    """
    Load the image at `dir_path` and process it to be a suitable input for vqa_resnet.
    """
    tokens = tokenize(question)

    img = Image.open(dir_path).convert("RGB")
    img_transformed = get_transform()(img).unsqueeze(0).to(device)

    q_len = torch.tensor(tokens.shape, device=device)

    inputs = (img_transformed, tokens.unsqueeze(0), q_len.unsqueeze(0))
    return inputs

In [None]:
def check_answers(img_path: str, question: str, topk=10):
    """Show the top `topk` answers of the model for a given question."""
    inputs = preprocess(img_path, question)
    logits = vqa_resnet(*inputs)
    probas = F.softmax(logits.squeeze(), dim=0)
    values, tokens_indices = torch.topk(probas, k=topk)

    print("Output probabilities:")
    for token, value in zip(tokens_indices, values):
        print(f"- {index_to_answer_word[token]!r: >10} \t-> {value:.2%}")

We provide two pictures and questions. Is the model doing well?
If not, form a hypothesis about why it makes an erroneous prediction. What feature in the image did it pick up on?

This is the question that Grad-CAM tries to answer.

In [None]:
dog_cat_path = "dog_cat.png"
dog_cat_question = "What animal"
check_answers(dog_cat_path, dog_cat_question)
Image.open(dog_cat_path)

In [None]:
hydrant_path = "hydrant.png"
hydrant_question = "What color"
check_answers(hydrant_path, hydrant_question)

Image.open(hydrant_path)

## Hooks in PyTorch

The goal of this exercise is to familiarize yourself with the hook system in PyTorch. Hooks are not used to manipulate the weights but to **manipulate the activations** of the model on a given input. You can read, and even modify, the hidden activations of the model.

1. Use hooks to log information about the inner workings of the model. Here we will just print the shapes of the activations.
2. But we can also view more interesting information. In the second exercise we plot the norm of each channel. The norm of a channel is a (bad) proxy for how much information there is in a channel.
3. Finally, we modify the activations to flip the sign of the output of a convolution. This should change the output of our model: we are butchering through it! (This is also a completely meaningless operation, but let's see what it does...)

Hooks in PyTorch are not very pleasant to work with:
- Once you add a hook to a module, it stays there until you remove it with `handle.remove()`. For this, you need to have saved the handle in the first place.
- Errors: if your hook function throws an error, the hook is not removed automatically, so you need to remove it yourself. The best way to do this is to always wrap calls that use hooks in a `try`/`finally` block and remove the hook in the `finally` block. This way, you are guaranteed that the hook is removed after one pass through the model.
- If you mess up, you can always reload the model; the hooks will be gone.
- They encourage the use of `global` state, which brings the usual drawbacks of global state.
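
The `try`/`finally` pattern described above can also be packaged as a context manager. The following is a minimal sketch on a toy model, not part of the workshop code: `temporary_forward_hook`, `toy_model`, `record_shape` and `seen_shapes` are hypothetical names introduced for illustration.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def temporary_forward_hook(module: nn.Module, hook):
    """Register `hook` on `module` and remove it on exit, even on error."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()


# Toy model standing in for vqa_resnet (an assumption for illustration)
toy_model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
seen_shapes = []


def record_shape(module, inputs, output):
    seen_shapes.append(tuple(output.shape))


with temporary_forward_hook(toy_model[0], record_shape):
    toy_model(torch.zeros(1, 3, 16, 16))

# The hook ran once inside the block and is gone afterwards,
# so this second forward pass records nothing.
toy_model(torch.zeros(1, 3, 16, 16))
print(seen_shapes)  # [(1, 8, 16, 16)]
```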

We do our interventions on the last convolution of the resnet, defined below.

You may need to read the [PyTorch hook tutorial](https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks).

In [None]:
# The last convolution
module_to_hook = vqa_resnet.resnet_layer4.r_model.layer4[2].conv3

### Exercise 1: Show the shapes of the activations

You need to print the shapes of the input and output of the hooked convolution (`module_to_hook`).
What are their shapes? What does each dimension represent?

Hint: both of them are 4D.


In [None]:
def show_shapes_hook(module, inputs, output):
    # Q: What's the type of `inputs`?
    # Hide: all
    # A: inputs is a tuple of one element (= all the inputs of the module)
    print(f"{inputs[0].shape=}")
    print(f"{output.shape=}")
    # Hide: none


hook_handle = module_to_hook.register_forward_hook(show_shapes_hook)

try:
    check_answers(dog_cat_path, dog_cat_question)
finally:
    hook_handle.remove()

### Exercise 2: Plotting in a hook

Goal: figuring out which channels have the highest norm (≈ are the most used).
You just need to compute the per-channel norm.

In [None]:
def plot_highest_output_norm(module, inputs, output):
    # Hide: solution
    norms = ...
    # Hide: all
    norms = torch.linalg.vector_norm(output, dim=(0, 2, 3))
    # Hide: none

    norms.squeeze_()
    assert norms.shape == (2048,)
    plt.plot(norms.detach().cpu())
    plt.show()


hook_handle = module_to_hook.register_forward_hook(plot_highest_output_norm)

try:
    check_answers(dog_cat_path, dog_cat_question)
finally:
    hook_handle.remove()

<details>
<summary>What does the plot tell you about the inner workings of the model?</summary>

Huh... nothing?
</details>

### Exercise 3: Modifying the output of a module

What would happen if we flip (i.e. multiply by $-1$) the contribution of this convolution to the residual stream?
**Make a prediction first!**


In [None]:
def flip_output(module, inputs, output):
    # You need to modify output *in place*
    # Hide: all
    output *= -1
    # Hide: none


hook_handle = module_to_hook.register_forward_hook(flip_output)

try:
    check_answers(dog_cat_path, dog_cat_question)
finally:
    hook_handle.remove()

### Exercise 4: Save the output activations so we can re-use them later.

There is no code to complete in this exercise, but you need to understand what is going on.
- What is `global`?
- Why do we use it, and what would happen if we did not?
- What are the drawbacks of using `global` variables?

In [None]:
saved_output = None
saved_output_grad = None


def forward_hook(module, inputs, output):
    global saved_output
    saved_output = output
    print("Saved output of shape:", output.shape)


def backward_hook(module, grad_input, grad_outputs):
    global saved_output_grad
    saved_output_grad = grad_outputs[0]
    print("Saved gradient of shape:", grad_outputs[0].shape)


forward_handle = module_to_hook.register_forward_hook(forward_hook)
backward_handle = module_to_hook.register_full_backward_hook(backward_hook)

try:
    inputs = preprocess(dog_cat_path, dog_cat_question)
    vqa_resnet.zero_grad()
    logits = vqa_resnet(*inputs)
    # Pretend the loss is the 'giraffe' logit.
    loss = logits[0, answer_word_to_index["giraffe"]]
    loss.backward()
finally:
    forward_handle.remove()
    backward_handle.remove()

print("Shape of saved output:", saved_output.shape)
print("Shape of saved grad:", saved_output_grad.shape)
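
One way to avoid `global` variables is to keep the saved tensors on an object and register its bound methods as hooks. The sketch below illustrates this on a toy network; `ActivationStore`, `toy_model` and `hooked` are hypothetical names, and the rest of this notebook keeps using the global variables defined above.

```python
import torch
import torch.nn as nn


class ActivationStore:
    """Keeps the output of a module and the gradient w.r.t. that output."""

    def __init__(self):
        self.output = None
        self.output_grad = None

    def forward_hook(self, module, inputs, output):
        self.output = output

    def backward_hook(self, module, grad_input, grad_output):
        self.output_grad = grad_output[0]


# Toy network standing in for vqa_resnet (an assumption for illustration)
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 6, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(6, 5),
)
hooked = toy_model[2]  # the second convolution

store = ActivationStore()
forward_handle = hooked.register_forward_hook(store.forward_hook)
backward_handle = hooked.register_full_backward_hook(store.backward_hook)

try:
    logits = toy_model(torch.randn(1, 3, 8, 8))
    logits[0, 2].backward()  # backpropagate a single logit
finally:
    forward_handle.remove()
    backward_handle.remove()

print(store.output.shape)  # torch.Size([1, 6, 8, 8])
print(store.output_grad.shape)  # torch.Size([1, 6, 8, 8])
```

Each `ActivationStore` instance holds its own tensors, so several modules can be hooked at once without their saved values overwriting each other.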

### Grad-CAM
* **Overview:** Given an image with a question, and a category (‘dog’) as input, we forward propagate the image through the model to obtain the `raw class scores` before softmax. We backpropagate only the logit of the target class. This signal is then backpropagated to the `convolutional feature map` of interest, where we can compute the coarse Grad-CAM localization (blue heatmap).

* We will define a `grad_cam` function to visualize each image and its saliency map.

* Here is the link to the paper: [Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization](https://arxiv.org/pdf/1610.02391.pdf)
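
For reference, equations 1 and 2 of the paper can be written as follows, where $A^k$ is the $k$-th channel of the chosen feature map, $y^c$ is the logit of the target answer $c$, and $Z$ is the number of spatial positions in the feature map:

$$\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}$$

$$L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)$$

The first equation averages the gradient over the spatial dimensions to get one weight per channel. The second takes the weighted sum of the channels and keeps only the positive part.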

In [None]:
def grad_cam(
    img_path="dog_cat.png", question="What animal", answer="dog", module_to_hook=module_to_hook
):
    # Make a figure with 3 subplots
    fig, axs = plt.subplots(1, 3, figsize=(15, 6))
    # Plot the original image on the left
    img = Image.open(img_path)
    axs[0].imshow(img)
    axs[0].set_title("Original image")

    inputs = preprocess(img_path, question)

    # Add the hooks to store the feature map and its gradient
    # in the global variables saved_output and saved_output_grad.
    forward_handle = module_to_hook.register_forward_hook(forward_hook)
    backward_handle = module_to_hook.register_full_backward_hook(backward_hook)

    try:
        # Make sure there are no gradients
        vqa_resnet.zero_grad()
        # Compute the predictions of the model
        logits = vqa_resnet(*inputs)

        # Backpropagate just on the logit of the given answer
        answer_logit = logits[0, answer_word_to_index[answer]]
        answer_logit.backward()
    finally:
        forward_handle.remove()
        backward_handle.remove()

    # Compute the Grad-CAM map (equations 1 and 2 of the paper)
    # Hide: hard
    mean_gradient = einops.reduce(
        # Hide: all
        saved_output_grad,
        "batch features w h -> features",
        "mean",
        # Hide: hard
    )
    grad_cam = einops.einsum(
        # Hide: all
        mean_gradient,
        saved_output.squeeze(0),
        "features, features w h -> w h",
        # Hide: hard
    )
    # Hide: none

    grad_cam = grad_cam.clip(min=0)

    # Upscale, normalize and convert to RGB
    grad_cam = grad_cam.cpu().detach().numpy()
    cam = cv2.resize(grad_cam, (224, 224))
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam))  # Normalize between 0-1
    cam = np.uint8(cam * 255)  # Scale between 0-255 to visualize

    # Heatmap of activation map. Plot in the center.
    # OpenCV produces BGR images while matplotlib expects RGB, so convert.
    activation_heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_HSV)
    activation_heatmap = cv2.cvtColor(activation_heatmap, cv2.COLOR_BGR2RGB)
    axs[1].imshow(activation_heatmap)
    axs[1].set_title("Heatmap of activation map")

    # Overlay heatmap and picture. Plot on the right.
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    org_img = cv2.resize(img, (224, 224))
    img_with_heatmap = np.float32(activation_heatmap) + np.float32(org_img)
    img_with_heatmap *= 0.99 / np.max(img_with_heatmap)
    axs[2].imshow(img_with_heatmap)
    axs[2].set_title("Heatmap on picture")

    plt.show()

In [None]:
grad_cam(img_path="dog_cat.png", question="What animal", answer="dog")

In [None]:
grad_cam(img_path="dog_cat.png", question="What animal", answer="cat")

In [None]:
grad_cam(img_path="dog_cat.png", question="What animal", answer="giraffe")

In [None]:
grad_cam(img_path="hydrant.png", question="What color", answer="green")

In [None]:
grad_cam(img_path="hydrant.png", question="What color", answer="yellow")

What's the interpretation of these plots?

Note: don't try too much to interpret the cat.


## Bonus: Safari - hunting the giraffe
Note: please don't do this in real life. 😘🦒

At the start we saw that the model predicts that the animal in the cat-dog picture is a giraffe.
Using Grad-CAM and hooks, can you find the parts of the model that are responsible for this prediction
and remove them?

This is a very exploratory exercise: you will have to form your own hypotheses and design your own experiments.
