# Import libraries

In [None]:
import os
import PIL
import numpy as np
import matplotlib.pyplot as plt
import cv2

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms as T
from torchvision.utils import make_grid, save_image

from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp

# Define device (GPU or CPU)

In [None]:
# Pick the compute device once: the current CUDA device when PyTorch can see a
# GPU, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

# Show image

In [None]:
# Read JPEG image file.
# Convert to RGB in case the file is grayscale, palette or CMYK, so the tensor
# built from it always has 3 channels and displays correctly. The Grad-CAM
# functions below apply the same conversion.
img_path = 'tiny-imagenet-200/val/images/n01882714/val_5008.JPEG'
img = PIL.Image.open(img_path).convert('RGB')

In [None]:
# Display-only preprocessing: scale the shorter side to 256 px, keep the
# central 224x224 patch, and turn it into a float tensor (no mean/std
# normalization, because this tensor is only plotted below).
preprocess_transform = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
    ]
)

# Pre-processing: apply the pipeline, then place the result on the chosen device.
torch_img = preprocess_transform(img)
torch_img = torch_img.to(device)

In [None]:
# Show image.
# torch_img is a (C, H, W) tensor and may live on the GPU (it was moved to
# `device`). matplotlib needs an (H, W, C) array in host memory, so permute the
# axes and copy the tensor to the CPU before plotting.
plt.imshow(torch_img.permute(1, 2, 0).cpu())
plt.show()

# Load trained models

In [None]:
# Baseline model: a ResNet-18 built without pretrained weights. Its final fully
# connected layer is swapped for a 200-class head, then every weight is
# filled from the baseline checkpoint.
model_baseline = models.resnet18(pretrained=False)
num_ftrs = model_baseline.fc.in_features
model_baseline.fc = nn.Linear(in_features=num_ftrs, out_features=200)

# map_location only controls where the checkpoint tensors are loaded.
# NOTE(review): the model itself is never moved with .to(device) here.
model_path = 'ResNet18_baseline_model.pth'
model_baseline.load_state_dict(
    torch.load(model_path, map_location=device)
)

In [None]:
# OS regularization model.
# Build the ResNet-18 skeleton with a 200-way classifier head. ImageNet weights
# are not requested: the strict load_state_dict on the wrapped model below
# overwrites every parameter and buffer anyway. Asking for pretrained weights
# would only force a network download, which fails offline.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 200)

class SplittedResNet18(nn.Module):
    """Wrap a ResNet-style network so its forward pass also returns its features.

    Every child of ``resnet18`` except the last one is chained into ``self.cnn``.
    For a torchvision ResNet-18 that is conv1 ... layer4 followed by avgpool.
    The original ``fc`` layer is reused as the classifier. The submodules are
    shared with the wrapped network, not copied.
    """

    def __init__(self, resnet18):
        super().__init__()
        # Assumes the classifier ``fc`` is the last registered child, so [:-1]
        # keeps only the feature extractor.
        self.cnn = nn.Sequential(*list(resnet18.children())[:-1])
        self.flatten = nn.Flatten()
        self.fc = resnet18.fc

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(representation, logits)`` for the input batch ``x``.

        ``representation`` is the flattened output of ``self.cnn``, and
        ``logits`` is ``self.fc`` applied to it.
        """
        representation = self.flatten(self.cnn(x))
        output = self.fc(representation)
        return representation, output

# Wrap the backbone so forward() returns (representation, logits), then load the
# OS-regularized checkpoint into the wrapper. Loading is strict, so the keys
# presumably come from this same wrapper (its `cnn` / `fc` submodule names).
model_path = 'ResNet18_OS_0.01.pth'
model_os = SplittedResNet18(model)
model_os.load_state_dict(
    torch.load(model_path, map_location=device)
)

In [None]:
# Drop the bare `model` name. SplittedResNet18 reuses the children and `fc` of
# the network it wraps, so model_os still references those layers and their
# memory is not freed. This mainly prevents accidentally using `model` instead
# of `model_os` later on.
del model

# Define Grad-CAM functions

In [None]:
# Baseline model
def gradcam_baseline(model, img_fpath):
    """Compute and plot a Grad-CAM heatmap for a torchvision-style ResNet.

    The image at ``img_fpath`` is resized to 256 px, center-cropped to 224 px
    and ImageNet-normalized, then run through ``model``. The gradient of the
    top-scoring logit with respect to the ``layer4`` activations gives one
    weight per channel. The ReLU of the channel mean of the weighted
    activations is the Grad-CAM map. The predicted class index is printed.
    The original image, the colorized heatmap and their overlay are then
    shown side by side.

    Args:
        model: module exposing ``conv1``, ``bn1``, ``relu``, ``maxpool``,
            ``layer1``-``layer4``, ``avgpool`` and ``fc`` (e.g. a torchvision
            ResNet). It is switched to eval mode as a side effect.
        img_fpath: path of the image file to explain.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_fpath`` for the overlay.
    """
    model.eval()

    transforms = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    with PIL.Image.open(img_fpath) as pil_img:
        rgb = pil_img.convert('RGB')
    # Run the input on whatever device the model's weights live on.
    param_device = next(model.parameters()).device
    img = transforms(rgb).unsqueeze(0).to(param_device)

    # Activations of the last convolutional stage (layer4).
    x = model.conv1(img)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    features = model.layer4(x)

    # Remainder of the ResNet: global pooling and the linear classifier.
    x = model.avgpool(features)
    x = x.view(x.size(0), -1)
    output = model.fc(x)
    pred = torch.argmax(output).item()
    print(pred)

    # Gradient of the predicted-class score w.r.t. the layer4 activations.
    # autograd.grad returns it directly: no global hook variable is needed and
    # no .grad is accumulated on the model parameters across calls.
    (feature_grad,) = torch.autograd.grad(output[0, pred], features)

    # Grad-CAM: channel weights alpha_k are the spatial means of the gradients;
    # map = ReLU(mean_k(alpha_k * A_k)).
    pooled_grad = torch.mean(feature_grad, dim=[0, 2, 3])
    weighted = features.detach() * pooled_grad.view(1, -1, 1, 1)
    heatmap = torch.clamp(torch.mean(weighted, dim=1).squeeze(), min=0)

    # Normalize to [0, 1] for plotting. Without the guard, an all-zero map
    # (no positive evidence) would divide by zero and produce NaNs.
    peak = torch.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = heatmap.cpu().numpy()

    # Project the heatmap onto the original image read from disk.
    # NOTE(review): the map covers only the 224x224 center crop but is stretched
    # over the full, uncropped image here, so it is slightly misaligned.
    img = cv2.imread(img_fpath)
    if img is None:
        raise FileNotFoundError(f"cv2 could not read image: {img_fpath}")
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    # OpenCV images are BGR; matplotlib expects RGB.
    heatmap_plot = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    image_plot = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    superimposed_img = heatmap * 0.8 + img
    superimposed_img = np.uint8(255 * superimposed_img / np.max(superimposed_img))
    superimposed_img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

    _, ax = plt.subplots(1, 3, figsize=(14, 3))
    ax[0].imshow(image_plot, aspect='auto')
    ax[1].imshow(heatmap_plot, aspect='auto')
    ax[2].imshow(superimposed_img, aspect='auto')
    plt.show()

In [None]:
# OS regularization model
def gradcam_OS_model(model, img_fpath):
    """Compute and plot a Grad-CAM heatmap for a ``SplittedResNet18`` model.

    The input is preprocessed exactly like the baseline version: resize to
    256 px, center-crop to 224 px, ImageNet-normalize. ``model.cnn[:8]``
    produces the last convolutional activations. The gradient of the
    top-scoring logit with respect to them gives one weight per channel, and
    the ReLU of the channel mean of the weighted activations is the Grad-CAM
    map. The predicted class index is printed. The original image, the
    colorized heatmap and their overlay are then shown side by side.

    Args:
        model: a ``SplittedResNet18``-like module. ``cnn`` must be an
            ``nn.Sequential`` whose entries 0-7 end at the last conv stage
            and whose entry 8 is the average pool; ``fc`` is the classifier.
            The model is switched to eval mode as a side effect.
        img_fpath: path of the image file to explain.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_fpath`` for the overlay.
    """
    model.eval()

    transforms = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    with PIL.Image.open(img_fpath) as pil_img:
        rgb = pil_img.convert('RGB')
    # Run the input on whatever device the model's weights live on.
    param_device = next(model.parameters()).device
    img = transforms(rgb).unsqueeze(0).to(param_device)

    # Activations of the last convolutional stage: the first 8 trunk modules.
    features = model.cnn[:8](img)

    # Remainder of the network: average pool, flatten, linear classifier.
    x = model.cnn[8](features)
    x = x.view(x.size(0), -1)
    output = model.fc(x)
    pred = torch.argmax(output).item()
    print(pred)

    # Gradient of the predicted-class score w.r.t. the conv activations.
    # autograd.grad returns it directly: no global hook variable is needed and
    # no .grad is accumulated on the model parameters across calls.
    (feature_grad,) = torch.autograd.grad(output[0, pred], features)

    # Grad-CAM: channel weights alpha_k are the spatial means of the gradients;
    # map = ReLU(mean_k(alpha_k * A_k)).
    pooled_grad = torch.mean(feature_grad, dim=[0, 2, 3])
    weighted = features.detach() * pooled_grad.view(1, -1, 1, 1)
    heatmap = torch.clamp(torch.mean(weighted, dim=1).squeeze(), min=0)

    # Normalize to [0, 1] for plotting. Without the guard, an all-zero map
    # (no positive evidence) would divide by zero and produce NaNs.
    peak = torch.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    heatmap = heatmap.cpu().numpy()

    # Project the heatmap onto the original image read from disk.
    # NOTE(review): the map covers only the 224x224 center crop but is stretched
    # over the full, uncropped image here, so it is slightly misaligned.
    img = cv2.imread(img_fpath)
    if img is None:
        raise FileNotFoundError(f"cv2 could not read image: {img_fpath}")
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

    # OpenCV images are BGR; matplotlib expects RGB.
    heatmap_plot = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    image_plot = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    superimposed_img = heatmap * 0.8 + img
    superimposed_img = np.uint8(255 * superimposed_img / np.max(superimposed_img))
    superimposed_img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

    _, ax = plt.subplots(1, 3, figsize=(14, 3))
    ax[0].imshow(image_plot, aspect='auto')
    ax[1].imshow(heatmap_plot, aspect='auto')
    ax[2].imshow(superimposed_img, aspect='auto')
    plt.show()

# Show Grad-CAM image

In [None]:
# Validation image to explain with Grad-CAM (the same file read near the top).
img_path = (
    'tiny-imagenet-200/val/images/'
    'n01882714/val_5008.JPEG'
)

In [None]:
# Grad-CAM for the baseline ResNet-18.
gradcam_baseline(model=model_baseline, img_fpath=img_path)

In [None]:
# Grad-CAM for the OS-regularized ResNet-18 on the same image.
gradcam_OS_model(model=model_os, img_fpath=img_path)