# Tests for reproduced Laplace
1. shape test
2. examine if the results given by reproduced Laplace and `laplace-torch` package match.
3. test on reproduced Laplace. 

In [1]:
import torch 
import torch.nn as nn

import numpy as np 
import matplotlib.pyplot as plt
 
from LABDL.laplace.curvatures import BackPackGGN as BPG
from LABDL.laplace.curvatures import BackPackEF as BPE

from laplace.curvature import BackPackEF, BackPackGGN

In [2]:
import time
A = torch.rand(10, 4, 4)
B = torch.rand(7, 9, 5)
t = torch.kron(A, B)
print(t.shape)

torch.Size([70, 36, 20])


In [3]:
%reload_ext autoreload 
%autoreload 2

In [6]:
batch_size = 10
in_features = 3
out_features = 2

def _model():
    """model to fit on"""
    model = nn.Linear(in_features, out_features)
    setattr(model, 'output_size', 2) # for the laplace-torch
    setattr(model, 'num_params', sum([len(param.flatten()) for param in model.parameters()])
    )
    return model
    
def _reg_Xy():
    """regression samples"""
    X = torch.randn(batch_size, in_features)
    y = torch.randn(batch_size, out_features)
    return X, y

def _cls_Xy():
    """classification samples"""
    X = torch.randn(batch_size, in_features)
    y = torch.randint(out_features, (batch_size, ))
    return X, y

In [7]:
model = _model()
reg_Xy = _reg_Xy()
cls_Xy = _cls_Xy()

In [8]:
def test_ggn_shape(model, Xy):
    """GGNCurvature methods shape test
    Args: 
    model: note that in BackPack, self-defined module should be extended. 
    """
    X, y = Xy
    ggn = BPG(model, likelihood='regression')

    Gs, loss = ggn.gradients(X, y)
    print("gradients.shape:", Gs.shape)
    assert Gs.shape == (batch_size, model.num_params)
        
    Js, out = ggn.jacobians(X)
    print("jacobians.shape:", Js.shape)
    assert Js.shape == (batch_size, out_features, model.num_params)
    
    ggn_full, loss = ggn.full(X, y)
    print("ggn_full.shape:", ggn_full.shape)
    assert ggn_full.shape == (model.num_params, model.num_params)
                     
    ggn_diag, loss = ggn.diag(X, y)
    print("ggn_diag.shape:", ggn_diag.shape)
    assert ggn_diag.shape == (model.num_params,)
        
    #print("=====kron=====")
    #gnn_kron, loss = ggn.kron(features, targets)

test_ggn_shape(model, reg_Xy)

gradients.shape: torch.Size([10, 8])
jacobians.shape: torch.Size([10, 2, 8])
ggn_full.shape: torch.Size([8, 8])
ggn_diag.shape: torch.Size([8])


In [9]:
def test_ef_shape(model, Xy):
    """EFCurvature methods shape test
    Args: 
    model: note that in BackPack, self-defined module should be extended. 
    """
    X, y = Xy
    ef = BPE(model, likelihood='regression')

    ef_full, loss = ef.full(X, y)
    print("ef_full.shape:", ef_full.shape)
    assert ef_full.shape == (model.num_params, model.num_params)
                     
    ef_diag, loss = ef.diag(X, y)
    print("ef_diag.shape:", ef_diag.shape)
    assert ef_diag.shape == (model.num_params,)
        
    #print("=====kron=====")
    #gnn_kron, loss = ef.kron(features, targets)

test_ef_shape(model, reg_Xy)

ef_full.shape: torch.Size([8, 8])
ef_diag.shape: torch.Size([8])


In [10]:
def test_ef(model, Xy, likelihood): 
    X, y = Xy
    ef = BPE(model, likelihood)
    efbp = BackPackEF(model, likelihood) 
    H_ef, loss = ef.full(X, y)     
    bploss, H_efbp = efbp.full(X, y)
    np.testing.assert_allclose(loss, bploss, atol=1e-8) 
    np.testing.assert_allclose(H_ef, H_efbp, atol=1e-8) 
    
    H_ef, loss = ef.diag(X, y)     
    bploss, H_efbp = efbp.diag(X, y)
    np.testing.assert_allclose(loss, bploss, atol=1e-8) 
    np.testing.assert_allclose(H_ef, H_efbp, atol=1e-8) 
    
    
test_ef(model, reg_Xy, 'regression')
test_ef(model, reg_Xy, 'regression')

In [12]:
def test_ggn(model, Xy, likelihood, backend_kwargs): 
    X, y = Xy
    ggn = BPG(model, likelihood, **backend_kwargs)
    ggnbp = BackPackGGN(model, likelihood, **backend_kwargs) 
    
    #H, loss = ggn.full(X, y)     
    #bploss, bpH = ggnbp.full(X, y)
    #H, bpH = H.detach().numpy(), bpH.detach().numpy()
        
    #np.testing.assert_allclose(loss, bploss, atol=1e-6) 
    #np.testing.assert_allclose(H, bpH, atol=1e-6) 
    
    H, loss = ggn.diag(X, y)     
    bploss, bpH = ggnbp.diag(X, y)
    H, bpH = H.detach().numpy(), bpH.detach().numpy()
    
    np.testing.assert_allclose(loss, bploss, atol=1e-6) 
    np.testing.assert_allclose(H, bpH, atol=1e-6) 
    
backend_kwargs = {'stochastic':False}
test_ggn(model, reg_Xy, 'regression', backend_kwargs)
test_ggn(model, cls_Xy, 'classification', backend_kwargs)

The EF, and GGN with MC approximation might show non-neglectable deviations from the exact GGN.

In [13]:
def test_ggn_exact_vs_stochastic(model, Xy): 
    X, y = Xy
    ggn_exact = BPG(model, 'regression', stochastic=False)
    ggn_stoch = BPG(model, 'regression', stochastic=True)
    ggnbp = BackPackGGN(model, 'regression', stochastic=False)
    
    eH, eloss = ggn_exact.diag(X, y)
    sH, sloss = ggn_stoch.diag(X, y)
    bploss, bpH = ggnbp.diag(X, y)
    
    np.testing.assert_allclose(bploss, eloss, atol=1e-8) 
    np.testing.assert_allclose(bpH, eH, atol=1e-8) 
    
    np.testing.assert_allclose(eloss, sloss, atol=1e-8) 
    np.testing.assert_allclose(eH, sH, atol=1e-8) 
    
test_ggn_exact_vs_stochastic(model, reg_Xy)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-08

Mismatched elements: 8 / 8 (100%)
Max absolute difference: 1.1201124
Max relative difference: 0.3260285
 x: array([9.239404, 2.315515, 3.606776, 9.239404, 2.315515, 3.606776,
       9.999999, 9.999999], dtype=float32)
 y: array([ 8.538967,  2.393187,  4.290271,  9.51006 ,  3.435627,  3.244742,
       10.883088, 10.20693 ], dtype=float32)

In [14]:
def test_ef_full_vs_ggn_full(model, Xy, likelihood): 
    X, y = Xy
    ef = BPE(model, likelihood)
    ggn = BPG(model, likelihood)
    
    H_ef, efloss = ef.full(X, y)
    H_ggn, ggnloss = ggn.full(X, y)
    H_ef, H_ggn = H_ef.detach().numpy(), H_ggn.detach().numpy()
    
    np.testing.assert_allclose(efloss, ggnloss, atol=1e-8) 
    np.testing.assert_allclose(H_ef, H_ggn, atol=1e-8) 
    
    
test_ef_full_vs_ggn_full(model, cls_Xy, 'classification')
test_ef_full_vs_ggn_full(model, reg_Xy, 'regression')

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-08

Mismatched elements: 64 / 64 (100%)
Max absolute difference: 2.6726496
Max relative difference: 4.7019544
 x: array([[ 4.443159,  2.600602,  0.798379, -4.443159, -2.600602, -0.798379,
         1.348529, -1.348529],
       [ 2.600602,  4.864106,  1.219426, -2.600602, -4.864106, -1.219426,...
 y: array([[ 3.418147,  0.920167,  0.629899, -3.418147, -0.920167, -0.629899,
         0.389323, -0.389323],
       [ 0.920167,  2.191456,  0.424093, -0.920167, -2.191456, -0.424093,...

In [15]:
def test_full_vs_diag_ef(model, Xy, likelihood):
    X, y = Xy
    ef = BPE(model, likelihood)
    
    dH, dloss = ef.diag(X, y)
    fH, floss = ef.full(X, y)
    
    assert dloss == floss
    assert torch.allclose(dH, fH.diagonal())
    
test_full_vs_diag_ef(model, reg_Xy, 'regression')
test_full_vs_diag_ef(model, cls_Xy, 'classification')

In [16]:
def test_full_vs_diag_ggn(model, Xy, likelihood):
    X, y = Xy
    ggn = BPG(model, likelihood)
    
    fH, floss = ggn.full(X, y)  
    dH, dloss = ggn.diag(X, y)
    fH, dH = fH.detach().numpy(), dH.detach().numpy()

    np.testing.assert_allclose(fH.diagonal(), dH, atol=1e-6) 
    np.testing.assert_allclose(floss, dloss) 

test_full_vs_diag_ggn(model, reg_Xy, 'regression')
test_full_vs_diag_ggn(model, cls_Xy, 'classification')

In [19]:
import time
def test_jacobians_time(model, Xy):
    """
    time test when computing jacobians by setting `retain_graph` as True or forwarding for every output dimension. 
    """
    X, y = Xy
    ggn = BPG(model, 'regression', stochastic=False)
    print("test set size: ", len(X))
    print("======jacobians using retain_graph=====")
    start = time.time()
    Js, out = ggn.jacobians(X)
    end = time.time()
    time_Js = end-start
    print("time: ", time_Js)
    
    #print("======jacobians not using retain_graph=====")
    #start = time.time()
    #Js_wr, out_wr = ggn.jacobians_without_retain(X)
    #end = time.time()
    #time_Jswr = end-start
    #print("time: ", time_Jswr)
    
    #print("ratio: ", time_Js / time_Jswr)
    
    #np.testing.assert_allclose(Js.detach().numpy(), Js_wr.detach().numpy())
    #np.testing.assert_allclose(out.detach().numpy(), out_wr.detach().numpy())

test_jacobians_time(model, reg_Xy)

test set size:  10
time:  0.003782510757446289
