In [3]:
import torch
import torch.nn as nn
from torch.autograd.functional import hvp

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def del_attr(obj, names):
    """Delete a nested attribute addressed by a sequence of names.

    All entries of ``names`` except the last are resolved one after another
    with ``getattr`` to reach the owning object; the last entry is then
    removed from that owner with ``delattr``.

    Args:
        obj: Object at which the attribute path starts.
        names: Indexable sequence of attribute names (e.g. the result of
            ``"layers.0.weight".split(".")``).
    """
    owner = obj
    # Walk down to the object that directly holds the final attribute.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    delattr(owner, names[-1])
def set_attr(obj, names, val):
    """Assign ``val`` to a nested attribute addressed by a sequence of names.

    All entries of ``names`` except the last are resolved one after another
    with ``getattr`` to reach the owning object; ``val`` is then stored on
    that owner under the last name with ``setattr``.

    Args:
        obj: Object at which the attribute path starts.
        names: Indexable sequence of attribute names.
        val: Value to store at the end of the path.
    """
    owner = obj
    # Walk down to the object that will directly hold the final attribute.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, names[-1], val)

def make_functional(mod):
    """Detach every parameter attribute from ``mod`` and return them.

    The parameters are captured first, then each attribute listed by
    ``mod.named_parameters()`` is deleted from the module tree via
    ``del_attr``.

    Args:
        mod: Module exposing ``parameters()`` and ``named_parameters()``.

    Returns:
        A pair ``(orig_params, names)``: the tuple of parameters as yielded
        by ``mod.parameters()`` before removal, and the list of their dotted
        names in the order they were deleted.
    """
    # Capture the parameter objects before any attribute is removed.
    orig_params = tuple(mod.parameters())
    # Snapshot the names up front so deletion does not disturb iteration.
    names = [dotted for dotted, _ in list(mod.named_parameters())]
    for dotted in names:
        del_attr(mod, dotted.split("."))
    return orig_params, names

def load_weights(mod, names, params):
    """Attach each value in ``params`` to ``mod`` under the matching dotted name.

    Names and values are paired positionally with ``zip``, so pairing stops
    at the shorter of the two sequences.

    Args:
        mod: Object that receives the attributes.
        names: Iterable of dotted attribute paths (split on ``"."``).
        params: Iterable of values to assign, in the same order as ``names``.
    """
    pairs = zip(names, params)
    for dotted, value in pairs:
        path = dotted.split(".")
        set_attr(mod, path, value)

def load_weights_from_optim(model, names, optimizer):
    """Attach the optimizer's trainable parameters to ``model`` by name.

    Parameters are gathered from ``optimizer.param_groups`` in group order,
    keeping only those with ``requires_grad`` set, and are then paired
    positionally with ``names`` (``zip`` stops at the shorter sequence).

    Args:
        model: Object that receives the attributes.
        names: Iterable of dotted attribute paths (split on ``"."``).
        optimizer: Object exposing ``param_groups``, each a mapping with a
            ``'params'`` entry.
    """
    # Flatten all param groups into one list of trainable parameters.
    trainable = [
        param
        for group in optimizer.param_groups
        for param in group['params']
        if param.requires_grad
    ]
    # Re-attach each parameter under its dotted path.
    for dotted, param in zip(names, trainable):
        set_attr(model, dotted.split("."), param)

In [5]:
# model
class mlp(nn.Module):
    """Bias-free 1 -> 2 -> 1 perceptron with a ReLU in between.

    The forward pass raises the network output to the fifth power.
    """

    def __init__(self):
        super().__init__()
        stack = [
            nn.Linear(1, 2, bias=False),
            nn.ReLU(),
            nn.Linear(2, 1, bias=False),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``layers(x)`` raised element-wise to the power 5."""
        hidden = self.layers(x)
        return torch.pow(hidden, 5)

# Training Loop
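Computation of Hg, where H is the Hessian of the loss and g is its gradient. Note: the cell below contracts the gradient with the detached parameter values, so the vector it prints is the Hessian applied to the current parameter vector.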

In [9]:
torch.manual_seed(1996)

model = mlp()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=.01)
print(sum(p.numel() for p in model.parameters()))

n = 10
x = torch.randn(n).reshape([n,1])
y = 2. * x - 3* x.pow(2) + x.pow(3)

# Trainable parameters in the order the optimizer holds them; used both for
# differentiation and for writing the gradients back before the step.
params = [
    p
    for group in optimizer.param_groups
    for p in group['params']
    if p.requires_grad
]

for inputs, targets in zip(x, y):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)

    # First-order gradients with the graph kept so they can be differentiated
    # a second time. torch.autograd.grad is used instead of
    # loss.backward(create_graph=True), which stores graph-carrying tensors in
    # p.grad and creates a parameter <-> gradient reference cycle.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # gp: inner product between g and p.
    # In the inner product g must stay differentiable (we differentiate
    # through it) while p is treated as a constant, hence the detach.
    # NOTE(review): the notebook heading talks about Hg, but the vector used
    # here is the parameter vector itself — confirm which one is intended.
    gp = torch.tensor(0., requires_grad=False)
    for g, p in zip(grads, params):
        gp = gp + torch.sum(torch.mul(g, p.detach().clone()))  # scalar product

    # Hessian-vector product: differentiate gp w.r.t. all parameters in a
    # single call, so no per-parameter retain_graph bookkeeping is needed.
    Hz_lis = list(torch.autograd.grad(gp, params))

    print(torch.cat([Hz_i.flatten() for Hz_i in Hz_lis]))

    # step: hand the optimizer plain (graph-free) gradients.
    optimizer.zero_grad()
    for p, g in zip(params, grads):
        p.grad = g.detach()
    optimizer.step()

4
tensor([-1.2894e-08,  0.0000e+00, -2.3014e-09,  0.0000e+00])
tensor([-7.9106e-10,  0.0000e+00, -1.4119e-10,  0.0000e+00])
tensor([-6.8691e-08,  0.0000e+00, -1.2260e-08,  0.0000e+00])
tensor([-4.5436e-06,  0.0000e+00, -8.1097e-07,  0.0000e+00])
tensor([-2.1609e-06,  0.0000e+00, -3.8569e-07,  0.0000e+00])
tensor([-6.6748e-08,  0.0000e+00, -1.1914e-08,  0.0000e+00])
tensor([-0.0031,  0.0000, -0.0005,  0.0000])
tensor([1.3738e-05, 0.0000e+00, 2.4521e-06, 0.0000e+00])
tensor([3.5232e-04, 0.0000e+00, 6.2887e-05, 0.0000e+00])
tensor([ 0.0000, -0.0038,  0.0000,  0.0129])
