In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from models import MLP
from data import gen_rnd_ds
from losses import LinearClassification
from block_analysis import block_hessian, curvature_effects

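## Helpers

Utilities to build the model, dataset and loss, and to compare the measured curvature effect `h` (from `curvature_effects`) against the summed block Hessian `H` (from `block_hessian`) over a sweep of learning rates.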

In [3]:
def get_model_ds_loss():
    """Build a fresh ``(model, ds, loss_fn)`` triple from the notebook config globals.

    Reads ``inp_dim``, ``hid_dim``, ``out_dim``, ``nlayer``, ``bias``, ``mode``,
    ``inp_mean``, ``inp_var``, ``nsamp`` and ``device`` from module scope, so the
    Parameters cell must have been executed first. The model is placed on the
    GPU with ``.cuda(device)``.

    Returns:
        tuple: ``(model, ds, loss_fn)`` built by ``MLP``, ``gen_rnd_ds`` and
        ``LinearClassification`` respectively.
    """
    # Construction order is kept (model, then data, then loss) in case the
    # constructors draw from a shared RNG.
    net = MLP(inp_dim, hid_dim, out_dim, nlayer, bias, mode).cuda(device)
    dataset = gen_rnd_ds(inp_dim, inp_mean, inp_var, out_dim, nsamp, device)
    criterion = LinearClassification(out_dim)
    return net, dataset, criterion

In [4]:
def pp(lr, delta, h, H):
    """Print one row comparing the measured effect ``h`` with the total of ``H``.

    Args:
        lr: Learning rate for this row (int or float; printed in ``.2E`` format).
        delta: Value printed in ``.2E`` format alongside the comparison.
        h: Single-element tensor or Python number; converted with ``float()``.
        H: Object with a ``.sum()`` method (e.g. a tensor); only its sum is
            compared against ``h``.

    The printed ``Error`` is ``|h - sum(H)| / min(|sum(H)|, |h|)`` and ``wari`` is
    the ratio ``h / sum(H)``. If a denominator is exactly zero, these are
    reported as ``inf``/``nan`` instead of raising ``ZeroDivisionError``.
    """
    import math

    h_val = float(h)
    hess_total = float(H.sum())  # computed once instead of three times

    denom = min(abs(hess_total), abs(h_val))
    if denom != 0:
        rel = abs((h_val - hess_total) / denom)
    else:
        # Both zero -> exact agreement; exactly one zero -> unbounded relative error.
        rel = 0.0 if h_val == hess_total else math.inf

    if hess_total != 0:
        wari = h_val / hess_total
    else:
        wari = math.nan if h_val == 0 else math.copysign(math.inf, h_val)

    print(f"LR {lr:.2E} \t || Delta={delta:.2E}\t ||Error={rel:.2E}  \t|| hoe={h_val} \t|| H={hess_total}\t||wari={wari}")

def lr_range(model, ds, loss_fn, start=-8, stop=8, step=1, log_scale=False):
    """Sweep learning rates and print one comparison row per rate via ``pp``.

    Values are drawn from ``range(start, stop, step)``, so ``stop`` is excluded
    and all three must be integers. With ``log_scale=True`` each value ``k``
    becomes the learning rate ``10**k``; otherwise ``k`` is used directly.
    For every rate, ``block_hessian`` and then ``curvature_effects`` are
    evaluated on the same ``model``/``ds``/``loss_fn``.
    """
    for k in range(start, stop, step):
        rate = 10**k if log_scale else k
        hess = block_hessian(model, ds, loss_fn, rate)
        delta, h = curvature_effects(model, ds, loss_fn, rate)
        pp(rate, delta, h, hess)

## Parameters

Model, data and device configuration read by `get_model_ds_loss`.

In [5]:
# Model architecture (passed to MLP by get_model_ds_loss)
mode = "linear"
bias = False
nlayer = 3
inp_dim = 10
out_dim = 10
hid_dim = 100

# Data parameters (passed to gen_rnd_ds by get_model_ds_loss)
nsamp = 100
inp_mean = 0
inp_var = 1

# Others
device = 0  # CUDA device index, used via .cuda(device) in get_model_ds_loss
lr = 7  # NOTE(review): not read by any visible cell; lr_range supplies its own rates

# Reproducibility: the dataset and model init are presumably random
# (gen_rnd_ds / MLP init) — seed torch (all devices) and numpy so reruns match.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [6]:
# Build a fresh model/dataset/loss, then sweep lr = 10**k for k in [-8, 8).
model, ds, loss_fn = get_model_ds_loss()
lr_range(model, ds, loss_fn, start=-8, stop=8, step=1, log_scale=True)

LR 1.00E-08 	 || Delta=0.00E+00	 ||Error=3.06E-07  	|| hoe=153226.953125 	|| H=153226.90625	||wari=1.000000305918857
LR 1.00E-07 	 || Delta=0.00E+00	 ||Error=6.37E-08  	|| hoe=15322.6943359375 	|| H=15322.693359375	||wari=1.0000000637330837
LR 1.00E-06 	 || Delta=5.82E-11	 ||Error=2.59E-07  	|| hoe=1415.8541259765625 	|| H=1415.853759765625	||wari=1.000000258650256
LR 1.00E-05 	 || Delta=6.69E-09	 ||Error=1.37E-01  	|| hoe=19.349321365356445 	|| H=17.02098846435547	||wari=1.1367918735082194
LR 1.00E-04 	 || Delta=7.62E-08	 ||Error=9.59E-06  	|| hoe=0.08392788469791412 	|| H=0.08392708003520966	||wari=1.0000095876408917
LR 1.00E-03 	 || Delta=7.66E-07	 ||Error=9.15E-01  	|| hoe=0.00024385825963690877 	|| H=0.00012732927280012518	||wari=1.9151782954081948
LR 1.00E-02 	 || Delta=7.66E-06	 ||Error=9.87E-01  	|| hoe=-4.729372449219227e-06 	|| H=-9.396786481374875e-06	||wari=0.5032967875340354
LR 1.00E-01 	 || Delta=7.66E-05	 ||Error=2.33E-02  	|| hoe=-2.5451299734413624e-06 	|| H=-2.4871951

In [7]:
# Linear sweep lr = 1, 2, ..., 19 on the model/dataset/loss built in the previous run cell.
lr_range(model, ds, loss_fn, start=1, stop=20, step=1, log_scale=False)

LR 1.00E+00 	 || Delta=7.67E-04	 ||Error=2.76E-04  	|| hoe=-2.532149665057659e-06 	|| H=-2.5314511731266975e-06	||wari=1.000275925500115
LR 2.00E+00 	 || Delta=1.54E-03	 ||Error=3.56E-04  	|| hoe=-2.5329645723104477e-06 	|| H=-2.532062353566289e-06	||wari=1.0003563177434742
LR 3.00E+00 	 || Delta=2.31E-03	 ||Error=4.39E-04  	|| hoe=-2.5336114504170837e-06 	|| H=-2.5324991383968154e-06	||wari=1.0004392151623682
LR 4.00E+00 	 || Delta=3.08E-03	 ||Error=5.83E-04  	|| hoe=-2.5338376872241497e-06 	|| H=-2.5323606678284705e-06	||wari=1.0005832579120515
LR 5.00E+00 	 || Delta=3.86E-03	 ||Error=7.02E-04  	|| hoe=-2.534203076720587e-06 	|| H=-2.5324243324575946e-06	||wari=1.0007023879214059
LR 6.00E+00 	 || Delta=4.64E-03	 ||Error=8.55E-04  	|| hoe=-2.5345943868160248e-06 	|| H=-2.5324279704364017e-06	||wari=1.0008554700883554
LR 7.00E+00 	 || Delta=5.43E-03	 ||Error=1.03E-03  	|| hoe=-2.5350029773107963e-06 	|| H=-2.532394319132436e-06	||wari=1.0010301153176073
LR 8.00E+00 	 || Delta=6.21E-03	