# TRAK

Here we'll provide a minimal and hackabale implementation of TRAK. For a more elaborate implementation with all the bells&whistles, see https://github.com/MadryLab/trak.

In [2]:
import torch

# let's abstract away the "boring" parts in utils.py
from utils import get_model, get_loader

model = get_model()
train_loader = get_loader(split="train")
val_loader = get_loader(split="val")

Files already downloaded and verified
Files already downloaded and verified


Suppose we have a few models trained already (identical models trained on the same dataset, but with a different random seed). Let's load two of them; below we'll load pre-trained checkpoints.

For convenience, we're also adding the code to train these models from scratch (commented out below).

In [35]:
NUM_MODELS = 2
models = []

# from utils import train
# for i in range(NUM_MODELS):
#     model = get_model()
#     model = train(model, train_loader)
#     models.append(model)

for path in ["./models/model_0.pt", "./models/model_1.pt"]:
    model = get_model()
    sd = torch.load(path)
    model.load_state_dict(sd)
    models.append(model)

Next, we create the random projection matrix of size model_size x proj_dim in practice, this ends up being too large to work with, so we use a custom CUDA kernel (https://github.com/MadryLab/trak/tree/main/fast_jl) to project using this matrix. 

In [36]:
proj_dim = 512
model_size = sum(torch.numel(p) for p in model.parameters())
P = torch.randn(model_size, proj_dim, device="cuda")

Now let's compute the attribution score from one train sample to one test (target) sample. To this end, we need to compute the surrogate features of the train and target samples.

Following the derivation in the second part of the tutorial, we'll use the *loss* when we "featurize" the train sample:

In [37]:
x, y = train_loader.dataset[0]
x = x.to("cuda")
y = torch.tensor(y).to(torch.int64).to("cuda")

Phi_train = {}
for i, model in enumerate(models):
    loss = torch.nn.CrossEntropyLoss(reduction="none")
    L = loss(model(x.unsqueeze(0)), y.unsqueeze(0))
    phi = torch.autograd.grad(L, model.parameters(), create_graph=True)
    phi = torch.cat([p.flatten() for p in phi])  # flatten the gradients into a single vector
    Phi_train[i] = (P.T @ phi).clone().detach()

...and we'll use the model output (or "measurement") of interest when we "featurize" the target sample:

In [38]:
x_target, y_target = val_loader.dataset[0]
def model_output(logits, label):
    """
    This function computes "margins", i.e. the difference between the logits of the target class and the log-sum-exp of the logits of all the other classes.
    """
    bindex = torch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
    logits_correct = logits[bindex, label.unsqueeze(0)]

    cloned_logits = logits.clone()
    # remove the logits of the correct labels from the sum
    # in logsumexp by setting to -ch.inf
    cloned_logits[bindex, label.unsqueeze(0)] = torch.tensor(
        -float("inf"), device=logits.device, dtype=logits.dtype
    )

    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return margins.sum()

Phi_target = {}
for i, model in enumerate(models):
    O = model_output(model(x.unsqueeze(0)), y.unsqueeze(0))
    phi = torch.autograd.grad(O, model.parameters(), create_graph=True)
    phi = torch.cat([p.flatten() for p in phi])  # flatten the gradients into a single vector
    Phi_target[i] = (P.T @ phi).clone().detach()

Next, we need to compute an estimate of the Hessian matrix. It turns out that for our linear surrogate model, the Hessian has a simple closed form!

In [39]:
H = {i: torch.zeros(proj_dim, proj_dim, device="cuda") for i in range(2)}
for i, model in enumerate(models):
    for x, y in train_loader:
        x = x.to("cuda")
        y = torch.tensor(y).to(torch.int64).to("cuda")
        loss = torch.nn.CrossEntropyLoss(reduction="sum")
        L = loss(model(x), y)
        phi = torch.autograd.grad(L, model.parameters(), create_graph=True)
        phi = torch.cat([p.flatten() for p in phi])
        X = (P.T @ phi.reshape(-1, 1)).clone().detach()
        H[i] += X @ X.T

# we can optionally add a damping term lambda * I here
H_inv = {i: torch.linalg.inv(H[i]) for i in range(2)}

  y = torch.tensor(y).to(torch.int64).to("cuda")


That's it, we are ready to compute our attribution scores:

In [40]:
score = 0.
for k in Phi_train.keys():
    score += Phi_train[k] @ H_inv[k] @ Phi_target[k] / len(Phi_train)
print(score)

tensor(-164915.9375, device='cuda:0')
