In [41]:
import torch
import torch.nn as nn

class DiffLrpWrapper(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.activations = {}
        self._register_hooks()

    def _register_hooks(self):
        # Register a forward hook on each module
        for name, module in self.net.named_modules():
            # Avoid registering hooks on containers
            if len(list(module.children())) == 0:
                module.register_forward_hook(self._save_activation(name))

    def _save_activation(self, name):
        # This method returns a hook function
        def hook(module, input, output):
            self.activations[name] = output.detach()
        return hook

    def forward(self, x):
        # Forward pass through the network
        output = self.net(x)
        return output

    def get_activations(self):
        return self.activations

    def apply_operations(self):
        # Example operation: sum of activations for each layer
        results = {}
        for name, activation in self.activations.items():
            results[name] = torch.sum(activation)
        return results

# Example Usage
# Let's assume we have a simple network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Linear(50, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 320)  # Flatten the output for the classifier
        x = self.classifier(x)
        return x

# Instantiate and use the wrapper
model = SimpleNet()
wrapped_model = DiffLrpWrapper(model)

# Forward pass
input_tensor = torch.randn(1, 1, 28, 28)
output = wrapped_model(input_tensor)

# Get activations
activations = wrapped_model.get_activations()
print(activations.keys())

# Apply operations on activations
# results = wrapped_model.apply_operations()

# print(results)


dict_keys(['features.0', 'features.1', 'features.2', 'features.3', 'classifier.0', 'classifier.1', 'classifier.2'])


In [42]:
# layers is the stripped list of tuples to access weights
layers = [(name, module) for name, module in wrapped_model.net.named_modules() if len(list(module.children())) == 0]
# activations is the stripped dictionary to access activations, indexed by layer names caught by hooks in the forward pass

# one can access the weights of a layer by iterating through the layers list (ordered list on order of execution of forward pass)
layers[0][1].weight.shape

torch.Size([10, 1, 5, 5])

In [31]:

def print_layers(model, prefix=""):
    for name, module in model.named_children():
        if isinstance(module, nn.Sequential):
            # Recursive call to handle nested structures
            print(f"{prefix}{name} (Sequential):")
            print_layers(module, prefix=prefix + "  ")
        else:
            # Print layer type
            print(f"{prefix}{name}: {type(module)}")


In [32]:
for name, module in wrapped_model.net.named_modules():
    print(name, type(module))

 <class '__main__.SimpleNet'>
features <class 'torch.nn.modules.container.Sequential'>
features.0 <class 'torch.nn.modules.conv.Conv2d'>
features.1 <class 'torch.nn.modules.activation.ReLU'>
features.2 <class 'torch.nn.modules.conv.Conv2d'>
features.3 <class 'torch.nn.modules.activation.ReLU'>
classifier <class 'torch.nn.modules.container.Sequential'>
classifier.0 <class 'torch.nn.modules.linear.Linear'>
classifier.1 <class 'torch.nn.modules.activation.ReLU'>
classifier.2 <class 'torch.nn.modules.linear.Linear'>


In [24]:
if type(wrapped_model.net.features[0]) is nn.Conv2d:
    print("yes")

yes


In [43]:
import torch
from captum.attr import LRP

atr = LRP(wrapped_model.net)

In [45]:

output = wrapped_model(input_tensor)
attr = atr.attribute(input_tensor, target=0)
loss = attr.sum() - 10
loss.backward()



RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [50, 10]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [None]:
# if one can manually define the gradient computation at run time with this function

# Step 1: Define the tensor
x = torch.tensor([2.0, 3.0], requires_grad=True)

# Step 2: Define the function, e.g., f = x^2
y = x ** 2

# Step 3: Calculate the gradient manually
grads = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=torch.tensor([1.0, 1.0]))
