# Grad-CAM for ViT

This notebook uses code adapted from the pytorch-grad-cam library: https://github.com/jacobgil/pytorch-grad-cam

In [27]:
import cv2
import numpy as np
import torch

from pytorch_grad_cam import GradCAM, \
    ScoreCAM, \
    GradCAMPlusPlus, \
    AblationCAM, \
    XGradCAM, \
    EigenCAM, \
    EigenGradCAM, \
    LayerCAM, \
    FullGrad

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    preprocess_image

from transformers import ViTForImageClassification

In [10]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directories of the two fine-tuned model variants (plain and augmented training data).
dirs = [
    "../final-models/hp-vit-base",
    "../final-models/aug-hp-vit-base",
]

# Sample test images: one "Healthy" image from the original dataset and one
# "Black Rot" image from the augmented dataset.
original_path = "../dataset/test/Healthy/0c1667a2-61d7-4dee-b4d9-0d141a1ceb20___Mt.N.V_HL 9127_new30degFlipLR.JPG"
augmented_path = "../augmented-dataset/test/Black Rot/0aff8add-93ad-4099-97ae-23515744e620___FAM_B.Rot 0748_flipLR_aug2.png"

In [None]:
def reshape_transform(x, height=14, width=14):
    """Turn a token sequence into a channels-first 2D feature map.

    Args:
        x: Tensor indexed as (batch, tokens, channels), or a tuple whose first
            element is such a tensor.
        height: Number of rows in the resulting spatial grid.
        width: Number of columns in the resulting spatial grid.

    Returns:
        Tensor of shape (batch, channels, height, width). The first token
        along dimension 1 (presumably the [CLS] token) is dropped, so the
        input must carry exactly ``1 + height * width`` tokens.
    """
    tokens = x[0] if isinstance(x, tuple) else x
    batch, channels = tokens.size(0), tokens.size(2)

    # Skip token 0 and lay the remaining tokens out row by row on the grid.
    patch_tokens = tokens[:, 1:, :]
    grid = patch_tokens.reshape(batch, height, width, channels)

    # Channels-last -> channels-first.
    return grid.permute(0, 3, 1, 2)


if __name__ == '__main__':
    # Lookup table from a method name to the pytorch_grad_cam class implementing it.
    methods = {
        "gradcam": GradCAM,
        "scorecam": ScoreCAM,
        "gradcam++": GradCAMPlusPlus,
        "ablationcam": AblationCAM,
        "xgradcam": XGradCAM,
        "eigencam": EigenCAM,
        "eigengradcam": EigenGradCAM,
        "layercam": LayerCAM,
        "fullgrad": FullGrad,
    }

    # Build the ViT classifier from the pretrained base checkpoint with a
    # 4-label head, then move it to the selected device.
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        num_labels=4,
        ignore_mismatched_sizes=True,
    )
    model = model.to(device)

    # Replace the weights with the fine-tuned checkpoint and switch to
    # inference mode.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    model.load_state_dict(
        torch.load(
            '../final-models/aug-hp-vit-base/model_lr0.0001_bs16.pth',
            map_location=device,
        )
    )
    model.eval()

In [None]:
# Module handed to GradCAM as its target layer: the `output.dense` sub-module
# of the encoder layer at index 9 of `model.vit.encoder.layer`.
target_layer = model.vit.encoder.layer[9].output.dense

class ViTWrapper(torch.nn.Module):
    """Adapter whose forward pass returns the wrapped model's ``logits`` directly.

    The wrapped model is stored as ``self.vit_model``; its output object is
    expected to expose a ``logits`` attribute.
    """

    def __init__(self, vit_model):
        """Keep a reference to the model being wrapped."""
        super().__init__()
        self.vit_model = vit_model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the ``logits`` of its output."""
        outputs = self.vit_model(x)
        return outputs.logits

# Wrap the classifier so the CAM object gets plain logits from forward().
wrapped_model = ViTWrapper(model)

cam = GradCAM(
    model=wrapped_model,
    target_layers=[target_layer],
    reshape_transform=reshape_transform,
)
# Batch size attribute on the CAM object, kept from the original cell.
cam.batch_size = 16

# cv2.imread returns None (it does not raise) when the file cannot be read,
# so fail here with a clear message instead of a TypeError further down.
bgr_img = cv2.imread(augmented_path)
if bgr_img is None:
    raise FileNotFoundError(f"Could not read image: {augmented_path}")

# OpenCV loads images as BGR; the preprocessing below works on RGB.
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
rgb_img = cv2.resize(rgb_img, (224, 224))
rgb_img_display = np.float32(rgb_img) / 255.0  # float image in [0, 1]
input_tensor = preprocess_image(
    rgb_img_display.copy(),
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
).to(device)

# No explicit target: the library picks the class to explain (per the
# pytorch-grad-cam docs, the highest-scoring one).
targets = None

# Pass aug_smooth=True to cam(...) for a smoothed map (left disabled here).
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

# The base image is RGB, so the heatmap must be RGB too (use_rgb=True);
# blending an RGB photo with a BGR heatmap mixes channel orders.
cam_image = show_cam_on_image(rgb_img_display, grayscale_cam, use_rgb=True)

# cv2.imwrite expects BGR and signals failure through its return value.
if not cv2.imwrite('grad_cam.jpg', cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)):
    raise IOError("Could not write grad_cam.jpg")

print("rgb_img_display shape:", rgb_img_display.shape, "range:", rgb_img_display.min(), rgb_img_display.max())
print("grayscale_cam shape:", grayscale_cam.shape, "range:", grayscale_cam.min(), grayscale_cam.max())


rgb_img_display shape: (224, 224, 3) range: 0.0 1.0
grayscale_cam shape: (224, 224) range: 0.0 0.9999999
