In [1]:
import torch

# Image-related utilities
from torchvision.io import decode_image, read_image
from torchvision.transforms import ToTensor
from torchvision import transforms
from PIL import Image

# Import models
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import vgg19, VGG19_Weights

# Dataset
from torchvision.datasets import Imagenette

# LRP package
from src.lrp import LRPModel
from src.data import get_data_loader

# Utils
import argparse
import time
import pathlib
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

In [2]:
# Two-stop colormap: `from_list` places the first color at 0.0 and the last at
# 1.0, giving a linear white -> red ramp sampled at 256 levels.
colors = ["white", "red"]
custom_cmap = LinearSegmentedColormap.from_list(
    name="white_red",
    colors=colors,
    N=256,
)

In [3]:
def _min_max_normalize(values):
    """Linearly rescale `values` to [0, 1]; a constant input maps to all zeros."""
    v_min = values.min()
    span = values.max() - v_min
    if span == 0:
        # Avoid 0/0 -> NaN, which would render as a blank/garbage image.
        return values - v_min
    return (values - v_min) / span


def plot_relevance_scores(
    x: torch.Tensor,
    r: torch.Tensor,
    name: str,
    output_dir: str = "./output/",
    cmap="hot",
) -> None:
    """Save the input image and its LRP relevance map side by side as a PNG.

    Only the first sample of the batch (``x[0]``) is drawn, so callers are
    expected to pass a batch size of one.

    Args:
        x: Input batch indexed as (N, C, H, W); ``x[0]`` is permuted to
            (H, W, C) for ``imshow``, which assumes C == 3 (RGB).
        r: Relevance scores for the image, drawn directly with ``imshow``;
            presumably a 2-D (H, W) map — TODO confirm against LRPModel output.
        name: Suffix of the output file, written as ``image_{name}.png``.
        output_dir: Directory for the PNG; created if it does not exist.
        cmap: Matplotlib colormap (name or Colormap object) for the heatmap.
    """
    output_path = pathlib.Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    max_fig_size = 20

    # Scale the figure so its longest side is `max_fig_size` inches while
    # preserving the image aspect ratio.
    _, _, img_height, img_width = x.shape
    max_dim = max(img_height, img_width)
    fig_height = max_fig_size * img_height / max_dim
    fig_width = max_fig_size * img_width / max_dim

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(fig_width, fig_height))
    try:
        # No squeeze(): it would drop a singleton H or W and break permute.
        image = _min_max_normalize(x[0].permute(1, 2, 0).detach().cpu())
        axes[0].imshow(image)
        axes[0].set_axis_off()

        if isinstance(r, torch.Tensor):
            # imshow needs host memory; a CUDA tensor cannot be converted.
            r = r.detach().cpu()
        axes[1].imshow(_min_max_normalize(r), cmap=cmap)
        axes[1].set_axis_off()

        fig.tight_layout()
        fig.savefig(output_path / f"image_{name}.png", bbox_inches="tight")
    finally:
        # Always release the figure; this is called once per image in a loop.
        plt.close(fig)

In [4]:
def per_image_lrp(model, top_k: float = 0.02, data_loader=None):
    """Run LRP on every batch of a data loader and save one heatmap per batch.

    Prints the device in use and, per batch, the throughput of the LRP
    forward pass. Figures are written by ``plot_relevance_scores`` with the
    batch index as name.

    Args:
        model: Network to explain; moved to the available device and switched
            to eval mode.
        top_k: Forwarded to ``LRPModel(top_k=...)``.
        data_loader: Iterable of ``(x, y)`` batches. Defaults to
            ``get_data_loader()``. Only the first sample of each batch is
            plotted, so a batch size of one is expected.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using: {device}\n")

    if data_loader is None:
        data_loader = get_data_loader()

    model.to(device)
    # Disable train-mode layers (e.g. Dropout in VGG classifiers) so the
    # relevance maps are deterministic.
    model.eval()

    lrp_model = LRPModel(model=model, top_k=top_k)

    # Labels are not used: the method is unsupervised.
    for i, (x, _) in enumerate(data_loader):
        x = x.to(device)

        t0 = time.perf_counter()
        r = lrp_model.forward(x)
        if device.type == "cuda":
            # CUDA work is asynchronous; wait for it so the timing is real.
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - t0
        fps = 1.0 / elapsed if elapsed > 0 else float("inf")
        print("{time:.2f} FPS".format(time=fps))

        plot_relevance_scores(x=x, r=r, name=str(i))

In [5]:
# Pre-trained model
# per_image_lrp(vgg19(weights=VGG19_Weights.DEFAULT))

# Retrained model: VGG19 whose last classifier layer is replaced by a
# NUM_CLASSES-way head, then restored from the checkpoint at PATH.
PATH = 'vgg19_imagenette.pth'
NUM_CLASSES = 10  # Imagenette contains 10 classes
model = vgg19()
model.classifier[6] = torch.nn.Linear(
    in_features=model.classifier[6].in_features,
    out_features=NUM_CLASSES,
)
# map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only
# hosts; per_image_lrp moves the model to the available device afterwards.
model.load_state_dict(torch.load(PATH, map_location="cpu", weights_only=True))
per_image_lrp(model)

Using: cuda

3.46 FPS
5.44 FPS
7.93 FPS
7.47 FPS
6.71 FPS
7.38 FPS
11.81 FPS
7.72 FPS
6.50 FPS
6.78 FPS
7.71 FPS
11.07 FPS
7.45 FPS
10.34 FPS
7.63 FPS
7.33 FPS
7.73 FPS
7.39 FPS
7.06 FPS
10.07 FPS
