# Benchmarks - Kernels

In MagmaClust, and in Gaussian processes (GPs) in general, kernels make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are called in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable with inputs of various dimensions, including a batch dimension with distinct hyperparameters for each element
* able to handle padded inputs (i.e. inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, so that it can either carry its hyperparameters around or be called with them
* easy to override, as users may want to define their own kernels

These goals sometimes conflict (e.g., writing a jittable version of a kernel is not trivial for most people), and the best implementation depends on the specific use case.
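
One possible way to reconcile the "padded inputs" and "jittable" requirements is to replace padded rows with a placeholder before computing the kernel and to mask the result afterwards, so that no data-dependent branching is needed. The sketch below assumes a squared-exponential kernel; the helper name `masked_se_gram` is hypothetical and this is not the Kernax implementation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_se_gram(x, length_scale):
    """Squared-exponential Gram matrix for x of shape (N, I).

    Rows containing at least one NaN are treated as padding.
    """
    valid = ~jnp.any(jnp.isnan(x), axis=-1)         # (N,), True for real points
    x_safe = jnp.where(valid[:, None], x, 0.0)      # placeholder keeps the math finite
    diff = x_safe[:, None, :] - x_safe[None, :, :]  # (N, N, I)
    sq_dist = jnp.sum(diff ** 2, axis=-1)           # (N, N)
    gram = jnp.exp(-0.5 * sq_dist / length_scale ** 2)
    pair_valid = valid[:, None] & valid[None, :]    # both points must be real
    return jnp.where(pair_valid, gram, jnp.nan)


x = jnp.array([[0.0], [1.0], [jnp.nan]])
K = masked_se_gram(x, 0.3)  # last row and last column are NaN
```

Whether this is faster than branching with `cond` depends on the backend and input sizes, so it would need to be benchmarked like the other variants.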

---
## Setup

In [1]:
from kernax import ConstantKernel

# Jax configuration
USE_JIT = True
USE_X64 = True
DEBUG_NANS = False
VERBOSE = False

In [2]:
# Standard library imports
import os

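# NOTE: JAX reads this variable when it is first imported, so it has no effect if JAX
# (or a package that imports it, such as kernax above) has already been loaded.
os.environ['JAX_ENABLE_X64'] = str(USE_X64).lower()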

import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [3]:
# Third party
import jax

jax.config.update("jax_disable_jit", not USE_JIT)
jax.config.update("jax_debug_nans", DEBUG_NANS)

In [4]:
# Third party
from jax import numpy as jnp
from jax import random as jr
from jax import vmap

import numpy as np

In [5]:
# Local
import kernax

In [6]:
# Config
key = jr.PRNGKey(0)

INFO:2025-11-21 14:31:28,179:jax._src.xla_bridge:812: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/miniconda3/envs/Kernax/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


---
## Data

In [7]:
B = 250  # Batch size
M = 100  # Number of points
N = 4  # Input dimension

In [8]:
# For pairwise covariance
input_1D_a = jnp.array([0.])
input_1D_b = jnp.array([1.])

input_ND_a = jnp.zeros(N)
input_ND_b = jnp.ones(N)

input_1D_a.shape, input_1D_b.shape, input_ND_a.shape, input_ND_b.shape

((1,), (1,), (4,), (4,))

In [9]:
# For cross-covariance
input_1D_grid_regular = jnp.linspace(0., 10., M).reshape(-1, 1)
key, subkey = jr.split(key)
input_1D_grid_irregular = jr.uniform(subkey, shape=(M, 1), minval=0., maxval=10.)

key, subkey = jr.split(key)
input_1D_grid_padded = input_1D_grid_irregular.copy()
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
# Set the selected points to NaN in a single vectorised update
input_1D_grid_padded = input_1D_grid_padded.at[nan_indices, 0].set(jnp.nan)

# Create a regular grid with ceil(M^(1/N)) points per dimension, then keep only the first M points
points_per_dim = int(np.ceil(M ** (1 / N)))
grid_1d = jnp.linspace(0., 5., points_per_dim)
grids = jnp.meshgrid(*[grid_1d for _ in range(N)], indexing='ij')
input_ND_grid_regular = jnp.stack([g.flatten() for g in grids], axis=-1)[:M]

key, subkey = jr.split(key)
input_ND_grid_irregular = jr.uniform(subkey, shape=(M, N), minval=0., maxval=10.)

# Padded version is a copy of the irregular version, with some NaNs
input_ND_grid_padded = input_ND_grid_irregular.copy()
key, subkey = jr.split(key)
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
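# Draw one dimension per selected point with a fresh key (reusing `subkey` would pick the same dimension every time)
key, subkey = jr.split(key)
dims_to_nan = jr.randint(subkey, nan_indices.shape, 0, N)
input_ND_grid_padded = input_ND_grid_padded.at[nan_indices, dims_to_nan].set(jnp.nan)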

input_1D_grid_regular.shape, input_1D_grid_irregular.shape, input_1D_grid_padded.shape, input_ND_grid_regular.shape, input_ND_grid_irregular.shape, input_ND_grid_padded.shape

((100, 1), (100, 1), (100, 1), (100, 4), (100, 4), (100, 4))

In [108]:
batched_input_1D_grid_regular = jnp.tile(input_1D_grid_regular[None, :, :], (B, 1, 1))
key, subkey = jr.split(key)
batched_input_1D_grid_irregular = jr.uniform(subkey, shape=(B, M, 1), minval=0., maxval=10.)

key, subkey = jr.split(key)  # Fresh key: do not reuse the one consumed by jr.uniform above
batched_input_1D_grid_padded = batched_input_1D_grid_irregular.copy()
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
# The same points are set to NaN in every batch element
batched_input_1D_grid_padded = batched_input_1D_grid_padded.at[:, nan_indices, 0].set(jnp.nan)

batched_input_ND_grid_regular = jnp.tile(input_ND_grid_regular[None, :, :], (B, 1, 1))

key, subkey = jr.split(key)
batched_input_ND_grid_irregular = jr.uniform(subkey, shape=(B, M, N), minval=0., maxval=10.)

key, subkey = jr.split(key)
batched_input_ND_grid_padded = batched_input_ND_grid_irregular.copy()
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
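# The same points are padded in every batch element; the NaN dimension is drawn
# independently for each (batch, point) pair with a fresh key
key, subkey = jr.split(key)
dims_to_nan = jr.randint(subkey, (B, nan_indices.shape[0]), 0, N)
batched_input_ND_grid_padded = batched_input_ND_grid_padded.at[
	jnp.arange(B)[:, None], nan_indices[None, :], dims_to_nan
].set(jnp.nan)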

batched_input_1D_grid_regular.shape, batched_input_1D_grid_irregular.shape, batched_input_1D_grid_padded.shape, batched_input_ND_grid_regular.shape, batched_input_ND_grid_irregular.shape, batched_input_ND_grid_padded.shape

((250, 100, 1),
 (250, 100, 1),
 (250, 100, 1),
 (250, 100, 4),
 (250, 100, 4),
 (250, 100, 4))

---
## Current implementation

In [11]:
old_kernel = kernax.RBFKernel(length_scale=jnp.array(0.3), variance=jnp.array(1.0))
old_batched_kernel = old_kernel  # No change, as previous kernels were always batched

---
## Custom implementation(s)

*Start by copy-pasting the original function from the Kernax module, then modify it.*

### Shortcomings of the previous implementation that we wish to correct/improve:

**We want kernels to be Equinox Modules**

This will make hyperparameter management easier, as hyperparameters will be stored as attributes of the kernel object. It also abstracts the whole "kernels as pytrees" logic away from the user. It might make some features easier to implement: freezing hyperparameters, saving/loading models, handling static arguments like `active_dims`, composing kernels with wrappers, adding constraints such as positivity to hyperparameters, and so on.
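
For context, the sketch below shows roughly what a kernel class needs in order to be a JAX pytree without Equinox. The class name `ManualSEKernel` is hypothetical; it only illustrates the boilerplate that an `eqx.Module` is expected to handle automatically.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ManualSEKernel:
    """Hypothetical hand-written pytree kernel (illustration only)."""

    def __init__(self, length_scale):
        self.length_scale = length_scale

    def __call__(self, x1, x2):
        diff = x1 - x2
        return jnp.exp(-0.5 * (diff @ diff) / self.length_scale ** 2)

    def tree_flatten(self):
        # Children are the traced leaves (hyperparameters);
        # aux_data would carry static fields such as active_dims
        return (self.length_scale,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


kernel = ManualSEKernel(jnp.array(0.3))
leaves = jax.tree_util.tree_leaves(kernel)  # the hyperparameters, exposed as leaves
# Because the kernel is a pytree, it can be passed directly to a jitted function
value = jax.jit(lambda k, a, b: k(a, b))(kernel, jnp.zeros(4), jnp.ones(4))
```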

**We want to abstract batch/ARD handling away from the base AbstractKernel class**

Kernels that work on batches of inputs and produce batches of covariance matrices are a very specific use case, only really useful for Magma's multi-task GP setting. This code should not live in the base class, but rather in a mixin or a wrapper class that can be used when needed.

ARD handling can likewise be delegated to a wrapper kernel that transforms the inputs before passing them to the base kernel.
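
As an illustration of the input-transforming idea, here is a minimal sketch written with plain functions rather than Equinox modules. The names `se_pairwise` and `with_ard` are hypothetical, and the sketch assumes the base kernel depends on its inputs only through their scaled difference.

```python
import jax.numpy as jnp


def se_pairwise(x1, x2, length_scale):
    """Squared-exponential covariance between two vectors of shape (I,)."""
    diff = x1 - x2
    return jnp.exp(-0.5 * (diff @ diff) / length_scale ** 2)


def with_ard(pairwise_fn, length_scales):
    """Hypothetical ARD wrapper: rescale each input dimension by its own
    length scale, then call the base kernel with a unit length scale."""
    def wrapped(x1, x2):
        return pairwise_fn(x1 / length_scales, x2 / length_scales, 1.0)
    return wrapped


ard_kernel = with_ard(se_pairwise, jnp.array([0.5, 2.0]))
value = ard_kernel(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
# value equals exp(-0.5 * ((1 / 0.5)**2 + (1 / 2.0)**2)) = exp(-2.125)
```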

Similarly, a kernel that transforms hyperparameters (exponentiation, softplus, ...) can be used as a wrapper during optimisation only, without affecting the base kernel.
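
A minimal sketch of such a hyperparameter-transforming wrapper, again with plain functions and hypothetical names (`se_pairwise`, `with_softplus`): the optimiser works on an unconstrained value, and the wrapper maps it to a positive length scale before calling the base kernel.

```python
import jax
import jax.numpy as jnp


def se_pairwise(x1, x2, length_scale):
    """Squared-exponential covariance between two vectors of shape (I,)."""
    diff = x1 - x2
    return jnp.exp(-0.5 * (diff @ diff) / length_scale ** 2)


def with_softplus(pairwise_fn):
    """Hypothetical wrapper: accept an unconstrained parameter and map it
    to a strictly positive length scale with softplus."""
    def wrapped(x1, x2, raw_length_scale):
        return pairwise_fn(x1, x2, jax.nn.softplus(raw_length_scale))
    return wrapped


opt_kernel = with_softplus(se_pairwise)
# Gradient with respect to the unconstrained parameter; any real value is admissible
grad_raw = jax.grad(opt_kernel, argnums=2)(jnp.zeros(2), jnp.ones(2), jnp.array(-1.0))
```

The base kernel is untouched, so it can still be called directly with a constrained length scale outside of optimisation.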

In [12]:
import jax.numpy as jnp
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class, tree_map
from jax.tree import reduce
from jax.lax import cond

import equinox as eqx

from functools import partial


class StaticAbstractKernel:
	@classmethod
	@partial(jit, static_argnums=(0,))
	def pairwise_cov(cls, kern, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two vectors.

		:param kern: the kernel to use, containing hyperparameters
		:param x1: vector array (I, )
		:param x2: vector array (I, )
		:return: scalar array
		"""
		return jnp.array(jnp.nan)  # To be overwritten in subclasses

	@classmethod
	@partial(jit, static_argnums=(0,))
	def pairwise_cov_if_not_nan(cls, kern, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Returns NaN if either x1 or x2 contains a NaN, otherwise calls the `pairwise_cov` method.

		:param kern: the kernel to use, containing hyperparameters
		:param x1: vector array (I, )
		:param x2: vector array (I, )
		:return: scalar array
		"""
		return cond(jnp.any(jnp.isnan(x1) | jnp.isnan(x2)),
		            lambda _: jnp.nan,
		            lambda _: cls.pairwise_cov(kern, x1, x2),
		            None)

	@classmethod
	@partial(jit, static_argnums=(0,))
	def cross_cov_vector(cls, kern, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel cross-covariance values between an array of vectors (matrix) and a vector.

		:param kern: the kernel to use, containing hyperparameters
		:param x1: matrix array (N, I)
		:param x2: vector array (I, )
		:return: vector array (N, )
		"""
		return vmap(lambda x: cls.pairwise_cov_if_not_nan(kern, x, x2), in_axes=0)(x1)

	@classmethod
	@partial(jit, static_argnums=(0,))
	def cross_cov_vector_if_not_nan(cls, kern, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Returns an array of NaN if x2 contains a NaN, otherwise calls the `cross_cov_vector` method.

		:param kern: the kernel to use, containing hyperparameters
		:param x1: matrix array (N, I)
		:param x2: vector array (I, )
		:param kwargs: unused (hyperparameters are read from `kern`)
		:return: vector array (N, )
		"""
		return cond(jnp.any(jnp.isnan(x2)),
		            lambda _: jnp.full(len(x1), jnp.nan),
		            lambda _: cls.cross_cov_vector(kern, x1, x2),
		            None)

	@classmethod
	@partial(jit, static_argnums=(0,))
	def cross_cov_matrix(cls, kern, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two arrays of vectors.

		:param kern: the kernel to use, containing hyperparameters
		:param x1: matrix array (N, I)
		:param x2: matrix array (M, I)
		:return: matrix array (N, M)
		"""
		return vmap(lambda x: cls.cross_cov_vector_if_not_nan(kern, x2, x), in_axes=0)(x1)


class AbstractKernel(eqx.Module):
	"""
	# TODO: check Equinox __str__ and __repr__ methods and adapt if needed
	def __str__(self):
		return f"{self.__class__.__name__}({', '.join([f'{key}={value}' for key, value in self.__dict__.items() if key not in self.static_attributes])})"

	def __repr__(self):
		return str(self)
	"""

	@jit
	def __call__(self, x1, x2=None):
		# If no x2 is provided, we compute the covariance between x1 and itself
		if x2 is None:
			x2 = x1

		# Turn scalar inputs into vectors
		x1, x2 = jnp.atleast_1d(x1), jnp.atleast_1d(x2)

		# Call the appropriate method
		if jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.static_class.pairwise_cov_if_not_nan(self, x1, x2)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 1:
			return self.static_class.cross_cov_vector_if_not_nan(self, x1, x2)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 2:
			return self.static_class.cross_cov_vector_if_not_nan(self, x2, x1)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.static_class.cross_cov_matrix(self, x1, x2)
		else:
			raise ValueError(
				f"Invalid input dimensions: x1 has shape {x1.shape}, x2 has shape {x2.shape}. "
				"Expected 1D or 2D arrays; wrap the kernel in a BatchKernel to handle batched (3D) inputs."
			)

	def __add__(self, other):
		from kernax.OperatorKernels import SumKernel
		return SumKernel(self, other)

	def __radd__(self, other):
		from kernax.OperatorKernels import SumKernel
		return SumKernel(other, self)

	def __neg__(self):
		from kernax.WrapperKernels import NegKernel
		return NegKernel(self)

	def __mul__(self, other):
		from kernax.OperatorKernels import ProductKernel
		return ProductKernel(self, other)

	def __rmul__(self, other):
		from kernax.OperatorKernels import ProductKernel
		return ProductKernel(other, self)


In [13]:
from jax import jit
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp

from functools import partial


# from kernax import StaticAbstractKernel, AbstractKernel


class StaticSEKernel(StaticAbstractKernel):
	@classmethod
	@partial(jit, static_argnums=(0,))
	def pairwise_cov(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two vectors.

		:param kern: the kernel to use, containing a `length_scale` parameter
		:param x1: vector array (I, )
		:param x2: vector array (I, )
		:return: scalar array
		"""
		return jnp.exp(-0.5 * ((x1 - x2) @ (x1 - x2)) / kern.length_scale ** 2)


class SEKernel(AbstractKernel):
	"""
	Squared Exponential (aka "RBF" or "Gaussian") Kernel
	"""
	length_scale: jax.Array = eqx.field(converter=jax.numpy.asarray)
	static_class = StaticSEKernel

	def __init__(self, length_scale):
		super().__init__()
		self.length_scale = length_scale
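
For reference, the value returned by `StaticSEKernel.pairwise_cov` above, for two inputs $x, x' \in \mathbb{R}^I$ and a length scale $\ell$, can be written as:

$$
k(x, x') = \exp\left(-\frac{\lVert x - x' \rVert^2}{2\ell^2}\right)
$$

There is no variance factor in this version, which is consistent with comparing it against `RBFKernel` with `variance=1.0`.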


In [14]:
class BatchKernel(AbstractKernel):
	"""
	Wrapper kernel to add batch handling to any kernel.

	A basic kernel usually works on inputs of shape (N, I), and produces covariance matrices of shape (N, N).

	Wrapped inside a batch kernel, they can either:
	- still work on inputs of shape (N, I), but produce covariance matrices of shape (B, N, N), where B is the batch size. This is useful when the hyperparameters are batched, i.e. each batch element has its own set of hyperparameters.
	- or work on inputs of shape (B, N, I), producing covariance matrices of shape (B, N, N). This is useful when the inputs are batched, regardless of whether the hyperparameters are batched or not.

	A batch kernel can itself be wrapped inside another batch kernel, to handle multiple batch dimensions/hyperparameter sets.

	This class uses vmap to vectorize the kernel computation over the batch dimension.
	"""
	inner_kernel: AbstractKernel = eqx.field()
	batch_in_axes: object = eqx.field(static=True)  # pytree with the structure of inner_kernel and leaves 0 (batched) or None (shared)
	batch_over_inputs: int|None = eqx.field(static=True)

	def __init__(self, inner_kernel, batch_size, batch_in_axes=None, batch_over_inputs=True):
		"""
		:param inner_kernel: the kernel to wrap, must be an instance of AbstractKernel
		:param batch_size: the size of the batch (int)
		:param batch_in_axes: a value or pytree indicating which hyperparameters are batched (0)
							   or shared (None) across the batch.
							   If None, all hyperparameters are assumed to be shared across the batch.
							   If 0, all hyperparameters are assumed to be batched across the batch.
							   If a pytree, it must have the same structure as inner_kernel, with hyperparameter
							   leaves being either 0 (batched) or None (shared).
		:param batch_over_inputs: whether to expect inputs of shape (B, N, I) (True) or (N, I) (False)
		"""
		# TODO: batch_size isn't needed if hyperparameters are shared
		# TODO: explicit error message when batch_in_axes is all None and batch_over_inputs is False, as that makes vmap (and a Batch Kernel) useless
		# Default: all array hyperparameters are shared (None for all array leaves)
		if batch_in_axes is None:
			# Extract only array leaves and map them to None
			self.batch_in_axes = jax.tree_util.tree_map(lambda _: None, inner_kernel)
		elif batch_in_axes == 0:
			# All hyperparameters are batched
			self.batch_in_axes = jax.tree_util.tree_map(lambda _: 0, inner_kernel)
		else:
			self.batch_in_axes = batch_in_axes

		self.batch_over_inputs = 0 if batch_over_inputs else None

		# Add batch dimension to parameters where batch_in_axes is 0
		self.inner_kernel = jax.tree_util.tree_map(
			lambda param, batch_in_ax: param if batch_in_ax is None else jnp.repeat(param[None, ...], batch_size, axis=0),
			inner_kernel,
			self.batch_in_axes
		)


	def __call__(self, x1, x2=None):
		"""
		Compute the kernel over batched inputs using vmap.

		Args:
			x1: Input of shape (B, ..., N, I)
			x2: Optional second input of shape (B, ..., M, I)

		Returns:
			Kernel matrix of appropriate shape with batch dimension
		"""
		# vmap over the batch dimension of inner_kernel and inputs
		# Each batch element gets its own version of inner_kernel with corresponding hyperparameters
		return vmap(
			lambda kernel, x1, x2: kernel(x1, x2),
			in_axes=(self.batch_in_axes, self.batch_over_inputs, self.batch_over_inputs if x2 is not None else None)
		)(self.inner_kernel, x1, x2)

In [94]:
new_kernel = SEKernel(length_scale=jnp.array(0.3))
new_batched_kernel_SHP = BatchKernel(inner_kernel=new_kernel, batch_size=B)  # Shared hyperparameters
new_batched_kernel_DHP = BatchKernel(inner_kernel=new_kernel, batch_size=B, batch_in_axes=0)  # Distinct hyperparameters
new_batched_kernel_DHP_BI = BatchKernel(inner_kernel=new_kernel, batch_size=B, batch_in_axes=0, batch_over_inputs=False)  # Distinct hyperparameters, un-batched inputs
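
The three configurations above correspond to three `in_axes` patterns for `vmap`. The sketch below reproduces them on a plain function instead of a kernel object; `se_gram` and the small shapes are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def se_gram(x, length_scale):
    """Squared-exponential Gram matrix for inputs x of shape (N, I)."""
    diff = x[:, None, :] - x[None, :, :]
    return jnp.exp(-0.5 * jnp.sum(diff ** 2, axis=-1) / length_scale ** 2)


n_batch, n_points, n_dims = 3, 5, 2
x_batched = jnp.arange(n_batch * n_points * n_dims, dtype=jnp.float32).reshape(n_batch, n_points, n_dims) / 10.0
x_single = x_batched[0]
shared_ls = jnp.array(0.3)
distinct_ls = jnp.array([0.3, 0.5, 1.0])

# Shared hyperparameters, batched inputs
K_shared = jax.vmap(se_gram, in_axes=(0, None))(x_batched, shared_ls)
# Distinct hyperparameters, batched inputs
K_distinct = jax.vmap(se_gram, in_axes=(0, 0))(x_batched, distinct_ls)
# Distinct hyperparameters, un-batched inputs
K_hp_only = jax.vmap(se_gram, in_axes=(None, 0))(x_single, distinct_ls)
```

In all three cases the result has shape `(n_batch, n_points, n_points)`; only the source of the batch dimension differs.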

---
## Comparison

### Pair-wise covariance

#### 1D inputs

In [19]:
jnp.allclose(old_kernel(input_1D_a, input_1D_b), new_kernel(input_1D_a, input_1D_b))

Array(True, dtype=bool)

In [20]:
%%timeit
old_kernel(input_1D_a, input_1D_b).block_until_ready()

4.32 μs ± 148 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [21]:
%%timeit
new_kernel(input_1D_a, input_1D_b).block_until_ready()

4.22 μs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
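
Outside IPython, where `%%timeit` is not available, a comparable measurement can be sketched with the standard `timeit` module. The helper `time_jax_call` is hypothetical; it reports the best per-call time rather than the mean and standard deviation shown above, and it makes one warm-up call so that compilation is excluded.

```python
import timeit

import jax
import jax.numpy as jnp


def time_jax_call(fn, *args, number=1000, repeat=7):
    """Return the best per-call time in seconds, excluding the first (compiling) call."""
    fn(*args).block_until_ready()  # warm-up: triggers tracing and compilation
    # block_until_ready() makes sure asynchronous dispatch does not hide the compute time
    timer = timeit.Timer(lambda: fn(*args).block_until_ready())
    return min(timer.repeat(repeat=repeat, number=number)) / number


@jax.jit
def se_pairwise(x1, x2):
    diff = x1 - x2
    return jnp.exp(-0.5 * (diff @ diff) / 0.3 ** 2)


seconds = time_jax_call(se_pairwise, jnp.zeros(4), jnp.ones(4), number=100, repeat=3)
```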


#### ND inputs

In [22]:
jnp.allclose(old_kernel(input_ND_a, input_ND_b), new_kernel(input_ND_a, input_ND_b))

Array(True, dtype=bool)

In [23]:
%%timeit
old_kernel(input_ND_a, input_ND_b).block_until_ready()

4.24 μs ± 27.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [24]:
%%timeit
new_kernel(input_ND_a, input_ND_b).block_until_ready()

4.39 μs ± 23.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


#### ND inputs with active_dims

In [25]:
# TODO

#### ND inputs with ARD

In [26]:
# TODO

### Cross-covariance (Gram matrix)

#### On regular 1D input grid

In [27]:
jnp.allclose(old_kernel(input_1D_grid_regular), new_kernel(input_1D_grid_regular))

Array(True, dtype=bool)

In [28]:
%%timeit
old_kernel(input_1D_grid_regular).block_until_ready()

23.5 μs ± 106 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [29]:
%%timeit
new_kernel(input_1D_grid_regular).block_until_ready()

23.6 μs ± 103 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On irregular 1D input grid

In [30]:
jnp.allclose(old_kernel(input_1D_grid_irregular), new_kernel(input_1D_grid_irregular))

Array(True, dtype=bool)

In [31]:
%%timeit
old_kernel(input_1D_grid_irregular).block_until_ready()

23.5 μs ± 497 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [32]:
%%timeit
new_kernel(input_1D_grid_irregular).block_until_ready()

23.8 μs ± 307 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On padded 1D input grid

In [33]:
jnp.allclose(old_kernel(input_1D_grid_padded), new_kernel(input_1D_grid_padded), equal_nan=True)

Array(True, dtype=bool)

In [34]:
%%timeit
old_kernel(input_1D_grid_padded).block_until_ready()

23.5 μs ± 1.01 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [35]:
%%timeit
new_kernel(input_1D_grid_padded).block_until_ready()

27.1 μs ± 569 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On regular ND input grid

In [36]:
jnp.allclose(old_kernel(input_ND_grid_regular), new_kernel(input_ND_grid_regular))

Array(True, dtype=bool)

In [37]:
%%timeit
old_kernel(input_ND_grid_regular).block_until_ready()

58.2 μs ± 2.96 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [38]:
%%timeit
new_kernel(input_ND_grid_regular).block_until_ready()

56.2 μs ± 1.84 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On irregular ND input grid

In [39]:
np.asarray(old_kernel(input_ND_grid_irregular))

array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 2.5218078e-38],
       ...,
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 1.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        1.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 2.5218078e-38, ..., 0.0000000e+00,
        0.0000000e+00, 1.0000000e+00]], shape=(100, 100), dtype=float32)

In [40]:
np.asarray(new_kernel(input_ND_grid_irregular))

array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 2.5218078e-38],
       ...,
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 1.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        1.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 2.5218078e-38, ..., 0.0000000e+00,
        0.0000000e+00, 1.0000000e+00]], shape=(100, 100), dtype=float32)

In [41]:
jnp.allclose(old_kernel(input_ND_grid_irregular), new_kernel(input_ND_grid_irregular))

Array(True, dtype=bool)

In [42]:
%%timeit
old_kernel(input_ND_grid_irregular).block_until_ready()

54.7 μs ± 141 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [43]:
%%timeit
new_kernel(input_ND_grid_irregular).block_until_ready()

54.1 μs ± 356 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On padded ND input grid

In [44]:
np.asarray(old_kernel(input_ND_grid_padded))

array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,           nan,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, ...,           nan,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, ...,           nan,
        0.0000000e+00, 2.5218078e-38],
       ...,
       [          nan,           nan,           nan, ...,           nan,
                  nan,           nan],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,           nan,
        1.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 2.5218078e-38, ...,           nan,
        0.0000000e+00, 1.0000000e+00]], shape=(100, 100), dtype=float32)

In [45]:
np.asarray(new_kernel(input_ND_grid_padded))

array([[1.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,           nan,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 1.0000000e+00, 0.0000000e+00, ...,           nan,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, ...,           nan,
        0.0000000e+00, 2.5218078e-38],
       ...,
       [          nan,           nan,           nan, ...,           nan,
                  nan,           nan],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,           nan,
        1.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 2.5218078e-38, ...,           nan,
        0.0000000e+00, 1.0000000e+00]], shape=(100, 100), dtype=float32)

In [46]:
jnp.allclose(old_kernel(input_ND_grid_padded), new_kernel(input_ND_grid_padded), equal_nan=True)

Array(True, dtype=bool)

In [47]:
%%timeit
old_kernel(input_ND_grid_padded).block_until_ready()

54.2 μs ± 795 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [48]:
%%timeit
new_kernel(input_ND_grid_padded).block_until_ready()

52.8 μs ± 731 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On ND input grid with active_dims

In [49]:
# TODO

#### On ND input grid with ARD

In [50]:
# TODO

### Batched cross-covariance - Shared hyperparameters

#### Regular 1D inputs

In [51]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_regular), new_batched_kernel_SHP(batched_input_1D_grid_regular))

Array(True, dtype=bool)

In [52]:
%%timeit
old_batched_kernel(batched_input_1D_grid_regular).block_until_ready()

279 μs ± 787 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [53]:
%%timeit
new_batched_kernel_SHP(batched_input_1D_grid_regular).block_until_ready()

376 μs ± 1.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Irregular 1D inputs

In [54]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_irregular), new_batched_kernel_SHP(batched_input_1D_grid_irregular))

Array(True, dtype=bool)

In [55]:
%%timeit
old_batched_kernel(batched_input_1D_grid_irregular).block_until_ready()

279 μs ± 1.05 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [56]:
%%timeit
new_batched_kernel_SHP(batched_input_1D_grid_irregular).block_until_ready()

379 μs ± 1.01 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Padded 1D inputs

In [57]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_padded), new_batched_kernel_SHP(batched_input_1D_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [58]:
%%timeit
old_batched_kernel(batched_input_1D_grid_padded).block_until_ready()

253 μs ± 515 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [59]:
%%timeit
new_batched_kernel_SHP(batched_input_1D_grid_padded).block_until_ready()

359 μs ± 1.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Regular ND inputs

In [60]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_regular), new_batched_kernel_SHP(batched_input_ND_grid_regular))

Array(True, dtype=bool)

In [61]:
%%timeit
old_batched_kernel(batched_input_ND_grid_regular).block_until_ready()

4.38 ms ± 5.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [62]:
%%timeit
new_batched_kernel_SHP(batched_input_ND_grid_regular).block_until_ready()

4.49 ms ± 14 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Irregular ND inputs

In [63]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_irregular), new_batched_kernel_SHP(batched_input_ND_grid_irregular))

Array(True, dtype=bool)

In [64]:
%%timeit
old_batched_kernel(batched_input_ND_grid_irregular).block_until_ready()

4.39 ms ± 14.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [65]:
%%timeit
new_batched_kernel_SHP(batched_input_ND_grid_irregular).block_until_ready()

4.49 ms ± 21.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Padded ND inputs

In [66]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_padded), new_batched_kernel_SHP(batched_input_ND_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [67]:
%%timeit
old_batched_kernel(batched_input_ND_grid_padded).block_until_ready()

4.32 ms ± 7.45 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [68]:
%%timeit
new_batched_kernel_SHP(batched_input_ND_grid_padded).block_until_ready()

4.45 ms ± 14.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Batched cross-covariance - Distinct hyperparameters

#### Regular 1D inputs

In [69]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_regular), new_batched_kernel_DHP(batched_input_1D_grid_regular))

Array(True, dtype=bool)

In [70]:
%%timeit
old_batched_kernel(batched_input_1D_grid_regular).block_until_ready()

279 μs ± 416 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [71]:
%%timeit
new_batched_kernel_DHP(batched_input_1D_grid_regular).block_until_ready()

389 μs ± 2.26 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Irregular 1D inputs

In [72]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_irregular), new_batched_kernel_DHP(batched_input_1D_grid_irregular))

Array(True, dtype=bool)

In [73]:
%%timeit
old_batched_kernel(batched_input_1D_grid_irregular).block_until_ready()

279 μs ± 476 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [74]:
%%timeit
new_batched_kernel_DHP(batched_input_1D_grid_irregular).block_until_ready()

395 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Padded 1D inputs

In [75]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_padded), new_batched_kernel_DHP(batched_input_1D_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [76]:
%%timeit
old_batched_kernel(batched_input_1D_grid_padded).block_until_ready()

255 μs ± 2.47 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [77]:
%%timeit
new_batched_kernel_DHP(batched_input_1D_grid_padded).block_until_ready()

373 μs ± 976 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Regular ND inputs

In [78]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_regular), new_batched_kernel_DHP(batched_input_ND_grid_regular))

Array(True, dtype=bool)

In [79]:
%%timeit
old_batched_kernel(batched_input_ND_grid_regular).block_until_ready()

4.41 ms ± 18.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [80]:
%%timeit
new_batched_kernel_DHP(batched_input_ND_grid_regular).block_until_ready()

4.7 ms ± 109 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Irregular ND inputs

In [81]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_irregular), new_batched_kernel_DHP(batched_input_ND_grid_irregular))

Array(True, dtype=bool)

In [82]:
%%timeit
old_batched_kernel(batched_input_ND_grid_irregular).block_until_ready()

4.4 ms ± 4.86 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [83]:
%%timeit
new_batched_kernel_DHP(batched_input_ND_grid_irregular).block_until_ready()

4.71 ms ± 153 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Padded ND inputs

In [84]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_padded), new_batched_kernel_DHP(batched_input_ND_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [85]:
%%timeit
old_batched_kernel(batched_input_ND_grid_padded).block_until_ready()

4.36 ms ± 28.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [86]:
%%timeit
new_batched_kernel_DHP(batched_input_ND_grid_padded).block_until_ready()

4.47 ms ± 5.86 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### On un-batched input with distinct hyperparameters

In [111]:
jnp.allclose(new_batched_kernel_SHP(batched_input_1D_grid_regular), new_batched_kernel_DHP_BI(input_1D_grid_regular))

Array(True, dtype=bool)

---
## Conclusion

In the runs recorded above, the new Equinox-based `SEKernel` matches the previous `RBFKernel` numerically in every comparison. Timings for pairwise covariances and un-batched Gram matrices are within about 4% of each other, except on the padded 1D grid (27.1 μs vs 23.5 μs). Wrapping the kernel in `BatchKernel` is roughly 35-46% slower than the previous implementation on batched 1D inputs (359-395 μs vs 253-279 μs) and about 2-7% slower on batched ND inputs (4.45-4.71 ms vs 4.32-4.41 ms). The `active_dims` and ARD cases remain to be benchmarked.

---