In [None]:
from typing import Callable
import torch
import torch.nn as nn

In [None]:
def get_ladj_fast(func: Callable, z: torch.Tensor) -> torch.Tensor:
    """Log-absolute-determinant of the Jacobian of an *element-wise* map.

    Adapted from zuko. A single vector-Jacobian product with a vector of ones
    gives ``d(sum_i x_i) / dz_j``. That equals the Jacobian diagonal only when
    ``func`` acts element-wise (``x_j`` depends on ``z_j`` alone). In that case
    the Jacobian is diagonal and log|det J| is the sum of log|diag|. For maps
    that mix coordinates (e.g. a dense linear layer) the result is NOT
    log|det J|; use ``get_ladj_slow`` for those.

    Args:
        func: Map applied to ``z``; assumed to be element-wise.
        z: Input whose last dimension is the event dimension, e.g.
            ``(nsamp, ndim)``. It does not need ``requires_grad`` set.

    Returns:
        Tensor of shape ``z.shape[:-1]`` with log|det J| per sample. The
        derivative graph is kept (``create_graph=True``), so the result is
        differentiable.
    """
    # Track gradients on a local copy so the function also works when the
    # caller passes a tensor without requires_grad or runs under no_grad().
    with torch.enable_grad():
        z = z.clone().requires_grad_()
        x = func(z)
        diag = torch.autograd.grad(x, z, torch.ones_like(x), create_graph=True)[0]
    # Reduce over the event (last) dimension only; a 1-D input yields a scalar.
    return diag.abs().log().sum(dim=-1)

In [None]:
def get_ladj_slow(func: Callable, z: torch.Tensor) -> torch.Tensor:
    """Exact per-sample log|det J| from the full Jacobian (reference method).

    For every row ``z[i]`` the full Jacobian of ``func`` is built with
    ``torch.autograd.functional.jacobian``, and its log-absolute-determinant is
    taken. Unlike ``get_ladj_fast``, this works for any square map,
    including ones that mix coordinates. The cost is one Jacobian per sample.

    Args:
        func: Map taking a 1-D tensor of length ``ndim`` to a 1-D tensor of
            the same length (the Jacobian must be square).
        z: Samples of shape ``(nsamp, ndim)``.

    Returns:
        Tensor of shape ``(nsamp,)``. It has the dtype of the computed
        Jacobians instead of the global default dtype, and is differentiable
        because the Jacobians are built with ``create_graph=True``.
    """
    nsamp = z.shape[0]
    if nsamp == 0:
        # torch.stack cannot stack an empty list.
        return z.new_zeros(0)
    logdets = []
    for row in z:
        jac = torch.autograd.functional.jacobian(func, row, create_graph=True)
        # slogdet is more numerically stable than log(abs(det(...))).
        logdets.append(torch.linalg.slogdet(jac).logabsdet)
    return torch.stack(logdets)

In [None]:
# Problem size and a reproducible batch of input samples.
ndim = 2
nsamp = 10

# Seed the global torch RNG so z (and any layer randomly initialised later)
# is the same on every Restart & Run All.
torch.manual_seed(0)
z = torch.randn(nsamp, ndim, requires_grad=True)
z

In [None]:
def func(z: torch.Tensor) -> torch.Tensor:
    """Element-wise map: its Jacobian is diagonal, so both estimators apply."""
    return torch.sin(z)

ladj_slow = get_ladj_slow(func, z)
ladj_fast = get_ladj_fast(func, z)
print(ladj_slow[:10])
print(ladj_fast[:10])
# For an element-wise map the diagonal-Jacobian shortcut is exact, so the fast
# estimator must reproduce the full-Jacobian reference.
torch.testing.assert_close(ladj_fast, ladj_slow)

In [None]:
# Build the layer ONCE. Creating it inside `func` drew new random weights on
# every call, so each sample (and the fast path) saw a different map.
layer = nn.Linear(ndim, ndim)

def func(z: torch.Tensor) -> torch.Tensor:
    """Dense affine map z -> z @ W.T + b; its Jacobian W is not diagonal."""
    return layer(z)

ladj_slow = get_ladj_slow(func, z)
ladj_fast = get_ladj_fast(func, z)
print(ladj_slow[:10])
print(ladj_fast[:10])
# The exact estimator must equal log|det W| for every sample. The fast one
# generally does not, because it assumes a diagonal Jacobian.
torch.testing.assert_close(
    ladj_slow, torch.linalg.slogdet(layer.weight).logabsdet.expand_as(ladj_slow)
)