In [1]:
import torch
from copy import deepcopy

In [2]:
def _flatten_grads(grads, params):
    """Concatenate per-parameter gradients into a single 1-D tensor.

    ``None`` entries (returned by ``torch.autograd.grad`` with
    ``allow_unused=True`` when the differentiated output does not depend on
    that parameter) are replaced by zeros shaped like the parameter, so the
    layout always matches ``torch.cat([p.reshape(-1) for p in params])``.
    """
    return torch.cat([
        torch.zeros_like(p).reshape(-1) if g is None else g.reshape(-1)
        for g, p in zip(grads, params)
    ])


def hessian_matmul(model, loss_function, v, batch):
    """Compute the Hessian-vector product ``H @ v`` of the loss w.r.t. the model parameters.

    The Hessian is never materialised: with ``g = dL/dtheta`` built using
    ``create_graph=True``, differentiating ``g . v`` once more yields ``H @ v``.

    Args:
        model: ``torch.nn.Module``; every parameter must have ``requires_grad=True``
            (``torch.autograd.grad`` raises otherwise).
        loss_function: callable ``(prediction, target) -> scalar tensor``.
        v: tensor with one entry per model parameter, ordered like
            ``model.parameters()`` (any shape; it is flattened). It is detached
            internally, so it may be a non-leaf tensor and the caller's tensor
            is never modified.
        batch: ``(inputs, targets)`` pair; ``model(inputs)`` and ``targets``
            are passed to ``loss_function``.

    Returns:
        Tensor of shape ``(1, num_params)`` holding ``H @ v``, not attached to
        any autograd graph.

    Side effects:
        Calls ``model.zero_grad()`` before returning. The parameters' ``.grad``
        fields are not used for the computation itself.
    """
    params = list(model.parameters())
    inputs, targets = batch
    loss = loss_function(model(inputs), targets)

    # First-order gradient, kept differentiable so it can be backpropagated
    # through again. allow_unused: parameters the loss ignores get a zero
    # gradient instead of raising.
    grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
    flat_grad = _flatten_grads(grads, params)

    # detach() instead of `v.requires_grad = False`: the latter raises on
    # non-leaf tensors and silently mutates the caller's tensor.
    v_flat = v.detach().reshape(-1)

    if flat_grad.requires_grad:
        # autograd.grad instead of .backward() + reading p.grad: p.grad may be
        # None (zero_grad() can reset it to None, and parameters outside the
        # second-order graph never receive one).
        hv_parts = torch.autograd.grad(
            flat_grad, params, grad_outputs=v_flat, allow_unused=True
        )
        hv = _flatten_grads(hv_parts, params)
    else:
        # The gradient does not depend on the parameters at all (the loss is at
        # most linear in them), so the Hessian, and hence H @ v, is zero.
        hv = torch.zeros_like(flat_grad)

    model.zero_grad()
    return hv.view(1, -1)

In [3]:
# Sanity check on a tiny linear regressor y_hat = x @ w.T + b (3 parameters).
# Seed first so both the data and the model's initial weights are reproducible
# (previously the model was built before seeding). The data is still drawn
# immediately after the seed, so `x` is the same as before; the Hessian of this
# MSE loss depends only on `x`, so the Hv shown below is unchanged.
torch.manual_seed(2019)
x = torch.rand(10, 2)
y = x.sum(1)
batch = x, y
model = torch.nn.Linear(2, 1)
v = torch.tensor([1.0, 2.0, 3.0])


def loss_function(y_hat, y):
    """Mean squared error between flattened predictions and flattened targets."""
    residual = y_hat.view(-1) - y.view(-1)
    return (residual * residual).sum() / y.numel()

In [4]:
# Hessian-vector product for the 3-parameter linear model, returned as a (1, 3) row.
hessian_matmul(model, loss_function, v=v, batch=batch)

tensor([[4.7984, 4.1026, 8.8368]])

In [6]:
class SampleNN(torch.nn.Module):
    """Small linear -> 3x3 conv -> linear regressor used to exercise ``hessian_matmul``.

    The 64 features produced by ``fc1`` are reinterpreted as a single-channel
    8x8 image so that the convolution has a spatial input, then flattened back
    to 64 features for the final scalar output.
    """

    def __init__(self, input_size):
        super().__init__()
        # Registration order (fc1, conv, fc2) fixes the order of parameters().
        self.fc1 = torch.nn.Linear(input_size, 64)
        self.conv = torch.nn.Conv2d(1, 1, 3, stride=1, padding=1)
        self.fc2 = torch.nn.Linear(64, 1)

    def forward(self, x):
        image = self.fc1(x).view(-1, 1, 8, 8)      # 64 features -> 1x8x8
        features = self.conv(image).view(-1, 64)   # padding=1 keeps 8x8
        return self.fc2(features)

input_size = 10
# Seed before drawing the data *and* before building the model so the
# Hessian-vector product below is reproducible. SampleNN composes fc2, conv and
# fc1, so its Hessian depends on the initial weights. The data is still drawn
# right after the seed, so `x` is the same as before.
torch.manual_seed(2019)
x = torch.rand(100, input_size)
y = x.sum(1)
batch = x, y
model = SampleNN(input_size)


def loss_function(y_hat, y):
    """Mean squared error between flattened predictions and flattened targets."""
    residual = y_hat.view(-1) - y.view(-1)
    return (residual * residual).sum() / y.numel()

In [7]:
# Random probe direction with one entry per model parameter.
v = torch.rand(sum(param.numel() for param in model.parameters()))

Hv = hessian_matmul(model, loss_function, v, batch)
Hv.shape

torch.Size([1, 779])

In [8]:
from torchvision.models import resnet18

# Seed before building the model and drawing data so the whole run is
# reproducible (previously the model was created before seeding).
torch.manual_seed(2019)
# NOTE(review): `pretrained` is deprecated in newer torchvision in favour of
# `weights=None`; kept here for compatibility with older versions.
model = resnet18(pretrained=False)
# Standard 224x224 inputs; the original 224x244 looked like a typo. The
# parameter count (and hence the Hv size) does not depend on the input size.
x = torch.rand(16, 3, 224, 224)
y = torch.randint(2, (16, ))  # labels drawn from classes {0, 1}
loss_function = torch.nn.CrossEntropyLoss()
batch = x, y

In [9]:
# Random probe direction sized to resnet18's full parameter vector.
v = torch.rand(sum(param.numel() for param in model.parameters()))

Hv = hessian_matmul(model, loss_function, v, batch)
Hv.shape

torch.Size([1, 11689512])