In [1]:
import argparse
import cv2
import numpy as np
import torch

from pytorch_grad_cam import GradCAM, \
    ScoreCAM, \
    GradCAMPlusPlus, \
    AblationCAM, \
    XGradCAM, \
    EigenCAM, \
    EigenGradCAM, \
    LayerCAM, \
    FullGrad

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    preprocess_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit

import timm

In [2]:
def reshape_transform(tensor, height=14, width=14, num_prefix_tokens=1):
    """Turn token-sequence activations into a CNN-style feature map for CAM.

    Drops the leading ``num_prefix_tokens`` tokens (by default one, e.g. a
    class token), lays the remaining tokens out on a ``height x width`` grid
    and moves the channel axis to dimension 1, as CAM methods expect.

    Args:
        tensor: activations indexed as (batch, tokens, channels); the token
            count must equal ``num_prefix_tokens + height * width``.
        height: number of rows in the token grid.
        width: number of columns in the token grid.
        num_prefix_tokens: number of leading non-grid tokens to discard.
            Use 0 for models that have no class token.

    Returns:
        A tensor of shape (batch, channels, height, width).
    """
    batch_size, _, channels = tensor.shape
    grid_tokens = tensor[:, num_prefix_tokens:, :]
    grid = grid_tokens.reshape(batch_size, height, width, channels)
    # Bring the channels to the first (non-batch) dimension, like in CNNs.
    return grid.permute(0, 3, 1, 2)


if __name__ == '__main__':
    # Example usage of the pytorch_grad_cam methods on a trained classifier.
    # Nothing is read from the command line: edit the configuration values
    # below and re-run the cell. The result is written to "<method>_cam.jpg"
    # in the current working directory.

    # ---- Configuration -----------------------------------------------------
    method = 'gradcam++'            # one of the keys of `methods` below
    use_cuda = True                 # falls back to CPU if CUDA is unavailable
    image_path = r"datasets\140k Real vs Fake\real_vs_fake\real-vs-fake\train\fake\00B4R41FLE.jpg"
    eigen_smooth = True
    aug_smooth = True
    input_size = 224                # image is resized to input_size x input_size
    # Set to True for ViT-style models whose target-layer activations are token
    # sequences (e.g. the DeiT example below); leave False for CNNs.
    use_reshape_transform = False
    # Printing every child module of the model produces very long output.
    print_model_children = False

    methods = \
        {"gradcam": GradCAM,
         "scorecam": ScoreCAM,
         "gradcam++": GradCAMPlusPlus,
         "ablationcam": AblationCAM,
         "xgradcam": XGradCAM,
         "eigencam": EigenCAM,
         "eigengradcam": EigenGradCAM,
         "layercam": LayerCAM,
         "fullgrad": FullGrad}

    if method not in methods:
        raise ValueError(f"method should be one of {list(methods.keys())}")

    # ---- Model and target layer --------------------------------------------
    # Exactly one of the model options below should be active.

    # This one does work: set use_reshape_transform = True above.
    # Results still look a bit odd.
    # model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)
    # target_layers = [model.blocks[-1].norm1]

    # NO IDEA HOW TO MAKE THIS ONE WORK
    # NOTE(review): reshape_transform drops the first token and assumes a
    # 14x14 token grid by default; Swin models presumably have no class token
    # and a different final grid size, so it likely needs adapting - TODO confirm.
    # model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=2)
    # model.load_state_dict(torch.load(r'CollectedData\Models\swinNewInit-0_Epoch25_Batch64_LR0.001_Momentum0.9'))
    # target_layers = [model.layers[-1].blocks[-1].norm1]

    # Very stripey vertically, may be passing in the wrong layer?
    # model = torch.load(r'CollectedData\Models\swinNewInit-0_Epoch25_Batch64_LR0.001_Momentum0.9')
    # target_layers = [model.layers[-1].blocks[-1].norm1]

    # Very stripey vertically, may be passing in the wrong layer?
    # model = timm.models.swin_base_patch4_window7_224_in22k(pretrained=True, num_classes = 2)
    # target_layers = [model.layers[-1].blocks[-1].norm1]

    # Visualizable VIT, this one should work but IDK which layer to use
    # model = torch.load(r'CollectedData\Models\visualizableVIT-0_Epoch25_Batch64_LR0.001_Momentum0.9') # Replace with model path
    # modelChildren = list(model.children())
    # target_layers = [] # Change which layer we look at, it's a lot easier to examine modelChildren in the debugger

    # Custom CNN, this one does work.
    # torch.load unpickles a whole model object, which can run arbitrary code:
    # only load checkpoints you trust. (On newer torch releases that default to
    # weights_only=True, loading a full model may also need weights_only=False.)
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to CUDA below when available.
    model = torch.load(r'CollectedData\Models\DWSConvNet3_learnedPoolingHwy-0_Epoch25_Batch64_LR0.001_Momentum0.9',
                       map_location='cpu')  # Replace with model path
    modelChildren = list(model.children())
    if print_model_children:
        print(modelChildren)
    # Change which layer we look at; it's a lot easier to examine modelChildren in the debugger.
    target_layers = [modelChildren[3]]

    model.eval()

    if use_cuda and not torch.cuda.is_available():
        print('CUDA requested but not available; running on CPU.')
        use_cuda = False
    if use_cuda:
        model = model.cuda()

    # ---- CAM object ---------------------------------------------------------
    cam_kwargs = {
        "model": model,
        "target_layers": target_layers,
        # None is the library default: activations are used without reshaping.
        "reshape_transform": reshape_transform if use_reshape_transform else None,
    }
    if method == "ablationcam" and use_reshape_transform:
        # AblationLayerVit is meant for token-shaped (ViT) activations; for CNNs
        # AblationCAM's default ablation layer is used instead.
        cam_kwargs["ablation_layer"] = AblationLayerVit()
    cam = methods[method](**cam_kwargs)

    # ---- Input image --------------------------------------------------------
    bgr_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr_img is None:  # cv2.imread reports a missing/unreadable file as None
        raise FileNotFoundError(f"Could not read image: {image_path}")
    rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
    rgb_img = cv2.resize(rgb_img, (input_size, input_size))
    rgb_img = np.float32(rgb_img) / 255
    # NOTE(review): mean/std of 0.5 presumably match the normalization used
    # when the loaded model was trained - confirm.
    input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5])
    # Keep the input on the same device as the model.
    if use_cuda:
        input_tensor = input_tensor.cuda()

    # If None, returns the map for the highest scoring category.
    # Otherwise, targets the requested category.
    targets = None

    # AblationCAM and ScoreCAM have batched implementations.
    # You can override the internal batch size for faster computation.
    cam.batch_size = 32

    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=targets,
                        eigen_smooth=eigen_smooth,
                        aug_smooth=aug_smooth)

    # Here grayscale_cam has only one image in the batch
    grayscale_cam = grayscale_cam[0, :]

    # rgb_img is RGB, so request an RGB overlay, then convert to BGR because
    # cv2.imwrite expects BGR channel order.
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
    outFileName = f'{method}_cam.jpg'
    if not cv2.imwrite(outFileName, cam_image):
        raise OSError(f'Failed to write {outFileName}')

    print(f'Exported result to {outFileName}!')


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


AttributeError: 'VisualizableVIT' object has no attribute 'blocks'