# Class Activation Mapping (CAM)

This notebook is based on code found at: 

https://snappishproductions.com/blog/2018/01/03/class-activation-mapping-in-pytorch.html.html


In [None]:
%matplotlib inline

from PIL import Image, ImageFilter, ImageOps
from matplotlib.pyplot import imshow
from torchvision import models, transforms
from torch.autograd import Variable
from torch.nn import functional as F
from torch import topk
import torch
import numpy as np
import skimage.transform

## Import
Download the images we want to classify: the original lighthouse photo and several modified versions of it. Based on code from a lab exercise in the SML course (UU).

In [None]:
from urllib.request import urlopen
from urllib.error import URLError
import matplotlib.pyplot as plt

image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/640px-Felis_catus-cat_on_snow.jpg'

lighthouse_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse_original.jpg?raw=true'
lighthouse_90deg_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse_90deg.jpg?raw=true'
lighthouse_blur_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse_blur.jpg?raw=true'
lighthouse_highpass_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse_highpass.jpg?raw=true'
lighthouse_mix_channel_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse_mix_channel.jpg?raw=true'
lighthouse_noise_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse_noise.jpg?raw=true'
lighthouse_equalized_url = 'https://github.com/Falk0/latex_master1_semester2/blob/main/deep_learning_for_image_analysis/figures/assignment_4/lighthouse%20equalized.jpg?raw=true'

# Order matters: `description` below labels these images by position.
lighthouse_url_list = [
    lighthouse_url,
    lighthouse_90deg_url,
    lighthouse_blur_url,
    lighthouse_highpass_url,
    lighthouse_mix_channel_url,
    lighthouse_noise_url,
    lighthouse_equalized_url
]

description = ['Original', '90deg', 'Gaussian blur', 'Highpass filtered', 'Mixed color channels', 'Noise', 'Hist.equalized']

lighthouse_image_list = []
for source_url in lighthouse_url_list:
    try:
        # Both the HTTP response and the PIL file handle are closed on exit.
        with urlopen(source_url, timeout=30) as response, Image.open(response) as im:
            # convert("RGB") returns an in-memory copy that stays valid after `im` closes,
            # and fixes some problems when loading images: https://stackoverflow.com/a/64598016
            lighthouse_image_list.append(im.convert("RGB"))
    except (URLError, OSError) as err:
        # Best effort: report the failing URL and continue.
        # NOTE: a skipped image shifts the remaining ones relative to `description`.
        print(f"could not load image from {source_url}: {err}")

# Preview one variant (index 3 is 'Highpass filtered' in `description`).
print(f"{lighthouse_image_list[3].mode} image of size {lighthouse_image_list[3].size}")
plt.imshow(np.asarray(lighthouse_image_list[3]))
plt.xticks([])
plt.yticks([])
plt.show()

Store the ImageNet class names in a dictionary that maps class index to label.

In [None]:
import ast

# ImageNet class index -> label used for plot titles.
class_name = {}

# Raw text file on GitHub; its content is a Python dict literal {index: 'label', ...}.
url = 'https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt'


def format_class_label(label):
    """Move everything after the first comma onto a second line.

    This keeps long multi-synonym labels readable as plot titles.
    Labels without a comma are returned unchanged.
    """
    head, sep, tail = label.partition(',')
    return f"{head}\n{tail.strip()}" if sep else label


try:
    with urlopen(url, timeout=10) as response:
        text = response.read().decode('utf-8')
    # literal_eval parses the dict safely (it never executes code).
    # Unlike splitting lines on ':', it keeps apostrophes inside labels and
    # does not choke on blank lines.
    raw_labels = ast.literal_eval(text)
    class_name.update({int(key): format_class_label(value) for key, value in raw_labels.items()})
except (OSError, ValueError, SyntaxError) as err:
    print(f"could not load ImageNet class names from {url}: {err}")

# Shorten 'beacon, lighthouse, beacon light, pharos' to keep titles compact.
class_name[437] = 'lighthouse'


Download the images used for the attention maps: a sandal and a toilet. The original lighthouse image is added as the third.

In [None]:
# Both example images live in the same `fig/` folder of the Okrash0 repository.
sandal_url, toilet_url = (
    "https://raw.githubusercontent.com/Okrash0/Explainable-Artificial-Intelligence/main/fig/" + file_name
    for file_name in ("sandal.jpg", "toilet.jpg")
)

In [None]:
image_url_list = [sandal_url, toilet_url]
image_list = []

for source_url in image_url_list:
    try:
        # Both the HTTP response and the PIL file handle are closed on exit.
        with urlopen(source_url, timeout=30) as response, Image.open(response) as im:
            # convert("RGB") returns an in-memory copy that stays valid after `im` closes,
            # and fixes some problems when loading images: https://stackoverflow.com/a/64598016
            image_list.append(im.convert("RGB"))
    except (URLError, OSError) as err:
        print(f"could not load image from {source_url}: {err}")

# Names must follow the order of `image_list`: the downloads in `image_url_list`
# order (sandal, toilet), then the original lighthouse appended below.
image_name_list = ["sandal", "toilet", "lighthouse"]

image_list.append(lighthouse_image_list[0])

In [None]:
# Display every image in `image_list` in one row, without axis ticks.
fig, axs = plt.subplots(1, len(image_list))
for ax, img in zip(axs, image_list):
    ax.imshow(np.asarray(img))
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
def plot_images_1(image_list, mode=None):
    """Show images side by side in a single row, without axis ticks.

    Args:
        image_list: sequence of images accepted by ``np.asarray`` (e.g. PIL images).
            An empty sequence draws nothing.
        mode: optional matplotlib colormap name, passed to ``imshow`` as ``cmap``.
            Matplotlib ignores it for RGB(A) images.
    """
    if len(image_list) == 0:
        return
    # squeeze=False always returns a 2-D array of Axes. With the default, a
    # single image yields a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(1, len(image_list), squeeze=False)
    for ax, image in zip(axs[0], image_list):
        ax.imshow(np.asarray(image), cmap=mode)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

In [None]:
def rotate_filter(image, angle=90, expand=False):
    """Rotate a PIL image counter-clockwise by ``angle`` degrees.

    Args:
        image: PIL image to rotate.
        angle: rotation in degrees, counter-clockwise.
        expand: if True, enlarge the output so the whole rotated image fits.
            With the default (False) the output keeps the input size, so a 90°
            rotation of a non-square image is cropped and padded.
    Returns:
        The rotated PIL image.
    """
    return image.rotate(angle, expand=expand)

def blur_filter(image, radius=2):
    """Return a Gaussian-blurred copy of a PIL image.

    Args:
        image: PIL image to blur.
        radius: blur radius handed to ``ImageFilter.GaussianBlur``.
    """
    gaussian = ImageFilter.GaussianBlur(radius=radius)
    return image.filter(gaussian)

def gaussian_noise_filter(image, sigma=0.1 ** 0.9, mean=0.0, rng=None):
    """Add independent Gaussian noise to every pixel and channel of an image.

    Args:
        image: PIL image, or anything ``np.asarray`` turns into a uint8 pixel array
            (H x W x 3 for RGB, H x W for grayscale).
        sigma: noise standard deviation as a fraction of the 0-255 range.
            The default (0.1 ** 0.9 ≈ 0.126, about 32 intensity levels) keeps the
            previous noise strength. That value was written as ``var ** 0.9`` with
            ``var = 0.1``, although a true standard deviation would be ``var ** 0.5``.
        mean: noise mean as a fraction of the 0-255 range.
        rng: optional ``np.random.Generator`` for reproducible noise. When None,
            the legacy global ``np.random`` state is used (seed it with ``np.random.seed``).
    Returns:
        A new PIL image of the same size. Values are clipped to [0, 255] and
        truncated to uint8.
    """
    pixels = np.asarray(image)
    sampler = np.random if rng is None else rng
    # Sample in 0-1 units, then scale to the 0-255 pixel range.
    noise = sampler.normal(mean, sigma, size=pixels.shape) * 255
    noisy = np.clip(pixels + noise, 0, 255)
    # Let PIL infer the mode from the array shape (RGB for HxWx3, L for HxW).
    return Image.fromarray(noisy.astype(np.uint8))

def mix_color_filter(image, mode=1):
    """Return a copy of a three-band PIL image with two color channels swapped.

    Args:
        image: PIL image whose ``split()`` yields exactly three bands (e.g. mode "RGB").
        mode: channel order of the result. 1 -> (B, G, R), 2 -> (R, B, G), 3 -> (G, R, B).
    Returns:
        A new "RGB" PIL image.
    Raises:
        ValueError: if ``mode`` is not 1, 2 or 3. Previously the function silently
            returned None in that case.
    """
    if mode not in (1, 2, 3):
        raise ValueError(f"mode must be 1, 2 or 3, got {mode!r}")

    # Split the color channels
    r, g, b = image.split()

    # Mix the color channels
    channel_orders = {1: (b, g, r), 2: (r, b, g), 3: (g, r, b)}
    return Image.merge("RGB", channel_orders[mode])



### Apply the filters
A quick visual check of each filter on the three images above (for testing only).

In [None]:
print(image_list[0])

# Apply each filter to every image and show each result as its own row.
rotated_list = [rotate_filter(img) for img in image_list]
plot_images_1(rotated_list)

gause_list = [blur_filter(img, 10) for img in image_list]
plot_images_1(gause_list)

noise_list = [gaussian_noise_filter(img) for img in image_list]
plot_images_1(noise_list)

mix_list = [mix_color_filter(img, 2) for img in image_list]
plot_images_1(mix_list)

## Preprocess 

In [None]:
# Per-channel ImageNet statistics (RGB order) expected by the pretrained model
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Model input pipeline: resize to 224x224, convert to a float tensor in [0, 1],
# then standardize each channel with the ImageNet mean/std. The result is roughly
# zero-centred; it is not confined to -1..1.
preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])

# Same resize without tensor conversion, used to draw images under the heatmaps
display_transform = transforms.Compose([transforms.Resize((224, 224))])

In [None]:
# Model-ready tensors, in the same order as `lighthouse_image_list`
lighthouse_tensors = [preprocess(img) for img in lighthouse_image_list]

In [None]:
# Model-ready tensors, in the same order as `image_list`
image_tensors = [preprocess(img) for img in image_list]

In [None]:
# Batched (1, 3, 224, 224) inputs for the lighthouse images.
# torch.autograd.Variable is deprecated because tensors carry requires_grad themselves.
# detach() gives a fresh leaf tensor that shares storage with the preprocessed tensor.
prediction_var_list = [
    tensor.unsqueeze(0).detach().requires_grad_(True) for tensor in lighthouse_tensors
]


In [None]:
# Batched (1, 3, 224, 224) inputs for the images in `image_list`.
# torch.autograd.Variable is deprecated because tensors carry requires_grad themselves.
# detach() gives a fresh leaf tensor that shares storage with the preprocessed tensor.
prediction_var_list_image = [
    tensor.unsqueeze(0).detach().requires_grad_(True) for tensor in image_tensors
]

## Load model

In [None]:
# ResNet-18 with ImageNet weights. `pretrained=True` is deprecated since torchvision 0.13
# in favour of the `weights` enum (IMAGENET1K_V1 is what pretrained=True loaded).
try:
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
except AttributeError:  # torchvision < 0.13 has no ResNet18_Weights
    model = models.resnet18(pretrained=True)

In [None]:
# Inference mode; train(False) is exactly what eval() does. The returned model is displayed.
model.train(False)

In [None]:
class SaveFeatures():
    """Forward hook that keeps the latest output of a module as a NumPy array.

    Usage::

        hook = SaveFeatures(model.layer4)
        model(x)        # hook.features now holds layer4's output for x
        hook.remove()

    or as a context manager, which removes the hook on exit::

        with SaveFeatures(model.layer4) as hook:
            model(x)
    """

    features = None  # latest captured output; None until the first forward pass

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # detach() drops the autograd graph; cpu() is required before numpy().
        # Each forward pass overwrites the previous value.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False

In [None]:
# Last convolutional block of ResNet-18; CAM weights its output feature maps.
# Attribute access fails loudly if the layer is missing, whereas the private
# `_modules.get` would silently return None.
final_layer = model.layer4

In [None]:
# Hook on layer4 that records its output on every forward pass while registered
activated_features = SaveFeatures(m=final_layer)

## Make prediction

In [None]:
# Class probabilities for each lighthouse variant.
# `activated_features` is deliberately left registered. Calling `.remove()`
# inside this loop would detach the hook after the first image, freezing
# `activated_features.features` at that image's activations.
pred_probabilities_list = []

with torch.no_grad():  # inference only; no autograd graph needed
    for input_var in prediction_var_list:
        prediction = model(input_var)
        # dim=1 is the class dimension of the (1, num_classes) logits;
        # omitting dim is deprecated.
        pred_probabilities_list.append(F.softmax(prediction, dim=1).squeeze())



In [None]:
# Top-1 (probability, class index) for each lighthouse variant
for probabilities in pred_probabilities_list[:len(lighthouse_tensors)]:
    print(topk(probabilities, 1))

Class indices follow https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a; for example, index 437 is 'beacon, lighthouse, beacon light, pharos'.

In [None]:
def getCAM(feature_conv, weight_fc, class_idx):
    """Compute a class activation map (CAM) for a single image.

    Args:
        feature_conv: array of shape (1, nc, h, w) holding the last conv block's
            activations (the batch size must be 1).
        weight_fc: array of shape (num_classes, nc) holding the final linear layer's weights.
        class_idx: class index as an int, NumPy integer, or single-element array/tensor.
    Returns:
        A one-element list containing an (h, w) float array scaled to [0, 1].
        The array is all zeros if the map is constant.
    Raises:
        ValueError: if ``class_idx`` holds more than one element.
    """
    _, nc, h, w = feature_conv.shape
    idx = int(np.asarray(class_idx).item())
    # Weighted sum of the nc feature maps, using the chosen class's fc weights
    cam = weight_fc[idx].dot(feature_conv.reshape((nc, h * w))).reshape(h, w)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    # A constant map would give 0/0 = NaN; return zeros instead
    cam_img = cam / peak if peak > 0 else np.zeros_like(cam, dtype=float)
    return [cam_img]


In [None]:
# fc layer parameters; the first is the (classes x features) weight matrix used by getCAM
weight_softmax_params = list(model.fc.parameters())
weight_softmax = weight_softmax_params[0].detach().cpu().numpy().squeeze()

In [None]:
# Show only the parameter shapes (weight, bias); displaying the tensors would dump every weight
[tuple(p.shape) for p in weight_softmax_params]

In [None]:
# CAM for the most likely class of each lighthouse variant
overlay_list = []
class_list = []
for i in range(len(lighthouse_tensors)):
    # Capture layer4 activations for *this* image with a dedicated hook. Whatever
    # `activated_features` last recorded need not belong to image i.
    features_hook = SaveFeatures(final_layer)
    try:
        with torch.no_grad():
            model(prediction_var_list[i])
    finally:
        features_hook.remove()

    top1 = topk(pred_probabilities_list[i], 1)
    class_idx = top1[1].int()
    overlay_list.append(getCAM(features_hook.features, weight_softmax, class_idx))
    class_list.append(class_idx)
    print(top1)

In [None]:
# Summarize instead of printing every CAM value: one (h, w) map per lighthouse variant
[overlay[0].shape for overlay in overlay_list]

## Plot heatmap of predicted class

Plotting functions

In [None]:
def plot_heatmap(overlay_images, titles, save_plot=False, filename='plot.png'):
    """Plot CAM heatmaps side by side, one subplot per image.

    Args:
        overlay_images: list of CAM results as returned by ``getCAM``
            (each a one-element list holding a 2-D array).
        titles: list of single-element tensors (class indices) shown as titles.
        save_plot: bool, save the figure to ``filename`` before showing it.
        filename: path used when ``save_plot`` is True. Pass a distinct name per
            call so successive figures do not overwrite each other.
    """
    n = len(overlay_images)
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[0][i] always works
    _, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)

    for i in range(n):
        ax = axes[0][i]
        ax.imshow(overlay_images[i][0], alpha=0.5, cmap='jet')
        ax.set_title(titles[i].item())
        ax.axis('off')

    if save_plot:
        plt.savefig(filename, bbox_inches='tight')

    plt.show()


def plot_images(images, titles, save_plot=False, filename='plot.png'):
    """Plot images side by side with a title above each.

    Args:
        images: list of images accepted by ``imshow`` (e.g. PIL images).
        titles: list of title strings, one per image.
        save_plot: bool, save the figure to ``filename`` before showing it.
        filename: path used when ``save_plot`` is True. Pass a distinct name per
            call so successive figures do not overwrite each other.
    """
    n = len(images)
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[0][i] always works
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)

    for i in range(n):
        ax = axes[0][i]
        ax.imshow(images[i])
        ax.set_title(titles[i])
        ax.axis('off')

    if save_plot:
        plt.savefig(filename, bbox_inches='tight')

    plt.show()


def plot_images_overlay(images, tensor, overlay_list, titles, save_plot=False, filename='plot.png'):
    """Plot each image with its CAM heatmap blended on top, titled by class name.

    Args:
        images: list of PIL images, resized with ``display_transform`` for display.
        tensor: list of preprocessed image tensors. Each heatmap is resized to
            ``tensor[i].shape[1:3]`` so it lines up with the resized image.
        overlay_list: list of CAM results as returned by ``getCAM``.
        titles: list of single-element tensors with the class index, looked up in ``class_name``.
        save_plot: bool, save the figure to ``filename`` before showing it.
        filename: path used when ``save_plot`` is True. Pass a distinct name per
            call so successive figures do not overwrite each other.
    """
    n = len(images)
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[0][i] always works
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)

    for i in range(n):
        ax = axes[0][i]
        ax.imshow(display_transform(images[i]))
        # Resize the CAM to the model-input size before blending it over the image
        heatmap = skimage.transform.resize(overlay_list[i][0], tensor[i].shape[1:3])
        ax.imshow(heatmap, alpha=0.5, cmap='jet')
        ax.set_title(class_name[titles[i].item()])
        ax.axis('off')

    if save_plot:
        plt.savefig(filename, bbox_inches='tight')

    plt.show()


In [None]:
# Top-1 class: the CAM alone, the input images, and the CAM overlaid on each image
plot_heatmap(overlay_images=overlay_list, titles=class_list)
plot_images(images=lighthouse_image_list, titles=description)
plot_images_overlay(images=lighthouse_image_list, tensor=lighthouse_tensors,
                    overlay_list=overlay_list, titles=class_list)

Relevant ImageNet classes:

- 437: 'beacon, lighthouse, beacon light, pharos'
- 972: 'cliff, drop, drop-off'
- 646: 'maze, labyrinth'
 

In [None]:
# Top-1 CAM of the fourth lighthouse variant ('Highpass filtered' in `description`)
# drawn over the resized image. Every tensor is 224x224, so tensor 0 gives the target size.
highpass_idx = 3
imshow(display_transform(lighthouse_image_list[highpass_idx]))
imshow(skimage.transform.resize(overlay_list[highpass_idx][0], lighthouse_tensors[0].shape[1:3]),
       alpha=0.5, cmap='jet');

## Plot heatmap of the second most likely class

In [None]:
# CAM for the second most likely class of each lighthouse variant
overlay_list_sec = []
class_list_sec = []
for i in range(len(lighthouse_tensors)):
    top2 = topk(pred_probabilities_list[i], 2)
    print(top2)
    class_idx = top2[1].int()[1]

    # Capture layer4 activations for *this* image with a dedicated hook. Whatever
    # `activated_features` last recorded need not belong to image i.
    features_hook = SaveFeatures(final_layer)
    try:
        with torch.no_grad():
            model(prediction_var_list[i])
    finally:
        features_hook.remove()

    overlay_list_sec.append(getCAM(features_hook.features, weight_softmax, class_idx))
    class_list_sec.append(class_idx)


In [None]:
# Second-ranked class: the CAM alone, the input images, and the CAM overlaid on each image
plot_heatmap(overlay_images=overlay_list_sec, titles=class_list_sec)
plot_images(images=lighthouse_image_list, titles=description)
plot_images_overlay(images=lighthouse_image_list, tensor=lighthouse_tensors,
                    overlay_list=overlay_list_sec, titles=class_list_sec)

Second-choice class and its probability for each image:

- 483: 'castle' [3.2850e-04]
- 976: 'promontory, headland, head, foreland' [0.0701]
- 975: 'lakeside, lakeshore' [0.0419]
- 50: 'American alligator, Alligator mississipiensis' [0.0419]
- 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty' [0.0021]
- 483: 'castle' [0.0043]
- 497: 'church, church building' [2.8490e-04]



## Plot heatmap of other classes
Let's check the heatmap for some other classes:

*   527: 'desktop computer'
*   587: 'hammer'
*   497: 'church, church building',


In [None]:
# CAM for a fixed class (`other_class`) on every lighthouse variant, regardless of the prediction
other_class = 497

overlay_list_otherclass = []
class_list_otherclass = []
for i in range(len(lighthouse_tensors)):
    input_image = prediction_var_list[i]

    # Capture layer4 activations for *this* image with a dedicated hook. Whatever
    # `activated_features` last recorded need not belong to image i.
    features_hook = SaveFeatures(final_layer)
    try:
        with torch.no_grad():
            model(input_image)
    finally:
        features_hook.remove()

    overlay_list_otherclass.append(getCAM(features_hook.features, weight_softmax, other_class))
    class_list_otherclass.append(torch.tensor(other_class, dtype=torch.int32))



In [None]:
# Fixed class `other_class`: the CAM alone, the input images, and the CAM overlaid on each image
plot_heatmap(overlay_images=overlay_list_otherclass, titles=class_list_otherclass)
plot_images(images=lighthouse_image_list, titles=description)
plot_images_overlay(images=lighthouse_image_list, tensor=lighthouse_tensors,
                    overlay_list=overlay_list_otherclass, titles=class_list_otherclass)


# Heatmaps for other images

Make predictions

In [None]:
# Predict each image in `image_list` and compute the CAM for its top-1 class
pred_probabilities_list_image = []
image_features = []  # layer4 activations, one entry per image

for input_var in prediction_var_list_image[:len(image_tensors)]:
    # A dedicated hook per image, so each CAM uses that image's own activations
    features_hook = SaveFeatures(final_layer)
    try:
        with torch.no_grad():
            prediction = model(input_var)
    finally:
        features_hook.remove()
    # dim=1 is the class dimension of the logits; omitting dim is deprecated
    pred_probabilities_list_image.append(F.softmax(prediction, dim=1).squeeze())
    image_features.append(features_hook.features)

print("Probability and predicted class:")
for probabilities in pred_probabilities_list_image:
    probs = topk(probabilities, 1)
    # print probabilities and predicted classes
    print(class_name[probs[1].item()], probs[0].item())

weight_softmax_params = list(model.fc.parameters())
weight_softmax = np.squeeze(weight_softmax_params[0].detach().cpu().numpy())

overlay_list = []
class_list = []
for probabilities, features in zip(pred_probabilities_list_image, image_features):
    class_idx = topk(probabilities, 1)[1].int()
    overlay_list.append(getCAM(features, weight_softmax, class_idx))
    class_list.append(class_idx)

In [None]:
# Top-1 class per image: the CAM alone, the input images, and the CAM overlaid on each image
plot_heatmap(overlay_images=overlay_list, titles=class_list)
plot_images(images=image_list, titles=image_name_list)
plot_images_overlay(images=image_list, tensor=image_tensors,
                    overlay_list=overlay_list, titles=class_list)


# LRP

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet18
from PIL import Image
from captum.attr import LRP
from captum.attr import visualization as viz
from urllib.request import urlopen
from urllib.error import URLError


# Preprocessing and display_transform functions
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

display_transform = transforms.Compose([
    transforms.Resize((224, 224))
])

threshold = 0.9  # NOTE(review): not used in this cell

# One explainer wrapping the already-loaded `model` serves every image.
# NOTE(review): ResNet's skip connections are tensor additions rather than modules,
# so captum's per-module LRP rules may not conserve relevance across them.
# Confirm that the maps are meaningful.
lrp = LRP(model)

# Process each image in the list
for idx, input_image in enumerate(lighthouse_image_list):
    preprocessed_image = preprocess(input_image).unsqueeze(0)

    # Predicted (top-1) class; the prediction itself needs no autograd graph
    with torch.no_grad():
        output = model(preprocessed_image)
    pred_class_idx = output.argmax(dim=1).item()

    # LRP attributions for the predicted class, same shape as the input (1, 3, 224, 224)
    attributions = lrp.attribute(preprocessed_image, target=pred_class_idx)

    # Sum the color channels into one relevance value per pixel. This gives a 2-D
    # map that imshow renders with `cmap`; an (H, W, 3) array would be drawn as
    # RGB and the colormap ignored. The attributions share the input's pixel
    # layout, so no flip is applied; flipping the width axis would mirror the
    # map relative to the image.
    relevance = attributions.squeeze(0).sum(dim=0).detach().cpu().numpy()

    # Min-max scale to [0, 1], guarding against a constant map (0/0)
    span = np.max(relevance) - np.min(relevance)
    relevance = (relevance - np.min(relevance)) / span if span > 0 else np.zeros_like(relevance)

    original_image_np = np.array(display_transform(input_image)) / 255.0

    # Create a subplot for the original image and the heatmap
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(original_image_np)
    axs[0].set_title('Original Image')
    axs[0].axis('off')
    axs[1].imshow(relevance, cmap='viridis')
    axs[1].set_title('Heatmap')
    axs[1].axis('off')

    plt.show()


