In [None]:
import torch
from torchvision.models import resnet18
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz

# Load a pre-trained ResNet18 model (ImageNet-1k weights).
# torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum and
# warns on the old flag; fall back to it only on older torchvision releases.
try:
    from torchvision.models import ResNet18_Weights
except ImportError:  # torchvision < 0.13
    model = resnet18(pretrained=True)
else:
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()  # Set the model to evaluation mode

# Index of the ImageNet class to explain. None means "explain the model's own
# top-1 prediction"; a fixed index (e.g. 0) explains that class regardless of
# what the image actually shows.
TARGET_CLASS = None

# `img` must already exist (e.g. defined in an earlier notebook cell) as a
# preprocessed image tensor. Fail early with an actionable message instead of
# a bare NameError from inside the attribution call.
if "img" not in globals():
    raise NameError(
        "`img` is not defined: load an image and preprocess it into a tensor "
        "(matching the ResNet18 training transforms) before running this cell."
    )

# ResNet needs a batch dimension; accept a single (C, H, W) image as well.
# A new name is used so the caller's `img` is left untouched.
inputs = img.unsqueeze(0) if img.dim() == 3 else img
inputs = inputs.to(next(model.parameters()).device)

if TARGET_CLASS is None:
    # Forward pass only to choose the class to explain; no gradients needed.
    with torch.no_grad():
        target = model(inputs).argmax(dim=1)
else:
    target = TARGET_CLASS

# Wrap model with Integrated Gradients
ig = IntegratedGradients(model)

# Compute attributions (same shape as `inputs`) plus the convergence delta:
# the gap between the summed attributions and f(input) - f(baseline).
# A large |delta| suggests raising `n_steps` in `ig.attribute`.
attributions, delta = ig.attribute(
    inputs, target=target, return_convergence_delta=True
)
print(
    "Integrated Gradients convergence delta (max |delta|): "
    f"{delta.abs().max().item():.4g}"
)

# Visualize the attributions of the first image in the batch.
# visualize_image_attr expects an (H, W, C) NumPy array, hence the permute.
# The visualization utility is more suitable for notebook environments.
viz.visualize_image_attr(
    attributions[0].detach().cpu().permute(1, 2, 0).numpy(),
    method="heat_map",
    sign="all",
    show_colorbar=True,
)
