# Checking the correctness

I'm still working on the checks of the Hessian-vector calculations.

In [121]:
import torch
from torchessian import hessian_matmul
from torchessian.batch_mode import lanczos
from itertools import product
from copy import deepcopy

In [174]:
def hessian_matmul_pearl(model, loss_function, v, batch):
    """
        This is a PyTorch implementation of the method introduced
        by the article "Fast Exact Multiplication by the Hessian"
        
        - Link: http://www.bcl.hamilton.ie/~barak/papers/nc-hessian.pdf
        
    """
    r = 1e-3
    model_delta = deepcopy(model)
    begin = end = 0
    for p in model_delta.parameters():
        if not p.requires_grad:
            continue
        end = begin + p.data.numel()
        p_flat = p.data.view(-1)
        p_flat += r * v[begin:end]
        begin = end


    model.zero_grad()

    x, y = batch
    E = loss_function(model(x), y)
    E_delta = loss_function(model_delta(x), y)

    E.backward()
    E_delta.backward()

    grad_w = torch.cat(
        list(p.grad.view(1, -1) for p in model.parameters() if p.requires_grad), 
        1
    )
    grad_w_delta = torch.cat(
        list(p.grad.view(1, -1) for p in model_delta.parameters() if p.requires_grad), 
        1
    )

    grad_w.squeeze_()
    grad_w_delta.squeeze_()


    return (grad_w_delta - grad_w) / r


In [175]:
n = 2
model = torch.nn.Linear(n, 1)
# torch.manual_seed(2019)
v = torch.rand(sum(p.data.numel() for p in model.parameters()))
x = torch.rand(1000, n)
y = x.sum(1)
def loss_function(y_hat, y):
    return ((y_hat.view(-1) - y.view(-1)) * (y_hat.view(-1) - y.view(-1))).sum() / y.numel()
batch = x, y
m = 3
# Exact Hessian

h11 = 2 * (x[:, 0] * x[:, 0]).sum() / x.size(0)
h21 = h12 = 2 * (x[:, 0] * x[:, 1]).sum() / x.size(0)
h31 = h13 = 2 * x[:, 0].sum() / x.size(0)

h22 = 2 * (x[:, 1] * x[:, 1]).sum() / x.size(0)
h32 = h23 = 2 * x[:, 1].sum() / x.size(0)

h33 = 2

H = torch.zeros(3, 3)

for i, j in product(range(3), range(3)):
    exec("H[%d, %d] = h%d%d" % (i, j, i+1, j+1))
    
print(H)

tensor([[0.6698, 0.5066, 0.9960],
        [0.5066, 0.6937, 1.0245],
        [0.9960, 1.0245, 2.0000]])


In [176]:
torch.eig(lanczos(model, loss_function, batch, 3, 3)[0])[0][:, 0]

100%|██████████| 2/2 [00:00<00:00, 841.72it/s]

[Batch-mode LANCZOS Algorithm running]





tensor([3.0796, 0.1088, 0.1751])

In [177]:
v = torch.rand(3)
h1 = hessian_matmul_pearl(model, loss_function, v, batch)
h2 = hessian_matmul(model, loss_function, v, batch)

(h1 - h2).norm()

tensor(0.0020)

In [169]:
from torchvision.models import MobileNetV2

model = MobileNetV2()
torch.manual_seed(2019)
x = torch.rand(64, 3, 224, 244)
y = torch.randint(2, (64, ))
loss_function = torch.nn.CrossEntropyLoss()
batch = x, y
v = torch.rand(sum(p.data.numel() for p in model.parameters() if p.requires_grad))

In [170]:
h1 = hessian_matmul_pearl(model, loss_function, v, batch)
h2 = hessian_matmul(model, loss_function, v, batch)

(h1 - h2).norm()

tensor(12318.9453)

In [171]:
h1.sort()[0]

tensor([-1174.3099, -1073.7307,  -960.3297,  ...,   771.1893,   795.3997,
          842.8501])

In [172]:
h2.sort()[0]

tensor([-664.6807, -473.8264, -432.6977,  ...,  315.0118,  340.9702,
         385.2315])