In [9]:
import torch
from stochman import nnj
from stochman.laplace import HessianCalculator

# Model

In [10]:
# model hyperparameter
_input_size = 5
_hidden_size = 200
_output_size = 2

# declare the model
model = nnj.Sequential([
                        nnj.Tanh(),
                        nnj.Linear(_input_size, _hidden_size),
                        nnj.Tanh(),
                        nnj.Linear(_hidden_size, _hidden_size),
                        nnj.Tanh(),
                        nnj.Linear(_hidden_size, _hidden_size),
                        nnj.Tanh(),
                        #nnj.L2Norm(),
                        nnj.Linear(_hidden_size, _output_size),
                        #nnj.L2Norm()
                        ],
            add_hooks = True
)

# MSE loss

In [11]:
# dataset (random)
_batch_size = 10
dataset = torch.randn(_batch_size, _input_size)

### Exact Diagonal

In [12]:
# wrt to weight
mse_exact_diagonal_weight = HessianCalculator(wrt = "weight",
                                              loss_func = "mse",
                                              shape = "diagonal",
                                              speed = "half")
hessian = mse_exact_diagonal_weight.compute_hessian(dataset, model)
print('Diagonal exact Generalized Gauss Netwon, with respect to weight\n',hessian)


# wrt to input
mse_exact_diagonal_input = HessianCalculator(wrt = "input",
                                             loss_func = "mse",
                                             shape = "diagonal",
                                             speed = "half")
hessian = mse_exact_diagonal_input.compute_hessian(dataset, model)
print('Diagonal exact Generalized Gauss Netwon, with respect to input\n',hessian)

Diagonal exact Generalized Gauss Netwon, with respect to weight
 tensor([3.9511e-05, 2.0655e-05, 3.5548e-05,  ..., 1.6478e-02, 1.0000e+00,
        1.0000e+00], grad_fn=<MeanBackward1>)
Diagonal exact Generalized Gauss Netwon, with respect to input
 tensor([0.0015, 0.0008, 0.0004, 0.0015, 0.0007], grad_fn=<MeanBackward1>)


### Approx diagonal

In [13]:
# wrt to weight
mse_approx_diagonal_weight = HessianCalculator(wrt = "weight",
                                              loss_func = "mse",
                                              shape = "diagonal",
                                              speed = "fast")
hessian = mse_approx_diagonal_weight.compute_hessian(dataset, model)
print('Diagonal approximated Generalized Gauss Netwon, with respect to weight\n',hessian)


# wrt to input
mse_approx_diagonal_input = HessianCalculator(wrt = "input",
                                              loss_func = "mse",
                                              shape = "diagonal",
                                              speed = "fast")
hessian = mse_approx_diagonal_input.compute_hessian(dataset, model)
print('Diagonal approximated Generalized Gauss Netwon, with respect to input\n',hessian)

Diagonal approximated Generalized Gauss Netwon, with respect to weight
 tensor([8.3223e-05, 4.5165e-05, 7.3964e-05,  ..., 1.6478e-02, 1.0000e+00,
        1.0000e+00], grad_fn=<MeanBackward1>)
Diagonal approximated Generalized Gauss Netwon, with respect to input
 tensor([0.0018, 0.0024, 0.0019, 0.0014, 0.0020], grad_fn=<MeanBackward1>)


# Contrastive loss

In [14]:
_batch_size = 20
dataset = torch.randn(_batch_size, _input_size)

# indexes for constrastive loss
ap = [0,1]
p = [2,3]
an = [4,5,6]
n = [7,8,9] 
tuple_indices = tuple((ap, p, an, n))

### Exact Diagonal

In [15]:
# wrt to weight
contrastive_exact_diagonal_weight = HessianCalculator(wrt = "weight",
                                                      loss_func = "contrastive",
                                                      shape = "diagonal",
                                                      speed = "half")
hessian = contrastive_exact_diagonal_weight.compute_hessian(dataset, model, tuple_indices=tuple_indices)
print('Diagonal exact Generalized Gauss Netwon, with respect to weight\n',hessian)


# wrt to input
contrastive_exact_diagonal_input = HessianCalculator(wrt = "input",
                                                     loss_func = "mse",
                                                     shape = "diagonal",
                                                     speed = "half")
hessian = contrastive_exact_diagonal_input.compute_hessian(dataset, model, tuple_indices=tuple_indices)
print('Diagonal exact Generalized Gauss Netwon, with respect to input\n',hessian)

Diagonal exact Generalized Gauss Netwon, with respect to weight
 tensor([-1.0632e-04, -3.6515e-05,  1.6400e-04,  ...,  3.5362e-03,
         0.0000e+00,  0.0000e+00], grad_fn=<SubBackward0>)
Diagonal exact Generalized Gauss Netwon, with respect to input
 tensor([0.0008, 0.0006, 0.0003, 0.0012, 0.0005], grad_fn=<MeanBackward1>)


### Approx diagonal

In [16]:
# wrt to weight
contrastive_approx_diagonal_weight = HessianCalculator(wrt = "weight",
                                                       loss_func = "contrastive",
                                                       shape = "diagonal",
                                                       speed = "fast")
hessian = contrastive_approx_diagonal_weight.compute_hessian(dataset, model, tuple_indices=tuple_indices)
print('Diagonal approximated Generalized Gauss Netwon, with respect to weight\n',hessian)


# wrt to input
contrastive_approx_diagonal_input = HessianCalculator(wrt = "input",
                                                      loss_func = "contrastive",
                                                      shape = "diagonal",
                                                      speed = "fast")
hessian = contrastive_approx_diagonal_input.compute_hessian(dataset, model, tuple_indices=tuple_indices)
print('Diagonal approximated Generalized Gauss Netwon, with respect to input\n',hessian)

Diagonal approximated Generalized Gauss Netwon, with respect to weight
 tensor([-1.9523e-04,  1.2209e-05,  4.2422e-04,  ...,  3.5362e-03,
         0.0000e+00,  0.0000e+00], grad_fn=<SubBackward0>)
Diagonal approximated Generalized Gauss Netwon, with respect to input
 tensor([ 2.2839e-03, -1.3082e-03,  1.2383e-05, -1.6364e-03,  3.6427e-04],
       grad_fn=<SubBackward0>)
