In [1]:
from torch.nn import MaxPool2d,Module,Sequential
from captum.attr._utils.lrp_rules import EpsilonRule
import torch
from collections import OrderedDict

In [2]:
# Max pooling over 2x2 windows moved with stride 2 (no padding).
layer1 = MaxPool2d(kernel_size=(2, 2), stride=2)

class Simple(Module):
    """Toy module that sums its input over every axis except the first.

    Each instance stores an ``EpsilonRule`` in its ``rule`` attribute
    (presumably so captum's LRP applies that rule to this custom module --
    confirm against the captum LRP documentation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rule = EpsilonRule()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce ``x`` by summation over all axes after axis 0.

        Args:
            x: Input tensor; axis 0 is preserved, every later axis is summed out.

        Returns:
            A tensor of shape ``(x.shape[0],)``. A 0-D input yields a 0-D tensor.
        """
        if x.dim() < 2:
            # There are no trailing axes to reduce. Passing an empty ``dim`` to
            # ``sum`` would make torch reduce over *all* axes and collapse
            # axis 0, so sum over a freshly added length-1 axis instead.
            return x.unsqueeze(-1).sum(dim=-1)
        return x.sum(dim=tuple(range(1, x.dim())))
    
# Two named stages: the pooling module ("one") followed by the summing module ("two").
layer = Sequential(OrderedDict([("one", layer1), ("two", Simple())]))

In [3]:
# Display the model's repr to check the two named stages and their order.
layer

Sequential(
  (one): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (two): Simple()
)

In [4]:
# 1x4x4 tensor of ones with a single value of 10 at row 1, column 1.
# NOTE: `input` shadows the Python builtin; the name is kept because later
# cells refer to it.
input = torch.ones(1, 4, 4)
input[0, 1, 1] = 10
input

tensor([[[ 1.,  1.,  1.,  1.],
         [ 1., 10.,  1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.]]])

In [5]:
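# input.requires_grad = True  # left disabled. NOTE(review): presumably not needed because LRP.attribute enables gradients on its inputs itself -- confirm.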

In [12]:
# Run only the pooling module on the input to inspect the per-window maxima.
output = layer1(input)
output

tensor([[[10.,  1.],
         [ 1.,  1.]]])

In [7]:
def _grad_total(grads):
    """Return the sum of the first entry of ``grads``, or None if it is missing."""
    first = grads[0] if grads else None
    return None if first is None else first.detach().sum()


def print_backward_hook(module, grad_input, grad_output):
    """Print the gradient tuples passed to a full backward hook, with totals.

    Args:
        module: The module the hook is attached to (unused).
        grad_input: Tuple of gradients with respect to the module's inputs.
        grad_output: Tuple of gradients with respect to the module's outputs.

    Returns:
        None, so the gradients flowing through the module are left unchanged.
    """
    print("Backward pass gradients (relevances):")
    # grad_output[0] is the gradient with respect to the module's output.
    print(grad_output, _grad_total(grad_output))

    print("Backward pass gradients - input (relevances):")
    # grad_input[0] is the gradient with respect to the module's input. It can
    # be None (e.g. for an input outside autograd), so the total is computed
    # by a helper that tolerates a missing entry instead of raising.
    print(grad_input, _grad_total(grad_input))

In [8]:
# Attach the printing hook to both stages, addressing each child by its name.
# The summing stage ("two") is registered first, then the pooling stage ("one").
layer.two.register_full_backward_hook(print_backward_hook)
layer.one.register_full_backward_hook(print_backward_hook)

<torch.utils.hooks.RemovableHandle at 0x13dca1c70>

In [9]:
from captum.attr import LRP

# Wrap the model in captum's LRP attribution method. Note that `layer` is the
# whole two-stage Sequential here, not a single layer.
epic = LRP(layer)

In [10]:
# Compute the LRP attribution for the toy input. The backward hooks registered
# earlier print the gradient tuples seen by each stage along the way.
epic.attribute(input)

Backward pass gradients (relevances):
(tensor([1.]),) tensor(1.)
Backward pass gradients - input (relevances):
(tensor([[[0.7692, 0.0769],
         [0.0769, 0.0769]]]),) tensor(1.)
Backward pass gradients (relevances):
(tensor([[[0.7692, 0.0769],
         [0.0769, 0.0769]]]),) tensor(1.)
Backward pass gradients - input (relevances):
(tensor([[[0.0000, 0.0000, 0.0769, 0.0000],
         [0.0000, 0.7692, 0.0000, 0.0000],
         [0.0769, 0.0000, 0.0769, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]]]),) tensor(1.)




tensor([[[ 0.,  0.,  1.,  0.],
         [ 0., 10.,  0.,  0.],
         [ 1.,  0.,  1.,  0.],
         [ 0.,  0.,  0.,  0.]]], grad_fn=<MulBackward0>)