In [1]:
import torch

# Image-related utilities
from torchvision.io import decode_image, read_image
from torchvision.transforms import ToTensor
from torchvision import transforms
from PIL import Image

# Import models
from torchvision.models import alexnet, AlexNet_Weights
from torchvision.models import vgg19, VGG19_Weights

# Dataset
from torchvision.datasets import Imagenette

# LRP package
from src.lrp import LRPModel
from src.data import get_data_loader

# Utils
import argparse
import time
import pathlib
import matplotlib.pyplot as plt

In [2]:
def _min_max_normalize(t: torch.Tensor) -> torch.Tensor:
    """Linearly rescale ``t`` into [0, 1]; a constant tensor maps to all zeros.

    Guards the ``max == min`` case, which would otherwise divide by zero and
    fill the plotted image with NaNs.
    """
    t_min, t_max = t.min(), t.max()
    span = t_max - t_min
    if span == 0:
        return torch.zeros_like(t)
    return (t - t_min) / span


def plot_relevance_scores(
    x: torch.Tensor, r: torch.Tensor, name: str, output_dir: str = "./output/"
) -> None:
    """Plot an input image next to its relevance heatmap and save the figure as PNG.

    Only the first sample ``x[0]`` is drawn, so the method is intended for a
    batch size of one. The figure is closed after saving (not shown inline).

    Args:
        x: Input batch of shape (batch, channels, height, width).
        r: Relevance scores for ``x[0]``, drawn with ``imshow`` using the
            ``afmhot`` colormap. May be a CPU/GPU tensor or an array-like.
            NOTE(review): assumed to be a 2-D (height, width) map — confirm
            against ``LRPModel.forward``.
        name: Image name; the file is written as ``image_{name}.png``.
        output_dir: Directory the figure is saved to. Created if it does not
            exist.
    """
    max_fig_size = 20  # Length of the figure's longer side, in inches.

    _, _, img_height, img_width = x.shape
    max_dim = max(img_height, img_width)
    # Keep the figure's aspect ratio equal to the image's.
    fig_height = max_fig_size * img_height / max_dim
    fig_width = max_fig_size * img_width / max_dim

    # (C, H, W) -> (H, W, C) on the CPU, the layout ``imshow`` expects.
    image = _min_max_normalize(x[0].detach().cpu().permute(1, 2, 0))
    # ``imshow`` rejects (H, W, 1); draw single-channel inputs as grayscale.
    image_cmap = None
    if image.shape[-1] == 1:
        image = image[..., 0]
        image_cmap = "gray"

    # Accept CPU/GPU tensors (with or without grad) as well as array-likes.
    relevance = _min_max_normalize(torch.as_tensor(r).detach().cpu())

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(fig_width, fig_height))
    try:
        axes[0].imshow(image, cmap=image_cmap)
        axes[0].set_axis_off()

        axes[1].imshow(relevance, cmap="afmhot")
        axes[1].set_axis_off()

        fig.tight_layout()
        out_dir = pathlib.Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f"image_{name}.png", bbox_inches="tight")
    finally:
        # Always release the figure so looping over many images does not leak memory.
        plt.close(fig)

In [3]:
def per_image_lrp(top_k: float = 0.02) -> None:
    """Run LRP with a pretrained VGG19 on every data-loader batch and save heatmaps.

    For each batch yielded by ``get_data_loader()``, relevance scores are
    computed with ``LRPModel`` and saved by ``plot_relevance_scores`` under
    the name ``str(batch_index)``. Labels from the loader are ignored because
    the method is unsupervised.

    Args:
        top_k: Forwarded unchanged to ``LRPModel(top_k=...)``; 0.02 is the
            value previously hard-coded here. NOTE(review): its exact meaning
            is defined by ``LRPModel`` — presumably the fraction of relevance
            kept per layer; confirm.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using: {device}\n")

    data_loader = get_data_loader()

    model = vgg19(weights=VGG19_Weights.DEFAULT)
    model.to(device)
    # Inference mode (e.g. dropout disabled) so relevance maps are reproducible.
    model.eval()

    lrp_model = LRPModel(model=model, top_k=top_k)

    for batch_idx, (x, _labels) in enumerate(data_loader):
        x = x.to(device)

        t0 = time.perf_counter()
        r = lrp_model.forward(x)
        if device.type == "cuda":
            # CUDA kernels run asynchronously; wait for them so the timing is real.
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - t0
        # Batches per second; equals images per second only for batch size one.
        print("{time:.2f} FPS".format(time=(1.0 / elapsed)))

        plot_relevance_scores(x=x, r=r, name=str(batch_idx))

In [4]:
# Run LRP over every batch from the data loader; one figure is saved per batch.
# Figures are written to disk and closed, so nothing is rendered inline here.
per_image_lrp()

Using: cuda

3.22 FPS
