In [1]:
import torch
from copy import deepcopy

In [2]:
def _flat_grad(loss, params):
    """Return d(loss)/d(params) concatenated into a single 1-D tensor."""
    # autograd.grad returns the gradients directly instead of accumulating
    # them into ``p.grad``, so repeated calls never see stale gradients.
    grads = torch.autograd.grad(loss, params, allow_unused=True)
    return torch.cat([
        # A parameter that does not take part in the loss has a zero gradient.
        torch.zeros_like(p).reshape(-1) if g is None else g.reshape(-1)
        for p, g in zip(params, grads)
    ])


def hessian_matmul(model, loss_function, v, batch, r=0.001):
    """Approximate the Hessian-vector product ``H @ v`` with a finite difference.

    The product is estimated as ``(grad(w + r * v) - grad(w)) / r`` where
    ``w`` is the concatenation of all flattened parameters of ``model`` (in
    ``model.parameters()`` order) and ``grad`` is the gradient of the loss on
    ``batch`` with respect to ``w``.

    Args:
        model: ``torch.nn.Module`` whose parameters define ``w``. It is not
            modified: the perturbed weights live in a deep copy, and the
            ``.grad`` attributes of its parameters are left untouched.
            Every parameter is expected to require gradients.
        loss_function: callable ``loss_function(model(x), y)`` returning a
            scalar tensor.
        v: tensor with as many elements as ``model`` has parameter entries,
            laid out in the same order as ``w``.
        batch: pair ``(x, y)`` of model inputs and targets.
        r: finite-difference step size.

    Returns:
        1-D tensor with one entry per parameter entry, approximating ``H @ v``.

    Raises:
        ValueError: if ``v`` does not have one element per parameter entry.
    """
    model_delta = deepcopy(model)
    params = list(model.parameters())
    params_delta = list(model_delta.parameters())

    v_flat = v.reshape(-1)
    n_entries = sum(p.numel() for p in params)
    if v_flat.numel() != n_entries:
        raise ValueError(
            "v has {} elements but the model has {} parameter entries".format(
                v_flat.numel(), n_entries
            )
        )

    # Shift every parameter of the copy by its slice of r * v.
    begin = 0
    with torch.no_grad():
        for p in params_delta:
            end = begin + p.numel()
            step = v_flat[begin:end].to(device=p.device, dtype=p.dtype)
            p.add_(r * step.reshape(p.shape))
            begin = end

    x, y = batch
    grad_w = _flat_grad(loss_function(model(x), y), params)
    grad_w_delta = _flat_grad(loss_function(model_delta(x), y), params_delta)

    return (grad_w_delta - grad_w) / r

In [3]:
# Seed before anything random is created, so that both the data and the
# model initialisation are reproducible after Restart & Run All.
torch.manual_seed(2019)
x = torch.rand(10, 2)
y = x.sum(1)
model = torch.nn.Linear(2, 1)

# One entry per model parameter: two weights followed by one bias.
v = torch.tensor([1.0, 2.0, 3.0])


def loss_function(y_hat, y):
    """Mean squared error between ``y_hat`` and ``y``, both flattened."""
    residual = y_hat.view(-1) - y.view(-1)
    return (residual * residual).sum() / y.numel()


batch = x, y

In [4]:
# Finite-difference estimate of H @ v for the linear model defined above.
hessian_matmul(model=model, loss_function=loss_function, v=v, batch=batch)

tensor([4.7985, 4.1026, 8.8365])

In [17]:
class SampleNN(torch.nn.Module):
    """Small test network: linear -> 3x3 convolution -> linear.

    The 64 outputs of the first linear layer are reinterpreted as a
    single-channel 8x8 image, convolved, flattened back to 64 features and
    mapped to one output value per sample.
    """

    def __init__(self, input_size):
        super().__init__()
        # Registration order fixes the order of ``parameters()``.
        self.fc1 = torch.nn.Linear(input_size, 64)
        self.conv = torch.nn.Conv2d(1, 1, 3, stride=1, padding=1)
        self.fc2 = torch.nn.Linear(64, 1)

    def forward(self, x):
        """Map inputs with ``input_size`` features to one value per sample."""
        hidden = self.fc1(x)
        image = hidden.view(-1, 1, 8, 8)
        features = self.conv(image).view(-1, 64)
        return self.fc2(features)

input_size = 10

# Seed before anything random is created, so that both the data and the
# model initialisation are reproducible after Restart & Run All.
torch.manual_seed(2019)
x = torch.rand(100, input_size)
y = x.sum(1)
model = SampleNN(input_size)


def loss_function(y_hat, y):
    """Mean squared error between ``y_hat`` and ``y``, both flattened."""
    residual = y_hat.view(-1) - y.view(-1)
    return (residual * residual).sum() / y.numel()


batch = x, y

In [19]:
# Random direction with one entry per parameter entry of the model.
v = torch.rand(sum(p.numel() for p in model.parameters()))

# The product has the same number of entries as the direction vector.
Hv = hessian_matmul(model, loss_function, v, batch)
Hv.shape

torch.Size([779])

In [40]:
from torchvision.models import resnet18

# Seed before anything random is created, so that the inputs, the labels and
# the model initialisation are all reproducible after Restart & Run All.
torch.manual_seed(2019)
x = torch.rand(16, 3, 224, 224)  # batch of 16 square 224x224 RGB inputs
y = torch.randint(2, (16, ))

# Randomly initialised network. Calling without arguments avoids the
# deprecated ``pretrained=`` flag and works on old and new torchvision.
model = resnet18()
loss_function = torch.nn.CrossEntropyLoss()
batch = x, y

In [42]:
# Random direction with one entry per parameter entry of the model.
v = torch.rand(sum(p.numel() for p in model.parameters()))

# The product has the same number of entries as the direction vector.
Hv = hessian_matmul(model, loss_function, v, batch)
Hv.shape

torch.Size([11689512])