In [None]:
import torch
from torch import nn
from torchvision import models
from torchvision.io import read_image
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from skimage import transform
import os

import warnings
warnings.filterwarnings('ignore')

In [None]:
class ModifyModel(nn.Module):
    """Wrap a classifier so that ``forward`` returns every intermediate activation.

    Two model layouts are supported:

    * Style1: models with ``avgpool`` and ``fc`` children at the end of ``model.named_children()``.
    * Style2: models with ``features``, ``avgpool`` and ``classifier`` children.

    Attributes:
        model: deep copy of the wrapped model; its ``avgpool`` child is replaced by
            ``AdaptiveAvgPool2d`` followed by ``Flatten`` so the children can be called
            one after another in ``forward``.
        weight_matrix: detached copy of the last linear layer's weight
            (``model.fc.weight`` for Style1, ``model.classifier[-1].weight`` for Style2).
        name: the name passed to the constructor.
        final_activation: index, inside the list returned by ``forward``, of the
            activation produced right before ``avgpool``.
    """

    def __init__(self, model, name: str = '') -> None:
        """Copy ``model`` and prepare it for activation extraction.

        Args:
            model: module exposing an ``avgpool`` child plus either ``fc`` or ``classifier``.
            name: model name; names starting with ``"vgg"`` get a (7, 7) pooling output,
                all others (1, 1).

        Raises:
            IndexError: if ``model`` has no child named ``avgpool``.
        """
        super(ModifyModel, self).__init__()
        self.model = deepcopy(model)
        try:
            # Style1
            self.weight_matrix = model.fc.weight.detach().clone()
        except AttributeError:
            # Style2: only a missing `fc` (or `fc.weight`) triggers the fallback,
            # any other error is no longer silently swallowed
            self.weight_matrix = model.classifier[-1].weight.detach().clone()
        self.model.avgpool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1) if not name.startswith("vgg") else (7, 7)),
            nn.Flatten())
        self.name = name

        # `forward` emits one activation per sub-layer of `features` and one per any other
        # child, so the position of the activation preceding `avgpool` has to be counted
        # the same way (counting only top-level children is wrong for Style2 models).
        emitted = 0
        final_activation = None
        for child_name, child in model.named_children():
            if child_name == 'avgpool':
                final_activation = emitted - 1
                break
            emitted += len(child) if child_name == 'features' else 1
        if final_activation is None:
            raise IndexError("model has no child named 'avgpool'")
        self.final_activation = final_activation

        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run ``x`` through the wrapped model child by child.

        Returns:
            list: the output of every child in call order; a child named ``features``
            contributes one entry per sub-layer. The last entry is the model output.
        """
        activation_list = []
        for name, child in self.model.named_children():
            if name == 'features':
                # Expand the feature extractor so each of its layers is recorded separately
                for seq in child:
                    x = seq(x)
                    activation_list.append(x)
                continue
            x = child(x)
            activation_list.append(x)
        return activation_list

# Functions
----------

In [None]:
def class_activation_map(input_path,
                         input_model,
                         weight_matrix,
                         visualize: bool = True,
                         device : str = 'cuda',
                         mean_vector: tuple or list = (0.485, 0.456, 0.406),
                         std_vector: tuple or list = (0.229, 0.224, 0.225)):
    """Compute the Class Activation Map (CAM) for the class predicted by ``input_model``.

    Args:
        input_path: path of the image file, loaded with ``read_image``.
        input_model: model whose forward pass returns a list of activations (the last
            entry being the class scores) and which exposes ``final_activation`` and ``name``.
        weight_matrix: weights of the last linear layer, indexed by class.
        visualize: when True, draw the image with the CAM overlaid.
        device: device on which the forward pass runs.
        mean_vector: per-channel mean used to normalize the image.
        std_vector: per-channel standard deviation used to normalize the image.

    Returns:
        tuple: ``(final_map, color_map)`` where ``final_map`` is the raw CAM tensor (still on
        ``device``) and ``color_map`` is the jet-colored RGBA array resized to the image size.
    """
    # Per-channel statistics as (C, 1, 1) torch tensors so normalization stays inside torch
    # (float64 keeps the precision the computation had before the final .float())
    mean_tensor = torch.as_tensor(mean_vector, dtype=torch.float64).reshape(-1, 1, 1)
    std_tensor = torch.as_tensor(std_vector, dtype=torch.float64).reshape(-1, 1, 1)

    # Import and normalize image
    raw_image = read_image(input_path) / 255.
    input_tensor = ((raw_image - mean_tensor) / std_tensor).unsqueeze(0).float().to(device)

    # Get model results
    input_model.to(device)
    with torch.no_grad():
        activations = input_model(input_tensor)

    # Extract last layer weights specifically associated with predicted class node.
    # .item() gives a plain int, so indexing works whatever device weight_matrix lives on.
    predicted_class = int(torch.argmax(activations[-1]).item())
    weight_vector = weight_matrix[predicted_class].to(device)

    # Get CAM based on activation-map and last layer weights
    final_map = torch.tensordot(weight_vector, activations[input_model.final_activation], dims=([0], [1]))

    # Get ready for colormap: resize to the image size and rescale to [0, 1]
    _, height, width = raw_image.shape
    heat = TF.resize(final_map.cpu(), [height, width])[0]
    heat_min, heat_max = heat.min(), heat.max()
    if heat_max > heat_min:
        heat = (heat - heat_min) / (heat_max - heat_min)
    else:
        # Constant map: avoid 0/0 -> NaN and fall back to an all-zero map
        heat = torch.zeros_like(heat)
    color_map = plt.cm.jet(heat.numpy())

    if visualize:
        show_image = raw_image.permute(1, 2, 0)
        plt.imshow(show_image, alpha=0.6)
        plt.imshow(color_map, alpha=0.4)
        plt.axis('off')
        plt.title(f"CAM|Model:{input_model.name}")
        plt.show()

    return final_map, color_map

In [None]:
def saliency_map(input_path,
                 input_model,
                 device: str = 'cuda',
                 target_label: int or float or np.array or torch.Tensor = torch.empty(0),
                 visualize: bool = True,
                 mean_vector: tuple or list = (0.485, 0.456, 0.406),
                 std_vector: tuple or list = (0.229, 0.224, 0.225)):
    """Compute a saliency map: the magnitude of the loss gradient with respect to the input image.

    Args:
        input_path: path of the image file, loaded with ``read_image``.
        input_model: model whose forward pass returns a list of activations (the last
            entry being the class scores) and which exposes ``name``.
        device: device on which the forward/backward pass runs.
        target_label: class index used as the target; when empty (the default) the
            model's own prediction is used.
        visualize: when True, draw the saliency map with the image overlaid.
        mean_vector: per-channel mean used to normalize the image.
        std_vector: per-channel standard deviation used to normalize the image.

    Returns:
        torch.Tensor: absolute input gradient, min-max scaled to [0, 1], detached and on
        the CPU, with the same shape as the loaded image.
    """
    # Normalize the target to a flat tensor; ints, floats, arrays and tensors of any
    # rank are all accepted, and an empty tensor means "use the model prediction"
    target = torch.as_tensor(target_label).reshape(-1)

    # Per-channel statistics as (C, 1, 1) torch tensors so normalization stays inside torch
    mean_tensor = torch.as_tensor(mean_vector, dtype=torch.float64).reshape(-1, 1, 1)
    std_tensor = torch.as_tensor(std_vector, dtype=torch.float64).reshape(-1, 1, 1)

    # Load, normalize and convert image to torch.Tensor
    raw_image = read_image(input_path) / 255.
    input_tensor = ((raw_image - mean_tensor) / std_tensor).unsqueeze(0).float().to(device).requires_grad_()

    # loss function is needed for computing loss and derivative with respect to input image
    criterion = nn.CrossEntropyLoss()

    # push model to eval mode and freeze all parameters (only input image needs gradient);
    # the previous requires_grad flags are remembered and restored afterwards so the call
    # does not leave hidden state behind on the model
    parameters = list(input_model.parameters())
    previous_flags = [p.requires_grad for p in parameters]
    input_model.eval()
    for p in parameters:
        p.requires_grad = False
    input_model.to(device)

    try:
        # Forward: use model output as target when none is given (in model we trust)
        y_hat = input_model(input_tensor)[-1]
        if target.numel():
            # The target must be an integer tensor on the same device as the scores
            y_true = target.long().to(y_hat.device)
        else:
            y_true = torch.argmax(y_hat).unsqueeze(0)
        loss = criterion(y_hat, y_true)

        # Compute gradient of loss with respect to input image
        input_grad = torch.autograd.grad(loss, input_tensor)[0][0]
    finally:
        for p, flag in zip(parameters, previous_flags):
            p.requires_grad = flag

    # Only magnitude of gradients are needed
    input_grad = input_grad.abs()

    # Normalize gradients for visualization
    grad_min, grad_max = input_grad.min(), input_grad.max()
    if grad_max > grad_min:
        input_grad = (input_grad - grad_min) / (grad_max - grad_min)
    else:
        # Constant gradient: avoid 0/0 -> NaN and fall back to an all-zero map
        input_grad = torch.zeros_like(input_grad)
    input_grad = input_grad.detach().cpu()

    if visualize:
        plt.imshow(input_grad.sum(0), alpha=0.8, cmap='hot')
        plt.imshow(raw_image.permute(1, 2, 0), alpha=0.2)
        plt.axis('off')
        plt.title(f"SaliencyMap|Model:{input_model.name}")
        plt.show()

    return input_grad

In [None]:
# For each model name: positions in the activation list returned by the model's forward pass
# whose activations grad_cam visualizes.
target_dict = {
    'resnet18': [1, 3, 5],
    'resnet50': [1, 3, 5],
    'resnet101': [1, 3, 5],
    'vgg16': [8, 16, 22],
    'googlenet': [2, 8, 15],
    'efficientnet_b4': [1, 2, 4],
    'efficientnet_b7': [1, 3, 5],
    'mobilenet_v3': [1, 4, 11],
}

In [None]:
def grad_cam(input_path,
             input_model,
             device: str = 'cuda',
             target_label: int or float or np.array or torch.Tensor = torch.empty(0),
             visualize: bool = True,
             mean_vector: tuple or list = (0.485, 0.456, 0.406),
             std_vector: tuple or list = (0.229, 0.224, 0.225)):
    """Compute gradient-weighted activation maps for the layers listed in ``target_dict``.

    For every selected activation the loss gradient is taken, its absolute value is averaged
    over the batch and spatial dimensions to give one weight per channel, and the activation
    channels are combined with those weights.

    Args:
        input_path: path of the image file, loaded with ``read_image``.
        input_model: model whose forward pass returns a list of activations (the last
            entry being the class scores) and which exposes ``name``; ``name`` must be a
            key of the module-level ``target_dict``.
        device: device on which the forward/backward pass runs.
        target_label: class index used as the target; when empty (the default) the
            model's own prediction is used.
        visualize: when True, draw one panel per selected activation.
        mean_vector: per-channel mean used to normalize the image.
        std_vector: per-channel standard deviation used to normalize the image.

    Returns:
        list: one detached CPU tensor per selected activation.

    Raises:
        KeyError: if ``input_model.name`` is not a key of ``target_dict``.
    """
    # Fail early, with a readable message, when no target layers are configured
    try:
        layer_indices = target_dict[input_model.name]
    except KeyError:
        raise KeyError(f"no Grad-CAM target layers configured for model '{input_model.name}'") from None

    # Normalize the target to a flat tensor; ints, floats, arrays and tensors of any
    # rank are all accepted, and an empty tensor means "use the model prediction"
    target = torch.as_tensor(target_label).reshape(-1)

    # Per-channel statistics as (C, 1, 1) torch tensors so normalization stays inside torch
    mean_tensor = torch.as_tensor(mean_vector, dtype=torch.float64).reshape(-1, 1, 1)
    std_tensor = torch.as_tensor(std_vector, dtype=torch.float64).reshape(-1, 1, 1)

    # Load, normalize and convert image to torch.Tensor
    raw_image = read_image(input_path) / 255.
    input_tensor = ((raw_image - mean_tensor) / std_tensor).unsqueeze(0).float().to(device)

    # Loss value is needed for calculating derivatives
    criterion = nn.CrossEntropyLoss()

    # push model to eval mode and make parameters requires_grad to True because we want derivative of middle activations
    input_model.to(device)
    input_model.eval()
    for param in input_model.parameters():
        param.requires_grad = True

    # Forward: if no target-label we use model output as target! (in model we trust)
    activations = input_model(input_tensor)
    scores = activations[-1]
    if target.numel():
        # The target must be an integer tensor on the same device as the scores
        y_true = target.long().to(scores.device)
    else:
        y_true = torch.argmax(scores).unsqueeze(0)
    loss = criterion(scores, y_true)

    # Compute gradient of loss with respect to the selected activations.
    # retain_graph=True because the same graph is differentiated once per activation.
    target_activations = [activations[i] for i in layer_indices]
    target_grads = [torch.autograd.grad(loss, act, retain_graph=True)[0] for act in target_activations]

    # Average the absolute derivatives over batch and spatial dimensions (height and width)
    # to obtain one weight per channel
    channel_weights = [torch.mean(torch.abs(grad), dim=(0, 2, 3)) for grad in target_grads]

    # Get GradCAM based on activation-maps and their derivatives
    final_maps = []
    for m, g in zip(target_activations, channel_weights):
        final_maps.append(torch.tensordot(m, g, dims=([1], [0])).squeeze().detach().cpu())

    if visualize:
        plt.figure(figsize=(15, 5))
        show_image = raw_image.permute(1, 2, 0)
        panel_shape = tuple(show_image.shape[:-1])
        for idx, each_map in enumerate(final_maps):
            plt.subplot(1, len(final_maps), idx+1)
            # skimage works on numpy arrays, so convert the tensor explicitly
            plt.imshow(transform.resize(each_map.numpy(), output_shape=panel_shape), alpha=0.8, cmap='hot')
            plt.imshow(show_image, alpha=0.2)
            plt.axis('off')
            plt.title(f"Activation:{layer_indices[idx]:02d}|Model:{input_model.name}")
        plt.show()

    return final_maps

In [None]:
# Load a pretrained ResNet-50 and wrap it so the forward pass returns every intermediate activation.
base_model = models.resnet50(pretrained=True)
mod_model = ModifyModel(base_model, name='resnet50')
# Image used by the single-model demos.
image_path = "pics/golden_retriever_1.jpeg"

print("Grad-Cam")
# No target label is passed, so grad_cam uses the model's own prediction; the returned maps are discarded.
_ = grad_cam(image_path, mod_model)

In [None]:
print("Saliency-Map")
# Uses the `image_path` and `mod_model` defined in the Grad-CAM demo cell; the returned map is discarded.
_ = saliency_map(image_path, mod_model)

In [None]:
print("Class-Activation-Map")
# The wrapper's stored last-layer weights are passed explicitly as the CAM weight matrix.
_ = class_activation_map(image_path, mod_model, mod_model.weight_matrix)

# Group Visualization
--------------

In [None]:
# Pretrained torchvision classifiers for the group visualizations.
# Keep these variable names in sync with the strings in `name_list`: later cells
# refer to the models by those names.
resnet18 = models.resnet18(pretrained=True)
resnet50 = models.resnet50(pretrained=True)
resnet101 = models.resnet101(pretrained=True)
googlenet = models.googlenet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)

In [None]:
# Model list for CAM
name_list = ['resnet18', 'resnet50', 'resnet101', 'googlenet']
root_image_path = "pics/samples"
num_samples = 10
# os.listdir returns entries in arbitrary order; sorting makes the selected samples
# (and therefore the figure) the same on every run
sample_list = [os.path.join(root_image_path, i) for i in sorted(os.listdir(root_image_path))][:num_samples]
# Report how many samples were actually found, which can be fewer than num_samples
print(f"num_samples: {len(sample_list)}")

In [None]:
# Create a subplot grid: one row per model in `name_list`, one column per sample image.
# squeeze=False keeps `axs` two-dimensional even with a single model or a single sample.
fig, axs = plt.subplots(len(name_list), len(sample_list), figsize=(35, 15), squeeze=False)

# Flatten the axes array for easy indexing
flatten_axs = axs.flatten()

# Explicit name -> model lookup instead of calling eval() on the name strings
cam_source_models = {'resnet18': resnet18, 'resnet50': resnet50, 'resnet101': resnet101, 'googlenet': googlenet}

# Each sample is drawn once per model row, so read and resize every image only once
OUTPUT_SHAPE = (224, 224, 3)
resized_samples = [transform.resize(plt.imread(path), OUTPUT_SHAPE) for path in sample_list]

# Loop over each model (row) and sample (column) and plot the image with its CAM overlay
for i, ax_list in enumerate(axs):
    cam_model = ModifyModel(cam_source_models[name_list[i]], name=name_list[i])
    for j, ax in enumerate(ax_list):
        _, cam_output = class_activation_map(sample_list[j], cam_model, cam_model.weight_matrix, visualize=False)
        ax.imshow(resized_samples[j], alpha=0.9)
        # Drop the alpha channel of the RGBA colormap output before resizing
        ax.imshow(transform.resize(cam_output[:, :, :-1], OUTPUT_SHAPE), alpha=0.4)
        ax.axis('off')
        if j == 0:
            # Row label: model name
            ax.text(-0.5, 0.5, name_list[i], transform=ax.transAxes, fontsize=18, fontweight='bold')
        if i == 0:
            # Column label: sample number
            ax.text(0.5, 1.1, f"sample-{j+1:02d}", transform=ax.transAxes,
                    fontsize=18, va='center', ha='center', fontweight='bold')

fig.set_facecolor((1., 1., 1.))
fig.tight_layout()

plt.show()

In [None]:
# Model List for Grad-CAM
# Note: this rebinds `name_list` and `image_path`, which earlier cells also define.
name_list = ['resnet18', 'resnet50', 'resnet101', 'googlenet', 'vgg16']
image_path = "pics/golden_retriever_1.jpeg"
# Image as an array, used as the faint background of every panel in the Grad-CAM grid.
input_image = plt.imread(image_path)

In [None]:
# Explicit name -> model lookup instead of calling eval() on the name strings
grad_cam_source_models = {'resnet18': resnet18, 'resnet50': resnet50, 'resnet101': resnet101,
                          'googlenet': googlenet, 'vgg16': vgg16}

# Not used by this figure (maps are resized to the input image size); kept as a notebook-level name
OUTPUT_SHAPE = (224, 224, 3)

# Compute the Grad-CAM maps of every model first so the grid can be sized from the results
total_cams = []
for n in range(len(name_list)):
    cam_model = ModifyModel(grad_cam_source_models[name_list[n]], name=name_list[n])
    cam_outputs = grad_cam(image_path, cam_model, visualize=False)
    total_cams.append(cam_outputs)

# One row per map, one column per model; use the smallest map count so every cell has data
num_maps = min(len(cams) for cams in total_cams)

# squeeze=False keeps `axs` two-dimensional even with a single row or a single model
fig, axs = plt.subplots(num_maps, len(name_list), figsize=(12, 8), squeeze=False)

# Flatten the axes array for easy indexing
flatten_axs = axs.flatten()

# Loop over each map (row) and model (column) and plot the map over the input image
for i, ax_list in enumerate(axs):
    for j, ax in enumerate(ax_list):
        # skimage works on numpy arrays, so convert the map explicitly
        ax.imshow(transform.resize(np.asarray(total_cams[j][i]), output_shape=input_image.shape[:-1]),
                  alpha=0.9, cmap='hot')
        ax.imshow(input_image, alpha=0.1)
        ax.axis('off')
        if j == 0:
            # Row label: map number
            ax.text(-0.5, 0.5, f"Map:{i+1:02d}", transform=ax.transAxes, fontsize=12, fontweight='bold')
        if i == 0:
            # Column label: model name
            ax.text(0.5, 1.1, name_list[j], transform=ax.transAxes,
                    fontsize=12, va='center', ha='center', fontweight='bold')

fig.set_facecolor((1., 1., 1.))
fig.tight_layout()

plt.show()