In [1]:
# reload edited modules (e.g. zeroptim) automatically before executing code,
# so changes to the package are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

# load a standard mlp model

In [2]:
from zeroptim.mlp import MLP
from zeroptim.data import loader
from zeroptim.utils import parse_yaml_config
import torch

In [3]:
# Seed the global RNG so the MLP initialisation (and, presumably, any shuffling
# inside `loader`) is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

dataloader = loader('mnist-digits')
m = MLP(**parse_yaml_config('mlp.yaml'))
# plain SGD (no momentum / weight decay): one step moves the params by -lr * grad
opt = torch.optim.SGD(m.parameters(), lr=0.1)
crit = torch.nn.CrossEntropyLoss()
m.norm()

tensor(28.3087, grad_fn=<AddBackward0>)

# take one gradient step

In [4]:
# fix one (inputs, targets) batch, reused by every experiment below
inputs, targets = next(iter(dataloader))
# collapse every non-batch dimension into one feature axis for the MLP
inputs = torch.flatten(inputs, start_dim=1)

In [5]:
# store parameters prior to the update; detach so the snapshots are plain values
# with no autograd history (they serve as primals / baselines in later cells)
prev_params = tuple(p.detach().clone() for p in m.parameters())

# perform one optimization step
opt.zero_grad()
loss = crit(m(inputs), targets)
loss.backward()
opt.step()

# store post-update params, gradients, and update directions v = cur - prev
cur_params = tuple(p.detach().clone() for p in m.parameters())
gs = [p.grad.clone() for p in m.parameters()]
vs = [cur - prev for cur, prev in zip(cur_params, prev_params)]

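# compute gradient correlation $\langle g, v \rangle$ manually
For plain SGD the update is $v = -\eta\, g$, so this quantity equals $-\eta \lVert g \rVert^2$.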

In [6]:
# gradient correlation <g, v>: elementwise product summed over every parameter tensor
gcorr = sum(torch.sum(grad * step) for grad, step in zip(gs, vs))
print(f'Forward value: {loss}')
print(f'Gradient correlation: {gcorr.item()}')

Forward value: 2.312971591949463
Gradient correlation: -0.011501125991344452


# working attempt using `torch.autograd.functional`

In [7]:
# https://discuss.pytorch.org/t/combining-functional-jvp-with-a-nn-module/81215/2

def del_attr(obj, names):
    """Delete the attribute addressed by a dotted path.

    Args:
        obj: root object to start from (an ``nn.Module`` in this notebook).
        names: path segments, e.g. ``"a.b.c".split(".")``. Every segment but the
            last is resolved with ``getattr``; the last one is removed with
            ``delattr`` on the object reached.

    Raises:
        AttributeError: if an intermediate or the final attribute is missing.
        IndexError: if ``names`` is empty.
    """
    target = obj
    for segment in names[:-1]:
        target = getattr(target, segment)
    delattr(target, names[-1])

def set_attr(obj, names, val):
    """Assign ``val`` to the attribute addressed by a dotted path.

    Args:
        obj: root object to start from (an ``nn.Module`` in this notebook).
        names: path segments, e.g. ``"a.b.c".split(".")``. Every segment but the
            last is resolved with ``getattr``; the last one is set with
            ``setattr`` on the object reached.
        val: value to store.

    Raises:
        AttributeError: if an intermediate attribute is missing.
        IndexError: if ``names`` is empty.
    """
    target = obj
    for segment in names[:-1]:
        target = getattr(target, segment)
    setattr(target, names[-1], val)

def make_functional(model):
    """Strip every registered parameter from ``model`` so plain tensors can be attached instead.

    Returns:
        (orig_params, names): the removed parameter objects and their dotted
        names, in the same order. Hand both to ``restore_functional`` to
        re-register the original parameters.
    """
    orig_params = tuple(model.parameters())
    # Materialise the names first: deleting while iterating named_parameters()
    # would mutate the module's parameter dicts mid-iteration.
    names = [name for name, _ in model.named_parameters()]
    for name in names:
        del_attr(model, name.split("."))
    return orig_params, names

def restore_functional(model, orig_params, names):
    """Undo ``make_functional``: put each original parameter back under its dotted name.

    ``names`` and ``orig_params`` must be the pair returned by ``make_functional``
    (same order).
    """
    pairs = zip(names, orig_params)
    for dotted_name, original in pairs:
        path = dotted_name.split(".")
        set_attr(model, path, original)

In [8]:
# sanity check: loss of m on the fixed batch before the functional jvp
# (m holds the post-update parameters at this point); it should be unchanged
# once the parameters are restored after the jvp
print(f'Forward pass before jvp {crit(m(inputs), targets)}')

Forward pass before jvp 2.301527976989746


In [9]:
def func(*params):
    """Loss of ``m`` on the fixed batch with ``params`` installed as its parameters.

    Relies on the module-level ``names`` (from ``make_functional``), aligned
    one-to-one with ``params``. The tensors are attached with ``setattr`` rather
    than copied, so the returned loss stays differentiable w.r.t. ``params``.
    This is what ``torch.autograd.functional.jvp`` / ``vhp`` need.
    """
    for dotted_name, tensor in zip(names, params):
        set_attr(m, dotted_name.split("."), tensor)
    return crit(m(inputs), targets)

# Strip m's parameters so `func` can install the jvp primals as plain tensors.
# try/finally guarantees m gets its Parameters back even if jvp raises;
# otherwise m would be left without parameters and every later cell would fail.
orig_params, names = make_functional(m)
try:
    # directional derivative of the loss at prev_params along the update v
    fwdvalue, jvp = torch.autograd.functional.jvp(func, prev_params, v=tuple(vs))
finally:
    restore_functional(m, orig_params, names)

# the jvp must agree with the manual <g, v> up to float32 round-off
assert torch.allclose(jvp, gcorr, rtol=1e-4), (jvp, gcorr)

# print final results
print(f'Forward value: {fwdvalue}')
print(f'Gradient correlation (jvp): {jvp}')
print(f'Gradient correlation (manual): {gcorr}')

Forward value: 2.312971591949463
Gradient correlation (jvp): -0.011501124128699303
Gradient correlation (manual): -0.011501125991344452


In [10]:
# confirm restore_functional put the original parameters back: this loss should
# equal the value printed before the jvp cell
print(f'Forward pass after jvp {crit(m(inputs), targets)}')

Forward pass after jvp 2.301527976989746


# compute vhv (directional sharpness $v^\top H v$) using `torch.autograd.functional.vhp`

In [11]:
# Same parameter-swapping trick as the jvp cell: `func` evaluates the loss with
# the supplied primals installed. try/finally guarantees m gets its Parameters
# back even if vhp raises.
orig_params, names = make_functional(m)
try:
    # vhp returns v^T H as one tensor per parameter; for a symmetric Hessian this
    # equals H v, hence the `hvp` name
    fwdvalue, hvp = torch.autograd.functional.vhp(func, prev_params, v=tuple(vs))
finally:
    restore_functional(m, orig_params, names)
# contract with v once more to get the scalar directional sharpness v^T H v
vhv = sum((v * hv).sum() for v, hv in zip(vs, hvp))

# print final results
print(f'Forward value: {fwdvalue}')
print(f'Directional sharpness (vhv): {vhv}')

Forward value: 2.312971591949463
Directional sharpness (vhv): -0.0009699021466076374


# Previous attempts and failed workarounds

# verify that `jvp` works with respect to the inputs

In [12]:
def func(*batch_inputs):
    """Loss of ``m`` on the first positional argument; any further arguments are ignored."""
    first = batch_inputs[0]
    return crit(m(first), targets)

# Two primals to check that jvp accepts tuple inputs. `func` only reads
# primals[0], so the second tangent contributes nothing to the result.
primals = (inputs, inputs,)
# dedicated seeded generator so the random tangents (and the printed JVP) are reproducible
tangent_gen = torch.Generator().manual_seed(0)
tangents = tuple(
    torch.randn(inputs.shape, generator=tangent_gen, dtype=inputs.dtype) for _ in primals
)
fwdvalue, jvp = torch.autograd.functional.jvp(func, primals, tangents)

print(f'Forward value: {round(fwdvalue.item(), 3)}')
print(f'JVP (on inputs): {round(jvp.item(), 3)}')

Forward value: 2.302
JVP (on inputs): 0.027


# compute gradient correlation using `jvp`

# attempt 01: reassign the model's parameters inside the closure (fails: copying under `no_grad` cuts the autograd dependency on `params`, so the JVP is 0)

In [13]:
def reassign(*params):
    """Copy ``params`` into ``m``'s parameters in place, then clear ``m``'s gradients.

    The copy runs under ``torch.no_grad()``, so the new values carry no autograd
    link back to ``params``. This is why attempt 01 yields a zero JVP.
    """
    with torch.no_grad():
        for target, source in zip(m.parameters(), params):
            target.copy_(source)
    m.zero_grad()

def func(*params):
    """Loss of ``m`` evaluated at ``params``; afterwards ``m`` is moved back to ``cur_params``.

    ``reassign`` copies values under ``no_grad``, so the loss has no autograd path
    to ``params`` and jvp sees a zero derivative.
    """
    reassign(*params)                          # move m to the requested point
    loss_at_params = crit(m(inputs), targets)  # evaluate there
    reassign(*cur_params)                      # put back the post-update params
    return loss_at_params

# gradient correlation via jvp: primal prev_params, tangent = update direction
fwdvalue, jvp = torch.autograd.functional.jvp(func, inputs=prev_params, v=tuple(vs))

print(f'Forward value: {round(fwdvalue.item(), 3)}')
print(f'JVP (on params): {round(jvp.item(), 3)}')

Forward value: 2.313
JVP (on params): 0.0


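# attempt 02: recreate the model from scratch inside the closure (fails for the same reason: the parameters are copied in under `no_grad`, so the JVP is 0)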

In [14]:
def func(*params):
    """Build a fresh MLP from ``mlp.yaml``, load ``params`` into it, and return its loss.

    The values are written by slice assignment under ``no_grad``. This again
    detaches the result from ``params``, so the JVP comes out as 0.
    """
    fresh_model = MLP(**parse_yaml_config('mlp.yaml'))
    with torch.no_grad():
        for dst, src in zip(fresh_model.parameters(), params):
            dst[:] = src
    return crit(fresh_model(inputs), targets)

# jvp of the rebuilt-model loss along the update direction, starting at prev_params
fwdvalue, jvp = torch.autograd.functional.jvp(
    func,
    prev_params,
    v=tuple(vs),
)

print(f'Forward value: {round(fwdvalue.item(), 3)}')
print(f'JVP (on params): {round(jvp.item(), 3)}')

Forward value: 2.313
JVP (on params): 0.0


# attempt 03: apply a parameter perturbation to avoid copying and detaching parameters (fails: updates written through `.data` are invisible to autograd, so the JVP is 0)

In [15]:
def func(*params):
    """Loss of ``m`` after shifting its parameters from ``cur_params`` to ``params``.

    The shift is applied through ``.data``, which autograd does not track. The
    loss therefore has no gradient path to ``params`` and jvp returns 0; that is
    the failure this attempt documents. The original parameter values are
    restored before returning, so the cell neither leaves ``m`` displaced nor
    drifts further each time it is re-run.
    """
    # snapshot exact values so the restore is bit-for-bit (no (a + e) - e round-off)
    saved = [p.detach().clone() for p in m.parameters()]
    # perturbation that moves cur_params onto params
    # (assumes m currently holds cur_params — TODO confirm when re-running cells out of order)
    perturb = [prev - cur for prev, cur in zip(params, cur_params)]
    try:
        # apply the perturbation to rebase the model at *params
        for p, eps in zip(m.parameters(), perturb):
            p.data += eps
        return crit(m(inputs), targets)
    finally:
        for p, original in zip(m.parameters(), saved):
            p.data = original

# jvp of the perturbation-based closure at prev_params along the update direction
fwdvalue, jvp = torch.autograd.functional.jvp(func, inputs=prev_params, v=tuple(vs))

# print final results
print(f'Forward value: {round(fwdvalue.item(), 3)}')
print(f'JVP (on params): {round(jvp.item(), 3)}')

Forward value: 2.313
JVP (on params): 0.0


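# attempt 04: write the forward pass manually (the JVP is nonzero, but the recorded forward value 2.72 disagrees with the 2.313 computed by the model at the same parameters, so this manual pass does not reproduce the model's forward)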

In [16]:
def func(*params):
    """Manual forward pass over alternating (weight, bias) pairs, with ``m.act_func`` between layers.

    ``params`` is read as ``(w1, b1, w2, b2, ..., wL, bL)``. Each layer computes
    ``h @ w.T + b``, and every layer except the last is followed by
    ``m.act_func``. For six tensors this is exactly the original hand-unrolled
    three-layer computation.

    NOTE(review): the recorded forward value of this manual pass (2.72) disagrees
    with the 2.313 that ``crit(m(inputs), targets)`` gives at the same
    ``prev_params``. This layout assumption therefore does not match
    ``MLP.forward``. TODO confirm the real architecture (layer count, bias usage,
    activation placement) in ``zeroptim.mlp``.

    Raises:
        ValueError: if ``params`` is empty or has an odd length (not (w, b) pairs).
    """
    if not params or len(params) % 2:
        raise ValueError(f"expected (weight, bias) pairs, got {len(params)} tensors")

    weights, biases = params[0::2], params[1::2]
    hidden = inputs
    last_layer = len(weights) - 1
    for layer, (w, b) in enumerate(zip(weights, biases)):
        hidden = hidden @ w.T + b
        if layer < last_layer:  # no activation on the output logits
            hidden = m.act_func(hidden)

    return crit(hidden, targets)

fwdvalue, jvp = torch.autograd.functional.jvp(
    func,            # manual forward pass
    prev_params,     # primal point
    tuple(vs),       # tangent = update direction
)

# print final results
print(f'Forward value: {round(fwdvalue.item(), 3)}')
print(f'JVP (on params): {round(jvp.item(), 3)}')

Forward value: 2.72
JVP (on params): -0.067
