# TRAK

Here we'll provide a minimal and hackable implementation of TRAK. For a more elaborate implementation with all the bells & whistles, see https://github.com/MadryLab/trak.

In [3]:
import torch
from tqdm.auto import tqdm

# let's abstract away the "boring" parts in utils.py
from utils import get_model, get_loader

# batch size used when iterating over the training split
TRAIN_BATCH_SIZE = 500

model = get_model()
# Both loaders come from utils.py. The training set is a smaller one that
# contains only samples from the cat & dog classes.
train_loader = get_loader(split="train", batch_size=TRAIN_BATCH_SIZE)
val_loader = get_loader(split="val")

Files already downloaded and verified
Files already downloaded and verified


Suppose we have a few models trained already (identical models trained on the same dataset, but with a different random seed). Let's load two of them.

For convenience, we're also adding the code to train these models from scratch.

In [11]:
NUM_MODELS = 2

# Set to True to train the models from scratch instead of loading the saved
# checkpoints (for instance when ./artifacts/model_{i}.pt is not available).
# The retrained models are only kept in memory; nothing is written to ./artifacts.
want_to_retrain = False

models = []
if want_to_retrain:
    # imported here so the training utilities are only required on this path
    from utils import train

    for i in range(NUM_MODELS):
        model = get_model()
        model = train(model, train_loader)
        models.append(model)
else:
    # The two paths are mutually exclusive: checkpoints are only read when we are
    # not retraining, so a missing checkpoint cannot break the retraining path.
    for i in range(NUM_MODELS):
        model = get_model()
        # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
        sd = torch.load(f"./artifacts/model_{i}.pt")
        model.load_state_dict(sd)
        models.append(model)

Next, we create the random projection matrix of size `model_size x proj_dim`. In practice, this ends up being too large to work with, so we use a custom CUDA kernel (https://github.com/MadryLab/trak/tree/main/fast_jl) to project using this matrix.

In [12]:
proj_dim = 512

# Every entry of `models` is built by get_model(), so the first one is used to
# count the parameters. Indexing `models` explicitly avoids depending on the
# `model` loop variable leaking out of the previous cell.
model_size = sum(p.numel() for p in models[0].parameters())

# Draw the projection matrix from a dedicated, seeded generator: the matrix (and
# therefore the attribution scores) is then identical on every run, and the
# global RNG state is left untouched.
PROJ_SEED = 0
proj_generator = torch.Generator(device="cuda").manual_seed(PROJ_SEED)
P = torch.randn(model_size, proj_dim, generator=proj_generator, device="cuda")

Now let's compute the attribution score from a few train samples to one test (target) sample. To this end, we need to compute the surrogate features of the train and target samples.

Following the derivation in the second part of the tutorial, we'll use the *loss* when we "featurize" the train sample:

In [13]:
# Featurize a single batch of training samples.
x_train, y_train = next(iter(train_loader))
x_train = x_train.to("cuda")
y_train = y_train.to("cuda")

# reduction="none" keeps one loss value per sample, so that each sample can be
# differentiated on its own
criterion = torch.nn.CrossEntropyLoss(reduction="none")

Phi_train = {}
for i, model in enumerate(models):
    per_sample_losses = criterion(model(x_train), y_train)

    projected = []
    for sample_loss in tqdm(per_sample_losses):
        # the same forward graph is reused for every sample, hence retain_graph=True
        grads = torch.autograd.grad(sample_loss, model.parameters(), retain_graph=True)
        # concatenate all parameter gradients into one long vector ...
        flat_grad = torch.cat([g.flatten() for g in grads])
        # ... and project it down to proj_dim dimensions
        projected.append((P.T @ flat_grad).clone().detach())

    # one row per training sample
    Phi_train[i] = torch.stack(projected)

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

...and we'll use the model output (or "measurement") of interest when we "featurize" the target sample:

In [20]:
# The target sample is the first sample of the first validation batch. Slicing
# with 0:1 (rather than indexing with 0) keeps the leading batch dimension.
x_target, y_target = (t[0:1].to("cuda") for t in next(iter(val_loader)))

def model_output(logits, label):
    """
    Return the summed "margin" of `logits`: for every row, the logit of the class
    given by `label` minus the log-sum-exp of the logits of all the other classes.
    """
    rows = torch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
    cols = label.unsqueeze(0)

    # logit of the labelled class for each row
    target_logits = logits[rows, cols]

    # Work on a copy and overwrite the labelled entries with -inf, so that they
    # contribute nothing to the log-sum-exp over the remaining classes.
    masked = logits.clone()
    neg_inf = torch.tensor(-float("inf"), device=logits.device, dtype=logits.dtype)
    masked[rows, cols] = neg_inf

    return (target_logits - masked.logsumexp(dim=-1)).sum()

# Featurize the target sample for every model: take the gradient of the model
# output (the margin computed by model_output) with respect to the parameters
# and project it with P.
Phi_target = {}
for i, model in enumerate(models):
    O = model_output(model(x_target), y_target)
    # Only the first-order gradient is needed and the result is detached right
    # below, so autograd is not asked to build a graph of the backward pass
    # (create_graph=True would only cost extra memory here).
    grads = torch.autograd.grad(O, model.parameters())
    phi = torch.cat([g.flatten() for g in grads])  # flatten the gradients into a single vector
    Phi_target[i] = (P.T @ phi).detach()

Next, we need to compute an estimate of the Hessian matrix. It turns out that for our linear surrogate model, the Hessian has a simple closed form!

In [23]:
# Hessian estimate of the linear surrogate model: H = Phi^T Phi, where every row
# of Phi is the projected gradient of the loss of ONE training sample.
#
# The gradients therefore have to be taken per sample. Differentiating the
# batch-summed loss would yield the sum of the per-sample gradients, and its
# outer product (sum_j phi_j)(sum_j phi_j)^T is a rank-1 matrix per batch rather
# than sum_j phi_j phi_j^T -- with fewer batches than proj_dim the result would
# be singular and could not be inverted meaningfully.
#
# This runs one backward pass per training sample, so it is much slower than a
# single backward pass per batch.
criterion = torch.nn.CrossEntropyLoss(reduction="none")

H = {i: torch.zeros(proj_dim, proj_dim, device="cuda") for i in range(len(models))}
for i, model in enumerate(models):
    for x, y in tqdm(train_loader):
        x = x.to("cuda")
        y = y.to("cuda")
        per_sample_losses = criterion(model(x), y)

        rows = []
        for sample_loss in per_sample_losses:
            # the forward graph is shared by all samples of the batch
            grads = torch.autograd.grad(sample_loss, model.parameters(), retain_graph=True)
            phi = torch.cat([g.flatten() for g in grads])
            rows.append((P.T @ phi).detach())

        # X has one row per sample; X^T X equals the sum of the per-sample outer products
        X = torch.stack(rows)
        H[i] += X.T @ X

# we can optionally add a damping term lambda * I here
H_inv = {i: torch.linalg.inv(H[i]) for i in H}

We are ready to compute our attribution scores:

In [24]:
# Average the per-model attribution scores; the result holds one score for each
# row (training sample) of Phi_train.
num_models = len(Phi_train)
scores = torch.zeros(Phi_train[0].shape[0])
for k in Phi_train:
    # score of every train sample for model k: Phi_train H^{-1} phi_target
    model_scores = Phi_train[k] @ H_inv[k] @ Phi_target[k]
    scores += (model_scores / num_models).cpu()

That's it!