In [1]:
import torch
from torchessian import hessian_matmul, lanczos
from itertools import product
from copy import deepcopy

In [119]:
n = 2
model = torch.nn.Linear(n, 1)
# torch.manual_seed(2019)
v = torch.rand(sum(p.data.numel() for p in model.parameters()))
x = torch.rand(1000, n)
y = x.sum(1)
def loss_function(y_hat, y):
    return ((y_hat.view(-1) - y.view(-1)) * (y_hat.view(-1) - y.view(-1))).sum() / y.numel()
batch = x, y
m = 3
# Exact Hessian

h11 = 2 * (x[:, 0] * x[:, 0]).sum() / x.size(0)
h21 = h12 = 2 * (x[:, 0] * x[:, 1]).sum() / x.size(0)
h31 = h13 = 2 * x[:, 0].sum() / x.size(0)

h22 = 2 * (x[:, 1] * x[:, 1]).sum() / x.size(0)
h32 = h23 = 2 * x[:, 1].sum() / x.size(0)

h33 = 2

H = torch.zeros(3, 3)

for i, j in product(range(3), range(3)):
    exec("H[%d, %d] = h%d%d" % (i, j, i+1, j+1))
    
print(H)

tensor([[0.6434, 0.4898, 0.9823],
        [0.4898, 0.6738, 1.0056],
        [0.9823, 1.0056, 2.0000]])


In [120]:
torch.eig(lanczos(model, loss_function, batch, 3, 3)[0])[0][:, 0]

tensor([3.0431, 0.1054, 0.1687])

In [101]:
def hessian_matmul_pearl(model, loss_function, v, batch):
    r = 1e-1
    model_delta = deepcopy(model)
    begin = end = 0
    for p in model_delta.parameters():
        if not p.requires_grad:
            continue
        end = begin + p.data.numel()
        p_flat = p.data.view(-1)
        p_flat += r * v[begin:end]
        begin = end


    model.zero_grad()

    x, y = batch
    E = loss_function(model(x), y)
    E_delta = loss_function(model_delta(x), y)

    E.backward()
    E_delta.backward()

    grad_w = torch.cat(
        list(p.grad.view(1, -1) for p in model.parameters() if p.requires_grad), 
        1
    )
    grad_w_delta = torch.cat(
        list(p.grad.view(1, -1) for p in model_delta.parameters() if p.requires_grad), 
        1
    )

    grad_w.squeeze_()
    grad_w_delta.squeeze_()


    return (grad_w_delta - grad_w) / r


In [102]:
model = model.to(torch.float64)
v = torch.ones(3).double()
h1 = hessian_matmul_pearl(model, loss_function, v, batch)
h2 = hessian_matmul(model, loss_function, v, batch)

(h1 - h2).norm()

tensor(1.3323e-15, dtype=torch.float64)

In [96]:
from torchvision.models import resnet18

model = resnet18(pretrained=False)
model = model.to(torch.float64)
torch.manual_seed(2019)
x = torch.rand(16, 3, 224, 244).double()
y = torch.randint(2, (16, ))
loss_function = torch.nn.CrossEntropyLoss()
batch = x, y
v = torch.rand(sum(p.data.numel() for p in model.parameters() if p.requires_grad)).double()

In [98]:
h1.sort()

torch.return_types.sort(
values=tensor([-2.7155, -2.5175, -2.3020,  ...,  2.1834,  2.3812,  2.6128],
       dtype=torch.float64),
indices=tensor([8128,   75, 7607,  ...,  489, 2924, 2016]))

In [99]:
h2.sort()

torch.return_types.sort(
values=tensor([-62.0502, -60.6201, -59.0441,  ...,  61.1140,  63.2780,  79.3621],
       dtype=torch.float64),
indices=tensor([4583, 2483, 9183,  ..., 7450, 2544, 6928]))

In [97]:
h1 = hessian_matmul_pearl(model, loss_function, v, batch)
h2 = hessian_matmul(model, loss_function, v, batch)

(h1 - h2).norm()

tensor(2215.4188, dtype=torch.float64)