In [None]:
import torch
import timm
from PIL import Image
from tensordict import TensorDict
from tdhook.attribution import Saliency
from tdhook.attribution import IntegratedGradients

In [None]:
# Load model and prepare image
MODEL_NAME = "vgg16.tv_in1k"
IMAGE_PATH = "results/simple/zebra_1.jpg"

model = timm.create_model(MODEL_NAME, pretrained=True)
# Attributions must come from the model in inference mode: in train mode, layers such
# as dropout (if present) would randomize the forward pass and hence the gradients.
model.eval()

# Preprocessing (resize / crop / normalize) that matches the pretrained weights.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Use a context manager so the underlying file handle is closed deterministically;
# `.convert` returns a new, fully loaded image that outlives the `with` block.
with Image.open(IMAGE_PATH) as raw_image:
    image = raw_image.convert("RGB")
image_tensor = transforms(image)

In [None]:
# Define attribution target (zebra class = 340)
def init_attr_targets(targets, _):
    """Pick the zebra logit as the scalar-per-sample quantity to attribute.

    Args:
        targets: TensorDict whose "output" entry presumably holds the model's
            class logits, with the class dimension last (as the `[..., idx]`
            indexing below requires).
        _: Second argument passed in by the attribution wrapper; unused here.

    Returns:
        A TensorDict with a single "out" entry containing the logit at class
        index 340 (the zebra class, per the cell comment) and the same batch
        size as `targets`.
    """
    zebra_class_idx = 340
    logits = targets["output"]
    selected = logits[..., zebra_class_idx]
    return TensorDict({"out": selected}, batch_size=targets.batch_size)

In [None]:
# Compute attribution
# A batch of one image, plus an all-zero baseline with the same shape, dtype and device
# for the integrated-gradients path to start from.
input_batch = image_tensor.unsqueeze(0)
baseline_batch = torch.zeros_like(image_tensor).unsqueeze(0)

with Saliency(IntegratedGradients(init_attr_targets=init_attr_targets)).prepare(model) as hooked_model:
    td = hooked_model(
        TensorDict(
            {"input": input_batch, ("baseline", "input"): baseline_batch},
            batch_size=1,
        )
    )
# The resulting attribution map can be read with td.get(("attr", "input")).