In [1]:
import torch

import torch.nn.functional as F
from torchvision.utils import make_grid, save_image

from utils import visualize_cam, Normalize
from gradcam import GradCAM, GradCAMpp

import os

from PIL import Image as Image

import numpy as np

device = torch.device("cpu")
torch.manual_seed(42)

<torch._C.Generator at 0x2b237f2c170>

In this notebook we calculate GradCAM for all samples. GradCAM is a visual XAI method to show which parts of the image are the most influential for classification. The PyTorch implementation used for this project is available [here](https://github.com/1Konny/gradcam_plus_plus-pytorch)

In [2]:
# we load our previous resnet model
model = torch.load('model.pt')
model.eval()
model = model.to(device)
for param in model.parameters():
    param.requires_grad = True

In [3]:
root_dir = "./data"

Images = []

# load all images

original_dir = f"{root_dir}/original"
for filename in os.listdir(original_dir):
    path = f"{original_dir}/{filename}"
    filename_no_extension = filename.split(".")[0]
    Images.append(tuple([Image.open(path).convert('RGB'), filename_no_extension]))

augmented_dir = f"{root_dir}/augmented"
for filename in os.listdir(augmented_dir):
    path = f"{augmented_dir}/{filename}"
    filename_no_extension = filename.split(".")[0]
    Images.append(tuple([Image.open(path).convert('RGB'), filename_no_extension]))

In [4]:
for image, filename in Images:
    # normalize and transform the image to tensor
    normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    torch_img = torch.from_numpy(np.asarray(image)).permute(2, 0, 1).unsqueeze(0).float().div(255)
    torch_img = F.interpolate(torch_img, size=(256, 256), mode='bilinear', align_corners=False)
    normed_torch_img = normalizer(torch_img)
    
    # load gradcam
    cam_dict = dict()
    resnet_model_dict = dict(type='resnet', arch=model, layer_name='layer4', input_size=(256, 256))
    resnet_gradcam = GradCAM(resnet_model_dict, True)
    resnet_gradcampp = GradCAMpp(resnet_model_dict, True)
    cam_dict['resnet'] = [resnet_gradcam, resnet_gradcampp]
    outputs = []
    for gradcam, gradcam_pp in cam_dict.values():
        # calculate gradcam mask for all images
        mask, _ = gradcam(normed_torch_img)
        heatmap, result = visualize_cam(mask, torch_img)
    
        mask_pp, _ = gradcam_pp(normed_torch_img)
        heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)
        
        outputs.append(torch.stack([torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0))
    
    
    outputs = make_grid(torch.cat(outputs, 0), nrow=5)
    output_dir = './data/gradcam'
    os.makedirs(output_dir, exist_ok=True)
    output_name = f"{filename}_gradcam.jpg"
    output_path = os.path.join(output_dir, output_name)
    
    save_image(outputs, output_path)

  torch_img = torch.from_numpy(np.asarray(image)).permute(2, 0, 1).unsqueeze(0).float().div(255)
  return self._call_impl(*args, **kwargs)
