## Imports & definitions

In [None]:
# enable IPython's autoreload extension; mode 2 reloads imported modules before each cell executes
%load_ext autoreload
%autoreload 2

# imports
import requests

import numpy as np
import timm
import torch
import matplotlib.pyplot as plt
from PIL import Image

from fgvc.special.grad_cam import GradCamTimm, plot_grad_cam
from fgvc.utils.utils import set_cuda_device

# constants
# IMG_URL: sample image downloaded over HTTP in the "Input image" section
# IMG_SIZE: side length in pixels of the square input used for the first examples
IMG_URL = "https://cdn.pixabay.com/photo/2015/11/16/22/14/cat-1046544_960_720.jpg"
IMG_SIZE = 224

# NOTE(review): presumably selects CUDA device "0" and returns the device object that is
# passed to GradCamTimm / plot_grad_cam further down — confirm in fgvc.utils.utils
device = set_cuda_device("0")

## Input image

In [None]:
def preprocess_image(img_pil, img_size):
    """Resize an image and return it as a NumPy array and as a float tensor.

    Parameters
    ----------
    img_pil
        Image object exposing a PIL-style ``resize(size)`` method.
    img_size
        Target size, forwarded unchanged to ``img_pil.resize``
        (the notebook passes a 2-tuple such as ``(224, 224)``).

    Returns
    -------
    img_np
        Resized image as a writable ``uint8`` array in (H, W, C) layout.
    img_torch
        The same pixels as a ``float32`` tensor in (C, H, W) layout, scaled to [0, 1].
        No batch dimension is added.
    """
    # np.array always copies (np.asarray may return a read-only view of the source buffer),
    # so the result is writable and torch.from_numpy does not warn about a non-writable array.
    img_np = np.array(img_pil.resize(img_size), dtype=np.uint8)

    # (H, W, C) to (C, H, W), then map 8-bit values to [0, 1]
    img_torch = torch.from_numpy(img_np).permute(2, 0, 1).float() / 255.0

    return img_np, img_torch


# Get and prepare an image
# The response is used as a context manager so the connection is released afterwards.
# The timeout keeps the cell from hanging on a dead connection, and raise_for_status()
# surfaces HTTP errors instead of handing an error page to PIL.
with requests.get(IMG_URL, stream=True, timeout=30) as response:
    response.raise_for_status()
    # make the raw stream undo any gzip/deflate content encoding when it is read
    response.raw.decode_content = True
    # convert() loads the pixel data, so the image is fully read before the stream closes
    img_pil = Image.open(response.raw).convert("RGB")
img_np, img_torch = preprocess_image(img_pil, img_size=(IMG_SIZE, IMG_SIZE))

# show the image (explicit axes; the last call returns None, so no repr is echoed)
fig, ax = plt.subplots()
ax.imshow(img_np)
ax.set_axis_off()

## Model attention

### 1. Select the last convolutional layer automatically (default)

- target_layer of GradCamTimm must be None:

        grad_cam = GradCamTimm(<timm model>, target_layer=None)  # or just: GradCamTimm(<timm model>)

- then you need to call the instance of GradCamTimm to receive the attentions for your image
- additionally, you can pass target_cls as a number in the range <i>[0, N - 1]</i>, where N is the number of classes,
  to get the attentions for that class. By default, the argmax of the classification head is used

        attn = grad_cam(<single image>, target_cls=<None (default) or number in range [0, N - 1]>)

- finally, you can visualize the attentions
    1. as heatmap with a scale using
       <code>grad_cam.visualize_as_heatmap(&lt;subplot ax&gt;, attn)</code>
    2. as attention to partial parts of the image using
       <code>grad_cam.visualize_as_image(&lt;subplot ax&gt;, attn, &lt;single image&gt;)</code>

- besides, you can get:
    1. the original features that have been weighted with attention
       <code>feats = grad_cam.get_features()</code>
    2. the gradients that were used for weighting the features
       <code>grads = grad_cam.get_gradients()</code>

In [None]:
# create a model and Grad-CAM instance
# no target_layer argument is passed here, so GradCamTimm uses its default layer selection
net = timm.create_model("resnet50", pretrained=True)
grad_cam = GradCamTimm(net, device=device)

# get the attentions
# the call result is unpacked into the attention output and a (features, gradients) pair
attn, (feats, grads) = grad_cam(img_torch)

In [None]:
# plot Grad-CAM for the example image; target_layer is not passed, so its default is used
plot_grad_cam(img_torch, model=net, device=device)

### 2. Select the target layer manually

- target_layer of GradCamTimm must not be None:

        grad_cam = GradCamTimm(<timm model>, target_layer=<required layer>)

- you can also set the target_layer using:

        grad_cam.set_target_layer("<target layer>")

- you can list the possible target layers by calling:

        pos_targ_layers = grad_cam.get_possible_target_layers()

In [None]:
pos_targ_layers = grad_cam.get_possible_target_layers()
print("Possible target layers for your model:")
# plain loop: a list comprehension used only for printing builds a throwaway list of None
for layer in pos_targ_layers:
    print(f"- {layer}")

In [None]:
# same plot as before, but with the target layer given explicitly by name
plot_grad_cam(img_torch, model=net, device=device, target_layer="layer4")

## Use Different Architectures

In [None]:
# prepare the example image at the two resolutions used in this section;
# only the tensor output is kept, the NumPy array is discarded
_, img_torch_224 = preprocess_image(img_pil, img_size=(224, 224))
_, img_torch_384 = preprocess_image(img_pil, img_size=(384, 384))

### ResNet-50

In [None]:
# "resnet50" with pretrained weights on the 224x224 input; target_layer left at its default
net = timm.create_model("resnet50", pretrained=True)
plot_grad_cam(img_torch_224, model=net, device=device)

In [None]:
# same "resnet50" model, this time on the 384x384 input
net = timm.create_model("resnet50", pretrained=True)
plot_grad_cam(img_torch_384, model=net, device=device)

### ViT

In [None]:
# "vit_base_patch16_384" on the 384x384 input, with target_layer="blocks" given explicitly
net = timm.create_model("vit_base_patch16_384", pretrained=True)
plot_grad_cam(img_torch_384, model=net, device=device, target_layer="blocks")

In [None]:
# "vit_base_patch8_224" on the 224x224 input, with target_layer="blocks" given explicitly
net = timm.create_model("vit_base_patch8_224", pretrained=True)
plot_grad_cam(img_torch_224, model=net, device=device, target_layer="blocks")

### SwinT

In [None]:
# "swin_large_patch4_window12_384" on the 384x384 input, with target_layer="layers" given explicitly
net = timm.create_model("swin_large_patch4_window12_384", pretrained=True)
plot_grad_cam(img_torch_384, model=net, device=device, target_layer="layers")

In [None]:
# "swin_base_patch4_window7_224_in22k" on the 224x224 input, with target_layer="layers" given explicitly
net = timm.create_model("swin_base_patch4_window7_224_in22k", pretrained=True)
plot_grad_cam(img_torch_224, model=net, device=device, target_layer="layers")

## Use Different Architectures

In [None]:
# NOTE(review): from this cell to the end, the notebook repeats the earlier
# "Use Different Architectures" section verbatim; consider removing the duplicate.
# prepare the example image at 224x224 and 384x384; only the tensor output is kept
_, img_torch_224 = preprocess_image(img_pil, img_size=(224, 224))
_, img_torch_384 = preprocess_image(img_pil, img_size=(384, 384))

### ResNet-50

In [None]:
# NOTE(review): identical to an earlier cell in this notebook (duplicated section)
# "resnet50" with pretrained weights on the 224x224 input; target_layer left at its default
net = timm.create_model("resnet50", pretrained=True)
plot_grad_cam(img_torch_224, model=net, device=device)

In [None]:
# NOTE(review): identical to an earlier cell in this notebook (duplicated section)
# "resnet50" with pretrained weights on the 384x384 input; target_layer left at its default
net = timm.create_model("resnet50", pretrained=True)
plot_grad_cam(img_torch_384, model=net, device=device)

### ViT

In [None]:
# NOTE(review): identical to an earlier cell in this notebook (duplicated section)
# "vit_base_patch16_384" on the 384x384 input, with target_layer="blocks" given explicitly
net = timm.create_model("vit_base_patch16_384", pretrained=True)
plot_grad_cam(img_torch_384, model=net, device=device, target_layer="blocks")

In [None]:
# NOTE(review): identical to an earlier cell in this notebook (duplicated section)
# "vit_base_patch8_224" on the 224x224 input, with target_layer="blocks" given explicitly
net = timm.create_model("vit_base_patch8_224", pretrained=True)
plot_grad_cam(img_torch_224, model=net, device=device, target_layer="blocks")

### SwinT

In [None]:
# NOTE(review): identical to an earlier cell in this notebook (duplicated section)
# "swin_large_patch4_window12_384" on the 384x384 input, with target_layer="layers" given explicitly
net = timm.create_model("swin_large_patch4_window12_384", pretrained=True)
plot_grad_cam(img_torch_384, model=net, device=device, target_layer="layers")

In [None]:
# NOTE(review): identical to an earlier cell in this notebook (duplicated section)
# "swin_base_patch4_window7_224_in22k" on the 224x224 input, with target_layer="layers" given explicitly
net = timm.create_model("swin_base_patch4_window7_224_in22k", pretrained=True)
plot_grad_cam(img_torch_224, model=net, device=device, target_layer="layers")