In [2]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*weights_only=False.*")

import torch
from torch import nn
from torch.utils.data import DataLoader

from _dattri.MLP.load import load_benchmark

from _dattri.algorithm.tracin import TracInAttributor
# from dattri.algorithm.rps import RPSAttributor
from dattri.metrics.metrics import lds
from dattri.task import AttributionTask

In [2]:
activation_fn = "relu"
proj_dim = 512

projector_kwargs = {
    "proj_dim": proj_dim,
    "device": "cuda",
}

In [3]:
model_details, groundtruth = load_benchmark(
    model=f"mlp_{activation_fn}", dataset="mnist", metric="lds", method="Grad-Dot"
)

def loss_tracin(params, data_target_pair):
    image, label = data_target_pair
    image_t = image.unsqueeze(0)
    label_t = label.unsqueeze(0)
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model_details["model"], params, image_t)
    return loss(yhat, label_t.long())

task = AttributionTask(
    model=model_details["model"].to("cuda"),
    loss_func=loss_tracin,
    checkpoints=model_details["models_full"][0]
)

In [4]:
train_loader = DataLoader(
    model_details["train_dataset"],
    shuffle=False,
    batch_size=500,
    sampler=model_details["train_sampler"],
)

test_loader = DataLoader(
    model_details["test_dataset"],
    shuffle=False,
    batch_size=500,
    sampler=model_details["test_sampler"],
)

attributor = TracInAttributor(
    task=task,
    weight_list=torch.ones(1) * 1e-3,
    normalized_grad=False,
    projector_kwargs=projector_kwargs,
    device="cuda",
)

In [5]:
score = attributor.attribute(train_loader, test_loader, verbose=False, sparse_check=True)

All zeros gradient after projection
original gradient:
tensor([8.0930e-11, 8.0930e-11, 8.0930e-11,  ..., 4.7851e-10, 3.2513e-12,
        8.6450e-11], device='cuda:0', grad_fn=<SelectBackward0>)

All zeros gradient after projection
original gradient:
tensor([1.4453e-10, 1.4453e-10, 1.4453e-10,  ..., 1.6917e-14, 0.0000e+00,
        4.3345e-10], device='cuda:0', grad_fn=<SelectBackward0>)

All zeros gradient after projection
original gradient:
tensor([6.0624e-11, 6.0624e-11, 6.0624e-11,  ..., 4.1798e-12, 1.0055e-11,
        7.0587e-12], device='cuda:0', grad_fn=<SelectBackward0>)

Average Sparsity of Original Gradients: 0.4155
Average Sparsity of Projected Gradients: 0.0012
Average Distance Relative Error (Original vs Projected): 0.0302


In [30]:
eps = 1e-2
delta = 1e-1
k = 12 * torch.log(torch.tensor(1 / delta)) / eps**2
c = 16 * torch.log(torch.tensor(1 / delta)) * (torch.log(torch.tensor(k / delta)) ** 2) / eps

# create a Rademaacher vector
r = torch.randint(0, 2, size=(10 * 20,), device='cpu', dtype=torch.float32) * 2 - 1

k_int = int(k)

# create a random hash function from [10 * self.block_size] to [k]
h_ = torch.randint(
    high=k_int,
    size=(10*20,),
    device='cpu',
    dtype=torch.int64,
)

print(k_int)
# H_ \in {0, \pm 1}^{k x (10 * self.block_size)} such that H_{i,j} = r_j if h_(j) = i, 0 otherwise
H_ = torch.zeros(
    k_int,
    10 * 20,
    dtype=torch.float32,
    device='cpu',
)
print(H_.shape)
print(h_.shape)
# for i in range(k_int):
#     for j in range(20):
#         H_[i, j] = r[j] if h_[j] == i else 0
H_[h_, torch.arange(10*20)] = r

# P \in {0, \pm 1}^{(10 * self.block_size) x self.block_size} such that P_{i,j} = 1 / sqrt(c) if (j-1)c + 1 <= i <= jc, 0 otherwise
P = torch.zeros(
    10 * 20,
    20,
    dtype=torch.float32,
    device='cpu',
)
c_int = int(c)
for j in range(20):
    P[j * c_int : (j + 1) * c_int, j] = 1 / torch.sqrt(torch.tensor(c, dtype=torch.float32))


276310
torch.Size([276310, 200])
torch.Size([200])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


  c = 16 * torch.log(torch.tensor(1 / delta)) * (torch.log(torch.tensor(k / delta)) ** 2) / eps
  P[j * c_int : (j + 1) * c_int, j] = 1 / torch.sqrt(torch.tensor(c, dtype=torch.float32))
