In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np
import matplotlib.pyplot as plt
import autograd_grad_sample

# Load the MNIST train and test splits. Both splits use the same storage
# root, the same ToTensor transform, and are downloaded when requested via
# download=True; only the `train` flag differs.
mnist_kwargs = dict(
    root='./data/',
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

train_data = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_data = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [2]:
# Flatten every training image into a row of 28*28 float values and collect
# the matching integer class labels.
#
# The legacy typed constructors (torch.FloatTensor / torch.LongTensor) were
# only re-wrapping tensors; explicit dtype conversions express the same thing.
#
# NOTE(review): this reads the dataset's `data` attribute directly instead of
# indexing the dataset, so presumably the ToTensor transform is not applied on
# this path and `x` holds raw pixel intensities -- confirm this is intended.
x = train_data.data.view(-1, 28 * 28).float()
y = torch.as_tensor(train_data.targets, dtype=torch.long)

In [3]:
# Three stacked linear layers (no activations in between) mapping a flattened
# 28*28 input to 10 outputs. A single-layer variant, nn.Linear(28*28, 10),
# can be substituted here for a smaller experiment.
model = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=512),
    nn.Linear(in_features=512, out_features=512),
    nn.Linear(in_features=512, out_features=10),
)

# Take the reference copy *before* hooks are attached, so `model_clone` starts
# with identical weights but is never passed to add_hooks.
model_clone = copy.deepcopy(model)

# NOTE(review): presumably registers hooks used for per-sample gradient
# capture -- confirm against the autograd_grad_sample module.
autograd_grad_sample.add_hooks(model)

In [4]:
# Manual per-sample gradients for a small batch.
#
# One forward pass produces a separate loss per sample. Each loss is then
# back-propagated on its own to obtain
#   * grad_per_sample[i]: the L2 norm of sample i's gradient taken over all
#     parameters, and
#   * grad_all[name]: the running sum of the per-sample gradients.
# Finally the sum is divided by the batch size and written into `param.grad`,
# so `model` ends up holding the batch-averaged gradient.
train_ids = np.arange(10)
output = F.log_softmax(model(x[train_ids]), dim=1)
target = y[train_ids]
loss = F.nll_loss(output, target, reduction='none')  # one loss value per sample

num_samples = loss.size(0)
grad_per_sample = torch.zeros_like(loss)
grad_all = {name: torch.zeros_like(param) for name, param in model.named_parameters()}

for i in range(num_samples):
    # Clear the previous sample's gradient so param.grad holds only sample i.
    model.zero_grad()
    # Every loss[i] comes from the same forward pass, so the graph has to be
    # kept alive for the remaining backward calls.
    loss[i].backward(retain_graph=True)
    grad_norm = 0.0
    for name, param in model.named_parameters():
        sample_grad = param.grad.detach()
        grad_all[name] += sample_grad
        grad_norm += sample_grad.norm(2) ** 2
    grad_per_sample[i] = torch.sqrt(grad_norm)

# Replace the last sample's gradient with the mean over the whole batch.
for name, param in model.named_parameters():
    param.grad = grad_all[name] / num_samples


In [9]:
# Reference gradient: back-propagate the mean loss of the same batch through
# the un-hooked copy, then compare it parameter by parameter with whatever is
# currently stored in `model`'s param.grad (this cell expects those gradients
# to have been populated beforehand).
#
# Alternative (disabled): the same mean-loss backward pass could be run on
# `model` itself instead of on the clone.
model_clone.zero_grad()
target = y[train_ids]
logits = model_clone(x[train_ids])
output = F.log_softmax(logits, dim=1)
loss = F.nll_loss(output, target, reduction='none').mean()
loss.backward(retain_graph=True)

param_pairs = zip(model.parameters(), model_clone.parameters())
for manual_param, reference_param in param_pairs:
    manual_grad = manual_param.grad.data
    reference_grad = reference_param.grad.data
    # Element-wise closeness under torch.allclose's default tolerances.
    print(torch.allclose(manual_grad, reference_grad))
    # Summary statistics of each gradient, for an eyeball comparison.
    print(manual_grad.max(), manual_grad.min(), manual_grad.mean())
    print(reference_grad.max(), reference_grad.min(), reference_grad.mean())
    # Squared L2 distance between the two gradients.
    squared_gap = (manual_grad - reference_grad).norm(2) ** 2
    print(squared_gap)

False
tensor(4.4751) tensor(-5.8465) tensor(-0.0077)
tensor(4.4751) tensor(-5.8465) tensor(-0.0077)
tensor(2.1736e-08)
False
tensor(0.0226) tensor(-0.0271) tensor(-0.0003)
tensor(0.0226) tensor(-0.0271) tensor(-0.0003)
tensor(5.1245e-15)
False
tensor(4.2421) tensor(-5.2108) tensor(0.0005)
tensor(4.2421) tensor(-5.2108) tensor(0.0005)
tensor(8.6035e-10)
True
tensor(0.0382) tensor(-0.0413) tensor(-1.2577e-05)
tensor(0.0382) tensor(-0.0413) tensor(-1.2577e-05)
tensor(8.2743e-16)
False
tensor(14.9703) tensor(-17.4761) tensor(-1.1211e-08)
tensor(14.9703) tensor(-17.4761) tensor(-2.6633e-08)
tensor(3.3423e-10)
True
tensor(0.2148) tensor(-0.2999) tensor(4.0303e-08)
tensor(0.2148) tensor(-0.2999) tensor(3.6880e-08)
tensor(7.5734e-16)


In [6]:
# Sanity check: print, for each parameter tensor, whether the hooked model and
# its copy hold (approximately) the same values.
for hooked_param, reference_param in zip(model.parameters(), model_clone.parameters()):
    weights_match = torch.allclose(hooked_param.data, reference_param.data)
    print(weights_match)

True
True
True
True
True
True


In [7]:
# Bare last expression: the notebook displays the repr of the object returned
# by model.parameters() (a generator, per the recorded output), not the
# parameter values themselves.
model.parameters()

<generator object Module.parameters at 0x7fbb4afd2b10>