In [1]:
import zeroptim.callbacks.functional as utils
from zeroptim.trainer import ZeroptimTrainer
from zeroptim.configs import load
import torch

# Build the experiment from its YAML config and keep handles to the pieces
# used by the cells below.
config = load("configs/experiment/mlp-sgd-full.yaml")
trainer = ZeroptimTrainer.from_config(config)
model, loader, crit = trainer.model, trainer.loader, trainer.crit


def get_params_data(named_params):
    """Lazily yield the second item of every ``(name, tensor)`` pair.

    The result is a one-shot iterator: wrap it in ``list``/``tuple`` if it
    has to be traversed more than once.
    """
    return (pair[1] for pair in named_params)


# Snapshot (detached clones) of the parameters as they are right after the
# trainer has been built.
named_params_t0 = [
    (name, param.clone().detach()) for name, param in model.named_parameters()
]

def func_fwd(*model_params, **kwargs):
    """Loss of the global ``model`` evaluated with ``model_params`` plugged in.

    Everything except the model itself is passed in explicitly through the
    ``state`` keyword, a dict with the keys ``"names"``, ``"inputs"`` and
    ``"targets"``.

    Args:
        *model_params: one tensor per entry of ``state["names"]``, in the
            same order.
        **kwargs: must contain ``state`` as described above.

    Returns:
        ``crit(model(inputs), targets)``.
    """
    state = kwargs.get("state")
    names = state.get("names")
    batch_inputs = state.get("inputs")
    batch_targets = state.get("targets")

    # Write every tensor onto the model at its dotted attribute path
    # (the name is split on "." before being handed to utils.set_attr).
    for dotted_name, tensor in zip(names, model_params):
        utils.set_attr(model, dotted_name.split("."), tensor)

    predictions = model(batch_inputs)
    return crit(predictions, batch_targets)

def metrics_in_batch(inputs, targets, names, primals, tangents, func_fwd):
    """Directional first and second derivatives of the loss on one batch.

    Args:
        inputs: batch inputs, forwarded to ``func_fwd`` via ``state``.
        targets: batch targets, forwarded to ``func_fwd`` via ``state``.
        names: parameter names, forwarded to ``func_fwd`` via ``state``.
        primals: tuple of tensors at which the loss is differentiated.
        tangents: tuple of direction tensors ``v``, one per primal.
        func_fwd: callable ``func_fwd(*primals, state=...)`` returning a
            single-element loss tensor.

    Returns:
        ``(jvp, vhv)`` as Python floats: the Jacobian-vector product of the
        loss along ``tangents`` and the quadratic form ``v^T H v``.
    """
    from functools import partial

    state = {"names": names, "inputs": inputs, "targets": targets}
    # torch.autograd.functional.jvp/hvp call ``func(*inputs)`` and accept no
    # extra keyword arguments, so the batch state is bound up front.
    loss_fn = partial(func_fwd, state=state)

    _, jvp = torch.autograd.functional.jvp(loss_fn, primals, tangents)
    _, hvp = torch.autograd.functional.hvp(loss_fn, primals, tangents)

    # v^T H v: contract the Hessian-vector product with v, summed over all
    # parameter tensors.
    vhv = sum((v * hv).sum() for v, hv in zip(tangents, hvp))
    return jvp.item(), vhv.item()

def metrics_in_landscape(iterator, names, primals, tangents, func_fwd):
    """Batch-size-weighted average of jvp and ``v^T H v`` over ``iterator``.

    Args:
        iterator: iterable of ``(inputs, targets)`` batches; ``inputs`` must
            support ``.size(0)``, which is used as the batch weight.
        names: parameter names, forwarded to ``func_fwd`` via ``state``.
        primals: tuple of tensors at which the loss is differentiated.
        tangents: tuple of direction tensors ``v``, one per primal.
        func_fwd: callable ``func_fwd(*primals, state=...)`` returning a
            single-element loss tensor.

    Returns:
        ``(jvp, vhv)`` as Python floats, each averaged over all samples seen.

    Raises:
        ValueError: if ``iterator`` yields no samples.
    """
    from functools import partial

    agg_jvp, agg_vhv, count = 0.0, 0.0, 0

    for inputs, targets in iterator:
        sz = inputs.size(0)
        state = {"names": names, "inputs": inputs, "targets": targets}
        # torch.autograd.functional.jvp/hvp call ``func(*inputs)`` and accept
        # no extra keyword arguments, so the batch state is bound up front.
        loss_fn = partial(func_fwd, state=state)

        _, jvp_ = torch.autograd.functional.jvp(loss_fn, primals, tangents)
        _, hvp_ = torch.autograd.functional.hvp(loss_fn, primals, tangents)
        vhv_ = sum((v * hv).sum() for v, hv in zip(tangents, hvp_))

        # Weight each batch by its number of samples so a smaller batch does
        # not count as much as a full one.
        agg_jvp += jvp_.item() * sz
        agg_vhv += vhv_.item() * sz
        count += sz

    if count == 0:
        raise ValueError("metrics_in_landscape received no samples to average over")

    return agg_jvp / count, agg_vhv / count

In [2]:
# One random direction per parameter tensor, drawn once and reused by every
# computation below.
# NOTE(review): no RNG seed is set in the visible cells, so these directions
# differ between kernel restarts unless the trainer seeds torch - confirm.
random_v = [torch.randn_like(param) for _, param in named_params_t0]

In [3]:
# Sanity check: there is one direction tensor per parameter and their shapes
# line up. Looping instead of indexing 0..7 by hand keeps this valid for any
# number of parameters and materialises the parameter list only once.
params_t0 = list(get_params_data(named_params_t0))
print(len(params_t0), len(random_v))
for param_t0, tangent in zip(params_t0, random_v):
    print(param_t0.shape, tangent.shape)

8 8
torch.Size([1024, 784]) torch.Size([1024, 784])
torch.Size([1024]) torch.Size([1024])
torch.Size([256, 1024]) torch.Size([256, 1024])
torch.Size([256]) torch.Size([256])
torch.Size([32, 256]) torch.Size([32, 256])
torch.Size([32]) torch.Size([32])
torch.Size([10, 32]) torch.Size([10, 32])
torch.Size([10]) torch.Size([10])


In [4]:
# Compute jvp and vhv for each batch separately, then take the plain mean of
# the per-batch values (unlike metrics_in_landscape, batches are not weighted
# by their size here).

# Neither of these changes between batches, so build them once.
primals = tuple(get_params_data(named_params_t0))
tangents = tuple(random_v)

tmp_params, names = utils.make_functional(model)
try:
    jvp_batches, vhv_batches = [], []
    for inputs, targets in loader:
        jvp, vhv = metrics_in_batch(inputs, targets, names, primals, tangents, func_fwd)
        jvp_batches.append(jvp)
        vhv_batches.append(vhv)
finally:
    # func_fwd overwrites attributes on `model`; hand the saved ones back
    # even if a batch raised, so later cells see the original model.
    utils.restore_functional(model, tmp_params, names)

jvp_batch = sum(jvp_batches) / len(jvp_batches)
vhv_batch = sum(vhv_batches) / len(vhv_batches)

In [5]:
# Compute jvp and vhv over the entire dataset in one pass, with every batch
# weighted by its size (see metrics_in_landscape).
iterator = torch.utils.data.DataLoader(
    dataset=loader.dataset,
    batch_size=loader.batch_size
)

tangents = tuple(random_v)
primals = tuple(get_params_data(named_params_t0))

tmp_params, names = utils.make_functional(model)
try:
    jvp_full, vhv_full = metrics_in_landscape(iterator, names, primals, tangents, func_fwd)
finally:
    # func_fwd overwrites attributes on `model`; hand the saved ones back
    # even if the computation raised, so later cells see the original model.
    utils.restore_functional(model, tmp_params, names)

In [6]:
# Show both estimates side by side, one per line.
report_lines = (
    f"Compute per batch and then average: JVP={jvp_batch:.4f}, VHV={vhv_batch:.4f}",
    f"Compute for the entire landscape: JVP={jvp_full:.4f}, VHV={vhv_full:.4f}",
)
print(*report_lines, sep="\n")

Compute per batch and then average: JVP=1.8681, VHV=742.7760
Compute for the entire landscape: JVP=1.8681, VHV=742.7760


In [None]:
# Pull a single batch from the loader and inspect the shape of its targets.
_first_batch = next(iter(loader))
inputs, targets = _first_batch
targets.shape