# REPEAT: Improving Uncertainty Estimation in Representation Learning Explainability
This notebook illustrates the basic usage of REPEAT, a new method for representation learning explainability with improved uncertainty estimation. If you are running this notebook on Google Colab, remember to enable GPU support to speed up computation.

REPEAT treats each pixel in an image as a Bernoulli random variable that is either important or unimportant to the representation of the image. From these Bernoulli random variables we can directly estimate the importance of a pixel and its associated certainty, thus enabling users to ascertain certainty in pixel importance.

This notebook consists of 6 cells of code:

1. Install and import packages: installs and imports the necessary packages.
2. Data transformations: A selection of functions that transform the image into the shape and format expected by the feature extractor.
3. Load and plot example image: Loads an example image, transforms it, and plots it.
4. Feature extractor: Defines a function that loads a pretrained ResNet18 feature extractor.
5. Code for REPEAT: A class that implements the REPEAT methodology.
6. Run REPEAT on example image: Loads feature extractor, runs REPEAT, and plots the results of the analysis.



In [None]:
#@title Install and import packages

!pip install relax-xai

import torch
import scipy.misc
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt

from PIL import Image
from relax_xai.relax import RELAX
from skimage.filters import threshold_triangle, threshold_otsu, threshold_mean

In [2]:
#@title Data transformations

# Per-channel mean and standard deviation of the ImageNet dataset. They are passed
# to torchvision.transforms.Normalize in imagenet_image_transforms so that inputs
# are normalized the same way as the data the pretrained encoder was trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class ToDevice(torch.nn.Module):
    """
    Transform step that moves whatever it receives onto a fixed device.

    The target device is chosen once, at construction time. Each call then
    returns ``img.to(device)``, which makes the class usable as the last stage
    of a ``torchvision.transforms.Compose`` pipeline.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, img):
        target_device = self.device
        moved = img.to(target_device)
        return moved

    def __repr__(self) -> str:
        # Show the configured device, e.g. "ToDevice(device=cuda)".
        return f"{self.__class__.__name__}(device={self.device})"

def unsqeeze_image(input_image: torch.Tensor) -> torch.Tensor:
    """Prepend a batch dimension of size 1, e.g. ``(C, H, W) -> (1, C, H, W)``."""
    batched = torch.unsqueeze(input_image, dim=0)
    return batched


def imagenet_image_transforms(device: str, new_shape_of_image: int = 224):
    """
    Build the preprocessing pipeline that turns an image into encoder input.

    The returned callable applies, in order:
      1. ``ToTensor`` - converts the image to a ``(C, H, W)`` tensor,
      2. ``Resize`` - resizes it to ``(new_shape_of_image, new_shape_of_image)``,
      3. ``Normalize`` - normalizes with the ImageNet channel statistics,
      4. a batch dimension is prepended, giving ``(1, C, H', W')``,
      5. the tensor is moved to ``device``.
    """
    target_size = (new_shape_of_image, new_shape_of_image)
    pipeline_steps = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(target_size),
        torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        torchvision.transforms.Lambda(unsqeeze_image),
        ToDevice(device),
    ]
    return torchvision.transforms.Compose(pipeline_steps)


def imsc(img, *args, quiet=False, lim=None, interpolation='lanczos', **kwargs):
    r"""Rescale an image to the [0, 1] range and prepare it for display.

    The image is assumed to have shape :math:`3\times H\times W` (RGB) or
    :math:`1\times H\times W` (grayscale).

    Args:
        img (:class:`torch.Tensor` or :class:`PIL.Image`): image. Anything that
            is not already a tensor is converted with
            :func:`torchvision.transforms.pil_to_tensor`.
        quiet (bool, optional): if True, return the rescaled tensor instead of
            an ``(H, W, 3)`` NumPy bitmap. Default: ``False``.
        lim (list, optional): ``[minimum, maximum]`` intensity values used for
            rescaling. Default: ``None`` (use the image's own min and max).
        interpolation (str, optional): accepted for API compatibility; this
            function does not call ``imshow`` itself, so it is unused.
        *args, **kwargs: accepted for API compatibility; unused.

    Returns:
        :class:`numpy.ndarray` of shape ``(H, W, 3)`` with values in [0, 1]
        when ``quiet`` is False (ready for :func:`matplotlib.pyplot.imshow`);
        otherwise the rescaled :class:`torch.Tensor` of shape ``(C, H, W)``.
    """
    if not isinstance(img, torch.Tensor):
        img = torchvision.transforms.pil_to_tensor(img)
    with torch.no_grad():
        # Integer tensors (e.g. uint8 from pil_to_tensor) cannot be scaled
        # in place by a float factor, so work in floating point.
        if not torch.is_floating_point(img):
            img = img.float()
        if not lim:
            lim = [img.min(), img.max()]
        img = img - lim[0]  # also makes a copy, so the caller's tensor is untouched
        value_range = lim[1] - lim[0]
        # A constant image has zero range; dividing would produce NaNs.
        if value_range != 0:
            img.mul_(1 / value_range)
        img = torch.clamp(img, min=0, max=1)
        if quiet:
            return img
        # expand() broadcasts a single grayscale channel to RGB (and is a
        # no-op for 3 channels); permute gives the (H, W, C) layout imshow expects.
        bitmap = img.expand(3, *img.shape[1:]).permute(1, 2, 0).cpu().numpy()
    return bitmap

In [None]:
#@title Load and plot example image

# scipy.misc.face is deprecated (since SciPy 1.10) and removed in newer releases;
# the same image is available from scipy.datasets, which downloads it on demand
# and needs the optional `pooch` package.
try:
    from scipy import datasets as scipy_datasets
    example_image = scipy_datasets.face()
except ImportError:
    # Older SciPy without scipy.datasets, or `pooch` is not installed.
    example_image = scipy.misc.face()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
example_image = imagenet_image_transforms(device)(example_image)

plt.figure(1)
plt.imshow(imsc(example_image.squeeze()))
plt.axis('off')
plt.show()


In [4]:
#@title Feature extractor

def load_resnet18_encoder() -> nn.Module:
    """
    Build a ResNet18 feature extractor from torchvision's default weights.

    All child modules of the network except the last one (the classification
    head) are kept, and an ``nn.Flatten`` is appended so the output is a
    ``(batch, features)`` tensor. The encoder is returned in eval mode.
    """
    backbone = torchvision.models.resnet18(weights="DEFAULT")
    # Unpack the children and discard the final (classifier) layer.
    *feature_layers, _classifier = backbone.children()
    encoder = nn.Sequential(*feature_layers, nn.Flatten())
    encoder.eval()
    return encoder

In [5]:
#@title Code for REPEAT


class REPEAT(nn.Module):
    """
    This class implements REPEAT, a new method for representation learning
    explainability. REPEAT treats each pixel in an image as a Bernoulli random
    variable that is either important or unimportant to the representation of the image.
    From these Bernoulli random variables we can directly estimate the importance
    of a pixel and its associated certainty, thus enabling users to ascertain
    certainty in pixel importance.

    Parameters
    ----------
    input_image
        Input image to be explained. Only ``input_image.shape[2:]`` is used to
        size the output maps, so the tensor is presumably laid out as
        (batch, channels, height, width) with a batch of one - TODO confirm.
    encoder
        Encoder that transforms the input image into a new representation.
    num_repeats
        How many times to repeat the RELAX calculation. Must be at least 1.
    batch_size
        The size of each batch of masks.
    num_batches
        Number of batches with masks to generate.
    explanation_threshold
        Name of the method used to threshold each RELAX explanation into
        important and non-important pixels: ``'mean'``, ``'otsu'`` or
        ``'triangle'`` (the ``skimage.filters`` function of the same name).

    Attributes
    ----------
    probability_of_importance
        Per-pixel map of shape ``input_image.shape[2:]``, filled in by
        :meth:`forward`.
    uncertainty
        Per-pixel ``p * (1 - p)`` with ``p = probability_of_importance`` (the
        variance of a Bernoulli variable), filled in by :meth:`forward`.

    Raises
    ------
    ValueError
        If ``num_repeats`` is smaller than 1 or ``explanation_threshold`` is
        not one of the supported methods.
    """
    def __init__(self,
                 input_image: torch.Tensor,
                 encoder: nn.Module,
                 num_repeats: int = 10,
                 batch_size: int = 100,
                 num_batches: int = 10,
                 explanation_threshold: str = 'mean'
                 ) -> None:
        super().__init__()

        # forward() divides by num_repeats; zero repeats would silently give NaN.
        if num_repeats < 1:
            raise ValueError(f"num_repeats must be at least 1, got {num_repeats}")

        self.input_image = input_image
        self.num_repeats = num_repeats
        self.num_batches = num_batches
        self.batch_size = batch_size

        self.device = input_image.device
        self.shape = tuple(input_image.shape[2:])

        self.probability_of_importance = torch.zeros(self.shape, device=self.device)
        self.uncertainty = torch.zeros(self.shape, device=self.device)

        self.encoder = encoder
        self.explanation_threshold = explanation_threshold

        # Each key maps to the scikit-image function of the same name.
        self.THRESHOLD_METHODS = {
            'mean': threshold_mean,
            'otsu': threshold_otsu,
            'triangle': threshold_triangle,
        }
        # Fail fast here instead of with a KeyError deep inside forward().
        if explanation_threshold not in self.THRESHOLD_METHODS:
            raise ValueError(
                f"Unknown explanation_threshold {explanation_threshold!r}; "
                f"expected one of {sorted(self.THRESHOLD_METHODS)}"
            )

    def forward(self) -> None:
        """Run RELAX ``num_repeats`` times and aggregate the thresholded explanations.

        The results are stored in ``self.probability_of_importance`` and
        ``self.uncertainty``; the method returns ``None``.
        """
        # Start from zero on every call so repeated calls do not add the new
        # estimates on top of the previous ones.
        self.probability_of_importance = torch.zeros(self.shape, device=self.device)

        with torch.no_grad():
            for _ in range(self.num_repeats):
                relax = RELAX(self.input_image, self.encoder, self.num_batches, self.batch_size)
                relax.forward()
                importance = relax.importance

                threshold_val = self.threshold_method(
                    importance, explanation_threshold=self.explanation_threshold
                )

                # Pixels above the threshold contribute their importance
                # normalised by the maximum importance; all others contribute 0.
                weight = importance / importance.max()
                self.probability_of_importance += weight * (importance > threshold_val)

        self.probability_of_importance /= self.num_repeats
        self.uncertainty = self.probability_of_importance * (1 - self.probability_of_importance)

        return None

    def threshold_method(self, explanation: torch.Tensor, explanation_threshold: str) -> float:
        """Return the scalar threshold for ``explanation`` computed by the named method.

        ``numpy(force=True)`` detaches the tensor and copies it to the CPU,
        because the scikit-image thresholding functions take NumPy arrays.
        The result is converted to a plain Python ``float``.
        """
        threshold_fn = self.THRESHOLD_METHODS[explanation_threshold]
        return float(threshold_fn(explanation.numpy(force=True)))



In [None]:
#@title Run REPEAT on example image


# Place the encoder on the same device as the image instead of hard-coding
# 'cuda', so this cell also works when the image was prepared on the CPU.
encoder = load_resnet18_encoder().to(example_image.device)
repeat = REPEAT(example_image, encoder)
repeat.forward()

# Rescale the input once and reuse it as the background of every panel.
background = imsc(example_image.squeeze())

# (title, overlay map or None) for each of the three subplots.
panels = [
    ('Input image', None),
    ('Probability of importance', repeat.probability_of_importance),
    ('Uncertainty', repeat.uncertainty),
]

plt.figure(1)
for position, (title, overlay) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(background)
    if overlay is not None:
        plt.imshow(overlay.numpy(force=True), alpha=0.75, cmap='bwr')
    plt.title(title)
    plt.axis('off')
plt.tight_layout()
plt.show()