In [None]:
import torch
from torch import nn
from torchvision import models
from torchvision.io import read_image
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
import os
from copy import deepcopy
from skimage.transform import resize

import warnings
warnings.filterwarnings('ignore')

# Helper Function and CAM Model
---------

In [None]:
def resize_image(image, output_shape):
    """Letterbox ``image`` onto a canvas of shape ``output_shape``.

    The image is scaled with its aspect ratio preserved so that it fits inside
    the first two dimensions of ``output_shape`` and is then centred on a
    canvas filled with ones.

    Args:
        image: Array-like of shape ``(H, W)`` or ``(H, W, C)``.
        output_shape: Canvas shape, ``(out_h, out_w)`` or ``(out_h, out_w, out_c)``.

    Returns:
        Array of shape ``output_shape`` whose dtype is the one returned by
        ``resize``.

    Raises:
        ValueError: If ``image`` has zero height or width.
    """
    # Get the original image shape
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {tuple(image.shape)}")

    # Largest scale factor for which the image still fits inside the canvas
    ratio = min(output_shape[0] / h, output_shape[1] / w)

    # Keep at least one pixel per side: truncation would otherwise give 0 for
    # images with an extreme aspect ratio
    new_h = max(1, int(h * ratio))
    new_w = max(1, int(w * ratio))

    # Resize the image using the scikit-image library
    resized_image = resize(image, (new_h, new_w))

    # Reconcile the channel layout with the canvas so that grayscale and
    # RGBA inputs do not fail with a broadcasting error on assignment
    if len(output_shape) == 3:
        out_channels = output_shape[2]
        if resized_image.ndim == 2:
            resized_image = resized_image[..., None]
        if resized_image.shape[2] == 1 and out_channels > 1:
            # Single channel: replicate it across every canvas channel
            resized_image = np.repeat(resized_image, out_channels, axis=2)
        elif resized_image.shape[2] > out_channels:
            # Extra trailing channels (e.g. alpha) are dropped
            resized_image = resized_image[..., :out_channels]

    # Create an output array with the desired output shape
    output = np.ones(output_shape, dtype=resized_image.dtype)

    # Offsets that centre the resized image on the canvas
    pad_h = (output_shape[0] - new_h) // 2
    pad_w = (output_shape[1] - new_w) // 2

    # Copy the resized image into the output array
    output[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized_image

    return output

In [None]:
class CAMModel(nn.Module):
    """Class Activation Mapping (CAM) wrapper around a classification network.

    The wrapped network must expose its layers as ordered child modules,
    include a pooling child named ``avgpool`` and end in a linear head stored
    as ``fc`` or ``classifier``. The feature map entering ``avgpool`` is
    combined with the head weights of the predicted class to build the CAM.
    """

    def __init__(self, input_model, model_name: str = '', device: str = 'cuda') -> None:
        """Copy ``input_model`` and prepare it for CAM extraction.

        Args:
            input_model: Network to explain; it is deep-copied, not modified.
            model_name: Label used in the title of the visualisation.
            device: Device on which the copy and the head weights are placed.

        Raises:
            ValueError: If ``input_model`` has no ``avgpool`` attribute.
            AttributeError: If no ``fc`` / ``classifier`` weight can be found.
        """
        super().__init__()
        if not hasattr(input_model, 'avgpool'):
            # Assigning `avgpool` on such a model would append it *after* the
            # head and produce a meaningless forward pass, so fail early.
            raise ValueError("input_model must have an 'avgpool' layer to compute a CAM")
        head_weight = self._find_head_weight(input_model)

        self.model = deepcopy(input_model)
        # forward() below chains the children directly, so flattening has to be
        # part of the pooling stage for the linear head to receive 2-D input.
        self.model.avgpool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten())
        self.weight_matrix = head_weight.detach().clone().to(device)
        self.device = device
        self.model_name = model_name
        self.model.to(device)
        self.model.eval()

    @staticmethod
    def _find_head_weight(input_model):
        """Return the weight of the final linear layer (``fc`` or ``classifier``)."""
        fc = getattr(input_model, 'fc', None)
        if fc is not None and hasattr(fc, 'weight'):
            return fc.weight
        classifier = getattr(input_model, 'classifier', None)
        if classifier is not None:
            last_layer = classifier[-1] if isinstance(classifier, nn.Sequential) else classifier
            if hasattr(last_layer, 'weight'):
                return last_layer.weight
        raise AttributeError(
            "input_model has neither an 'fc' nor a 'classifier' layer with a weight matrix")

    def get_cam(self, input_path, visualize: bool = False,
                mean_vector: "tuple | list" = (0.485, 0.456, 0.406),
                std_vector: "tuple | list" = (0.229, 0.224, 0.225)):  # imageNet mean and std for each channel
        """Compute the class activation map of the predicted class for one image.

        Args:
            input_path: Path of the image file to read.
            visualize: If True, show the image with the CAM overlaid.
            mean_vector: Per-channel mean used to normalise the image.
            std_vector: Per-channel standard deviation used to normalise the image.

        Returns:
            Tuple ``(final_map, color_map)``: the raw CAM at feature-map
            resolution (tensor with a leading batch dimension of 1) and the
            min-max normalised CAM resized to the image size and mapped
            through the ``jet`` colormap (RGBA array of shape ``(h, w, 4)``).
        """
        from torchvision.io import ImageReadMode

        # Import and normalize image; forcing RGB keeps grayscale and RGBA
        # files compatible with the three-channel statistics
        raw_image = read_image(input_path, mode=ImageReadMode.RGB) / 255.
        mean = torch.as_tensor(mean_vector, dtype=raw_image.dtype).view(-1, 1, 1)
        std = torch.as_tensor(std_vector, dtype=raw_image.dtype).view(-1, 1, 1)
        input_image = ((raw_image - mean) / std).unsqueeze(0).float().to(self.device)

        with torch.no_grad():
            # Get model results
            activation_map, final_output = self(input_image)

            # Extract last layer weights specifically associated with predicted class node
            weight_vector = self.weight_matrix[torch.argmax(final_output)]

            # Get CAM based on activation-map and last layer weights
            final_map = torch.tensordot(weight_vector, activation_map, dims=([0], [1]))

        # Get ready for colormap
        _, h, w = raw_image.shape
        heat_map = TF.resize(final_map, [h, w])[0].cpu()
        value_range = heat_map.max() - heat_map.min()
        if value_range > 0:
            heat_map = (heat_map - heat_map.min()) / value_range
        else:
            # A constant map would otherwise divide by zero and yield NaNs
            heat_map = torch.zeros_like(heat_map)
        color_map = plt.cm.jet(heat_map.numpy())

        if visualize:
            show_image = raw_image.permute(1, 2, 0).numpy()
            plt.imshow(show_image, alpha=0.6)
            plt.imshow(color_map, alpha=0.4)
            plt.axis('off')
            plt.title(f"{self.model_name}:CAM")
            plt.show()

        return final_map, color_map

    # Modify input_model to return desired activations (feature map) in addition to final output vector (logits)
    def forward(self, x):
        """Return ``(activation, logits)`` for the input batch ``x``.

        The children of the wrapped model are applied one after another in
        registration order, which bypasses the model's own ``forward``; the
        tensor entering ``avgpool`` is captured as the activation map.
        """
        activation = None
        for name, child in self.model.named_children():
            if name == 'avgpool':
                activation = x
            x = child(x)
        return activation, x

# Import Models
------

In [None]:
# ResNet-18 with pretrained weights; every parameter is frozen (no gradients)
resnet18 = models.resnet18(pretrained=True).requires_grad_(False)

In [None]:
# ResNet-101 with pretrained weights; every parameter is frozen (no gradients)
resnet101 = models.resnet101(pretrained=True).requires_grad_(False)

In [None]:
# GoogLeNet with pretrained weights; every parameter is frozen (no gradients)
googlenet = models.googlenet(pretrained=True).requires_grad_(False)

In [None]:
# RegNetX-400MF with pretrained weights; every parameter is frozen (no gradients)
regnet = models.regnet_x_400mf(pretrained=True).requires_grad_(False)

# Get Results
-----------

In [None]:
# Sample Data
root_image_path = "pics/samples"
num_samples = 10
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# os.listdir returns entries in arbitrary order, so sort them to pick the same
# samples on every run, and skip anything that is not a JPEG/PNG image
# (hidden files, sub-directories, ...).
sample_list = [
    os.path.join(root_image_path, file_name)
    for file_name in sorted(os.listdir(root_image_path))
    if file_name.lower().endswith(IMAGE_EXTENSIONS)
][:num_samples]

if not sample_list:
    raise FileNotFoundError(f"No image files found in '{root_image_path}'")

# Report how many samples were actually found (may be fewer than requested)
print(f"num_samples: {len(sample_list)}")

In [None]:
# Names of the model variables to visualise, in plotting order
name_list = [
    'resnet18',
    'resnet101',
    'googlenet',
    'regnet',
]

In [None]:
# Single Image: show the CAM of every model for the first sample
source_path = sample_list[0]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for each_model in name_list:
    # Look the model up by its variable name rather than eval()-ing the string
    cam_model = CAMModel(globals()[each_model], model_name=each_model, device=device)
    cam_model.get_cam(source_path, visualize=True)

In [None]:
# Create a (models x samples) grid of subplots; squeeze=False keeps `axs`
# two-dimensional even when there is a single model or a single sample
fig, axs = plt.subplots(len(name_list), len(sample_list), figsize=(35, 15), squeeze=False)

# Flat view of the axes array (not used in this cell)
flatten_axs = axs.flatten()

OUTPUT_SHAPE = (224, 224, 3)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The letterboxed originals do not depend on the model, so build them once
background_images = [resize_image(plt.imread(path), OUTPUT_SHAPE) for path in sample_list]

# One row per model, one column per sample
for i, ax_list in enumerate(axs):
    # Look the model up by its variable name rather than eval()-ing the string
    cam_model = CAMModel(globals()[name_list[i]], model_name=name_list[i], device=device)
    for j, ax in enumerate(ax_list):
        cam_output = cam_model.get_cam(sample_list[j], visualize=False)[1]
        ax.imshow(background_images[j], alpha=0.9)
        # Drop the last (alpha) channel of the colormap output before letterboxing
        ax.imshow(resize_image(cam_output[:, :, :-1], OUTPUT_SHAPE), alpha=0.4)
        ax.axis('off')
        if j == 0:
            # Row label: model name to the left of the first column
            ax.text(-0.5, 0.5, name_list[i], transform=ax.transAxes, fontsize=18, fontweight='bold')
        if i == 0:
            # Column label: sample index above the first row
            ax.text(0.5, 1.1, f"sample-{j+1:02d}", transform=ax.transAxes,
                    fontsize=18, va='center', ha='center', fontweight='bold')

fig.set_facecolor((1., 1., 1.))
fig.tight_layout()

plt.show()