In [1]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
sys.path.append("../../src")
from lr_tools import lr_calibrate
from models import MLP
from data import gen_rnd_ds
from losses import LinearClassification, SquaredClassification
from block_analysis import block_hessian, curvature_effects

## Helpers

In [3]:
def get_model_ds_loss(nlayer=5, mode="linear", loss="linear"):    
    model =  MLP(inp_dim, hid_dim, out_dim, nlayer, bias, mode).cuda(device)
    ds = gen_rnd_ds(inp_dim, inp_mean, inp_var, 
                   out_dim, nsamp, device)
    loss_fn = {
        "linear":LinearClassification(out_dim),
        "squared":SquaredClassification(out_dim),
        "ce":torch.nn.CrossEntropyLoss()
    }[loss]
    return model, ds, loss_fn

import time

def tick():
    torch.cuda.synchronize()
    return time.time()

def timer(func, *args):
    start = tick()
    func(*args)
    return start - tick()

## Params

In [4]:
# Model
mode = "linear"
bias = False
nlayer = 5
inp_dim = 10 
out_dim = 10
hid_dim = 100

# Data parameters
nsamp = 100
inp_mean = 0
inp_var = 1

# Others
device = 0
lr = 7

In [5]:
from block_analysis import block_hessian, legacy_block_hessian

In [6]:
for nlayer in [5, 7, 9]:
    for mode in ["linear", "relu"]:
        for loss in ["linear", "squared", "ce"]:
            model, ds, loss_fn = get_model_ds_loss(nlayer, mode, loss)
            H, gnorms = block_hessian(model, ds, loss_fn, lr)
            H_l, gnorms_l = legacy_block_hessian(model, ds, loss_fn, lr)
            print((H-H_l).abs().sum().item(), (gnorms-gnorms_l).abs().sum().item())

0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0


In [8]:
for nlayer in [5, 7, 9, 11, 13]:
    for mode in ["linear", "relu"]:
        for loss in ["linear", "squared", "ce"]:
            model, ds, loss_fn = get_model_ds_loss(nlayer, mode, loss)
            t1 = timer(legacy_block_hessian, model, ds, loss_fn, lr)
            t2 = timer(block_hessian, model, ds, loss_fn, lr)
            print(t1/t2)

2.0767333160218127
2.0792261752485466
2.2810839159994103
2.3957500144902335
2.437676376062966


### Line profile

In [9]:
%load_ext line_profiler

In [13]:
model, ds, loss_fn = get_model_ds_loss(nlayer, mode)
%lprun -f block_hessian block_hessian(model, ds, loss_fn, lr)

Timer unit: 1e-06 s

Total time: 0.349877 s
File: ../../src/block_analysis.py
Function: block_hessian at line 98

Line #      Hits         Time  Per Hit   % Time  Line Contents
    98                                           def block_hessian(model, ds, loss_fn, lr):
    99                                               """
   100                                                   Missing merge_DH(D, H)
   101                                               """
   102         1      13457.0  13457.0      3.8      model = clone_model(model)
   103         1       5809.0   5809.0      1.7      grads, loss_t = get_grad_loss(model, ds, loss_fn, lr)
   104         1       1523.0   1523.0      0.4      gnorms = get_gnorms(grads, model)
   105         1      46483.0  46483.0     13.3      d = _block_hessian_diag(model, ds, loss_fn, grads, loss_t, lr)
   106         1     282314.0 282314.0     80.7      H = _block_hessian_off_diag(model, ds, loss_fn, grads, loss_t, lr)
   107         1        289

In [14]:
model, ds, loss_fn = get_model_ds_loss(nlayer, mode)
%lprun -f legacy_block_hessian legacy_block_hessian(model, ds, loss_fn, lr)

Timer unit: 1e-06 s

Total time: 1.37359 s
File: ../../src/block_analysis.py
Function: legacy_block_hessian at line 177

Line #      Hits         Time  Per Hit   % Time  Line Contents
   177                                           def legacy_block_hessian(model, ds, loss_fn, lr):
   178                                               """
   179                                                   Missing merge_DH(D, H)
   180                                               """
   181         1       7349.0   7349.0      0.5      grads, loss_t = get_grad_loss(model, ds, loss_fn, lr)
   182         1       1915.0   1915.0      0.1      gnorms = get_gnorms(grads, model)
   183         1     198156.0 198156.0     14.4      d = legacy_block_hessian_diag(model, ds, loss_fn, grads, loss_t, lr)
   184         1    1165883.0 1165883.0     84.9      H = legacy_block_hessian_off_diag(model, ds, loss_fn, grads, loss_t, lr)
   185         1        282.0    282.0      0.0      H = _merge_blocks(H, d)
   