# Efficient Batched Hessian Calculation in PyTorch

In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

In [None]:
from hessian import hessian, jacobian

In [None]:
# Report the PyTorch version this notebook is running against.
print(f'Using PyTorch version: {t.__version__}')
# IPython shell escape (not plain Python): `nvidia-smi -L` lists the visible GPUs.
!nvidia-smi -L

To get started, here is a naive way of computing the jacobian of a vector-valued function:

In [None]:
def f(weight, bias, x):
    """Apply the affine map defined by ``weight`` and ``bias`` to ``x``, then tanh.

    Args:
        weight: Weight tensor passed to ``F.linear``.
        bias: Bias tensor passed to ``F.linear``.
        x: Input tensor passed to ``F.linear``.

    Returns:
        The elementwise hyperbolic tangent of ``F.linear(x, weight, bias)``.
    """
    pre_activation = F.linear(x, weight, bias)
    return pre_activation.tanh()

In [None]:
# Dimension shared by the weight matrix, the bias, and the input below.
D = 16
weight = t.randn(D, D)  # affine mapping R^D -> R^D
bias = t.randn(D)  # offset added by the affine mapping
x = t.randn(D)  # feature vector

In [None]:
def compute_jacobian(xp):
    """Compute the Jacobian of ``f(weight, bias, .)`` at ``xp`` one row at a time.

    Each row is obtained from a vector-Jacobian product with a unit vector in
    output space. The forward pass is evaluated once and its graph is reused
    for every row, and the unit vectors are sized by the *output* dimension,
    so the function also works when ``f`` is not a square map.

    Args:
        xp: 1-D tensor with ``requires_grad`` enabled; the point at which the
            Jacobian is evaluated.

    Returns:
        Tensor of shape ``(num_outputs, xp.size(0))`` whose i-th row holds the
        derivatives of the i-th output with respect to every entry of ``xp``.
    """
    assert xp.dim() == 1
    # Run the forward pass once; it does not depend on which row is requested.
    outputs = f(weight, bias, xp)
    # One unit cotangent per output component, matching the output's dtype/device.
    unit_vectors = t.eye(outputs.size(0), dtype=outputs.dtype, device=outputs.device)
    # retain_graph=True keeps the graph alive so it can be back-propagated repeatedly.
    jacobian_rows = [t.autograd.grad(outputs, xp, vec, retain_graph=True)[0]
                     for vec in unit_vectors]
    return t.stack(jacobian_rows)

In [None]:
# Differentiate with respect to a gradient-tracking copy so `x` itself is left untouched.
xp = x.clone().requires_grad_()
print(f'Input shape: {xp.shape}')
# Row-by-row Jacobian from the helper defined above.
my_jacobian = compute_jacobian(xp)
print(f'Jacobian shape: {my_jacobian.shape}')

In [None]:
# Cross-check: take another gradient-tracking copy, run the forward pass, and hand
# the outputs and inputs to the `jacobian` helper imported from the `hessian` module.
newxp = xp.clone().requires_grad_()
outputs = f(weight, bias, newxp)
lib_jacobian = jacobian(outputs, newxp)

In [None]:
# Evaluates to True only if every entry of the two Jacobians is elementwise close.
t.isclose(my_jacobian, lib_jacobian).all()

As we can see, the output of the above is a `[D, D]` matrix, where the $i$-th row holds the derivatives of the $i$-th function value with respect to all inputs.

Computing the Jacobian row-by-row like this is very computationally inefficient, particularly with larger matrices.

Rather than looping, we can vectorise the above using `vmap` together with `vjp`, the vector-Jacobian product function:

In [None]:
# from torch.autograd.functional import vjp
from functorch import vmap, vjp

In [None]:
# vjp evaluates f at x (weight and bias are bound via partial) and returns the
# output, discarded here, together with a function that maps a cotangent vector
# to its vector-Jacobian product.
_, vjp_fn = vjp(partial(f, weight, bias), x)
# Mapping vjp_fn over the rows of the identity produces every Jacobian row in one
# batched call; the trailing comma unpacks the single-element result.
# NOTE: the identity is sized by the input, which matches the output size only
# because weight is square (D x D) at this point in the notebook.
ft_jacobian, = vmap(vjp_fn)(t.eye(x.size(0)))

# Must agree with the loop-based Jacobian computed earlier.
assert t.allclose(ft_jacobian, my_jacobian)

Functorch provides a handy alias to the above:

In [None]:
from functorch import jacrev
# argnums=2 selects the third positional argument (x) as the differentiation variable.
ft_jacobian = jacrev(f, argnums=2)(weight, bias, x)

# Compare against the loop-based Jacobian *tensor*; `jacobian` is the function
# imported from the `hessian` module and cannot be passed to allclose.
assert t.allclose(ft_jacobian, my_jacobian)

We can also flip the problem around and say we want to compute the Jacobians of the parameters to our model (i.e. the weight and bias terms), rather than the input:

In [None]:
# argnums=(0, 1) differentiates with respect to weight and bias; the result is
# unpacked as one Jacobian per requested argument, in the same order.
ft_jac_weight, ft_jac_bias = jacrev(f, argnums=(0, 1))(weight, bias, x)

Note that if we're computing the Jacobian of a $\mathbb{R}^N \to \mathbb{R}^M$ function with more outputs than inputs ($M > N$), then `jacfwd` (the version of the above using forward-mode automatic differentiation) is preferred. Otherwise, use `jacrev`, which uses reverse-mode AD.

In reverse-mode AD, we compute the Jacobian row-by-row, while in forward-mode AD (which uses Jacobian-vector products), we compute it column-by-column. Since the Jacobian matrix has $M$ rows and $N$ columns, when it is much taller than it is wide (or vice versa) we prefer the method that iterates over the smaller of the two dimensions.

In [None]:
# Rebind the notebook globals to a map with far more outputs (2048) than inputs (32).
Din = 32
Dout = 2048
weight = t.randn(Dout, Din)
bias = t.randn(Dout)
x = t.randn(Din)

print(f'weight shape: {weight.shape}')

Here the Jacobian is a tall matrix: it has 2048 rows (outputs) but only 32 columns (inputs). Hence, we should use forward mode:

In [None]:
from functorch import jacfwd
# Forward-mode Jacobian of f with respect to x (positional argument index 2).
using_fwd = jacfwd(f, argnums=2)(weight, bias, x)

If the function f gave fewer outputs than inputs, then we should use `jacrev`.

## Hessian Computation

Recall that a Hessian is merely the Jacobian of the Jacobian; the matrix of second-order derivatives.

In [None]:
# NOTE: this rebinds the name `hessian`, replacing the one imported from the
# `hessian` module at the top of the notebook.
from functorch import hessian

# Rebind the notebook globals: many inputs (512), few outputs (32).
Din = 512
Dout = 32
weight = t.randn(Dout, Din)
bias = t.randn(Dout)
x = t.randn(Din)

# Hessian of f with respect to x via the library helper ...
hess_api = hessian(f, argnums=2)(weight, bias, x)
# ... and via two nested reverse-mode Jacobians; the two results must agree.
hess_revrev = jacrev(jacrev(f, argnums=2), argnums=2)(weight, bias, x)
assert t.allclose(hess_api, hess_revrev)

## Batched Computation

In the examples above, we've been using single (unbatched) input vectors. We usually want to take the Jacobian (and Hessian) of a batch of outputs with respect to a batch of inputs.

Given a batch of inputs of shape `(B, N)` and a function $f: \mathbb{R}^N \to \mathbb{R}^M$, we'd like a Jacobian of shape `(B, M, N)`.

We can vectorise this operation using vmap:

In [None]:
# Rebind the notebook globals for the batched examples; Din and Dout differ so
# the input and output axes can be told apart in printed shapes.
batch_size = 64
Din = 31
Dout = 33

weight = t.randn(Dout, Din)
bias = t.randn(Dout)
print(f'f is a transformation from {weight.size(1)} to {weight.size(0)}')

# One input vector per row.
x = t.randn(batch_size, Din)
print(f'Input batch is of size: {x.shape}')

In [None]:
# in_dims=(None, None, 0): weight and bias are shared across the batch, while x
# is mapped over its leading dimension, giving one Jacobian per input row.
batch_jacobian_fn = vmap(jacrev(f, argnums=2), in_dims=(None, None, 0))
batch_jacobian = batch_jacobian_fn(weight, bias, x)
print(f'Resulting Jacobian is of size: {batch_jacobian.shape}')

We can compute batched Hessians in a similar manner:

In [None]:
# Same batching scheme as for the Jacobian: only x is mapped over (dimension 0).
batch_hessian_fn = vmap(hessian(f, argnums=2), in_dims=(None, None, 0))
batch_hess = batch_hessian_fn(weight, bias, x)
print(f'Resulting Hessian is of size: {batch_hess.shape}')

## Advanced usage

In [None]:
# Rebind the notebook globals for the critic example: one output per sample.
batch_size = 64
Din = 32
Dout = 1

# Weight matrix read as a global by `critic` below (no bias is used there).
weights = t.randn(Dout, Din)

def critic(states, actions):
    """Return two value estimates for a batch of state/action pairs.

    Each sample's action vector is summed to a scalar and added to every
    component of that sample's state; the result goes through a bias-free
    linear layer (global ``weights``) and a tanh. The same computation is
    performed twice, so both returned tensors hold equal values.

    Args:
        states: Tensor of shape ``(batch_size, Din)``.
        actions: Tensor of shape ``(batch_size, Din)``.

    Returns:
        A pair of tensors, each the tanh of the linear layer's output.
    """
    expected_shape = (batch_size, Din)
    assert states.shape == expected_shape
    assert actions.shape == expected_shape
    # Per-sample action total, kept as a column so it broadcasts across features.
    action_totals = actions.sum(1).unsqueeze(1)
    shifted_states = states + action_totals
    q_first = F.linear(shifted_states, weights).tanh()
    q_second = F.linear(shifted_states, weights).tanh()
    return q_first, q_second

def td_err(states, actions, q_target):
    """Return the difference between ``q_target`` and each critic estimate.

    Args:
        states: Batch of states forwarded to ``critic``.
        actions: Batch of actions forwarded to ``critic``.
        q_target: Target values the critic estimates are subtracted from.

    Returns:
        A pair ``(q_target - q1, q_target - q2)`` for the two critic outputs.
    """
    q_first, q_second = critic(states, actions)
    err_first = q_target - q_first
    err_second = q_target - q_second
    return err_first, err_second

states = t.randn(batch_size, Din)
actions = t.randn(batch_size, Din)
q_target = t.randn(batch_size, Dout)

# td_err takes exactly three arguments, so it is called with (states, actions,
# q_target) and differentiated with respect to `actions` only (argnums=1).
# With has_aux=True the first element returned by td_err is differentiated and the
# second is passed through untouched: `a` is the Jacobian of the first TD error
# with respect to actions, and `b` is the second TD error itself.
a, b = jacrev(td_err, has_aux=True, argnums=1)(states, actions, q_target)

In [None]:
# Display the shape of `b`.
b.shape

In [None]:
# Rebinds `a` to the value returned by td_err (a pair of TD errors); `vjpfunc`
# is the accompanying vector-Jacobian product function.
(a, vjpfunc) = vjp(td_err, states, actions, q_target)

In [None]:
# Split the pair of TD errors returned through vjp.
aa, bb = a

In [None]:
# Display the shape of the second TD error obtained through vjp.
bb.shape

In [None]:
# Display the shape of `b` again, for comparison with `bb.shape`.
b.shape

def _td_err_with_aux(states, actions, q_target):
    """Return the pair of TD errors twice: once to differentiate, once as aux."""
    errors = td_err(states, actions, q_target)
    return errors, errors


# With has_aux=True, jacrev differentiates only the first element of the returned
# tuple and passes the second through unchanged. Calling td_err directly would
# therefore yield a single Jacobian tensor plus a single aux tensor, and neither
# can be unpacked into two names. The wrapper exposes both TD errors as the
# differentiated output (two Jacobians w.r.t. actions) and as the aux value.
(jac1, jac2), (td_err_1, td_err_2) = jacrev(_td_err_with_aux, has_aux=True, argnums=1)(states, actions, q_target)