Imports & definitions
---------------------

In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all loaded modules
# before each cell is executed, so edits to imported source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# imports
import os
import requests

import numpy as np
import timm
import torch
import matplotlib.pyplot as plt
from PIL import Image

from fgvc.special.grad_cam import GradCamTimm, plot_grad_cam

# constants
IMG_SIZE = 224
IMG_URL = "https://cdn.pixabay.com/photo/2015/11/16/22/14/cat-1046544_960_720.jpg"

# settings
# restrict the CUDA devices visible to this process to device index 0
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
# use CUDA when PyTorch reports it as available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

Input image
-----------

In [None]:
# Download and decode the image.
# - timeout: a stalled connection must not hang the cell forever
# - raise_for_status(): fail with a clear HTTP error instead of a confusing
#   PIL decoding error when the server returns an error page
# - the `with` block releases the streamed connection once the image is loaded
with requests.get(IMG_URL, stream=True, timeout=30) as response:
    response.raise_for_status()
    # let urllib3 undo any gzip/deflate content encoding before PIL reads the bytes
    response.raw.decode_content = True
    # convert() forces the pixel data to be read while the connection is still open
    img = Image.open(response.raw).convert("RGB")
# to use a local file instead: img = Image.open("<path to image>").convert("RGB")

# resize so that the longer side equals IMG_SIZE while keeping the aspect ratio;
# round (instead of truncating) and never let the shorter side collapse to 0 px
width, height = img.size
if width > height:
    im_newsize = (IMG_SIZE, max(1, round(IMG_SIZE * height / width)))
else:
    im_newsize = (max(1, round(IMG_SIZE * width / height)), IMG_SIZE)
img = img.resize(im_newsize)
np_img = np.array(img, dtype=np.uint8)

# convert to a float tensor scaled to [0, 1]: (H, W, C) -> (C, H, W)
# note: this is a single image, no batch dimension is added here
tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() / 255.0

# show the image (the trailing semicolon hides the AxesImage repr in the output)
plt.axis("off")
plt.imshow(np_img);

Model attention
---------------

#### 1. Select the last convolutional layer automatically (default)

- target_layer of GradCamTimm must be None:

        grad_cam = GradCamTimm(<timm model>, target_layer=None)  # or just: GradCamTimm(<timm model>)

- then you need to call the instance of GradCamTimm to receive the attentions for your image
- additionally, you can pass target_cls as a number in the range <i>[0, N - 1]</i>, where N is the number of classes,
  to get the attentions for that class. By default, the argmax of the classification head is used

        attn = grad_cam(<single image>, target_cls=<None (default) or number in range [0, N - 1]>)

- finally, you can visualize the attentions
    1. as heatmap with a scale using
       <code>grad_cam.visualize_as_heatmap(&lt;subplot ax&gt;, attn)</code>
    2. as attention to partial parts of the image using
       <code>grad_cam.visualize_as_image(&lt;subplot ax&gt;, attn, &lt;single image&gt;)</code>

- besides, you can get:
    1. the original features that have been weighted with attention
       <code>feats = grad_cam.get_features()</code>
    2. the gradients that were used for weighting the features
       <code>grads = grad_cam.get_gradients()</code>

In [None]:
# create a model and Grad-CAM instance
net = timm.create_model("resnet50", pretrained=True)
# Switch to evaluation mode: newly created modules start in training mode, in
# which layers such as BatchNorm use per-batch statistics and Dropout is active —
# neither is wanted when explaining a prediction for a single image.
# eval() is idempotent, so this is harmless if GradCamTimm also does it.
net.eval()
grad_cam = GradCamTimm(net, device=device)

# get the attentions, together with the features and gradients returned alongside them
attn, (feats, grads) = grad_cam(tensor_img)

In [None]:
# Plot Grad-CAM for the prepared image in a single call.
# No target_layer is passed here — presumably the helper then falls back to the
# automatic layer selection described above (TODO confirm in fgvc).
plot_grad_cam(
    tensor_img,
    model=net,
    device=device,
)

#### 2. Select the target layer manually

- target_layer of GradCamTimm must not be None:

        grad_cam = GradCamTimm(<timm model>, target_layer=<required layer>)

- you can also set the target_layer using:

        grad_cam.set_target_layer("<target layer>")

- you can list the possible target layers by calling:

        pos_targ_layers = grad_cam.get_possible_target_layers()

In [None]:
# list the layers that can be passed as `target_layer`
pos_targ_layers = grad_cam.get_possible_target_layers()
print("Possible target layers for your model:")
# plain loop: a list comprehension used only for print() side effects builds a
# throwaway list of None values and needs a trailing `;` to hide it
for layer in pos_targ_layers:
    print(f"- {layer}")

In [None]:
# Plot Grad-CAM again, this time with an explicitly chosen target layer
# ("layer2") instead of the default layer selection.
plot_grad_cam(
    tensor_img,
    model=net,
    device=device,
    target_layer="layer2",
)
