In [11]:
import torch
from torch.nn import Module
import torch_geometric
from torch_geometric.data import Data

class LayerIntegratedGradients:
    
    def __init__(self, model: Module, layer: Module):
        """
        Initialize the Layer Integrated Gradients instance.

        Parameters:
        model: PyTorch model
            The model to analyze.
        layer: PyTorch layer
            The layer to analyze gradients for.
        """
        self.model = model
        self.layer = layer
        self.hook = None
        self.activations = None
        self.gradients = None
        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks to capture activations and gradients."""
        def forward_hook(module, inp, out):
            self.activations = out

        def backward_hook(module, grad_inp, grad_out):
            self.gradients = grad_out[0]

        self.hook = self.layer.register_forward_hook(forward_hook)
        self.layer.register_full_backward_hook(backward_hook)

    def compute_integrated_gradients(self, input_tensor, target_class_idx, baseline=None, steps=50):
        """
        Compute integrated gradients for a specific input and target class.

        Parameters:
        input_tensor: torch.Tensor
            The input tensor to analyze.
        target_class_idx: int
            Index of the target class for which gradients are calculated.
        baseline: torch.Tensor, optional
            Baseline tensor for IG. If None, a zero tensor of the same shape as input is used.
        steps: int
            Number of steps for the Riemann approximation.

        Returns:
        torch.Tensor: Integrated gradients for the layer's activations.
        """
        if baseline is None:
            baseline = torch.zeros_like(input_tensor)

        # Scale inputs from baseline to the actual input
        scaled_inputs = torch.stack([
            baseline + (float(i) / steps) * (input_tensor - baseline) for i in range(steps + 1)
        ])

        integrated_grads = None
        for scaled_input in scaled_inputs:
            self.model.zero_grad()

            # Forward pass
            output = self.model(scaled_input)

            # Select the target class
            target = output[:, target_class_idx].sum()

            # Backward pass
            target.backward(retain_graph=True)

            # Accumulate gradients wrt activations
            if integrated_grads is None:
                integrated_grads = self.gradients.clone()
            else:
                integrated_grads += self.gradients

        # Average over the steps and scale by input difference
        integrated_grads /= steps
        return integrated_grads * (input_tensor - baseline)

    def remove_hooks(self):
        """Remove the registered hooks."""
        if self.hook:
            self.hook.remove()

# Example usage with a custom PyTorch or PyTorch-Geometric model
class CustomModel(Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(20, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create a toy model and dummy input
model = CustomModel()
layer = model.fc1
lig = LayerIntegratedGradients(model, layer)

input_tensor = torch.randn(1, 10, requires_grad=True)
target_class_idx = 1  # Target class index

# Compute Layered Integrated Gradients
lig_results = lig.compute_integrated_gradients(input_tensor, target_class_idx)
print("Integrated Gradients:", lig_results)

lig.remove_hooks()


RuntimeError: The size of tensor a (20) must match the size of tensor b (10) at non-singleton dimension 1