# Grad-CAM examples

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import requests
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from fgvc.core.models import get_model
from fgvc.special.grad_cam import GradCamTimm, plot_grad_cam, plot_heatmap, plot_image_heatmap
from fgvc.utils.utils import set_cuda_device

# example input used throughout the notebook: a cat photo hosted on Pixabay
IMG_URL = (
    "https://cdn.pixabay.com/photo/2015/11/16/22/14/"
    "cat-1046544_960_720.jpg"
)

# choose CUDA device "0"; the returned handle is passed to every Grad-CAM call below
device = set_cuda_device("0")

## Input image

In [None]:
# load image
# - `timeout` keeps the cell from hanging forever on a stalled connection
# - `raise_for_status()` surfaces HTTP errors instead of handing an error page to PIL
# - `decode_content = True` makes the raw stream undo any gzip/deflate content encoding
# - the `with` block releases the connection once the pixels have been read
with requests.get(IMG_URL, stream=True, timeout=30) as response:
    response.raise_for_status()
    response.raw.decode_content = True
    # `.convert("RGB")` forces PIL to fully decode the image while the stream is still open
    image = np.asarray(Image.open(response.raw).convert("RGB"))

# show the image
plt.imshow(image)
plt.axis("off")
plt.show()

## Model attention

### 1. Select the last convolutional layer automatically (default)

- the target_layer argument of GradCamTimm must be None:

        grad_cam = GradCamTimm(<timm model>, target_layer=None)  # or just: GradCamTimm(<timm model>)

- then call the GradCamTimm instance on your image to obtain its attention map
- optionally, pass target_cls as an integer in the range <i>[0, N - 1]</i>, where N is the number of classes,
  to get the attention for that class. By default, the argmax of the classification head is used

        attn = grad_cam(<single image>, target_cls=<None (default) or number in range [0, N - 1]>)

- finally, you can visualize the attention
    1. as a heatmap with a value scale, using
       <code>grad_cam.visualize_as_heatmap(&lt;subplot ax&gt;, attn)</code>
    2. as an overlay showing which parts of the image the model attends to, using
       <code>grad_cam.visualize_as_image(&lt;subplot ax&gt;, attn, &lt;single image&gt;)</code>

- in addition, you can retrieve:
    1. the original features that were weighted by the attention
       <code>feats = grad_cam.get_features()</code>
    2. the gradients that were used to weight the features
       <code>grads = grad_cam.get_gradients()</code>

In [None]:
# create a pretrained ResNet-50 model
model = get_model("resnet50", pretrained=True)

# create a Grad-CAM instance; no target_layer is given, so the last convolutional
# layer is selected automatically (see section 1 above)
grad_cam = GradCamTimm(model, device=device)
# compute the attention for `image`; the call yields the weighted features together
# with the raw features and their gradients
weighted_features, (features, gradients) = grad_cam(image)

In [None]:
# plot the Grad-CAM attention for `image`; no target_layer is passed, so presumably the
# same automatic last-convolutional-layer selection as GradCamTimm applies — confirm
plot_grad_cam(image, model, device=device)

### 2. Select the target layer manually

- the target_layer argument of GradCamTimm must not be None:

        grad_cam = GradCamTimm(<timm model>, target_layer=<required layer>)

- you can also set the target_layer using:

        grad_cam.set_target_layer("<target layer>")

- you can list the possible target layers with:

        pos_targ_layers = grad_cam.get_possible_target_layers()

In [None]:
# list the layer names that can be passed as `target_layer`
pos_targ_layers = grad_cam.get_possible_target_layers()
print("Possible target layers for your model:")
# plain loop: the output is a side effect, so no throwaway list (or `;` to hide it) is needed
for layer_name in pos_targ_layers:
    print(f"- {layer_name}")

In [None]:
# plot the Grad-CAM attention computed on the manually selected `layer4` module
plot_grad_cam(image, model, device=device, target_layer="layer4")

In [None]:
# same target layer as above, but with use_min_zero=False
# NOTE(review): presumably this disables shifting/clamping the attention minimum to zero — confirm in plot_grad_cam
plot_grad_cam(image, model, device=device, target_layer="layer4", use_min_zero=False)

## Use Different Architectures

### ViT

In [None]:
# pretrained ViT (`vit_base_patch16_384`); Grad-CAM is computed on its `blocks` module
model = get_model("vit_base_patch16_384", pretrained=True)
plot_grad_cam(image, model, device=device, target_layer="blocks")

In [None]:
# pretrained ViT (`vit_base_patch8_224`); Grad-CAM is computed on its `blocks` module
model = get_model("vit_base_patch8_224", pretrained=True)
plot_grad_cam(image, model, device=device, target_layer="blocks")

### SwinT

In [None]:
# pretrained Swin Transformer (`swin_large_patch4_window12_384`); Grad-CAM is computed on its `layers` module
model = get_model("swin_large_patch4_window12_384", pretrained=True)
plot_grad_cam(image, model, device=device, target_layer="layers")

In [None]:
# pretrained Swin Transformer (`swin_base_patch4_window7_224_in22k`); Grad-CAM is computed on its `layers` module
# NOTE(review): newer timm releases may expose this model under a different (tagged) name — confirm against the installed timm version
model = get_model("swin_base_patch4_window7_224_in22k", pretrained=True)
plot_grad_cam(image, model, device=device, target_layer="layers")