In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys; sys.path.append('..')
from zeroptim.mlp import MLP
from zeroptim.data import loader
from zeroptim.utils import parse_yaml_config
from zeroptim.utils import *

In [None]:
from zeroptim.optimizers.mezo import MeZO
from zeroptim.optimizers.smartes import SmartES

In [None]:
import torch
from tqdm.auto import trange

# load standard mlp model

In [None]:
# Build the 'mnist-digits' loader and pull a single batch to sanity-check shapes.
dataloader = loader('mnist-digits')
first_batch = next(iter(dataloader))
# Each batch is unpacked as an (inputs, targets) pair.
inputs, targets = first_batch
print("Shape of inputs:", inputs.shape)
print("Shape of targets:", targets.shape)
# Shape after collapsing every dimension except the first; the benchmark loop
# applies the same flatten(start_dim=1) before calling the model.
inputs.flatten(start_dim=1).shape

In [None]:
# Fresh loader for training (same call as in the shape-check cell above).
dataloader = loader('mnist-digits')
# Model hyper-parameters come from 'mlp.yaml' and are passed to MLP as keyword arguments.
m = MLP(**parse_yaml_config('mlp.yaml'))
# Baseline first-order optimizer: plain SGD with learning rate 1e-3.
opt = torch.optim.SGD(m.parameters(), lr=1e-3)
crit = torch.nn.CrossEntropyLoss()
# Trailing expression: display the model's repr as the cell output.
m

# define benchmark loop

In [None]:
def benchmark(model, loader, opt, criterion, train_iters=1000):
    """Train ``model`` and record per-step loss and update-direction diagnostics.

    For every optimizer step the function snapshots the parameters before the
    step, lets ``opt`` update the model, and then evaluates, at the *pre-step*
    parameters and on the same batch:

    * the Jacobian-vector product of the loss along the parameter update
      ``v = params_after - params_before`` (a directional derivative), and
    * the quadratic form ``v^T H v``, obtained from
      ``torch.autograd.functional.vhp``.

    Args:
        model: Module being trained. It is called on inputs flattened with
            ``flatten(start_dim=1)``.
        loader: Iterable yielding ``(inputs, targets)`` batches. It is
            re-iterated until ``train_iters`` steps have been taken.
        opt: Optimizer whose ``step`` accepts a closure and returns the loss.
            For ``MeZO`` and ``SmartES`` instances the closure does not call
            ``backward()``; for any other optimizer it does.
        criterion: Callable ``criterion(model_output, targets)`` returning a
            scalar loss tensor.
        train_iters: Number of optimizer steps to run.

    Returns:
        A ``(losses, jvps, vhvs)`` tuple of Python-float lists, one entry per
        optimizer step.
    """

    def closure(inputs, targets, with_backward=False):
        # Closure handed to `opt.step`: reset gradients, evaluate the loss and
        # backpropagate only when the optimizer needs gradients.
        opt.zero_grad()
        loss = criterion(model(inputs), targets)
        if with_backward:
            loss.backward()
        return loss

    def func_fwd(*params):
        # Loss on the current batch as an explicit function of the parameters.
        # `names`, `inputs` and `targets` are read from the enclosing scope at
        # call time. Uses the `model` / `criterion` arguments rather than any
        # notebook-level globals, so the diagnostics always describe the model
        # that is actually being trained.
        for name, p in zip(names, params):
            set_attr(model, name.split("."), p)
        return criterion(model(inputs), targets)

    # Loop-invariant: MeZO and SmartES steps are run without a backward pass
    # (presumably because they are gradient-free -- see their implementations).
    with_backward = not isinstance(opt, (MeZO, SmartES))

    losses = []
    jvps, vhvs = [], []
    n_iters = 0

    model.train()

    # The bar counts optimizer steps and is advanced manually below. It is
    # deliberately NOT iterated over, otherwise every pass over `loader` would
    # add an extra tick on top of the per-step updates.
    pbar = trange(int(train_iters))
    try:
        # At most `train_iters` passes over the loader are ever needed; the
        # step counter ends the loops earlier when a pass has several batches.
        for _ in range(int(train_iters)):
            if n_iters >= train_iters:
                break

            for inputs, targets in loader:
                if n_iters >= train_iters:
                    break

                inputs = inputs.flatten(start_dim=1)

                # Detached snapshot of the parameters before the step.
                prev_params = tuple(p.detach().clone() for p in model.parameters())

                # take optimization step
                loss = opt.step(
                    lambda: closure(inputs, targets, with_backward=with_backward)
                )

                # Parameter update produced by this step. The subtraction
                # already allocates new tensors, so no second clone is needed.
                vs = tuple(
                    p.detach() - p_prev
                    for p, p_prev in zip(model.parameters(), prev_params)
                )

                # compute jvp and vhp at the pre-step parameters
                tmp_params, names = make_functional(model)
                try:
                    _, jvp = torch.autograd.functional.jvp(func_fwd, prev_params, vs)
                    _, vhp = torch.autograd.functional.vhp(func_fwd, prev_params, vs)
                finally:
                    # Always put the real parameters back, even if the
                    # autograd calls fail, so the model is never left in its
                    # functional (parameter-swapped) state.
                    restore_functional(model, tmp_params, names)
                vhv = sum((v * hv).sum() for v, hv in zip(vs, vhp))

                # append metrics
                losses.append(loss.item())
                jvps.append(jvp.item())
                vhvs.append(vhv.item())

                # update tqdm bar
                pbar.set_description(f'train loss: {loss.item():.3f}')
                pbar.update(1)
                n_iters += 1
    finally:
        pbar.close()

    return losses, jvps, vhvs

In [None]:
# Run the benchmark for 30 passes' worth of batches (30 * number of batches).
losses, jvps, vhvs = benchmark(
    model=m,
    loader=dataloader,
    opt=opt,
    criterion=crit,
    train_iters=30 * len(dataloader),
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Training loss recorded at every optimizer step of the benchmark run.
fig, ax = plt.subplots()
ax.plot(losses, label='SGD', color='C0', linestyle='--')
ax.set(title='Training loss', xlabel='optimization step', ylabel='train loss')
ax.legend()
plt.show()

In [None]:
# Directional derivative of the loss along each parameter update (JVP),
# evaluated at the pre-step parameters.
fig, ax = plt.subplots()
ax.plot(jvps, label='jvps', color='C2', linestyle='--')
ax.set(
    title='Loss JVP along the parameter update',
    xlabel='optimization step',
    ylabel='jvp',
)
ax.legend()
# Explicit show() so the cell does not echo the Legend object's repr.
plt.show()
In [None]:
# Quadratic form v^T H v of the loss Hessian along each parameter update,
# evaluated at the pre-step parameters.
fig, ax = plt.subplots()
ax.plot(vhvs, label='vhvs', color='C1', linestyle='--')
ax.set(
    title='Loss curvature along the parameter update',
    xlabel='optimization step',
    ylabel='vhv',
)
ax.legend()
# Explicit show() so the cell does not echo the Legend object's repr.
plt.show()