In [21]:
import torch

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.extensions.firstorder.base import FirstOrderModuleExtension

from LANAM.models.activation.exu import ExU
from LANAM.models.extended_laplace.curvature.extensions import BatchGradExU

In [11]:
# Build a BatchGrad extension and map the custom ExU module class to its
# BatchGradExU module extension, so the extension has a handler for ExU.
extension = BatchGrad()
exu_module_extension = BatchGradExU()
extension.set_module_extension(ExU, exu_module_extension)

In [14]:
# Fix the global RNG seed so the random inputs, targets, and any random
# parameter initialization inside ExU are identical on every fresh run.
torch.manual_seed(0)

# Problem dimensions and placement.
batch_size = 10
batch_axis = 0
input_size = 4
device = 'cpu'

# Synthetic data: Gaussian inputs and binary integer class labels.
inputs = torch.randn(batch_size, input_size, device=device)
targets = torch.randint(0, 2, (batch_size,), device=device)

# Loss reduction mode; index 1 selects "sum".
# NOTE(review): with "mean", BackPACK documents that individual gradients are
# scaled by 1/batch_size, so the per-sample autograd comparison below would
# presumably need the same scaling -- confirm before switching.
reduction = ["mean", "sum"][1]
my_module = ExU(input_size, 2).to(device)
lossfunc = torch.nn.CrossEntropyLoss(reduction=reduction).to(device)

  return _no_grad_trunc_normal_(tensor, mean, std, a, b)


In [15]:
# Reference result: feed one sample at a time through the module and loss,
# differentiate that single-sample loss w.r.t. the bias with plain autograd,
# and stack the resulting gradients into one tensor.
grad_batch_autograd = torch.stack(
    [
        torch.autograd.grad(
            lossfunc(my_module(single_input), single_target),
            [my_module.bias],
        )[0]
        for single_input, single_target in zip(
            inputs.split(1, dim=batch_axis),
            targets.split(1, dim=batch_axis),
        )
    ]
)

print("bias.shape:             ", my_module.bias.shape)
print("grad_batch_autograd.shape:", grad_batch_autograd.shape)

bias.shape:              torch.Size([4])
grad_batch_autograd.shape: torch.Size([10, 4])


In [16]:
# Wrap the module and the loss with BackPACK's extend() before the backward pass.
my_module = extend(my_module)
lossfunc = extend(lossfunc)

# Single forward pass over the whole batch.
loss = lossfunc(my_module(inputs), targets)

# Backward pass with the registered extension active.
with backpack(extension):
    loss.backward()

# Per-sample bias gradients, read from the parameter after the backward pass.
grad_batch_backpack = my_module.bias.grad_batch

# Label corrected: this line reports the bias shape (it previously said "weight").
print("bias.shape:               ", my_module.bias.shape)
print("grad_batch_backpack.shape:", grad_batch_backpack.shape)

weight.shape:              torch.Size([4])
grad_batch_backpack.shape: torch.Size([10, 4])


In [17]:
# Compare the per-sample gradients from the autograd loop with BackPACK's
# result (torch.allclose with its default tolerances).
match = torch.allclose(grad_batch_autograd, grad_batch_backpack)
print(f"autograd and BackPACK individual gradients match? {match}")

# On disagreement, fail with both tensors included in the message.
if not match:
    mismatch_report = (
        "Individual gradients don't match:"
        + f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}"
    )
    raise AssertionError(mismatch_report)

autograd and BackPACK individual gradients match? True


In [18]:
# Reference result for the weight: feed one sample at a time through the
# module and loss, differentiate that single-sample loss w.r.t. the weight
# with plain autograd, and stack the resulting gradients into one tensor.
grad_batch_autograd = torch.stack(
    [
        torch.autograd.grad(
            lossfunc(my_module(single_input), single_target),
            [my_module.weight],
        )[0]
        for single_input, single_target in zip(
            inputs.split(1, dim=batch_axis),
            targets.split(1, dim=batch_axis),
        )
    ]
)

print("weight.shape:             ", my_module.weight.shape)
print("grad_batch_autograd.shape:", grad_batch_autograd.shape)

weight.shape:              torch.Size([4, 2])
grad_batch_autograd.shape: torch.Size([10, 4, 2])


In [19]:
# Wrap the module and the loss with BackPACK's extend() (module first, then loss).
my_module, lossfunc = extend(my_module), extend(lossfunc)

# One forward pass over the full batch, then a backward pass with the
# registered extension active.
loss = lossfunc(my_module(inputs), targets)
with backpack(extension):
    loss.backward()

# Per-sample weight gradients, read from the parameter after the backward pass.
weight_param = my_module.weight
grad_batch_backpack = weight_param.grad_batch

print("weight.shape:             ", weight_param.shape)
print("grad_batch_backpack.shape:", grad_batch_backpack.shape)

weight.shape:              torch.Size([4, 2])
grad_batch_backpack.shape: torch.Size([10, 4, 2])


In [20]:
# Compare the per-sample weight gradients from the autograd loop with
# BackPACK's result (torch.allclose with its default tolerances).
match = torch.allclose(grad_batch_autograd, grad_batch_backpack)
print(f"autograd and BackPACK individual gradients match? {match}")

# On disagreement, fail with both tensors included in the message.
if not match:
    mismatch_report = (
        "Individual gradients don't match:"
        + f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}"
    )
    raise AssertionError(mismatch_report)

autograd and BackPACK individual gradients match? True
