In [1]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*weights_only=False.*")

import torch
from torch import nn
from torch.utils.data import DataLoader

from _dattri.MLP.load import load_benchmark

from _dattri.TRAK.trak import TRAKAttributor
# from dattri.algorithm.rps import RPSAttributor
from dattri.metrics.metrics import lds
from dattri.task import AttributionTask

In [2]:
activation_fn = "leaky_relu"
proj_dim = 128

projector_kwargs = {
    "proj_dim": proj_dim,
    "device": "cpu",
}

In [3]:
model_details, groundtruth = load_benchmark(
    model=f"mlp_{activation_fn}", dataset="mnist", metric="lds"
)

def loss_trak(params, data_target_pair):
    image, label = data_target_pair
    image_t = image.unsqueeze(0)
    label_t = label.unsqueeze(0)
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model_details["model"], params, image_t)
    logp = -loss(yhat, label_t)
    return logp - torch.log(1 - torch.exp(logp))

def m_trak(params, image_label_pair):
    image, label = image_label_pair
    image_t = image.unsqueeze(0)
    label_t = label.unsqueeze(0)
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model_details["model"], params, image_t)
    p = torch.exp(-loss(yhat, label_t.long()))
    return p

In [4]:
train_loader = DataLoader(
    model_details["train_dataset"],
    shuffle=False,
    batch_size=500,
    sampler=model_details["train_sampler"],
)

task = AttributionTask(
    model=model_details["model"].to("cpu"),
    loss_func=loss_trak,
    checkpoints=model_details["models_half"][0:50],
)


attributor = TRAKAttributor(
    task=task,
    correct_probability_func=m_trak,
    projector_kwargs=projector_kwargs,
    device="cpu",
)

In [5]:
data, label = attributor.cache(train_loader, verbose=False, sparse_check=True)

yhat
 GradTrackingTensor(lvl=2, value=
    BatchedTensor(lvl=1, bdim=0, value=
        tensor([[[  1.4075,  -2.9710,   1.8976,  ...,  -6.3893,   0.8470,  -1.8786]],

                [[ 16.4520,  -6.1797,   4.0359,  ...,  -2.5516,  -2.6246,  -1.3072]],

                [[ -3.2890,  -4.1089,   3.7946,  ...,  -0.5409,  -3.6524,   1.1941]],

                ...,

                [[ -4.0341,  -1.2962,   6.8789,  ...,   9.3506,  -0.0266,   7.2985]],

                [[  0.0398,  -2.7580,   6.6621,  ..., -12.0137,   1.8494,  -8.7620]],

                [[  0.9621,   0.2927,   1.7146,  ..., -10.2995,  15.4168,  -4.0678]]],
               grad_fn=<AddBackward0>)
    )
)

original
log p:	 GradTrackingTensor(lvl=2, value=
    BatchedTensor(lvl=1, bdim=0, value=
        tensor([-7.7700e-03, -3.6477e-05, -6.1545e-03, -5.0901e-05, -8.5967e-03,
                -6.4519e-04, -6.2941e-05, -2.3722e-05, -8.1089e-04, -8.1297e-05,
                -5.5012e-04, -3.4960e-02, -1.1325e-05, -3.6358e-05, -9.2333e-

In [6]:
# load parameters
parameters = torch.load("result/retrain/mlp_linear/0/model_weights_0.pt")
print(loss_trak(parameters, (data.to("cuda"), label.to("cuda"))))
# print(m_trak(parameters, (data.to("cuda"), label.to("cuda"))))


yhat
 tensor([[-5.9349, -0.0772,  1.3386, 11.5718, -1.7603,  0.2596, -7.2953,  2.5628,
          2.5986, -1.1153]], device='cuda:0')

original
log p:	 tensor(-0.0003, device='cuda:0')

using log_softmax
log_softmax_pred:	 tensor([[-1.7507e+01, -1.1649e+01, -1.0234e+01, -3.1073e-04, -1.3332e+01,
         -1.1313e+01, -1.8867e+01, -9.0093e+00, -8.9735e+00, -1.2687e+01]],
       device='cuda:0')
log p:	 tensor([-0.0003], device='cuda:0')
tensor(8.0765, device='cuda:0')
tensor(8.0765, device='cuda:0')
