In [3]:
from typing import Callable
import torch
import torch.nn as nn

In [45]:
def get_ladj_fast(func: Callable, z: torch.Tensor) -> torch.Tensor:
    # Copied from zuko
    # with torch.enable_grad():
    #     z = z.clone().requires_grad_()
    #     x = func(z)
    x = func(z)
    dj = torch.autograd.grad(x, z, torch.ones_like(x), create_graph=True)[0]
    ladj = torch.log(torch.abs(dj))
    ladj = torch.distributions.utils._sum_rightmost(ladj, -1)
    return ladj

In [46]:
def get_ladj_slow(func: Callable, z: torch.Tensor) -> torch.Tensor:
    ndim = z.shape[0]
    ladj = torch.zeros(ndim)
    for i in range(ndim):
        jac = torch.autograd.functional.jacobian(func, z[i, :], create_graph=True)
        ladj[i] = torch.log(torch.abs(torch.linalg.det(jac)))
    return ladj

In [49]:
ndim = 2
nsamp = 10

z = torch.randn(nsamp, ndim)
z.requires_grad_()
print(z)

tensor([[ 1.4742, -0.0877],
        [ 0.4071, -1.3831],
        [ 0.4139, -0.1267],
        [-0.9442, -0.2014],
        [ 0.5780, -0.7723],
        [-2.4239,  0.6479],
        [ 1.6165, -0.9512],
        [-0.0566,  0.5406],
        [-0.2317,  0.8546],
        [ 0.1791,  1.2622]], requires_grad=True)


In [50]:
def func(z: torch.Tensor) -> torch.Tensor:
    return torch.sin(z)

print(get_ladj_slow(func, z)[:10])
print(get_ladj_fast(func, z)[:10])

tensor([-2.3424, -1.7643, -0.0963, -0.5542, -0.5110, -0.5097, -3.6297, -0.1555,
        -0.4479, -1.2077], grad_fn=<SliceBackward0>)
tensor([-2.3424, -1.7643, -0.0963, -0.5542, -0.5110, -0.5097, -3.6297, -0.1555,
        -0.4479, -1.2077], grad_fn=<SliceBackward0>)


In [54]:
def func(z: torch.Tensor) -> torch.Tensor:
    layer = torch.nn.Linear(ndim, ndim)
    return layer(z)

print(get_ladj_slow(func, z)[:10])
print(get_ladj_fast(func, z)[:10])

tensor([-0.5359, -2.5240, -1.1045, -1.0119, -1.6705, -2.3893, -1.3587, -2.0419,
        -1.1163, -1.9069], grad_fn=<SliceBackward0>)
tensor([-1.1761, -1.1761, -1.1761, -1.1761, -1.1761, -1.1761, -1.1761, -1.1761,
        -1.1761, -1.1761], grad_fn=<SliceBackward0>)
