# Autodiff benchmarks

In [21]:
import jax
from jax import numpy as jnp
from jax import jit, vmap, grad

In [22]:
USE_JIT = True
jax.config.update("jax_disable_jit", not USE_JIT)

## RBF kernel

In [23]:
@jit
def rbf_kernel(input1, input2, lengthscale, variance):
	"""
	Compute the RBF kernel between two inputs.

	Args:
		input1: First input array.
		input2: Second input array.
		lengthscale: Lengthscale parameter of the RBF kernel.
		variance: Variance parameter of the RBF kernel.

	Returns:
		Kernel value between input1 and input2.
	"""
	sq_dist = jnp.sum((input1 - input2) ** 2) / (lengthscale ** 2)
	return variance * jnp.exp(-0.5 * sq_dist)

In [24]:
# autodiff grad of rbf with respect to lengthscale and variance
rbf_kernel_autodiff_grad = jit(grad(rbf_kernel, argnums=(2, 3)))

In [25]:
# explicit computation of the RBF kernel gradient with respect to lengthscale and variance
@jit
def rbf_kernel_explicit_grad(input1, input2, lengthscale, variance):
	"""
	Compute the explicit gradient of the RBF kernel with respect to lengthscale and variance.

	Args:
		input1: First input array.
		input2: Second input array.
		lengthscale: Lengthscale parameter of the RBF kernel.
		variance: Variance parameter of the RBF kernel.

	Returns:
		Tuple of gradients with respect to lengthscale and variance.
	"""
	sq_dist = jnp.sum((input1 - input2) ** 2) / (lengthscale ** 2)
	dk_dl = variance * jnp.exp(-0.5 * sq_dist) * (sq_dist / lengthscale)
	dk_dv = jnp.exp(-0.5 * sq_dist)
	return dk_dl, dk_dv

In [26]:
rbf_kernel_autodiff_grad(1, 2, 0.3, 1.)

(Array(0.14318226, dtype=float32, weak_type=True),
 Array(0.00386592, dtype=float32, weak_type=True))

In [27]:
rbf_kernel_explicit_grad(1, 2, 0.3, 1.)

(Array(0.14318225, dtype=float32, weak_type=True),
 Array(0.00386592, dtype=float32, weak_type=True))

In [28]:
%timeit rbf_kernel_autodiff_grad(1, 2, 0.3, 1.)[0].block_until_ready()

13.1 μs ± 56.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [29]:
%timeit rbf_kernel_explicit_grad(1, 2, 0.3, 1.)[0].block_until_ready()

13 μs ± 34.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


## Multivariate Normal Log Likelihood


In [30]:
x1 = jnp.array([1.0, 2.0, 3.0])
x2 = jnp.array([-1.0, 0., 1.0])

In [31]:
@jit
def mvn_log_pdf(x, mu, lengthscale, variance):
	"""
	Compute the log of a multivariate normal probability density function.

	Args:
		x: Input array.
		mu: Mean vector.
		sigma: Covariance matrix.
		lengthscale: Lengthscale parameter of the covariance matrix.
		variance: Variance parameter of the covariance matrix.

	Returns:
		Log probability density value.
	"""
	cross_cov = vmap(vmap(rbf_kernel, in_axes=(None, 0, None, None)), in_axes=(0, None, None, None))(x, x, lengthscale, variance)

	cov_det = jnp.linalg.det(cross_cov)
	cross_cov_L = jnp.linalg.cholesky(cross_cov)
	z = jnp.linalg.solve(cross_cov_L, x - mu)

	normalization_term = -0.5 * len(x) * jnp.log(2 * jnp.pi)
	determinant_term = -0.5 * jnp.log(cov_det)
	quadratic_term = -0.5 * (z.T @ z)

	return normalization_term + determinant_term + quadratic_term

In [32]:
mvn_log_pdf_autodiff_grad = jit(grad(mvn_log_pdf, argnums=(2, 3)))

In [33]:
@jit
def mvn_log_pdf_explicit_grad(x, mu, lengthscale, variance):
	"""
	Compute the explicit gradient of the multivariate normal log PDF with respect to lengthscale and variance.

	Args:
		x: Input array.
		mu: Mean vector.
		lengthscale: Lengthscale parameter of the covariance matrix.
		variance: Variance parameter of the covariance matrix.

	Returns:
		Tuple of gradients with respect to lengthscale and variance.
	"""
	cross_cov = vmap(vmap(rbf_kernel, in_axes=(None, 0, None, None)), in_axes=(0, None, None, None))(x, x, lengthscale, variance)
	cross_cov_grads = vmap(vmap(rbf_kernel_explicit_grad, in_axes=(None, 0, None, None)), in_axes=(0, None, None, None))(x, x, lengthscale, variance)
	cross_cov_L = jnp.linalg.cholesky(cross_cov)

	z = jnp.linalg.solve(cross_cov_L.T, jnp.linalg.solve(cross_cov_L, x - mu))
	quad_term_l = 0.5 * (z.T @ cross_cov_grads[0] @ z)
	quad_term_v = 0.5 * (z.T @ cross_cov_grads[1] @ z)

	# use solve, but it's exactly the same as `trace(inv(cross_cov) * cross_cov_grads)`
	det_term_l = -0.5 * jnp.trace(jnp.linalg.solve(cross_cov_L.T, jnp.linalg.solve(cross_cov_L, cross_cov_grads[0])))
	det_term_v = -0.5 * jnp.trace(jnp.linalg.solve(cross_cov_L.T, jnp.linalg.solve(cross_cov_L, cross_cov_grads[1])))

	return quad_term_l + det_term_l, quad_term_v + det_term_v

In [34]:
mvn_log_pdf_autodiff_grad(x1, jnp.array([0.0, 0.0, 0.0]), 0.3, 1.0)

(Array(1.1333827, dtype=float32, weak_type=True),
 Array(5.4692526, dtype=float32, weak_type=True))

In [35]:
mvn_log_pdf_explicit_grad(x1, jnp.array([0.0, 0.0, 0.0]), 0.3, 1.0)

(Array(1.1333826, dtype=float32), Array(5.469252, dtype=float32))

In [36]:
%timeit mvn_log_pdf_autodiff_grad(x1, jnp.array([0.0, 0.0, 0.0]), 0.3, 1.0)[0].block_until_ready()

58.1 μs ± 399 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [37]:
%timeit mvn_log_pdf_explicit_grad(x1, jnp.array([0.0, 0.0, 0.0]), 0.3, 1.0)[0].block_until_ready()

59.4 μs ± 367 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [38]:
# On a random, large array
x3 = jax.random.uniform(jax.random.PRNGKey(0), (1000,))
mu = jnp.zeros_like(x3)

In [39]:
%timeit mvn_log_pdf_autodiff_grad(x3, mu, 0.3, 1.0)[0].block_until_ready()

19.4 ms ± 870 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [40]:
%timeit mvn_log_pdf_explicit_grad(x3, mu, 0.3, 1.0)[0].block_until_ready()

17.6 ms ± 182 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
