In [None]:
import torch
import torch.nn as nn
import copy


In [None]:


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Linear(50, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 320)  # Flatten the output for the classifier
        x = self.classifier(x)
        return x

class DiffLrpWrapper(nn.Module):
    def __init__(self, net):
        super().__init__()
        assert isinstance(net, nn.Module), f"Expected net to be an instance of nn.Module, got {type(net)}"
        self.net = net
        self.activations = {}
        self.outputs = {}
        self._register_hooks()

    def _register_hooks(self):
        # Register a forward hook on each module
        for name, module in self.net.named_modules():
            # Avoid registering hooks on containers
            if len(list(module.children())) == 0:
                module.register_forward_hook(self._save_activation(name))

    def _save_activation(self, name):
        # This method returns a hook function
        def hook(module, input, output):
            self.activations[name] = input[0].detach()
            self.outputs[name] = output.detach()
        return hook

    def forward(self, x, target_class:torch.Tensor):
        assert x.shape[0] == target_class.shape[0], f"Expected x and target_class to have the same batch size, got {x.shape[0]} and {target_class.shape[0]}"
        # Forward pass through the network
        initial_out = self.net(x)
        # Create a mask that zeros out all elements except for the target class
        mask = torch.zeros_like(initial_out)
        mask[torch.arange(mask.size(0)), target_class] = 1  # Ensure target_class is either a scalar or has the same batch size as x

        # Apply the mask to propogate relenace forwards
        relevance = initial_out * mask
        # loop backwards from output to input layer
        for name, actual_module in list(self.net.named_modules())[::-1]:
            if len(list(actual_module.children())) == 0:
                # if the module is a leaf module, apply LRP
                # print(f"reversing layer {name} which is of type {type(actual_module)}")
                relevance = self._apply_lrp(name, actual_module, relevance.detach())
        return relevance
    

    def _apply_lrp(self, name:str, actual_module:torch.nn.Module, relevance_to_be_propagaed:torch.Tensor):
        # Get the activation of the module
        layer_activation_values = self.activations[name]
        # check datatypes coming through
        assert isinstance(layer_activation_values, torch.Tensor)
        assert isinstance(actual_module, nn.Module)
        assert isinstance(relevance_to_be_propagaed, torch.Tensor)
        # Check that the shape of the layer outputs and relevance are the same
        if not self.outputs[name].shape == relevance_to_be_propagaed.shape:
            print(f"shapes didn't match for layer {name}")
            relevance_to_be_propagaed = relevance_to_be_propagaed.view(self.outputs[name].shape)
        # Get the relevance of the output & apply LRP
        relevance = self._reverse_layer_(layer_activation_values, actual_module, relevance_to_be_propagaed)
        return relevance
    
    def _reverse_layer_(self, activations_at_start:torch.Tensor, actual_module:torch.nn.Module, relevance:torch.Tensor, epsilon=1e-9):
        # make sure corret data is coming in
        assert isinstance(activations_at_start, torch.Tensor), f"Expected activations_at_start to be a torch.Tensor, got {type(activations_at_start)}"
        assert isinstance(actual_module, nn.Module), f"Expected actual_module to be an nn.Module, got {type(actual_module)}"
        assert isinstance(relevance, torch.Tensor), f"Expected relevance to be a torch.Tensor, got {type(relevance)}"
        # print(f"activations_at_start shape: {activations_at_start.shape}")
        activations_at_start.requires_grad_()
        activations_at_start.retain_grad()
        # perform a modified forward pass (alpha beta rule applied here apparently)
        z = epsilon + actual_module.forward(activations_at_start)
        # divide the outputs by the relevance
        s = torch.div(relevance, z)
        # multiply with weights matrix and perform a backwards pass to get the unit relevance
        torch.multiply(z, s.data).sum().backward()
        # multiple activations with gradients to get the final relevance
        c = activations_at_start * activations_at_start.grad
        return c

        
    def get_activations(self):
        return self.activations



In [None]:

# Example Usage
# Instantiate and use the wrapper
model = SimpleNet()
wrapped_model = DiffLrpWrapper(model)
target_class = torch.randint(0, 10, (20,1))

target_class

In [None]:

# Forward pass
input_tensor = torch.randn(20, 1, 28, 28)
output = wrapped_model(input_tensor, target_class)
output.sum().backward()


In [None]:

def print_layers(model, prefix=""):
    for name, module in model.named_children():
        if isinstance(module, nn.Sequential):
            # Recursive call to handle nested structures
            print(f"{prefix}{name} (Sequential):")
            print_layers(module, prefix=prefix + "  ")
        else:
            # Print layer type
            print(f"{prefix}{name}: {type(module)}")


In [None]:
for name, module in wrapped_model.net.named_modules():
    print(name, type(module))

In [None]:
if type(wrapped_model.net.features[0]) is nn.Conv2d:
    print("yes")

In [None]:
import torch
from captum.attr import LRP

atr = LRP(wrapped_model.net)

output = wrapped_model(input_tensor)
attr = atr.attribute(input_tensor, target=0)
loss = attr.sum() - 10
loss.backward()

In [None]:
# if one can manually define the gradient computation at run time with this function

# Step 1: Define the tensor
x = torch.tensor([2.0, 3.0], requires_grad=True)

# Step 2: Define the function, e.g., f = x^2
y = x ** 2

# Step 3: Calculate the gradient manually
grads = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=torch.tensor([1.0, 1.0]))


In [None]:
import torch
import copy

def copy_layer(layer):
    # Create a deep copy of the layer
    layer_copy = copy.deepcopy(layer)

    return layer_copy

In [None]:
x_prime = copy_layer(wrapped_model.net.features[0])

In [None]:
def reverse_layer(activations_at_start, layer, relevance, epsilon=1e-9):
    activations_at_start.requires_grad_()
    activations_at_start.retain_grad()
    # perform a modified forward pass (alpha beta rule applied here apparently)
    z = epsilon + copy_layer(layer).forward(activations_at_start)
    # divide the outputs by the relevance
    s = torch.div(relevance, z)
    # 
    torch.multiply(z, s.data).sum().backward()
    c = activations * activations.grad
    return c




    
    

In [None]:
for name, module in list(wrapped_model.net.named_modules())[::-1]:
            if len(list(module.children())) == 0:
                print(f"name: {name} module: {type(module)}")
            

In [None]:
layer = wrapped_model.net.classifier[2]
print(type(layer))
input_tensor = torch.randn(1, 1, 28, 28, requires_grad=True)
starting_relevance = wrapped_model.net(input_tensor)
activations_at_start = wrapped_model.get_activations()["classifier.1"].requires_grad_()
activations_at_start.retain_grad()
epsilon = 1e-9
# reverse_layer(activations, layer, starting_relevance)
z = epsilon + layer.forward(activations_at_start)
# divide the outputs by the relevance
s = torch.div(starting_relevance, z)
# 
torch.multiply(z, s.detach()).sum().backward()
c = activations_at_start * activations_at_start.grad

In [9]:
import torch
import torch.nn.functional as F

class InputStoringLayer(torch.nn.Module):
    def __init__(self, layer):
        super(InputStoringLayer, self).__init__()
        self.layer = layer
        self.input = None

    def forward(self, x):
        self.input = x
        return self.layer(x)

def lrp_linear(layer, R, eps=1e-6):
    """
    LRP for a linear layer.
    Arguments:
        layer: the linear layer (InputStoringLayer wrapping nn.Linear)
        R: relevance scores from the previous layer (Tensor)
        eps: small value to avoid division by zero (float)
    Returns:
        relevance scores for the input of this layer (Tensor)
    """
    W = layer.layer.weight
    X = layer.input
    Z = W @ X.t() + layer.layer.bias[:, None] + eps
    S = R / Z
    C = W.t() @ S
    R_new = X * C.t()
    return R_new

def lrp_conv2d(layer, R, eps=1e-6):
    """
    LRP for a convolutional layer.
    Arguments:
        layer: the convolutional layer (InputStoringLayer wrapping nn.Conv2d)
        R: relevance scores from the previous layer (Tensor)
        eps: small value to avoid division by zero (float)
    Returns:
        relevance scores for the input of this layer (Tensor)
    """
    W = layer.layer.weight
    X = layer.input
    Z = F.conv2d(X, W, bias=layer.layer.bias, stride=layer.layer.stride, padding=layer.layer.padding) + eps
    S = R / Z
    C = F.conv_transpose2d(S, W, stride=layer.layer.stride, padding=layer.layer.padding)
    R_new = X * C
    return R_new

class LRPModel(torch.nn.Module):
    def __init__(self, model):
        super(LRPModel, self).__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def lrp(self, x, target_class):
        """
        Perform LRP on the model.
        Arguments:
            x: input data (Tensor)
            target_class: index of the target class (int)
        Returns:
            relevance scores for the input (Tensor)
        """
        # Forward pass
        output = self.model(x)
        
        # Initialize relevance for the output layer
        R = torch.zeros_like(output)
        R[:, target_class] = output[:, target_class]
        
        # Perform backward pass for LRP
        for layer in reversed(list(self.model.children())):
                print(f"reversing layer {layer} which is of type {type(layer)}")
                if isinstance(layer.layer, torch.nn.Linear):
                    R = lrp_linear(layer, R)
                elif isinstance(layer.layer, torch.nn.Conv2d):
                    R = lrp_conv2d(layer, R)
                elif isinstance(layer.layer, torch.nn.ReLU):
                    R = R * (layer.input > 0).float()
        
        return R

# Define a simple model for demonstration
class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = InputStoringLayer(torch.nn.Conv2d(1, 10, kernel_size=5))
        self.conv2 = InputStoringLayer(torch.nn.Conv2d(10, 20, kernel_size=5))
        self.fc1 = InputStoringLayer(torch.nn.Linear(320, 50))
        self.fc2 = InputStoringLayer(torch.nn.Linear(50, 10))
        self.relu = InputStoringLayer(torch.nn.ReLU())

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = x.view(-1, 320)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# Example usage
model = SimpleCNN()
lrp_model = LRPModel(model)

# Dummy input
x = torch.randn(1, 1, 28, 28)
target_class = 0

# Perform LRP
R = lrp_model.lrp(x, target_class)
print(R)

reversing layer relu.layer which is of type <class 'torch.nn.modules.activation.ReLU'>


AttributeError: 'ReLU' object has no attribute 'layer'