In [1]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
plt.style.use("dark_background")
%matplotlib inline

sys.path.append("../../src")
from models import MLP
from data import gen_rnd_ds
from losses import LinearClassification
from block_analysis import block_hessian, curvature_effects

## Helpers

In [2]:
def get_model_ds_loss():
    """Build a fresh model, random dataset and loss from the notebook config.

    Takes no arguments: it reads the module-level parameters ``inp_dim``,
    ``hid_dim``, ``out_dim``, ``nlayer``, ``bias``, ``mode``, ``inp_mean``,
    ``inp_var``, ``nsamp`` and ``device`` defined in the Params cell, so that
    cell must have been executed first.

    Returns:
        (model, ds, loss_fn): an ``MLP`` moved with ``.cuda(device)``, the
        dataset produced by ``gen_rnd_ds`` and a ``LinearClassification``
        loss over ``out_dim`` outputs.
    """
    net = MLP(inp_dim, hid_dim, out_dim, nlayer, bias, mode)
    net = net.cuda(device)
    data = gen_rnd_ds(inp_dim, inp_mean, inp_var, out_dim, nsamp, device)
    criterion = LinearClassification(out_dim)
    return net, data, criterion

In [6]:
def pp(lr, delta, h, H):
    """Print one sweep row comparing ``h`` against the summed block values ``H``.

    Args:
        lr: learning rate for this row (formatted with ``:.2E``).
        delta: scalar printed in the ``Delta`` column.
        h: object with ``.item()`` (e.g. a 0-d tensor); printed as ``H``.
        H: tensor whose ``.sum()`` is printed as ``BH`` and compared to ``h``.

    Returns:
        The relative error ``|h - sum(H)| / min(|h|, |sum(H)|)``. It is
        ``0.0`` when the two values are equal (including both zero) and
        ``inf`` when they differ but one of them is exactly zero. The
        original version raised ``ZeroDivisionError`` in those cases.
    """
    # Extract the Python scalars once instead of re-calling .item()/.sum().
    h_val = h.item()
    bh_val = H.sum().item()

    diff = abs(h_val - bh_val)
    denom = min(abs(h_val), abs(bh_val))
    if diff == 0:
        rel = 0.0
    elif denom == 0:
        # One estimate is exactly zero while the other is not: unbounded relative error.
        rel = float("inf")
    else:
        rel = diff / denom

    print(f"LR {lr:.2E} \t ||Error={rel:.2E}  \t ||H={h_val:.2E} \t || BH={bh_val:.2E} \t || Delta={delta:.2E}")
    return rel

def lr_range(model, ds, loss_fn, start=-8, stop=8, step=1, log_scale=True):
    """Sweep learning rates and print one comparison row per value.

    The sweep iterates ``k`` over ``range(start, stop, step)``, so ``start``,
    ``stop`` and ``step`` must be integers and ``stop`` is exclusive. The
    learning rate is ``10**k`` when ``log_scale`` is true, and ``k`` itself
    otherwise.

    For each learning rate, ``block_hessian`` and ``curvature_effects`` are
    evaluated on ``(model, ds, loss_fn, lr)`` and the results are printed via
    ``pp``. Nothing is returned.
    """
    for k in range(start, stop, step):
        lr = 10**k if log_scale else k
        block_h = block_hessian(model, ds, loss_fn, lr)
        delta, h = curvature_effects(model, ds, loss_fn, lr)
        pp(lr, delta, h, block_h)

## Params

In [7]:
# Model
mode = "linear"
bias = False
nlayer = 3
inp_dim = 10
out_dim = 10
hid_dim = 100

# Data parameters
nsamp = 100
inp_mean = 0
inp_var = 1

# Others
device = 0  # passed to MLP(...).cuda(device) and gen_rnd_ds in get_model_ds_loss
lr = 7  # NOTE(review): not read by any visible cell (lr_range sweeps its own values)
seed = 0

# Model init and the random dataset are stochastic; seed every RNG they may
# draw from so that Restart & Run All reproduces the printed tables.
np.random.seed(seed)
torch.manual_seed(seed)

In [8]:
model, ds, loss_fn = get_model_ds_loss()
# Log-scale sweep: lr = 10**k for integer k in [-8, 8).
lr_range(model, ds, loss_fn, start=-8, stop=8, step=1, log_scale=True)

LR 1.00E-08 	 ||Error=6.40E-08  	 ||H=4.89E+05 	 || BH=4.89E+05 	 || Delta=0.00E+00
LR 1.00E-07 	 ||Error=0.00E+00  	 ||H=4.89E+04 	 || BH=4.89E+04 	 || Delta=0.00E+00
LR 1.00E-06 	 ||Error=8.35E-02  	 ||H=3.02E+03 	 || BH=2.79E+03 	 || Delta=9.31E-10
LR 1.00E-05 	 ||Error=1.47E-01  	 ||H=1.59E+01 	 || BH=1.82E+01 	 || Delta=2.36E-08
LR 1.00E-04 	 ||Error=3.99E+00  	 ||H=-4.08E-02 	 || BH=-2.04E-01 	 || Delta=2.44E-07
LR 1.00E-03 	 ||Error=2.91E-04  	 ||H=3.42E-04 	 || BH=3.42E-04 	 || Delta=2.44E-06
LR 1.00E-02 	 ||Error=6.05E-01  	 ||H=3.89E-06 	 || BH=6.25E-06 	 || Delta=2.44E-05
LR 1.00E-01 	 ||Error=2.44E-03  	 ||H=9.91E-06 	 || BH=9.94E-06 	 || Delta=2.44E-04
LR 1.00E+00 	 ||Error=3.97E-04  	 ||H=9.94E-06 	 || BH=9.95E-06 	 || Delta=2.44E-03
LR 1.00E+01 	 ||Error=2.54E-03  	 ||H=9.92E-06 	 || BH=9.95E-06 	 || Delta=2.39E-02
LR 1.00E+02 	 ||Error=2.60E-02  	 ||H=9.69E-06 	 || BH=9.95E-06 	 || Delta=1.96E-01
LR 1.00E+03 	 ||Error=3.40E-01  	 ||H=7.42E-06 	 || BH=9.95E-06 	 || Delta

In [6]:
# Linear sweep: lr = 1, 2, ..., 19 on the same model/dataset as above.
# NOTE(review): the saved output below uses an older pp() format
# (Delta/hoe/H columns) and an earlier execution count; re-run to refresh it.
lr_range(model, ds, loss_fn, start=1, stop=20, step=1, log_scale=False)

LR 1.00E+00 	 || Delta=4.14E-04	 ||Error=5.56E-04  	|| hoe=-7.73E-07 	|| H=-7.72E-07
LR 2.00E+00 	 || Delta=8.28E-04	 ||Error=2.59E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 3.00E+00 	 || Delta=1.24E-03	 ||Error=2.27E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 4.00E+00 	 || Delta=1.66E-03	 ||Error=2.89E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 5.00E+00 	 || Delta=2.08E-03	 ||Error=3.19E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 6.00E+00 	 || Delta=2.49E-03	 ||Error=6.24E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 7.00E+00 	 || Delta=2.91E-03	 ||Error=6.68E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 8.00E+00 	 || Delta=3.33E-03	 ||Error=7.54E-04  	|| hoe=-7.73E-07 	|| H=-7.73E-07
LR 9.00E+00 	 || Delta=3.75E-03	 ||Error=8.16E-04  	|| hoe=-7.74E-07 	|| H=-7.73E-07
LR 1.00E+01 	 || Delta=4.17E-03	 ||Error=9.40E-04  	|| hoe=-7.74E-07 	|| H=-7.73E-07
LR 1.10E+01 	 || Delta=4.59E-03	 ||Error=1.02E-03  	|| hoe=-7.74E-07 	|| H=-7.73E-07
LR 1.20E+01 	 || Delta=5.01E-03	 ||Error=1.13E-03  	|| hoe=-7.74E