# Grad-CAM Results

In [1]:
%pip install torch torchvision matplotlib seaborn opencv-python wilds ipykernel lightning-lite pytorch-lightning==1.8.6

Collecting torch
  Using cached torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Using cached torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting opencv-python
  Using cached opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting wilds
  Using cached wilds-2.0.0-py3-none-any.whl.metadata (22 kB)
Collecting lightning-lite
  Using cached lightning_lite-1.8.6-py3-none-any.whl.metadata (2.7 kB)
Collecting pytorch-lightning==1.8.6
  Using cached pytorch_lightning-1.8.6-py3-none-any.whl.metadata (23 kB)
Collecting numpy>=1.17.2 (from pytorch-lightning==1.8.6)
  Using cached numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)

In [2]:
import torch
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import cv2
from wilds import get_dataset
from torchvision.transforms import ToTensor, Normalize, Compose, CenterCrop, Resize
from torchvision.models import vgg16, alexnet, resnet18
from utils.datasets import BinaryDataset
import warnings 
warnings.filterwarnings('ignore')

In [5]:
class VGGWithGradCAM(torch.nn.Module):
    """Wrap a VGG-style model so Grad-CAM can read its last convolutional activations.

    The wrapped ``model`` must expose ``features`` (sliceable) and ``classifier``.
    During ``forward`` a hook is attached to the output of ``features[:30]`` so
    that a later ``backward()`` stores the gradient with respect to those
    activations in ``self.gradients``.

    NOTE(review): the ``[:30]`` slice presumably targets torchvision's VGG16
    layout (this file imports ``vgg16``); other VGG variants have a different
    number of feature layers - confirm before reusing.
    NOTE(review): any pooling the wrapped model applies between ``features`` and
    ``classifier`` (other than the max pool re-created here) is not applied -
    confirm that the classifier input size still matches for your input size.
    """

    def __init__(self, model):
        """Split ``model`` into conv features, the trailing max pool and the classifier.

        Args:
            model: network exposing ``features`` and ``classifier`` attributes.
        """
        super(VGGWithGradCAM, self).__init__()

        # keep a handle on the wrapped network
        self.vgg = model

        # dissect the network to access its last convolutional activations
        self.features_conv = self.vgg.features[:30]

        # max pool applied after the hooked activations
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        # classifier head of the wrapped network
        self.classifier = self.vgg.classifier

        # placeholder for the gradients captured by the hook
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the hooked activations."""
        self.gradients = grad

    def forward(self, x):
        """Run the network and, when gradients are tracked, hook the conv activations.

        Args:
            x: input batch accepted by ``features_conv``.

        Returns:
            The classifier output, one row per sample in the batch.
        """
        x = self.features_conv(x)

        # A hook can only be attached to a tensor that requires grad; skipping it
        # otherwise lets the wrapper be used for plain inference under no_grad().
        if x.requires_grad:
            x.register_hook(self.activations_hook)

        # apply the remaining pooling
        x = self.max_pool(x)
        # flatten everything except the batch dimension (works for any batch size)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def get_activations_gradient(self):
        """Return the gradient stored by the last backward pass (``None`` if none yet)."""
        return self.gradients

    def get_activations(self, x):
        """Return the activations of ``features_conv`` for input ``x``."""
        return self.features_conv(x)

class AlexNetWithGradCAM(torch.nn.Module):
    """Wrap an AlexNet-style model so Grad-CAM can read its last convolutional activations.

    The wrapped ``model`` must expose ``features`` (sliceable) and ``classifier``.
    During ``forward`` a hook is attached to the output of ``features[:12]`` so
    that a later ``backward()`` stores the gradient with respect to those
    activations in ``self.gradients``.

    NOTE(review): the ``[:12]`` slice presumably targets torchvision's AlexNet
    layout (this file imports ``alexnet``) - confirm before reusing elsewhere.
    NOTE(review): any pooling the wrapped model applies between ``features`` and
    ``classifier`` (other than the max pool re-created here) is not applied -
    confirm that the classifier input size still matches for your input size.
    """

    def __init__(self, model):
        """Split ``model`` into conv features, the trailing max pool and the classifier.

        Args:
            model: network exposing ``features`` and ``classifier`` attributes.
        """
        super(AlexNetWithGradCAM, self).__init__()

        # keep a handle on the wrapped network
        self.alexnet = model

        # dissect the network to access its last convolutional activations
        self.features_conv = self.alexnet.features[:12]

        # max pool applied after the hooked activations
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)

        # classifier head of the wrapped network
        self.classifier = self.alexnet.classifier

        # placeholder for the gradients captured by the hook
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the hooked activations."""
        self.gradients = grad

    def forward(self, x):
        """Run the network and, when gradients are tracked, hook the conv activations.

        Args:
            x: input batch accepted by ``features_conv``.

        Returns:
            The classifier output, one row per sample in the batch.
        """
        x = self.features_conv(x)

        # A hook can only be attached to a tensor that requires grad; skipping it
        # otherwise lets the wrapper be used for plain inference under no_grad().
        if x.requires_grad:
            x.register_hook(self.activations_hook)

        # apply the remaining pooling
        x = self.max_pool(x)
        # flatten everything except the batch dimension (works for any batch size)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def get_activations_gradient(self):
        """Return the gradient stored by the last backward pass (``None`` if none yet)."""
        return self.gradients

    def get_activations(self, x):
        """Return the activations of ``features_conv`` for input ``x``."""
        return self.features_conv(x)

class ResNet18WithGradCAM(nn.Module):
    """Wrap a ResNet-style model so Grad-CAM can read its last feature-map activations.

    The wrapped ``model`` must expose ``avgpool`` and ``fc`` and list them as its
    last two children; every child before them forms the feature extractor.
    When hooking is enabled, a later ``backward()`` stores the gradient with
    respect to the feature extractor's output in ``self.gradients``.
    """

    def __init__(self, model):
        """Split ``model`` into feature extractor, average pool and output layer.

        Args:
            model: network whose last two children are the average pool and the
                final linear layer (exposed as ``avgpool`` and ``fc``).
        """
        super(ResNet18WithGradCAM, self).__init__()

        self.model = model
        # all children except the last two (the avgpool and the classifier)
        self.feature_extractor = nn.Sequential(*list(self.model.children())[:-2])
        self.avgpool = self.model.avgpool
        self.out = self.model.fc
        # placeholder for the gradients captured by the hook
        self.gradients = None

    def forward(self, x, reg_hook=True):
        """Run the network, optionally hooking the feature-map gradients.

        Args:
            x: input batch accepted by the feature extractor.
            reg_hook: when True (default), attach the gradient hook needed for
                Grad-CAM. The hook is skipped automatically if the activations
                do not require grad (e.g. inside ``torch.no_grad()``), because
                hooks cannot be attached to such tensors.

        Returns:
            The output of the final layer, one row per sample in the batch.
        """
        x = self.feature_extractor(x)

        # register hook (needed for grad-cam)
        if reg_hook and x.requires_grad:
            x.register_hook(self.activations_hook)

        x = self.avgpool(x)

        # flatten everything except the batch dimension
        x = x.view(x.shape[0], -1)
        x = self.out(x)
        return x

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the hooked activations."""
        self.gradients = grad

    def get_activations_gradient(self):
        """Return the gradient stored by the last backward pass (``None`` if none yet)."""
        return self.gradients

    def get_activations(self, x):
        """Return the feature extractor's output for input ``x``."""
        return self.feature_extractor(x)
     


In [10]:
if __name__ == "__main__":
    # Local import: only needed to create the output folder below.
    import os

    # Use the GPU when one is available; otherwise run everything on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resize to 224x224 and normalise with the usual ImageNet channel statistics.
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    main_dataset = get_dataset(dataset="waterbirds", download=False, root_dir='../SCRATCH/')
    data = main_dataset.get_subset("val", transform=transform)
    # presumably maps the labels to -1 / +1 (the code below compares against -1) - confirm
    data = BinaryDataset(data, which_dataset='waterbirds')

    # Untransformed view of the same split, used to fetch the raw image for the overlay.
    data_to_visualize = main_dataset.get_subset("val", transform=None)
    data_loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False, num_workers=2)

    norm = "L2"
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    state_dict = torch.load(f"../SCRATCH/CFE_datasets/WaterBirds_{norm}", map_location='cpu', weights_only=False)['classifier']
    # Strip the first 'model.resnet.' occurrence so the keys match a plain resnet18.
    state_dict = OrderedDict((k.replace('model.resnet.', '', 1), v) for k, v in state_dict.items())

    # Build on the CPU, load the weights, then move the wrapped model to the device.
    resnet18_model = resnet18()
    resnet18_model.fc = nn.Linear(resnet18_model.fc.in_features, 1)
    resnet18_model.load_state_dict(state_dict)
    model = ResNet18WithGradCAM(resnet18_model).to(device)
    model.eval()

    # Advance the (unshuffled) loader until it yields the sample at sample_idx.
    sample_idx = 10
    iter_dataloader = iter(data_loader)
    for _ in range(sample_idx + 1):
        img, target = next(iter_dataloader)
    input_tensor = img.to(device)

    Class = 'Landbird' if target.float() == -1 else 'Waterbird'
    model.zero_grad()
    pred = model(input_tensor)
    print(pred)
    # Multiply the single logit by the target so the backward pass follows the
    # score of the true class.
    predi = pred * target.to(device)
    predi.sum().backward()

    # Grad-CAM: average the gradients per channel and use them as channel weights.
    gradients = model.get_activations_gradient()
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
    with torch.no_grad():  # activations are only read, no graph needed
        activations = model.get_activations(input_tensor)
    activations = activations * pooled_gradients.view(1, -1, 1, 1)

    # Channel-average, keep positive evidence only, then scale to [0, 1].
    heatmap = torch.mean(activations.cpu(), dim=1).squeeze()
    heatmap = torch.clamp(heatmap, min=0)
    heatmap_max = torch.max(heatmap)
    if heatmap_max > 0:  # avoid 0/0 -> NaN when nothing is positive
        heatmap = heatmap / heatmap_max

    plt.matshow(heatmap.numpy())
    plt.colorbar()

    output_dir = 'results/GradCAM/WaterBirds'
    os.makedirs(output_dir, exist_ok=True)
    original_path = f'{output_dir}/WaterBirds_Original_{Class}.jpg'

    # presumably a PIL image (it is saved via .save) - confirm against the dataset
    raw_image = data_to_visualize[sample_idx][0]
    raw_image.save(original_path)

    img = cv2.imread(original_path)
    if img is None:
        raise RuntimeError(f"Could not read back {original_path}")
    # cv2.resize expects (width, height)
    heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # Blend, then clip to the valid 8-bit range before writing.
    superimposed_img = np.clip(np.rint(heatmap * 0.4 + img), 0, 255).astype(np.uint8)
    cv2.imwrite(f'{output_dir}/WaterBirds_{norm}_{Class}_heatmap.jpg', superimposed_img)

    predicted_label = "WaterBird" if pred.item() > 0 else "LandBird"
    print(f"Predicted label: {predicted_label}, True label: {Class}")