In [1]:
import torch
from stochman import nnj
from stochman.hessian import HessianCalculator

ModuleNotFoundError: No module named 'stochman.laplace'

# Model

In [None]:
# Model hyperparameters.
# The hidden activations are reshaped into a square, single-channel image
# further down, so the hidden width is derived from the image side instead of
# hard-coding 36 here and 6 in the Reshape layer independently.
_input_size = 5
_image_side = 6
_hidden_size = _image_side * _image_side  # 36
_output_size = 2

# MaxPool2d(2) followed by Upsample(scale_factor=2) only gives back the
# original spatial size when the side is even (assuming the nnj layers follow
# the torch.nn shape rules -- TODO confirm). With an odd side, Flatten() would
# no longer yield _hidden_size features for the next Linear layer.
assert _image_side % 2 == 0, "_image_side must be even"

# Fix the global RNG seed so re-running this cell rebuilds the model from the
# same RNG state (presumably the nnj layers draw their initial weights from
# torch's global RNG, like torch.nn layers -- verify).
torch.manual_seed(0)

# Declare the model: an MLP front-end, a small conv / pool / upsample stage on
# the reshaped hidden activations, then an MLP head with L2 normalisation.
model = nnj.Sequential(
    [
        nnj.Tanh(),
        nnj.Linear(_input_size, _hidden_size),
        nnj.Tanh(),
        nnj.Linear(_hidden_size, _hidden_size),
        nnj.Reshape(1, _image_side, _image_side),
        nnj.Conv2d(1, 1, 3, stride=1, padding=1),
        nnj.MaxPool2d(2),
        nnj.Upsample(scale_factor=2),
        nnj.Flatten(),
        nnj.Tanh(),
        nnj.Linear(_hidden_size, _hidden_size),
        nnj.Tanh(),
        nnj.L2Norm(),
        nnj.Linear(_hidden_size, _output_size),
        nnj.L2Norm(),
    ],
    add_hooks=True,
)

# MSE loss

In [None]:
# Dataset (random): _batch_size rows of _input_size standard-normal features.
_batch_size = 13
# Sample from a dedicated, seeded generator so this cell produces the same
# inputs on every run, regardless of how much of the global RNG stream earlier
# cells have consumed.
dataset = torch.randn(
    _batch_size,
    _input_size,
    generator=torch.Generator().manual_seed(0),
)

### Exact Diagonal

In [None]:
# Settings common to both calculators in this cell: MSE loss, diagonal shape,
# and speed="half" (the variant this notebook labels "exact").
_mse_exact_settings = {"loss_func": "mse", "shape": "diagonal", "speed": "half"}

# With respect to the weights.
mse_exact_diagonal_weight = HessianCalculator(wrt="weight", **_mse_exact_settings)
hessian = mse_exact_diagonal_weight.compute_hessian(dataset, model)
print('Diagonal exact Generalized Gauss Netwon, with respect to weight\n', hessian)

# With respect to the input. Note that `hessian` is rebound here, so after the
# cell has run it holds the input result only.
mse_exact_diagonal_input = HessianCalculator(wrt="input", **_mse_exact_settings)
hessian = mse_exact_diagonal_input.compute_hessian(dataset, model)
print('Diagonal exact Generalized Gauss Netwon, with respect to input\n', hessian)

Diagonal exact Generalized Gauss Netwon, with respect to weight
 tensor([8.7947e-03, 1.0705e-02, 7.6393e-03,  ..., 2.4208e-02, 2.4775e+01,
        1.8568e+00], grad_fn=<MeanBackward1>)
Diagonal exact Generalized Gauss Netwon, with respect to input
 tensor([0.0225, 0.0420, 0.0104, 0.0273, 0.0212], grad_fn=<MeanBackward1>)


### Approximate Diagonal

In [None]:
# Settings common to both calculators in this cell: MSE loss, diagonal shape,
# and speed="fast" (the variant this notebook labels "approximated").
_mse_approx_settings = {"loss_func": "mse", "shape": "diagonal", "speed": "fast"}

# With respect to the weights.
mse_approx_diagonal_weight = HessianCalculator(wrt="weight", **_mse_approx_settings)
hessian = mse_approx_diagonal_weight.compute_hessian(dataset, model)
print('Diagonal approximated Generalized Gauss Netwon, with respect to weight\n', hessian)

# With respect to the input. Note that `hessian` is rebound here, so after the
# cell has run it holds the input result only.
mse_approx_diagonal_input = HessianCalculator(wrt="input", **_mse_approx_settings)
hessian = mse_approx_diagonal_input.compute_hessian(dataset, model)
print('Diagonal approximated Generalized Gauss Netwon, with respect to input\n', hessian)

Diagonal approximated Generalized Gauss Netwon, with respect to weight
 tensor([3.2199e-03, 5.4905e-03, 3.5651e-03,  ..., 2.4208e-02, 2.4775e+01,
        1.8568e+00], grad_fn=<MeanBackward1>)
Diagonal approximated Generalized Gauss Netwon, with respect to input
 tensor([0.0164, 0.0128, 0.0135, 0.0100, 0.0089], grad_fn=<MeanBackward1>)


# Contrastive loss

In [None]:
# Fresh random dataset for the contrastive-loss examples. This rebinds
# `_batch_size` and `dataset` from the MSE section.
_batch_size = 20
# Sample from a dedicated, seeded generator so this cell produces the same
# inputs on every run, regardless of the state of the global RNG.
dataset = torch.randn(
    _batch_size,
    _input_size,
    generator=torch.Generator().manual_seed(1),
)

# Row indices into `dataset` for the contrastive loss.
# NOTE(review): presumably ap/p are anchor/positive pairs and an/n are
# anchor/negative pairs -- confirm against HessianCalculator's documentation.
ap = [0, 1]
p = [2, 3]
an = [4, 5, 6]
n = [7, 8, 9]
# Passed to compute_hessian(..., tuple_indices=...) in the cells below.
tuple_indices = (ap, p, an, n)

### Exact Diagonal

In [None]:
# Settings common to both calculators in this cell: contrastive loss, diagonal
# shape, and speed="half" (the variant this notebook labels "exact").
_contrastive_exact_settings = {
    "loss_func": "contrastive",
    "shape": "diagonal",
    "speed": "half",
}

# With respect to the weights.
contrastive_exact_diagonal_weight = HessianCalculator(
    wrt="weight", **_contrastive_exact_settings
)
hessian = contrastive_exact_diagonal_weight.compute_hessian(
    dataset, model, tuple_indices=tuple_indices
)
print('Diagonal exact Generalized Gauss Netwon, with respect to weight\n', hessian)

# With respect to the input. Note that `hessian` is rebound here, so after the
# cell has run it holds the input result only.
contrastive_exact_diagonal_input = HessianCalculator(
    wrt="input", **_contrastive_exact_settings
)
hessian = contrastive_exact_diagonal_input.compute_hessian(
    dataset, model, tuple_indices=tuple_indices
)
print('Diagonal exact Generalized Gauss Netwon, with respect to input\n', hessian)

Diagonal exact Generalized Gauss Netwon, with respect to weight
 tensor([ 5.9148e-02,  3.9651e-02,  1.6767e-03,  ..., -3.4360e-01,
        -2.5451e+01, -1.6038e+01], grad_fn=<SubBackward0>)
Diagonal exact Generalized Gauss Netwon, with respect to input
 tensor([0.0644, 0.0762, 0.0124, 0.0223, 0.0263], grad_fn=<MeanBackward1>)


### Approximate Diagonal

In [None]:
# Settings common to both calculators in this cell: contrastive loss, diagonal
# shape, and speed="fast" (the variant this notebook labels "approximated").
_contrastive_approx_settings = {
    "loss_func": "contrastive",
    "shape": "diagonal",
    "speed": "fast",
}

# With respect to the weights.
contrastive_approx_diagonal_weight = HessianCalculator(
    wrt="weight", **_contrastive_approx_settings
)
hessian = contrastive_approx_diagonal_weight.compute_hessian(
    dataset, model, tuple_indices=tuple_indices
)
print('Diagonal approximated Generalized Gauss Netwon, with respect to weight\n', hessian)

# With respect to the input. Note that `hessian` is rebound here, so after the
# cell has run it holds the input result only.
contrastive_approx_diagonal_input = HessianCalculator(
    wrt="input", **_contrastive_approx_settings
)
hessian = contrastive_approx_diagonal_input.compute_hessian(
    dataset, model, tuple_indices=tuple_indices
)
print('Diagonal approximated Generalized Gauss Netwon, with respect to input\n', hessian)

Diagonal approximated Generalized Gauss Netwon, with respect to weight
 tensor([-2.6338e-03, -1.6888e-02, -4.0947e-02,  ..., -3.4360e-01,
        -2.5451e+01, -1.6038e+01], grad_fn=<SubBackward0>)
Diagonal approximated Generalized Gauss Netwon, with respect to input
 tensor([-0.0384, -0.0512, -0.0033, -0.0304, -0.0149], grad_fn=<SubBackward0>)
