## Hessian trace estimation


This notebook shows how to efficiently approximate the trace of a Hessian, using automatic differentiation ([PyTorch](https://pytorch.org)) and a recently published randomized algorithm called [Hutch++](https://arxiv.org/abs/2010.09649).

In [1]:
import torch
torch.set_printoptions(precision=3)

In [2]:
class LinearOperator(object):
    def __init__(self, matvec):
        self._matvec = matvec

    def matvec(self, vecs):
        return self._matvec(vecs)

In [3]:
def hutchpp(A, d, m):
    """https://arxiv.org/abs/2010.09649

    A is the LinearOperator whose trace to estimate
    d is the input dimension
    m is the number of queries (larger m yields better estimates)
    """
    S = torch.randn(d, m // 3)
    G = torch.randn(d, m // 3)
    Q, _ = torch.qr(A.matvec(S))
    proj = G - Q @ (Q.T @ G)
    return torch.trace(Q.T @ A.matvec(Q)) + (3./m)*torch.trace(proj.T @ A.matvec(proj))

In [4]:
torch.manual_seed(0)
d = 1000

B = torch.randn(d, d)
A = B.T @ B
torch.trace(A)

tensor(999781.188)

In [5]:
m = 100
estimate = hutchpp(LinearOperator(lambda vec: A@vec), d, m)
estimate

tensor(982108.562)

In [6]:
percent_error = 100*(estimate - torch.trace(A)).abs() / torch.trace(A)
percent_error

tensor(1.768)

In [7]:
def make_hvp(f, x):
    def hvp(vecs):
        # torch.autograd.functional.vhp doesn't support batching
        vecs = torch.split(vecs, 1, dim=1)
        products = []
        for v in vecs:
            _, product = torch.autograd.functional.vhp(f, x, v.squeeze())
            products.append(product)
        return torch.stack(products, dim=1)
    return LinearOperator(hvp)

In [8]:
def cubic(x):
    return (x**3).mean()


x = torch.arange(5, dtype=torch.float)
hvp = make_hvp(cubic, x)
hessian = hvp.matvec(torch.eye(x.nelement()))
hessian

tensor([[0.000, 0.000, 0.000, 0.000, 0.000],
        [0.000, 1.200, 0.000, 0.000, 0.000],
        [0.000, 0.000, 2.400, 0.000, 0.000],
        [0.000, 0.000, 0.000, 3.600, 0.000],
        [0.000, 0.000, 0.000, 0.000, 4.800]])

In [9]:
x = torch.arange(10000, dtype=torch.float)
hvp = make_hvp(cubic, x)
%time hessian = hvp.matvec(torch.eye(x.nelement()))
torch.trace(hessian)

CPU times: user 3.14 s, sys: 275 ms, total: 3.42 s
Wall time: 3.27 s


tensor(29996.998)

In [10]:
%time estimate = hutchpp(hvp, d=x.nelement(), m=100)
estimate

CPU times: user 72.4 ms, sys: 11.5 ms, total: 83.9 ms
Wall time: 40.4 ms


tensor(29615.984)

In [11]:
percent_error = 100*(estimate - torch.trace(hessian)).abs() / torch.trace(hessian)
percent_error

tensor(1.270)