
# Example using all extensions

Basic example showing how to compute the gradient
and other quantities with BackPACK
on a linear model for MNIST.


In [4]:
# NOTE(review): environment-debugging aid. Prints the interpreter's module
# search path; the saved output contains machine-specific absolute paths, so
# consider clearing it before sharing the notebook.
import sys
print(sys.path)


['/Users/davidsuckrow/Documents/Developing/bachelor_thesis/experiments/exp_03_searching_values', '/opt/anaconda3/envs/taylor/lib/python38.zip', '/opt/anaconda3/envs/taylor/lib/python3.8', '/opt/anaconda3/envs/taylor/lib/python3.8/lib-dynload', '', '/opt/anaconda3/envs/taylor/lib/python3.8/site-packages', '/opt/anaconda3/envs/taylor/lib/python3.8/site-packages/setuptools/_vendor']


Let's start by loading one batch of MNIST data and extending the model and the loss function.



In [5]:
from torch import rand
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import (
    GGNMP,
    HMP,
    KFAC,
    KFLR,
    KFRA,
    PCHMP,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
    SumGradSquared,
    Variance,
)
from backpack.utils.examples import load_one_batch_mnist

# Fetch a single mini-batch of 512 samples that every example below reuses.
X, y = load_one_batch_mnist(batch_size=512)

# Linear classifier: flatten the input, then map 784 features to 10 outputs.
# Both the model and the loss are passed through BackPACK's `extend` so the
# extensions used in the following cells can be applied to them.
model = extend(Sequential(Flatten(), Linear(784, 10)))
lossfunc = extend(CrossEntropyLoss())

## First order extensions



Batch gradients



In [6]:
# Forward pass, then backpropagate with the BatchGrad extension active;
# afterwards each parameter carries `.grad_batch` next to the usual `.grad`.
loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".grad_batch.shape:       ", parameter.grad_batch.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.grad_batch.shape:        torch.Size([512, 10, 784])
1.bias
.grad.shape:              torch.Size([10])
.grad_batch.shape:        torch.Size([512, 10])


Variance



In [7]:
# Backward pass with the Variance extension; each parameter then exposes
# a `.variance` attribute in addition to `.grad`.
loss = lossfunc(model(X), y)
with backpack(Variance()):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".variance.shape:         ", parameter.variance.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.variance.shape:          torch.Size([10, 784])
1.bias
.grad.shape:              torch.Size([10])
.variance.shape:          torch.Size([10])


Second moment/sum of gradients squared



In [8]:
# Backward pass with the SumGradSquared extension; each parameter then
# exposes a `.sum_grad_squared` attribute in addition to `.grad`.
loss = lossfunc(model(X), y)
with backpack(SumGradSquared()):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".sum_grad_squared.shape: ", parameter.sum_grad_squared.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.sum_grad_squared.shape:  torch.Size([10, 784])
1.bias
.grad.shape:              torch.Size([10])
.sum_grad_squared.shape:  torch.Size([10])


L2 norm of individual gradients



In [9]:
# Backward pass with the BatchL2Grad extension; each parameter then exposes
# a `.batch_l2` attribute in addition to `.grad`.
loss = lossfunc(model(X), y)
with backpack(BatchL2Grad()):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".batch_l2.shape:         ", parameter.batch_l2.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.batch_l2.shape:          torch.Size([512])
1.bias
.grad.shape:              torch.Size([10])
.batch_l2.shape:          torch.Size([512])


It's also possible to ask for multiple quantities at once



In [10]:
# Several extensions can be activated in one `backpack(...)` context; a single
# backward pass then fills in all of the corresponding attributes.
loss = lossfunc(model(X), y)
with backpack(BatchGrad(), Variance(), SumGradSquared(), BatchL2Grad()):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".grad_batch.shape:       ", parameter.grad_batch.shape)
    print(".variance.shape:         ", parameter.variance.shape)
    print(".sum_grad_squared.shape: ", parameter.sum_grad_squared.shape)
    print(".batch_l2.shape:         ", parameter.batch_l2.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.grad_batch.shape:        torch.Size([512, 10, 784])
.variance.shape:          torch.Size([10, 784])
.sum_grad_squared.shape:  torch.Size([10, 784])
.batch_l2.shape:          torch.Size([512])
1.bias
.grad.shape:              torch.Size([10])
.grad_batch.shape:        torch.Size([512, 10])
.variance.shape:          torch.Size([10])
.sum_grad_squared.shape:  torch.Size([10])
.batch_l2.shape:          torch.Size([512])


## Second order extensions



Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation



In [11]:
# Backward pass with both the exact GGN diagonal and its Monte-Carlo variant
# (configured with mc_samples=1); results land in `.diag_ggn_exact` and
# `.diag_ggn_mc`.
loss = lossfunc(model(X), y)
with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".diag_ggn_mc.shape:      ", parameter.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape:   ", parameter.diag_ggn_exact.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.diag_ggn_mc.shape:       torch.Size([10, 784])
.diag_ggn_exact.shape:    torch.Size([10, 784])
1.bias
.grad.shape:              torch.Size([10])
.diag_ggn_mc.shape:       torch.Size([10])
.diag_ggn_exact.shape:    torch.Size([10])


Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation



In [12]:
# Backward pass with the per-sample GGN diagonal extensions (exact and
# Monte-Carlo with mc_samples=1); results land in `.diag_ggn_exact_batch`
# and `.diag_ggn_mc_batch`.
loss = lossfunc(model(X), y)
with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".diag_ggn_mc_batch.shape:      ", parameter.diag_ggn_mc_batch.shape)
    print(".diag_ggn_exact_batch.shape:   ", parameter.diag_ggn_exact_batch.shape)

1.weight
.diag_ggn_mc_batch.shape:       torch.Size([512, 10, 784])
.diag_ggn_exact_batch.shape:    torch.Size([512, 10, 784])
1.bias
.diag_ggn_mc_batch.shape:       torch.Size([512, 10])
.diag_ggn_exact_batch.shape:    torch.Size([512, 10])


KFAC, KFRA and KFLR



In [13]:
# Backward pass with the three Kronecker-factored extensions active
# (KFAC configured with mc_samples=1).
loss = lossfunc(model(X), y)
with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

# `.kfac`, `.kflr` and `.kfra` are iterated as collections of factors, so the
# shape of each individual factor is listed.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".kfac (shapes):          ", [factor.shape for factor in parameter.kfac])
    print(".kflr (shapes):          ", [factor.shape for factor in parameter.kflr])
    print(".kfra (shapes):          ", [factor.shape for factor in parameter.kfra])

1.weight
.grad.shape:              torch.Size([10, 784])
.kfac (shapes):           [torch.Size([10, 10]), torch.Size([784, 784])]
.kflr (shapes):           [torch.Size([10, 10]), torch.Size([784, 784])]
.kfra (shapes):           [torch.Size([10, 10]), torch.Size([784, 784])]
1.bias
.grad.shape:              torch.Size([10])
.kfac (shapes):           [torch.Size([10, 10])]
.kflr (shapes):           [torch.Size([10, 10])]
.kfra (shapes):           [torch.Size([10, 10])]


Diagonal Hessian and per-sample diagonal Hessian



In [14]:
# Backward pass with the Hessian-diagonal extension and its per-sample
# counterpart; results land in `.diag_h` and `.diag_h_batch`.
loss = lossfunc(model(X), y)
with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".diag_h.shape:           ", parameter.diag_h.shape)
    print(".diag_h_batch.shape:     ", parameter.diag_h_batch.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.diag_h.shape:            torch.Size([10, 784])
.diag_h_batch.shape:      torch.Size([512, 10, 784])
1.bias
.grad.shape:              torch.Size([10])
.diag_h.shape:            torch.Size([10])
.diag_h_batch.shape:      torch.Size([512, 10])


Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation



In [15]:
# Backward pass with the GGN matrix-square-root extensions (exact and
# Monte-Carlo with mc_samples=1); results land in `.sqrt_ggn_exact` and
# `.sqrt_ggn_mc`.
loss = lossfunc(model(X), y)
with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
    loss.backward()

# Report the shape of every stored quantity, parameter by parameter.
for param_name, parameter in model.named_parameters():
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print(".sqrt_ggn_exact.shape:   ", parameter.sqrt_ggn_exact.shape)
    print(".sqrt_ggn_mc.shape:      ", parameter.sqrt_ggn_mc.shape)

1.weight
.grad.shape:              torch.Size([10, 784])
.sqrt_ggn_exact.shape:    torch.Size([10, 512, 10, 784])
.sqrt_ggn_mc.shape:       torch.Size([1, 512, 10, 784])
1.bias
.grad.shape:              torch.Size([10])
.sqrt_ggn_exact.shape:    torch.Size([10, 512, 10])
.sqrt_ggn_mc.shape:       torch.Size([1, 512, 10])


## Block-diagonal curvature products



Curvature-matrix product (`MP`) extensions provide functions
that multiply with the block diagonal of different curvature matrices, such as

- the Hessian (`HMP`)
- the generalized Gauss-Newton (`GGNMP`)
- the positive-curvature Hessian (`PCHMP`)



In [16]:
loss = lossfunc(model(X), y)

# Collect the curvature-matrix-product extensions first. PCHMP is used twice
# with different `modify` settings, each saved under its own field name so the
# two results do not overwrite each other.
curvature_products = (
    HMP(),
    GGNMP(),
    PCHMP(savefield="pchmp_clip", modify="clip"),
    PCHMP(savefield="pchmp_abs", modify="abs"),
)

with backpack(*curvature_products):
    loss.backward()

Multiply a random vector with curvature blocks.



In [17]:
# Number of vectors to multiply at once (leading dimension of `vec`).
V = 1

for param_name, parameter in model.named_parameters():
    # Random input with one leading axis of size V followed by the
    # parameter's own shape.
    vec = rand(V, *parameter.shape)
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print("vec.shape:               ", vec.shape)
    # Each saved field is callable and is applied to `vec` directly.
    print(".hmp(vec).shape:         ", parameter.hmp(vec).shape)
    print(".ggnmp(vec).shape:       ", parameter.ggnmp(vec).shape)
    print(".pchmp_clip(vec).shape:  ", parameter.pchmp_clip(vec).shape)
    print(".pchmp_abs(vec).shape:   ", parameter.pchmp_abs(vec).shape)

1.weight
.grad.shape:              torch.Size([10, 784])
vec.shape:                torch.Size([1, 10, 784])
.hmp(vec).shape:          torch.Size([1, 10, 784])
.ggnmp(vec).shape:        torch.Size([1, 10, 784])
.pchmp_clip(vec).shape:   torch.Size([1, 10, 784])
.pchmp_abs(vec).shape:    torch.Size([1, 10, 784])
1.bias
.grad.shape:              torch.Size([10])
vec.shape:                torch.Size([1, 10])
.hmp(vec).shape:          torch.Size([1, 10])
.ggnmp(vec).shape:        torch.Size([1, 10])
.pchmp_clip(vec).shape:   torch.Size([1, 10])
.pchmp_abs(vec).shape:    torch.Size([1, 10])


Multiply a collection of three vectors (a matrix) with curvature blocks.



In [18]:
# Number of vectors to multiply at once (leading dimension of `vec`).
V = 3

for param_name, parameter in model.named_parameters():
    # Random input with one leading axis of size V followed by the
    # parameter's own shape.
    vec = rand(V, *parameter.shape)
    print(param_name)
    print(".grad.shape:             ", parameter.grad.shape)
    print("vec.shape:               ", vec.shape)
    # Each saved field is callable and is applied to `vec` directly.
    print(".hmp(vec).shape:         ", parameter.hmp(vec).shape)
    print(".ggnmp(vec).shape:       ", parameter.ggnmp(vec).shape)
    print(".pchmp_clip(vec).shape:  ", parameter.pchmp_clip(vec).shape)
    print(".pchmp_abs(vec).shape:   ", parameter.pchmp_abs(vec).shape)

1.weight
.grad.shape:              torch.Size([10, 784])
vec.shape:                torch.Size([3, 10, 784])
.hmp(vec).shape:          torch.Size([3, 10, 784])
.ggnmp(vec).shape:        torch.Size([3, 10, 784])
.pchmp_clip(vec).shape:   torch.Size([3, 10, 784])
.pchmp_abs(vec).shape:    torch.Size([3, 10, 784])
1.bias
.grad.shape:              torch.Size([10])
vec.shape:                torch.Size([3, 10])
.hmp(vec).shape:          torch.Size([3, 10])
.ggnmp(vec).shape:        torch.Size([3, 10])
.pchmp_clip(vec).shape:   torch.Size([3, 10])
.pchmp_abs(vec).shape:    torch.Size([3, 10])
