# Grad-CAM for SwinV2

This notebook adapts code from https://github.com/jacobgil/pytorch-grad-cam to visualise Grad-CAM heatmaps for a SwinV2 image classifier.

In [79]:
import math

import cv2
import numpy as np
import torch

from pytorch_grad_cam import GradCAM, \
    ScoreCAM, \
    GradCAMPlusPlus, \
    AblationCAM, \
    XGradCAM, \
    EigenCAM, \
    EigenGradCAM, \
    LayerCAM, \
    FullGrad

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    preprocess_image

from transformers import Swinv2ForImageClassification

In [80]:
# Sample images for the Grad-CAM visualisation: one from the original test
# split and one from the augmented test split (the CAM cell reads
# `augmented_path`).
original_path = (
    "../dataset/test/Healthy/"
    "0c1667a2-61d7-4dee-b4d9-0d141a1ceb20___Mt.N.V_HL 9127_new30degFlipLR.JPG"
)
augmented_path = (
    "../augmented-dataset/test/Black Rot/"
    "0aff8add-93ad-4099-97ae-23515744e620___FAM_B.Rot 0748_flipLR_aug2.png"
)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): these are ViT checkpoint directories; the SwinV2 checkpoint is
# loaded by an explicit path elsewhere, and `dirs` is not referenced by the
# cells shown here.
dirs = [
    "../final-models/hp-vit-base",
    "../final-models/aug-hp-vit-base",
]

In [None]:
def reshape_transform(x):
    """Turn a (B, N, C) token sequence into a (B, C, side, side) feature map.

    pytorch-grad-cam computes CAMs over channels-first spatial activations, so
    the transformer tokens are laid back out on a square grid of
    side = sqrt(N) and the channel axis is moved to position 1.

    Args:
        x: activation tensor of shape (B, N, C) whose N tokens form a square
            grid in row-major order.

    Returns:
        Tensor of shape (B, C, side, side).

    Raises:
        ValueError: if ``x`` is not 3-D or N is not a perfect square.
    """
    if x.dim() != 3:
        raise ValueError(f"expected a (B, N, C) tensor, got shape {tuple(x.shape)}")
    B, N, C = x.size()
    # Exact integer square root; reject token counts that cannot form a square
    # grid instead of failing later with an opaque view/shape error.
    side = math.isqrt(N)
    if side * side != N:
        raise ValueError(f"token count {N} is not a perfect square; cannot form a square grid")
    return x.view(B, side, side, C).permute(0, 3, 1, 2)


# Registry of CAM implementations from pytorch-grad-cam, keyed by name.
# NOTE(review): only GradCAM is instantiated in the cells shown here.
methods = {
    "gradcam": GradCAM,
    "scorecam": ScoreCAM,
    "gradcam++": GradCAMPlusPlus,
    "ablationcam": AblationCAM,
    "xgradcam": XGradCAM,
    "eigencam": EigenCAM,
    "eigengradcam": EigenGradCAM,
    "layercam": LayerCAM,
    "fullgrad": FullGrad,
}

# Local state dict for the 4-class SwinV2 classifier.
checkpoint_path = '../final-models/aug-hp-swinv2-base/model_lr1e-05_bs32.pth'

# Build the SwinV2 architecture with a 4-way head; ignore_mismatched_sizes lets
# the hub checkpoint load even though its classification head has a different
# size. All weights are then overwritten by the local checkpoint.
model = Swinv2ForImageClassification.from_pretrained(
    "microsoft/swinv2-base-patch4-window8-256",
    num_labels=4,
    ignore_mismatched_sizes=True,
)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# load_state_dict copies values into the model's existing parameters and does
# not move the module, so it must be placed on `device` explicitly to match the
# input tensors. Assigning the result also keeps Jupyter from echoing the repr.
model = model.to(device).eval()

In [123]:
# Grad-CAM target layer: the SwinV2 backbone's `layernorm` module. Its output is
# expected to be a (B, N, C) token sequence, which reshape_transform lays out as
# a (B, C, H, W) map for the CAM computation.
# Alternative target (last block of the last encoder stage):
# target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
target_layer = model.swinv2.layernorm

class SwinV2Wrapper(torch.nn.Module):
    """Expose a Hugging Face Swinv2 classifier as a plain ``x -> logits`` module.

    The wrapped model returns an output object; the CAM classes index the
    model's output directly, so this wrapper unpacks its ``logits`` field.

    Args:
        swin_model: the classifier to wrap, stored as ``self.model``.
    """

    def __init__(self, swin_model):
        super(SwinV2Wrapper, self).__init__()
        self.model = swin_model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its logits."""
        outputs = self.model(x)
        return outputs.logits
    
wrapped_model = SwinV2Wrapper(model)

# Load the sample image. cv2.imread returns None (rather than raising) when the
# file is missing or unreadable, so fail with a clear message here.
bgr_img = cv2.imread(augmented_path)
if bgr_img is None:
    raise FileNotFoundError(f"could not read image: {augmented_path}")
# OpenCV loads BGR; both the model input and the overlay are built from RGB.
rgb_img = cv2.resize(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB), (256, 256))
rgb_img_display = np.float32(rgb_img) / 255.0  # RGB floats in [0, 1]

# NOTE(review): mean/std of 0.5 must match the normalization used when the
# checkpoint was trained — confirm against the training pipeline.
input_tensor = preprocess_image(rgb_img_display.copy(), mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5])
# Put the input wherever the model's weights actually live so the two cannot
# end up on different devices.
input_tensor = input_tensor.to(next(wrapped_model.parameters()).device)

# None -> explain the highest-scoring class. Pass aug_smooth=True to cam(...)
# for test-time-augmentation smoothing.
targets = None

# Using the CAM object as a context manager releases the hooks it registers on
# target_layer deterministically when the block exits, instead of whenever the
# object happens to be garbage-collected (relevant when re-running this cell).
with GradCAM(model=wrapped_model, target_layers=[target_layer],
             reshape_transform=reshape_transform) as cam:
    # batch_size is used by batched methods (e.g. ScoreCAM / AblationCAM);
    # presumably ignored by GradCAM, kept so the method can be swapped easily.
    cam.batch_size = 16
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

# rgb_img_display is RGB, so request an RGB heatmap to blend with it, then
# convert the blended image to BGR because cv2.imwrite expects BGR.
cam_image = show_cam_on_image(rgb_img_display, grayscale_cam, use_rgb=True)
cv2.imwrite('grad_cam.jpg', cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR))

print("rgb_img_display shape:", rgb_img_display.shape, "range:", rgb_img_display.min(), rgb_img_display.max())
print("grayscale_cam shape:", grayscale_cam.shape, "range:", grayscale_cam.min(), grayscale_cam.max())


rgb_img_display shape: (256, 256, 3) range: 0.0 1.0
grayscale_cam shape: (256, 256) range: 0.0 0.9999999
