In [None]:
!pip install torchray

In [None]:
import torchray
import torchray.benchmark
import matplotlib.pyplot as plt
%matplotlib inline

Let's take an example image together with its ImageNet categories (`cat_id1` and `cat_id2` are simply the network output indices of two classes present in the image).

In [None]:
model, image_tensor, cat_id1, cat_id2 = torchray.benchmark.get_example_data(arch='vgg16', shape=224)

In [None]:
# image_tensor is already preprocessed with the normalization the model expects
image_tensor.size(), cat_id1, cat_id2

## Vanilla gradient visualization
Let's start with the simplest and most obvious visualization technique: the gradients themselves. We can backpropagate through the network all the way to the input and use the input gradient as a saliency map.
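
A first-order Taylor expansion gives one way to see why the input gradient is a reasonable saliency signal. This is only a sketch: the symbols $f_c$, $g$ and $\delta$ are notation introduced here, with $f_c(x)$ standing for the network output for class $c$ on input $x$ and $\delta$ for a small change of the input.

$$
f_c(x + \delta) \approx f_c(x) + \langle g, \delta \rangle, \qquad g = \frac{\partial f_c(x)}{\partial x}
$$

Pixels where $g$ has a large magnitude are the ones where a small change moves the class score the most.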

In [None]:
x = image_tensor
category_id = cat_id1
x.requires_grad_(True)
# inference + backward
y = model(x)
z = y[0, category_id]
z.backward()

def to_saliency(gradient):
    # gradient has shape (B, 3, H, W) but saliency should have shape (B, 1, H, W)
    # and probably be non-negative
    # how would you propose to compute saliency?
    saliency = gradient.norm(dim=1, keepdim=True) 
    return saliency

saliency = to_saliency(x.grad)
saliency.size()

Let's plot what we have...

The `plot_example` helper function simply plots the input image next to the saliency map and does nothing more.

In [None]:
plt.figure(figsize=(15,15))
torchray.benchmark.plot_example(x, saliency, 
                                method="simple_backprop", 
                                category_id=category_id, 
                                show_plot=True, 
                                save_path=None)

In [None]:
plt.figure(figsize=(15,15))
torchray.benchmark.plot_example(x, x.grad, 
                                method="simple_backprop", 
                                category_id=category_id, 
                                show_plot=True, 
                                save_path=None)

You can try different ways of transforming `grad` into `saliency`; experiment with them before continuing.
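
As a starting point for such experiments, here is a minimal pure-Python sketch of three common ways to collapse the per-channel gradient at a single pixel into one non-negative score. The function name `reduce_channels` and the sample values are made up for illustration.

```python
import math

def reduce_channels(pixel_grad, how="l2"):
    """Collapse per-channel gradient values at one pixel into one non-negative score."""
    if how == "l2":        # Euclidean norm over channels
        return math.sqrt(sum(g * g for g in pixel_grad))
    if how == "max_abs":   # largest absolute channel value
        return max(abs(g) for g in pixel_grad)
    if how == "sum_abs":   # sum of absolute channel values
        return sum(abs(g) for g in pixel_grad)
    raise ValueError(f"unknown reduction: {how}")

# Hypothetical gradient of one pixel over its (R, G, B) channels.
pixel_grad = [0.3, -0.4, 0.0]
scores = {how: reduce_channels(pixel_grad, how) for how in ("l2", "max_abs", "sum_abs")}
print(scores)
```

On a gradient tensor of shape `(B, 3, H, W)` the corresponding reductions would be `gradient.norm(dim=1, keepdim=True)`, `gradient.abs().max(dim=1, keepdim=True)[0]` and `gradient.abs().sum(dim=1, keepdim=True)`.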

## Changing backpropagation rules!

As you can (or cannot...) see from the examples above, plain gradient visualization works to some extent, but we believe better saliency visualizations exist. One simple improvement is to slightly modify the backpropagation rules (for specific layers/functions) to obtain cleaner saliency maps.

Some of these methods are already implemented in `TorchRay`; let's check them out.

In [None]:
from torchray.attribution.common import gradient_to_saliency
from torchray.benchmark import get_example_data, plot_example

import torch

Two well-known methods, Deconvnet and Guided Backprop, modify only the backward computation of the ReLU function.

One way to change the backprop rules is to define a custom context manager.

So we are heading toward something like this:

```python
with ChangedBackpropRules():
    y = model(x)
    z = y[0, category_id]
    z.backward()
# and now x.grad contains the *modified* gradient
```

What does this context manager look like?
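
Before the full implementation, the underlying mechanism (temporarily replacing a function in a module and restoring it on exit) can be sketched with the standard library alone. The names `patched` and `hypot_len` are hypothetical, and `math.sqrt` merely stands in for `torch.relu`.

```python
import math
from contextlib import contextmanager

@contextmanager
def patched(module, name, replacement):
    """Temporarily replace `module.<name>` and restore the original on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        # Runs even if the body raises, so the patch never leaks out.
        setattr(module, name, original)

def hypot_len(a, b):
    # Looks up math.sqrt at call time, so it sees the patched version.
    return math.sqrt(a * a + b * b)

before = hypot_len(3, 4)
with patched(math, "sqrt", lambda v: -1.0):
    inside = hypot_len(3, 4)
after = hypot_len(3, 4)
print(before, inside, after)
```

The patch only affects code that looks the function up through the module at call time, which is why the same trick can redirect ReLU calls made inside a model's forward pass.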

In [None]:
# `Patch` is a fancy tool to replace callable in a module
class Patch(object):
    """Patch a callable in a module."""

    @staticmethod
    def resolve(target):
        """Resolve a target into a module and an attribute.
        The function resolves a string such as ``'this.that.thing'`` into a
        module instance `this.that` (importing the module) and an attribute
        `thing`.
        Args:
            target (str): target string.
        Returns:
            tuple: module, attribute.
        """
        target, attribute = target.rsplit('.', 1)
        components = target.split('.')
        import_path = components.pop(0)
        target = __import__(import_path)
        for comp in components:
            import_path += '.{}'.format(comp)
            __import__(import_path)
            target = getattr(target, comp)
        return target, attribute

    def __init__(self, target, new_callable):
        """Patch a callable in a module.
        Args:
            target (str): path to the callable to patch.
            new_callable (fun): new callable.
        """
        target, attribute = Patch.resolve(target)
        self.target = target
        self.attribute = attribute
        self.orig_callable = getattr(target, attribute)
        setattr(target, attribute, new_callable)

    def __del__(self):
        self.remove()

    def remove(self):
        """Remove the patch."""
        if self.target is not None:
            setattr(self.target, self.attribute, self.orig_callable)
        self.target = None


# This is our context manager (base class)
# we will need to specify ReLU function implementation here
class ReLUContext(object):
    """
    A context manager that replaces :func:`torch.relu` with
        :attr:`relu_function`.
    Args:
        relu_func (:class:`torch.autograd.function.FunctionMeta`): class
            definition of a :class:`torch.autograd.Function`.
    """

    def __init__(self, relu_func):
        assert isinstance(relu_func, torch.autograd.function.FunctionMeta)
        self.relu_func = relu_func
        self.patches = []

    def __enter__(self):
        relu = self.relu_func().apply
        self.patches = [
            Patch('torch.relu', relu),
            Patch('torch.relu_', relu),
        ]
        return self

    def __exit__(self, type, value, traceback):
        for p in self.patches:
            p.remove()
        return False  # re-raise any exception


# Our fancy ReLU with changed backward pass
class DeConvNetReLU(torch.autograd.Function):
    """DeConvNet ReLU autograd function.
    This is an autograd function that redefines the ``relu`` function
    to match the DeConvNet ReLU definition.
    """

    @staticmethod
    def forward(ctx, input):
        """DeConvNet ReLU forward function."""
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """DeConvNet ReLU backward function."""
        # One possible solution: pass only the positive part of the incoming gradient.
        return grad_output.clamp(min=0)


# And finally our context manager which
# swaps the `relu` implementation
class DeConvNetContext(ReLUContext):
    """DeConvNet context.
    This context modifies the computation of gradient to match the DeConvNet
    definition.
    See :mod:`torchray.attribution.deconvnet` for how to use it.
    """

    def __init__(self):
        super(DeConvNetContext, self).__init__(DeConvNetReLU)

Note the difference:

![ReLU backprop changed](https://cdn1.imggmi.com/uploads/2019/12/10/7ebba954e965228e071c407189f84986-full.png)
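
The backward rules being contrasted can also be written out for a single scalar. The sketch below assumes the usual definitions from the deconvnet and guided backprop papers; the function `relu_backward` and the sample values are illustrative only.

```python
def relu_backward(x, grad_out, rule):
    """Backward pass of ReLU for one scalar under different rules.

    x        -- input of the ReLU in the forward pass
    grad_out -- gradient arriving from the layer above
    """
    if rule == "gradient":    # standard backprop: gate on the forward input
        return grad_out if x > 0 else 0.0
    if rule == "deconvnet":   # gate on the sign of the incoming gradient only
        return max(grad_out, 0.0)
    if rule == "guided":      # both gates at once
        return max(grad_out, 0.0) if x > 0 else 0.0
    raise ValueError(f"unknown rule: {rule}")

# All four sign combinations of (forward input, incoming gradient).
inputs = [1.0, -1.0, 2.0, -2.0]
grads = [0.5, 0.5, -0.5, -0.5]
table = {
    rule: [relu_backward(x, g, rule) for x, g in zip(inputs, grads)]
    for rule in ("gradient", "deconvnet", "guided")
}
for rule, values in table.items():
    print(rule, values)
```

Guided backprop lets a value through only where both other rules do, which is why its maps tend to look the sparsest.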

In [None]:
# Obtain example data.
model, x, category_id, _ = get_example_data()

# DeConvNet method.
x.requires_grad_(True)

with DeConvNetContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_saliency(x)

# Plots.
plt.figure(figsize=(15,15))
plot_example(x, saliency, 'deconvnet', category_id)

Deconvnet is already implemented in `TorchRay`; let's check our result against it:

In [None]:
from torchray.attribution.deconvnet import DeConvNetContext

# Obtain example data.
model, x, category_id, _ = get_example_data()

# DeConvNet method.
x.requires_grad_(True)

with DeConvNetContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_saliency(x)

# Plots.
plt.figure(figsize=(15,15))
plot_example(x, saliency, 'deconvnet', category_id)

Let's implement GuidedBackprop now; it is very similar.

In [None]:
class GuidedBackpropReLU(torch.autograd.Function):
    """This class implements a ReLU function with the guided backprop rules."""
    @staticmethod
    def forward(ctx, input):
        """Guided backprop ReLU forward function."""
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Guided backprop ReLU backward function."""
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
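        # One possible solution: block the gradient where the forward input
        # was negative and where the incoming gradient is negative.
        grad_input[input < 0] = 0
        grad_input[grad_output < 0] = 0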
        return grad_input


class GuidedBackpropContext(ReLUContext):
    r"""GuidedBackprop context.
    This context modifies the computation of gradients
    to match the guided backpropagation definition.
    See :mod:`torchray.attribution.guided_backprop` for how to use it.
    """

    def __init__(self):
        super(GuidedBackpropContext, self).__init__(GuidedBackpropReLU)

In [None]:
model, image_tensor, cat_id1, cat_id2 = torchray.benchmark.get_example_data(arch='vgg16', shape=224)

x = image_tensor
category_id = cat_id1
x.requires_grad_(True)

with GuidedBackpropContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()
# now compute saliency map from gradient
saliency = gradient_to_saliency(x)
saliency_cat1 = saliency
# plot saliency for category1(dog)
plt.figure(figsize=(15,15))
torchray.benchmark.plot_example(x, saliency, 
                                method="guided_backprop", 
                                category_id=category_id, 
                                show_plot=True, 
                                save_path=None)

In [None]:
# repeat for other category
x = image_tensor
category_id = cat_id2
# note that we zero the gradients here
# as `x` already has .grad from previous computations
x.grad.zero_()
model.zero_grad()

with GuidedBackpropContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()
# now compute saliency map from gradient
saliency = gradient_to_saliency(x)
plt.figure(figsize=(15,15))
torchray.benchmark.plot_example(x, saliency, 
                                method="guided_backprop", 
                                category_id=category_id, 
                                show_plot=True, 
                                save_path=None)

Note how similar the two saliency maps are for different categories!
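
To put a number on "similar", one option is the cosine similarity between the two flattened maps. This is a minimal pure-Python sketch on made-up toy values; the helper `cosine_similarity` is defined here and is not a TorchRay function.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two flat sequences of equal length."""
    dot = sum(p * q for p, q in zip(a, b))
    norm_a = math.sqrt(sum(p * p for p in a))
    norm_b = math.sqrt(sum(q * q for q in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # convention chosen here for all-zero maps
    return dot / (norm_a * norm_b)

# Toy flattened "saliency maps" (made-up numbers).
map_a = [0.0, 0.2, 0.9, 0.1]
map_b = [0.0, 0.4, 1.8, 0.2]   # same pattern as map_a, different scale
map_c = [0.9, 0.1, 0.0, 0.2]   # different pattern

sim_ab = cosine_similarity(map_a, map_b)
sim_ac = cosine_similarity(map_a, map_c)
print(sim_ab, sim_ac)
```

For the tensors in this notebook the inputs could be obtained with `saliency.flatten().tolist()` and `saliency_cat1.flatten().tolist()`; a value close to 1 would mean the two maps are nearly independent of the chosen class.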

In [None]:
print(saliency.min(), saliency.max())
print(saliency.size())
saliency

In [None]:
saliency_cat1

Of course, `GuidedBackprop` is also already implemented in `TorchRay`:

In [None]:
from torchray.attribution.guided_backprop import GuidedBackpropContext

Before we move on, compare the two ways of calling it:

In [None]:
# 1
# Obtain example data.
model, x, category_id, _ = get_example_data()

# Guided backprop.
x.requires_grad_(True)

with GuidedBackpropContext():
    y = model(x)
    z = y[0, category_id]
    z.backward()

saliency = gradient_to_saliency(x)

# Plots.
plot_example(x, saliency, 'guided backprop', category_id)

In [None]:
from torchray.attribution.guided_backprop import guided_backprop
# 2
# Obtain example data.
model, x, category_id, _ = get_example_data()

# Guided backprop.
saliency = guided_backprop(model, x, category_id)

# Plots.
plot_example(x, saliency, 'guided backprop', category_id)

To write a bit less code, let's import a more general function that wraps the forward and backward calls inside the context.

In [None]:
from torchray.attribution.common import saliency as compute_saliency
# Obtain example data.
model, x, category_id, _ = get_example_data()

# Guided backprop.
saliency = compute_saliency(model, x, category_id, context_builder=GuidedBackpropContext)
# saliency = guided_backprop(model, x, category_id)  # (the same)

# Plots.
plot_example(x, saliency, 'guided backprop', category_id)

Let's try and compare other methods and models!

In [None]:
# gradient
from torchray.attribution.gradient import gradient
# gradient with different grad2saliency function
from torchray.attribution.grad_cam import grad_cam
from torchray.attribution.linear_approx import linear_approx
# ReLU overwritten
from torchray.attribution.deconvnet import deconvnet
from torchray.attribution.guided_backprop import guided_backprop
# Linear layers overwritten
from torchray.attribution.excitation_backprop import excitation_backprop
from torchray.attribution.excitation_backprop import contrastive_excitation_backprop

See the `TorchRay` [docs](https://facebookresearch.github.io/TorchRay/attribution.html#) for details on each method.

In [None]:
def draw_method(model, x, category_id, method, method_name, figsize=(15,15), **kwargs):
    saliency = method(model, x, category_id, **kwargs)
    plt.figure(figsize=figsize)
    plot_example(x, saliency, method_name, category_id)

In [None]:
model, x, category_id, _ = get_example_data()
draw_method(model, x, category_id, method=gradient, method_name="gradient")

In [None]:
methods = [gradient, linear_approx, deconvnet, excitation_backprop]
method_names = ["gradient", "linear_approx", "deconvnet", "excitation_backprop"]

archs = ['alexnet', 'vgg11', 'resnet18', 'resnet50', 'wide_resnet101_2', 'densenet121', 'densenet201', 'mobilenet_v2']

for arch in archs:
    model, image_tensor, cat_id1, cat_id2 = torchray.benchmark.get_example_data(arch=arch, shape=224)
    print(f"Model={arch}", flush=True)
    for method, method_name in zip(methods, method_names):
        draw_method(model, image_tensor, cat_id1, method=method, method_name=method_name)
        if image_tensor.grad is not None:
            image_tensor.grad.zero_()
        draw_method(model, image_tensor, cat_id2, method=method, method_name=method_name)
    plt.show()

So far we have visualized saliency only at the network input, but can we do this for an intermediate layer?

Yes, we can! Let's check it out.

In [None]:
model, image_tensor, cat_id1, cat_id2 = torchray.benchmark.get_example_data()
# for method, method_name in zip(methods, method_names):
draw_method(model, image_tensor, cat_id1, method=guided_backprop, method_name="guided_backprop", saliency_layer="features.9")

In [None]:
model

Now try to visualize saliency at every non-activation layer in `model.features`.

In [None]:
model, image_tensor, cat_id1, cat_id2 = torchray.benchmark.get_example_data()
method = guided_backprop
method_name = "guided_backprop"
# iterate through all non-ReLU layers
for i in range(31):
    if image_tensor.grad is not None:
        image_tensor.grad.zero_()

    # One possible solution: check the layer type directly.
    layer_is_not_relu = not isinstance(model.features[i], torch.nn.ReLU)
    if layer_is_not_relu:
        saliency_layer = f"features.{i}"
        print(f"Saliency layer: {saliency_layer}", flush=True)
        draw_method(model, image_tensor, cat_id1, method=method, method_name=method_name, saliency_layer=saliency_layer)

You can also try other methods and architectures.

In [None]:
model, image_tensor, cat_id1, cat_id2 = torchray.benchmark.get_example_data()
draw_method(model, image_tensor, cat_id1, method=excitation_backprop, method_name="excitation_backprop", saliency_layer="features.30")

## Perturbation methods
Perturbation methods **do not require your model to be differentiable, so you can apply them to any classifier**, i.e. a black-box model.
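
To make the black-box idea concrete, here is a simplified RISE-style Monte Carlo estimate in pure Python. It is a sketch under stated assumptions, not TorchRay's implementation: masks are sampled independently per pixel (the real method upsamples low-resolution masks), and `rise_style_saliency` and `black_box` are hypothetical names.

```python
import random

def rise_style_saliency(score_fn, image, num_masks=500, keep_prob=0.5, seed=0):
    """Average random binary masks, each weighted by the score of the masked image."""
    rng = random.Random(seed)
    h, w = len(image), len(image[0])
    saliency = [[0.0] * w for _ in range(h)]
    for _ in range(num_masks):
        # Keep each pixel with probability keep_prob, zero it otherwise.
        mask = [[1.0 if rng.random() < keep_prob else 0.0 for _ in range(w)]
                for _ in range(h)]
        masked = [[image[i][j] * mask[i][j] for j in range(w)] for i in range(h)]
        score = score_fn(masked)  # only forward evaluations, no gradients
        for i in range(h):
            for j in range(w):
                saliency[i][j] += score * mask[i][j]
    norm = num_masks * keep_prob
    return [[v / norm for v in row] for row in saliency]

def black_box(img):
    # Toy "classifier": its score depends on a single pixel only.
    return img[1][2]

image = [[1.0] * 4 for _ in range(3)]
toy_saliency = rise_style_saliency(black_box, image)
best = max(((i, j) for i in range(3) for j in range(4)),
           key=lambda ij: toy_saliency[ij[0]][ij[1]])
print(best)
```

The estimate peaks at the one pixel the toy classifier depends on, even though the classifier is never differentiated.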

In [None]:
# RISE
from torchray.attribution.rise import rise, rise_class
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Obtain example data.
model, x, category_id_1, category_id_2 = get_example_data()

# Run on GPU if available.
device = get_device()
model.to(device)
x = x.to(device)

# RISE method.
saliency = rise(model, x)
saliency1 = saliency[:, category_id_1].unsqueeze(0)
saliency2 = saliency[:, category_id_2].unsqueeze(0)

# Plots.
plt.figure(figsize=(15,15))
plot_example(x, saliency1, 'RISE', category_id_1)
plt.figure(figsize=(15,15))
plot_example(x, saliency2, 'RISE', category_id_2)

In [None]:
# RISE per class
# Obtain example data.
model, x, category_id_1, category_id_2 = get_example_data()

# Run on GPU if available.
device = get_device()
model.to(device)
x = x.to(device)

# RISE method.
saliency1 = rise_class(model, x, target=[category_id_1])  # should work with target=category_id_1, looks like a bug
print(saliency1.size())
saliency2 = rise_class(model, x, target=[category_id_2])

# Plots.
plt.figure(figsize=(15,15))
plot_example(x, saliency1, 'RISE', category_id_1)
plt.figure(figsize=(15,15))
plot_example(x, saliency2, 'RISE', category_id_2)

The TorchRay authors benchmark several approaches and propose a new method called "Extremal Perturbation".
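
Roughly, and with notation chosen here rather than copied from the library, the method looks for a smooth mask $m$ of fixed relative area $a$ that maximizes the class score of the perturbed image:

$$
m_a = \underset{m \in \mathcal{M},\ \|m\|_1 = a\,|\Omega|}{\operatorname{argmax}}\ \Phi_c(m \otimes x)
$$

Here $\Omega$ is the set of pixels, $\mathcal{M}$ a family of smooth masks, $\Phi_c$ the model output for class $c$, and $m \otimes x$ the image that is kept where $m = 1$ and perturbed (for example blurred) elsewhere. The `areas` argument in the cell below appears to play the role of $a$.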

In [None]:
from torchray.attribution.extremal_perturbation import extremal_perturbation, contrastive_reward
from torchray.benchmark import get_example_data, plot_example
from torchray.utils import get_device

# Obtain example data.
model, x, category_id_1, category_id_2 = get_example_data()

# Run on GPU if available.
device = get_device()
model.to(device)
x = x.to(device)

# Extremal perturbation backprop.
masks_1, _ = extremal_perturbation(
    model, x, category_id_1,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.12],
)

masks_2, _ = extremal_perturbation(
    model, x, category_id_2,
    reward_func=contrastive_reward,
    debug=True,
    areas=[0.05],
)

# Plots.
plot_example(x, masks_1, 'extremal perturbation', category_id_1)
plot_example(x, masks_2, 'extremal perturbation', category_id_2)

So which method is better? TorchRay provides [benchmarking results](https://facebookresearch.github.io/TorchRay/benchmark.html#id8) for the "pointing game". In this game, a method must produce a saliency map whose maximum-valued pixel lies on an object of the specified class.
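
The scoring rule can be sketched in a few lines of pure Python. This is a simplified illustration with hypothetical function names and toy data; it ignores details of the real benchmark such as pixel tolerance and per-class averaging.

```python
def pointing_game_hit(saliency, object_mask):
    """Return True if the maximum of `saliency` falls on the object."""
    h, w = len(saliency), len(saliency[0])
    i, j = max(((i, j) for i in range(h) for j in range(w)),
               key=lambda ij: saliency[ij[0]][ij[1]])
    return bool(object_mask[i][j])

def pointing_accuracy(examples):
    """Fraction of (saliency, object_mask) pairs that count as hits."""
    hits = sum(pointing_game_hit(s, m) for s, m in examples)
    return hits / len(examples)

# Toy 2x2 example: the object occupies the top-right pixel.
object_mask = [[0, 1], [0, 0]]
saliency_hit = [[0.1, 0.9], [0.2, 0.3]]    # maximum on the object
saliency_miss = [[0.8, 0.1], [0.2, 0.3]]   # maximum off the object

accuracy = pointing_accuracy([(saliency_hit, object_mask),
                              (saliency_miss, object_mask)])
print(accuracy)
```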

Beyond classification interpretability: can we interpret segmentation/detection models? Yes, we can, but it is a harder task and the simple methods give much worse results.

# References & Further reading

1. TorchRay [github](https://github.com/facebookresearch/TorchRay/tree/master), [docs](https://facebookresearch.github.io/TorchRay/index.html)

2. ICCV'19 interpretability tutorial (theory) [slides](https://interpretablevision.github.io/slide/iccv19_binder_slide.pdf) [site](https://interpretablevision.github.io/)

3. Explainable AI: Interpreting, Explaining and Visualizing Deep Learning [book](https://www.springer.com/gp/book/9783030289539)

4. Papers: [deconvnet](https://doi.org/10.1007/978-3-319-10590-1_53), [guided backprop](https://arxiv.org/abs/1412.6806), [grad-cam](http://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html), [excitation backprop](https://arxiv.org/abs/1608.00507), [RISE](https://arxiv.org/pdf/1806.07421.pdf), [Extremal Perturbations](https://arxiv.org/abs/1910.08485)