# Benchmarks - Kernels

In MagmaClust, and in GPs in general, kernels are what make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They sit in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable at many dimensions, including a batch dimension with distinct hyperparameters for each element
* able to handle padded inputs (i.e. inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, to either carry around hyperparameters or be called with them
* easy to override, as users may want to define their own kernels

These goals sometimes conflict (e.g. a jittable kernel is not trivial for most people to write), and the best implementation depends on the specific use case.
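To make these requirements concrete, here is a minimal standalone sketch in plain JAX. It is not the Kernax API, and the names `se_gram` and `batched_se_gram` are hypothetical. The kernel is a pure function of its hyperparameters and inputs, so it is jittable and can be called "statically". Padded rows are handled with a mask instead of Python control flow, and `jax.vmap` adds a batch dimension with a distinct length scale per element. Because it imports JAX on its own, it is best run outside this notebook's session, where x64 mode is configured before the first import.

```python
import jax
import jax.numpy as jnp


@jax.jit
def se_gram(length_scale, x1, x2):
    """Hypothetical SE Gram matrix between x1 (N, I) and x2 (M, I).

    Rows containing any NaN are treated as padding: they are zeroed before the
    computation, and the matching output rows/columns are set to NaN afterwards.
    """
    valid1 = ~jnp.any(jnp.isnan(x1), axis=-1)  # (N,)
    valid2 = ~jnp.any(jnp.isnan(x2), axis=-1)  # (M,)
    x1_safe = jnp.where(valid1[:, None], x1, 0.0)
    x2_safe = jnp.where(valid2[:, None], x2, 0.0)
    sq_dist = jnp.sum((x1_safe[:, None, :] - x2_safe[None, :, :]) ** 2, axis=-1)  # (N, M)
    gram = jnp.exp(-0.5 * sq_dist / length_scale**2)
    return jnp.where(valid1[:, None] & valid2[None, :], gram, jnp.nan)


# Batch dimension with one length scale and one input set per element
batched_se_gram = jax.vmap(se_gram, in_axes=(0, 0, 0))

x = jnp.array([[0.0], [1.0], [jnp.nan]])  # 3 points in 1D, the last one is padding
k = se_gram(0.3, x, x)  # (3, 3); last row and column are NaN
xs = jnp.stack([x, x + 1.0])  # (2, 3, 1)
ks = batched_se_gram(jnp.array([0.3, 1.0]), xs, xs)  # (2, 3, 3)
```

Zeroing the padded rows before the distance computation, instead of letting NaNs flow through the arithmetic, means the NaN only appears in the final `jnp.where`. Gradients with respect to the hyperparameters should therefore stay finite when the padded entries are ignored downstream (e.g. with `jnp.nansum`).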

---
## Setup

In [1]:
# Jax configuration
USE_JIT = True
USE_X64 = True
DEBUG_NANS = False
VERBOSE = False

In [2]:
# Standard library imports
import os

os.environ['JAX_ENABLE_X64'] = str(USE_X64).lower()

import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [3]:
# Third party
import jax

jax.config.update("jax_disable_jit", not USE_JIT)
jax.config.update("jax_debug_nans", DEBUG_NANS)

In [4]:
# Third party
from jax import numpy as jnp
from jax import random as jr
from jax import vmap

import numpy as np

In [5]:
# Local
import kernax
from kernax import ConstantKernel

In [6]:
# Config
key = jr.PRNGKey(0)



---
## Data

In [7]:
B = 250  # Batch size
M = 100  # Number of points
N = 4  # Input dimension

In [8]:
# For pairwise covariance
input_1D_a = jnp.array([0.])
input_1D_b = jnp.array([1.])

input_ND_a = jnp.zeros(N)
input_ND_b = jnp.ones(N)

input_1D_a.shape, input_1D_b.shape, input_ND_a.shape, input_ND_b.shape

((1,), (1,), (4,), (4,))

In [9]:
# For cross-covariance
input_1D_grid_regular = jnp.linspace(0., 10., M).reshape(-1, 1)
key, subkey = jr.split(key)
input_1D_grid_irregular = jr.uniform(subkey, shape=(M, 1), minval=0., maxval=10.)

key, subkey = jr.split(key)
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
# Padded version: the irregular grid with 20% of its points set to NaN
input_1D_grid_padded = input_1D_grid_irregular.at[nan_indices, 0].set(jnp.nan)

# Create a grid with M^(1/N) points per dimension, then reshape to M points total
points_per_dim = int(np.ceil(M ** (1 / N)))
grid_1d = jnp.linspace(0., 5., points_per_dim)
grids = jnp.meshgrid(*[grid_1d for _ in range(N)], indexing='ij')
input_ND_grid_regular = jnp.stack([g.flatten() for g in grids], axis=-1)[:M]

key, subkey = jr.split(key)
input_ND_grid_irregular = jr.uniform(subkey, shape=(M, N), minval=0., maxval=10.)

# Padded version: the irregular grid with one NaN coordinate in 20% of its points
key, subkey = jr.split(key)
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
key, subkey = jr.split(key)  # fresh key: reusing subkey would correlate the draws and pick the same dimension every time
nan_dims = jr.randint(subkey, (len(nan_indices),), 0, N)  # one random dimension per padded point
input_ND_grid_padded = input_ND_grid_irregular.at[nan_indices, nan_dims].set(jnp.nan)

input_1D_grid_regular.shape, input_1D_grid_irregular.shape, input_1D_grid_padded.shape, input_ND_grid_regular.shape, input_ND_grid_irregular.shape, input_ND_grid_padded.shape

((100, 1), (100, 1), (100, 1), (100, 4), (100, 4), (100, 4))

In [10]:
batched_input_1D_grid_regular = jnp.linspace(0., 10., M).repeat(B, -1).reshape(M, B, 1).swapaxes(0, 1)
key, subkey = jr.split(key)
batched_input_1D_grid_irregular = jr.uniform(subkey, shape=(B, M, 1), minval=0., maxval=10.)
key, subkey = jr.split(key)  # fresh key, rather than reusing the one that drew the inputs
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
# Same padded positions in every batch element
batched_input_1D_grid_padded = batched_input_1D_grid_irregular.at[:, nan_indices, 0].set(jnp.nan)

batched_input_ND_grid_regular = jnp.tile(input_ND_grid_regular[None, :, :], (B, 1, 1))

key, subkey = jr.split(key)
batched_input_ND_grid_irregular = jr.uniform(subkey, shape=(B, M, N), minval=0., maxval=10.)

key, subkey = jr.split(key)
nan_indices = jr.choice(subkey, M, shape=(int(M * 0.2),), replace=False)
key, subkey = jr.split(key)  # fresh key for the NaN dimensions
nan_dims = jr.randint(subkey, (len(nan_indices),), 0, N)
# Same padded positions (point and dimension) in every batch element
batched_input_ND_grid_padded = batched_input_ND_grid_irregular.at[:, nan_indices, nan_dims].set(jnp.nan)

batched_input_1D_grid_regular.shape, batched_input_1D_grid_irregular.shape, batched_input_1D_grid_padded.shape, batched_input_ND_grid_regular.shape, batched_input_ND_grid_irregular.shape, batched_input_ND_grid_padded.shape

((250, 100, 1),
 (250, 100, 1),
 (250, 100, 1),
 (250, 100, 4),
 (250, 100, 4),
 (250, 100, 4))

---
## Current implementation

In [11]:
old_kernel = kernax.SEKernel(length_scale=jnp.array(0.3))
old_batched_kernel = old_kernel  # No change, as previous kernels were always batched

---
## Custom implementation(s)

*Start by copy-pasting the original implementation from the Kernax module, then modify it*

In [12]:
from functools import partial

import jax.numpy as jnp
from jax import vmap
from jax.lax import cond
import equinox as eqx
from equinox import filter_jit


class AbstractKernel(eqx.Module):
	"""
	# TODO: check Equinox __str__ and __repr__ methods and adapt if needed
	def __str__(self):
		return f"{self.__class__.__name__}({', '.join([f'{key}={value}' for key, value in self.__dict__.items() if key not in self.static_attributes])})"

	def __repr__(self):
		return str(self)
	"""
	@filter_jit
	def __call__(self, x1, x2=None):
		# If no x2 is provided, we compute the covariance between x1 and itself
		if x2 is None:
			x2 = x1


		# Turn scalar inputs into vectors
		x1, x2 = jnp.atleast_1d(x1), jnp.atleast_1d(x2)

		# Call the appropriate method
		if jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.static_class.pairwise_cov_if_not_nan(self, x1, x2)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 1:
			return self.static_class.cross_cov_vector_if_not_nan(self, x1, x2)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 2:
			return self.static_class.cross_cov_vector_if_not_nan(self, x2, x1)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.static_class.cross_cov_matrix(self, x1, x2)
		else:
			raise ValueError(
				f"Invalid input dimensions: x1 has shape {x1.shape}, x2 has shape {x2.shape}. "
				"Expected scalar, 1D or 2D arrays; wrap the kernel in a BatchKernel for batched (3D) inputs."
			)

	def __add__(self, other):
		from kernax.OperatorKernels import SumKernel
		return SumKernel(self, other)

	def __radd__(self, other):
		from kernax.OperatorKernels import SumKernel
		return SumKernel(other, self)

	def __neg__(self):
		from kernax.WrapperKernels import NegKernel
		return NegKernel(self)

	def __mul__(self, other):
		from kernax.OperatorKernels import ProductKernel
		return ProductKernel(self, other)

	def __rmul__(self, other):
		from kernax.OperatorKernels import ProductKernel
		return ProductKernel(other, self)


class StaticAbstractKernel:
	@classmethod
	@filter_jit
	def pairwise_cov(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two vectors.

		:param kern: kernel instance containing the hyperparameters
		:param x1: vector array (I, )
		:param x2: vector array (I, )
		:return: scalar array
		"""
		return jnp.array(jnp.nan)  # To be overwritten in subclasses

	@classmethod
	@filter_jit
	def pairwise_cov_if_not_nan(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Returns NaN if either x1 or x2 contains a NaN, otherwise calls pairwise_cov.

		:param kern: kernel instance containing the hyperparameters
		:param x1: vector array (I, )
		:param x2: vector array (I, )
		:return: scalar array
		"""
		return cond(jnp.any(jnp.isnan(x1) | jnp.isnan(x2)),
		            lambda _: jnp.nan,
		            lambda _: cls.pairwise_cov(kern, x1, x2),
		            None)

	@classmethod
	@filter_jit
	def cross_cov_vector(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel cross covariance values between an array of vectors (matrix) and a vector.

		:param kern: kernel instance containing the hyperparameters
		:param x1: matrix array (N, I)
		:param x2: vector array (I, )
		:return: vector array (N, )
		"""
		return vmap(lambda x: cls.pairwise_cov_if_not_nan(kern, x, x2), in_axes=0)(x1)

	@classmethod
	@filter_jit
	def cross_cov_vector_if_not_nan(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Returns an array of NaN if x2 contains a NaN, otherwise calls cross_cov_vector.

		:param kern: kernel instance containing the hyperparameters
		:param x1: matrix array (N, I)
		:param x2: vector array (I, )
		:return: vector array (N, )
		"""
		"""
		return cond(jnp.any(jnp.isnan(x2)),
		            lambda _: jnp.full(len(x1), jnp.nan),
		            lambda _: cls.cross_cov_vector(kern, x1, x2),
		            None)

	@classmethod
	@filter_jit
	def cross_cov_matrix(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two vector arrays.

		:param kern: kernel instance containing the hyperparameters
		:param x1: matrix array (N, I)
		:param x2: matrix array (M, I)
		:return: matrix array (N, M)
		"""
		return vmap(lambda x: cls.cross_cov_vector_if_not_nan(kern, x2, x), in_axes=0)(x1)
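
Stripped of the class machinery, the dispatch above builds a Gram matrix by lifting `pairwise_cov` with two nested `vmap`s and guarding against NaNs with `lax.cond`. The sketch below reproduces that composition with plain functions; all names are hypothetical and this is not the Kernax code. It swaps `lax.cond` for `jnp.where`: under `vmap`, a `lax.cond` with a batched predicate is turned into a select that evaluates both branches anyway, so the two versions are expected to do similar work, although this has not been benchmarked here.

```python
import jax
import jax.numpy as jnp


def se_pairwise(length_scale, x1, x2):
    """SE covariance between two vectors of shape (I,)."""
    d = x1 - x2
    return jnp.exp(-0.5 * (d @ d) / length_scale**2)


def pairwise_if_not_nan(length_scale, x1, x2):
    """NaN guard written as a select instead of lax.cond."""
    has_nan = jnp.any(jnp.isnan(x1) | jnp.isnan(x2))
    return jnp.where(has_nan, jnp.nan, se_pairwise(length_scale, x1, x2))


@jax.jit
def cross_cov_matrix(length_scale, x1, x2):
    """(N, I) and (M, I) inputs -> (N, M) covariance via two nested vmaps."""
    one_row = jax.vmap(pairwise_if_not_nan, in_axes=(None, None, 0))  # x1[i] against every x2[j]
    return jax.vmap(one_row, in_axes=(None, 0, None))(length_scale, x1, x2)  # every x1[i]


x = jnp.array([[0.0, 0.0], [1.0, 1.0], [jnp.nan, 2.0]])  # third point is padded
gram = cross_cov_matrix(0.3, x, x)  # (3, 3)
```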


In [13]:
from functools import partial

from jax import Array
from jax import numpy as jnp
import equinox as eqx
from equinox import filter_jit

from kernax import StaticAbstractKernel, AbstractKernel


class StaticSEKernel(StaticAbstractKernel):
	@classmethod
	@filter_jit
	def pairwise_cov(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two vectors.

		:param kern: kernel instance containing the hyperparameters
		:param x1: vector array (I, )
		:param x2: vector array (I, )
		:return: scalar array
		"""
		kern = eqx.combine(kern)
		return jnp.exp(-0.5 * ((x1 - x2) @ (x1 - x2)) / kern.length_scale ** 2)


class SEKernel(AbstractKernel):
	"""
	Squared Exponential (aka "RBF" or "Gaussian") Kernel
	"""
	length_scale: Array = eqx.field(converter=jnp.asarray)
	static_class = StaticSEKernel

	def __init__(self, length_scale):
		super().__init__()
		self.length_scale = length_scale
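
For reference, `StaticSEKernel.pairwise_cov` evaluates the squared exponential kernel on the left below, with length scale $\ell$ and no output-variance factor. The ARD form on the right, with one length scale $\ell_d$ per input dimension, is what `ARDKernel` aims to produce by rescaling the inputs. With $\ell = 0.3$, $x = 0$ and $x' = 1$, the left-hand formula gives $e^{-1/0.18} \approx 0.00387$, consistent with the `0.00386592` printed in the comparison section.

$$
k_{\mathrm{SE}}(\mathbf{x}, \mathbf{x}') = \exp\left(-\frac{\lVert \mathbf{x} - \mathbf{x}' \rVert^2}{2\ell^2}\right),
\qquad
k_{\mathrm{ARD}}(\mathbf{x}, \mathbf{x}') = \exp\left(-\frac{1}{2}\sum_{d} \frac{(x_d - x'_d)^2}{\ell_d^2}\right)
$$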


In [14]:
from functools import partial

import jax.numpy as jnp
from jax import jit, vmap
from jax.lax import cond
import jax.tree_util as jtu
import equinox as eqx

from kernax import StaticAbstractKernel, AbstractKernel, ConstantKernel


class WrapperKernel(AbstractKernel):
	""" Class for kernels that perform some operation on the output of another "inner" kernel."""
	inner_kernel: AbstractKernel = eqx.field()

	def __init__(self, inner_kernel=None):
		"""
		Instantiates a wrapper kernel with the given inner kernel.

		:param inner_kernel: the inner kernel to wrap
		"""
		# If the inner kernel is not a kernel, we try to convert it to a ConstantKernel
		if not isinstance(inner_kernel, AbstractKernel):
			inner_kernel = ConstantKernel(value=inner_kernel)

		self.inner_kernel = inner_kernel



class BatchKernel(WrapperKernel):
	"""
	Wrapper kernel to add batch handling to any kernel.

	A basic kernel usually works on inputs of shape (N, I), and produces covariance matrices of shape (N, N).

	Wrapped inside a batch kernel, they can either:
	- still work on inputs of shape (N, I), but produce covariance matrices of shape (B, N, N), where B is the batch size. This is useful when the hyperparameters are batched, i.e. each batch element has its own set of hyperparameters.
	- or work on inputs of shape (B, N, I), producing covariance matrices of shape (B, N, N). This is useful when the inputs are batched, regardless of whether the hyperparameters are batched or not.

	A batch kernel can itself be wrapped inside another batch kernel, to handle multiple batch dimensions/hyperparameter sets.

	This class uses vmap to vectorize the kernel computation over the batch dimension.
	"""
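	inner_kernel: AbstractKernel = eqx.field()
	batch_in_axes: AbstractKernel = eqx.field(static=True)  # same structure as inner_kernel, with 0 (batched) or None (shared) leaves
	batch_over_inputs: int|None = eqx.field(static=True)  # 0 if inputs are batched, None otherwise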

	def __init__(self, inner_kernel, batch_size, batch_in_axes=None, batch_over_inputs=True):
		"""
		:param inner_kernel: the kernel to wrap, must be an instance of AbstractKernel
		:param batch_size: the size of the batch (int)
		:param batch_in_axes: a value or pytree indicating which hyperparameters are batched (0)
							   or shared (None) across the batch.
							   If None, all hyperparameters are assumed to be shared across the batch.
							   If 0, all hyperparameters are assumed to be batched across the batch.
							   If a pytree, it must have the same structure as inner_kernel, with hyperparameter
							   leaves being either 0 (batched) or None (shared).
		:param batch_over_inputs: whether to expect inputs of shape (B, N, I) (True) or (N, I) (False)
		"""
		# Initialize the WrapperKernel
		super().__init__(inner_kernel=inner_kernel)

		# TODO: batch_size isn't needed if hyperparameters are shared
		# TODO: explicit error message when batch_in_axes is all None and batch_over_inputs is False, as that makes vmap (and a Batch Kernel) useless
		# Default: all array hyperparameters are shared (None for all array leaves)
		if batch_in_axes is None:
			# Extract only array leaves and map them to None
			self.batch_in_axes = jtu.tree_map(lambda _: None, inner_kernel)
		elif batch_in_axes == 0:
			# All hyperparameters are batched
			self.batch_in_axes = jtu.tree_map(lambda _: 0, inner_kernel)
		else:
			self.batch_in_axes = batch_in_axes

		self.batch_over_inputs = 0 if batch_over_inputs else None

		# Add batch dimension to parameters where batch_in_axes is 0
		self.inner_kernel = jtu.tree_map(
			lambda param, batch_in_ax: param if batch_in_ax is None else jnp.repeat(param[None, ...], batch_size, axis=0),
			self.inner_kernel,
			self.batch_in_axes
		)

	@jit
	def __call__(self, x1, x2=None):
		"""
		Compute the kernel over batched inputs using vmap.

		Args:
			x1: Input of shape (B, ..., N, I)
			x2: Optional second input of shape (B, ..., M, I)

		Returns:
			Kernel matrix of appropriate shape with batch dimension
		"""
		# vmap over the batch dimension of inner_kernel and inputs
		# Each batch element gets its own version of inner_kernel with corresponding hyperparameters
		return vmap(
			lambda kernel, x1, x2: kernel(x1, x2),
			in_axes=(self.batch_in_axes, self.batch_over_inputs, self.batch_over_inputs if x2 is not None else None)
		)(self.inner_kernel, x1, x2)
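
Conceptually, `BatchKernel` is a single `vmap` whose `in_axes` states which arguments carry the batch dimension. The standalone sketch below uses plain functions rather than Equinox modules (`plain_se_gram` is hypothetical and skips NaN handling) to show the three configurations compared later in this notebook. The mapping to `BatchKernel` arguments given in the comments is our reading of the class above, not a documented equivalence.

```python
import jax
import jax.numpy as jnp


def plain_se_gram(length_scale, x1, x2):
    """SE Gram matrix between x1 (N, I) and x2 (M, I); no NaN handling."""
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / length_scale**2)


n_batch, n_points, n_dims = 3, 5, 2
xs = jax.random.uniform(jax.random.PRNGKey(0), (n_batch, n_points, n_dims))  # batched inputs
x = xs[0]  # un-batched inputs
shared_ls = jnp.array(0.3)
distinct_ls = jnp.array([0.3, 0.5, 1.0])  # one length scale per batch element

# Shared hyperparameters, batched inputs (roughly batch_in_axes=None)
shp = jax.vmap(plain_se_gram, in_axes=(None, 0, 0))(shared_ls, xs, xs)
# Distinct hyperparameters, batched inputs (roughly batch_in_axes=0)
dhp = jax.vmap(plain_se_gram, in_axes=(0, 0, 0))(distinct_ls, xs, xs)
# Distinct hyperparameters, un-batched inputs (roughly batch_in_axes=0, batch_over_inputs=False)
dhp_unbatched = jax.vmap(plain_se_gram, in_axes=(0, None, None))(distinct_ls, x, x)
```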


In [15]:
class ActiveDimsKernel(WrapperKernel):
	"""
	Wrapper kernel to select active dimensions from the inputs before passing them to the inner kernel.
	"""
	# Static fields must be hashable, so the indices are stored as a tuple of Python ints rather than a JAX array
	active_dims: tuple = eqx.field(static=True, converter=lambda dims: tuple(np.atleast_1d(dims).tolist()))

	def __init__(self, inner_kernel, active_dims):
		"""
		:param inner_kernel: the kernel to wrap, must be an instance of AbstractKernel
		:param active_dims: the indices of the active dimensions to select from the inputs (1D sequence of integers)
		"""
		super().__init__(inner_kernel=inner_kernel)
		self.active_dims = active_dims

	@jit
	def __call__(self, x1: jnp.ndarray, x2: jnp.ndarray = None) -> jnp.ndarray:
		# TODO: add runtime error if active_dims doesn't match input dimensions
		if x2 is None:
			x2 = x1

		dims = jnp.asarray(self.active_dims)  # static tuple -> index array at trace time
		return self.inner_kernel(x1[..., dims], x2[..., dims])
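
`active_dims` is declared as a static field, and static data ends up in the cache key that `jax.jit` hashes, so it should be hashable: a tuple of Python ints is, a JAX array is not. The same constraint shows up with plain functions and `static_argnums`, as in this sketch (the function `se_on_active_dims` is hypothetical):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=3)
def se_on_active_dims(length_scale, x1, x2, active_dims):
    """SE covariance between two vectors, restricted to `active_dims` (a tuple of ints)."""
    idx = jnp.asarray(active_dims)  # built at trace time from the static, hashable tuple
    d = x1[..., idx] - x2[..., idx]
    return jnp.exp(-0.5 * jnp.sum(d**2, axis=-1) / length_scale**2)


x_a = jnp.zeros(4)
x_b = jnp.array([1.0, 5.0, 5.0, 5.0])
only_first = se_on_active_dims(0.3, x_a, x_b, (0,))  # only dimension 0 contributes
```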

In [16]:
class ARDKernel(WrapperKernel):
	"""
	Wrapper kernel to apply Automatic Relevance Determination (ARD) to the inputs before passing them to the inner kernel.
	Each input dimension is scaled by a separate length scale hyperparameter.
	"""
	length_scales: jnp.ndarray = eqx.field(converter=jnp.array)

	def __init__(self, inner_kernel, length_scales):
		"""
		:param inner_kernel: the kernel to wrap, must be an instance of AbstractKernel
		:param length_scales: the length scales for each input dimension (1D array of floats)
		"""
		# TODO: for now, this kernel only works as the direct child of an Isotropic kernel, as it modifies the inner kernel length_scale directly
		#  It would be nice if it could work on combinations of Isotropic kernels, modifying every "length_scale" parameter it finds in the inner kernel tree
		#  It's hard to implement this behavior at the instance-level. Maybe we should make the user do it, or have a utility function to do it for them.
		#  (It's not as simple as turning length_scale to a static attribute, see https://github.com/patrick-kidger/equinox/issues/154)
		#  (It's also not as simple as setting `length_scale` to 1 during initialization, because the value could still be optimized to other values)
		super().__init__(inner_kernel=inner_kernel)
		self.length_scales = length_scales

	@jit
	def __call__(self, x1: jnp.ndarray, x2: jnp.ndarray = None) -> jnp.ndarray:
		if x2 is None:
			x2 = x1

		# Equinox modules are immutable: build a copy of the inner kernel whose length_scale is 1
		inner_kernel = eqx.tree_at(lambda k: k.length_scale, self.inner_kernel, jnp.ones_like(self.inner_kernel.length_scale))

		return inner_kernel(x1 / self.length_scales, x2 / self.length_scales)
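
`ARDKernel` relies on the identity that an isotropic SE kernel with unit length scale, evaluated on inputs divided element-wise by per-dimension length scales, equals the ARD SE kernel. A quick numerical check of that identity with plain functions (names are hypothetical; this checks the maths, not the Kernax classes):

```python
import jax
import jax.numpy as jnp


def se_isotropic(length_scale, x1, x2):
    d = x1 - x2
    return jnp.exp(-0.5 * (d @ d) / length_scale**2)


def se_ard_explicit(length_scales, x1, x2):
    """ARD SE kernel written out dimension by dimension."""
    return jnp.exp(-0.5 * jnp.sum((x1 - x2) ** 2 / length_scales**2))


def se_ard_by_scaling(length_scales, x1, x2):
    """The ARDKernel approach: rescale the inputs, then use an isotropic kernel with length scale 1."""
    return se_isotropic(1.0, x1 / length_scales, x2 / length_scales)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
ard_ls = jax.random.uniform(k1, (4,), minval=0.5, maxval=2.0)
xa = jax.random.normal(k2, (4,))
xb = jax.random.normal(k3, (4,))
explicit = se_ard_explicit(ard_ls, xa, xb)
scaled = se_ard_by_scaling(ard_ls, xa, xb)
```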

### Shortcomings of the previous implementation that we want to address

**We want to abstract batch/ARD handling away from the base AbstractKernel class**

Kernels that work on batches of inputs and produce batches of covariance matrices serve a very specific use case, which is mostly useful in Magma's multi-task GP setting. This logic should not live in the base class, but in a mixin or a wrapper class that can be used when needed.

ARD can likewise be handled by a wrapper kernel that transforms the inputs before passing them to the base kernel.

Similarly, a wrapper that transforms hyperparameters (exponentiation, softplus, ...) could be applied only during optimisation, without affecting the base kernel.
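A minimal sketch of such an optimisation-only transform, assuming hyperparameters are optimised in an unconstrained space and mapped to positive values with `jax.nn.softplus` right before the kernel is evaluated. Plain functions stand in for a wrapper class, and every name here (`plain_se_gram`, `constrained_se_gram`, `inverse_softplus`) is hypothetical.

```python
import jax
import jax.numpy as jnp


def plain_se_gram(length_scale, x1, x2):
    """SE Gram matrix between x1 (N, I) and x2 (M, I); expects length_scale > 0."""
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / length_scale**2)


def constrained_se_gram(raw_length_scale, x1, x2):
    """Optimisation-time wrapper: the raw parameter is unconstrained, softplus keeps it positive."""
    return plain_se_gram(jax.nn.softplus(raw_length_scale), x1, x2)


def inverse_softplus(y):
    # log(exp(y) - 1), written with expm1 for accuracy near 0
    return jnp.log(jnp.expm1(y))


def loss(raw_length_scale):
    return jnp.sum(constrained_se_gram(raw_length_scale, x, x))


x = jnp.linspace(0.0, 1.0, 5)[:, None]
raw = inverse_softplus(jnp.array(0.3))  # start from a length scale of 0.3
grad_raw = jax.grad(loss)(raw)  # gradient with respect to the unconstrained parameter
```

The base kernel never sees the raw parameter, so the transform can be dropped after optimisation by calling `plain_se_gram` with `jax.nn.softplus(raw)` directly.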

In [17]:
new_kernel = SEKernel(length_scale=jnp.array(0.3))
new_batched_kernel_SHP = BatchKernel(inner_kernel=new_kernel, batch_size=B)  # Shared hyperparameters
new_batched_kernel_DHP = BatchKernel(inner_kernel=new_kernel, batch_size=B, batch_in_axes=0)  # Distinct hyperparameters
new_batched_kernel_DHP_BI = BatchKernel(inner_kernel=new_kernel, batch_size=B, batch_in_axes=0, batch_over_inputs=False)  # Distinct hyperparameters, un-batched inputs

In [18]:
new_kernel

SEKernel(length_scale=weak_f64[])

In [19]:
new_kernel_ative_dims = ActiveDimsKernel(inner_kernel=new_kernel, active_dims=[0])
key, subkey = jr.split(key)
ard_length_scales = jr.uniform(subkey, shape=(N,), minval=0.5, maxval=2.0)
new_kernel_ARD = ARDKernel(inner_kernel=new_kernel, length_scales=ard_length_scales)



---
## Comparison

### Pair-wise covariance

#### 1D inputs

In [20]:
jnp.allclose(old_kernel(input_1D_a, input_1D_b), new_kernel(input_1D_a, input_1D_b))

Array(True, dtype=bool)

In [21]:
%%timeit
old_kernel(input_1D_a, input_1D_b).block_until_ready()

28.6 μs ± 1.47 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [22]:
%%timeit
new_kernel(input_1D_a, input_1D_b).block_until_ready()

26.9 μs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### ND inputs

In [23]:
jnp.allclose(old_kernel(input_ND_a, input_ND_b), new_kernel(input_ND_a, input_ND_b))

Array(True, dtype=bool)

In [24]:
%%timeit
old_kernel(input_ND_a, input_ND_b).block_until_ready()

27.5 μs ± 322 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [25]:
%%timeit
new_kernel(input_ND_a, input_ND_b).block_until_ready()

27 μs ± 207 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### ND inputs with active_dims

In [26]:
jnp.allclose(input_1D_a, input_ND_a[0]), jnp.allclose(input_1D_b, input_ND_b[0])

(Array(True, dtype=bool), Array(True, dtype=bool))

In [27]:
jnp.allclose(old_kernel(input_1D_a, input_1D_b), new_kernel_ative_dims(input_ND_a, input_ND_b))

Array(True, dtype=bool)

In [28]:
input_1D_grid_regular

Array([[ 0.        ],
       [ 0.1010101 ],
       [ 0.2020202 ],
       [ 0.3030303 ],
       [ 0.4040404 ],
       [ 0.50505051],
       [ 0.60606061],
       [ 0.70707071],
       [ 0.80808081],
       [ 0.90909091],
       [ 1.01010101],
       [ 1.11111111],
       [ 1.21212121],
       [ 1.31313131],
       [ 1.41414141],
       [ 1.51515152],
       [ 1.61616162],
       [ 1.71717172],
       [ 1.81818182],
       [ 1.91919192],
       [ 2.02020202],
       [ 2.12121212],
       [ 2.22222222],
       [ 2.32323232],
       [ 2.42424242],
       [ 2.52525253],
       [ 2.62626263],
       [ 2.72727273],
       [ 2.82828283],
       [ 2.92929293],
       [ 3.03030303],
       [ 3.13131313],
       [ 3.23232323],
       [ 3.33333333],
       [ 3.43434343],
       [ 3.53535354],
       [ 3.63636364],
       [ 3.73737374],
       [ 3.83838384],
       [ 3.93939394],
       [ 4.04040404],
       [ 4.14141414],
       [ 4.24242424],
       [ 4.34343434],
       [ 4.44444444],
       [ 4.54545455],
       ...

In [29]:
input_ND_grid_regular[10]

Array([0.        , 0.        , 3.33333333, 3.33333333], dtype=float64)

In [30]:
old_kernel(input_1D_a, input_1D_b)

Array(0.00386592, dtype=float64)

In [31]:
new_kernel_ative_dims(input_ND_a, input_ND_b)

Array(0.00386592, dtype=float64)

In [32]:
# TODO

#### ND inputs with ARD

In [33]:
# TODO

### Cross-covariance (Gram matrix)

#### On regular 1D input grid

In [34]:
jnp.allclose(old_kernel(input_1D_grid_regular), new_kernel(input_1D_grid_regular))

Array(True, dtype=bool)

In [35]:
%%timeit
old_kernel(input_1D_grid_regular).block_until_ready()

69.1 μs ± 638 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [36]:
%%timeit
new_kernel(input_1D_grid_regular).block_until_ready()

70.5 μs ± 842 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On irregular 1D input grid

In [37]:
jnp.allclose(old_kernel(input_1D_grid_irregular), new_kernel(input_1D_grid_irregular))

Array(True, dtype=bool)

In [38]:
%%timeit
old_kernel(input_1D_grid_irregular).block_until_ready()

70.8 μs ± 1.55 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [39]:
%%timeit
new_kernel(input_1D_grid_irregular).block_until_ready()

70.6 μs ± 1.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On padded 1D input grid

In [40]:
jnp.allclose(old_kernel(input_1D_grid_padded), new_kernel(input_1D_grid_padded), equal_nan=True)

Array(True, dtype=bool)

In [41]:
%%timeit
old_kernel(input_1D_grid_padded).block_until_ready()

67.9 μs ± 680 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [42]:
%%timeit
new_kernel(input_1D_grid_padded).block_until_ready()

70.7 μs ± 1.17 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On regular ND input grid

In [43]:
jnp.allclose(old_kernel(input_ND_grid_regular), new_kernel(input_ND_grid_regular))

Array(True, dtype=bool)

In [44]:
%%timeit
old_kernel(input_ND_grid_regular).block_until_ready()

95.9 μs ± 1.71 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [45]:
%%timeit
new_kernel(input_ND_grid_regular).block_until_ready()

97.1 μs ± 2.11 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On irregular ND input grid

In [46]:
np.asarray(old_kernel(input_ND_grid_irregular))

array([[1.00000000e+000, 9.41030794e-127, 2.53771102e-136, ...,
        5.76859952e-081, 1.82065224e-017, 4.09909225e-082],
       [9.41030794e-127, 1.00000000e+000, 1.18237080e-209, ...,
        3.36453945e-079, 9.70492285e-143, 3.60090181e-216],
       [2.53771102e-136, 1.18237080e-209, 1.00000000e+000, ...,
        1.13244540e-099, 1.12323097e-240, 2.83822014e-149],
       ...,
       [5.76859952e-081, 3.36453945e-079, 1.13244540e-099, ...,
        1.00000000e+000, 2.96316562e-138, 2.90961345e-081],
       [1.82065224e-017, 9.70492285e-143, 1.12323097e-240, ...,
        2.96316562e-138, 1.00000000e+000, 7.63318691e-119],
       [4.09909225e-082, 3.60090181e-216, 2.83822014e-149, ...,
        2.90961345e-081, 7.63318691e-119, 1.00000000e+000]],
      shape=(100, 100))

In [47]:
np.asarray(new_kernel(input_ND_grid_irregular))

array([[1.00000000e+000, 9.41030794e-127, 2.53771102e-136, ...,
        5.76859952e-081, 1.82065224e-017, 4.09909225e-082],
       [9.41030794e-127, 1.00000000e+000, 1.18237080e-209, ...,
        3.36453945e-079, 9.70492285e-143, 3.60090181e-216],
       [2.53771102e-136, 1.18237080e-209, 1.00000000e+000, ...,
        1.13244540e-099, 1.12323097e-240, 2.83822014e-149],
       ...,
       [5.76859952e-081, 3.36453945e-079, 1.13244540e-099, ...,
        1.00000000e+000, 2.96316562e-138, 2.90961345e-081],
       [1.82065224e-017, 9.70492285e-143, 1.12323097e-240, ...,
        2.96316562e-138, 1.00000000e+000, 7.63318691e-119],
       [4.09909225e-082, 3.60090181e-216, 2.83822014e-149, ...,
        2.90961345e-081, 7.63318691e-119, 1.00000000e+000]],
      shape=(100, 100))

In [48]:
jnp.allclose(old_kernel(input_ND_grid_irregular), new_kernel(input_ND_grid_irregular))

Array(True, dtype=bool)

In [49]:
%%timeit
old_kernel(input_ND_grid_irregular).block_until_ready()

96.4 μs ± 1.6 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [50]:
%%timeit
new_kernel(input_ND_grid_irregular).block_until_ready()

94.8 μs ± 2.79 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On padded ND input grid

In [51]:
np.asarray(old_kernel(input_ND_grid_padded))

array([[1.00000000e+000, 9.41030794e-127, 2.53771102e-136, ...,
                    nan, 1.82065224e-017, 4.09909225e-082],
       [9.41030794e-127, 1.00000000e+000, 1.18237080e-209, ...,
                    nan, 9.70492285e-143, 3.60090181e-216],
       [2.53771102e-136, 1.18237080e-209, 1.00000000e+000, ...,
                    nan, 1.12323097e-240, 2.83822014e-149],
       ...,
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       [1.82065224e-017, 9.70492285e-143, 1.12323097e-240, ...,
                    nan, 1.00000000e+000, 7.63318691e-119],
       [4.09909225e-082, 3.60090181e-216, 2.83822014e-149, ...,
                    nan, 7.63318691e-119, 1.00000000e+000]],
      shape=(100, 100))

In [52]:
np.asarray(new_kernel(input_ND_grid_padded))

array([[1.00000000e+000, 9.41030794e-127, 2.53771102e-136, ...,
                    nan, 1.82065224e-017, 4.09909225e-082],
       [9.41030794e-127, 1.00000000e+000, 1.18237080e-209, ...,
                    nan, 9.70492285e-143, 3.60090181e-216],
       [2.53771102e-136, 1.18237080e-209, 1.00000000e+000, ...,
                    nan, 1.12323097e-240, 2.83822014e-149],
       ...,
       [            nan,             nan,             nan, ...,
                    nan,             nan,             nan],
       [1.82065224e-017, 9.70492285e-143, 1.12323097e-240, ...,
                    nan, 1.00000000e+000, 7.63318691e-119],
       [4.09909225e-082, 3.60090181e-216, 2.83822014e-149, ...,
                    nan, 7.63318691e-119, 1.00000000e+000]],
      shape=(100, 100))

In [53]:
jnp.allclose(old_kernel(input_ND_grid_padded), new_kernel(input_ND_grid_padded), equal_nan=True)

Array(True, dtype=bool)

In [54]:
%%timeit
old_kernel(input_ND_grid_padded).block_until_ready()

87.6 μs ± 1.59 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [55]:
%%timeit
new_kernel(input_ND_grid_padded).block_until_ready()

91.9 μs ± 2.1 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### On ND input grid with active_dims

In [56]:
# TODO

#### On ND input grid with ARD

In [57]:
# TODO

### Batched cross-covariance - Shared hyperparameters

#### Regular 1D inputs

In [58]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_regular), new_batched_kernel_SHP(batched_input_1D_grid_regular))

ValueError: Invalid input dimensions: x1 has shape (250, 100, 1), x2 has shape (250, 100, 1). Expected scalar, 1D or 2D arrays as inputs.

In [52]:
%%timeit
old_batched_kernel(batched_input_1D_grid_regular).block_until_ready()

279 μs ± 787 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [53]:
%%timeit
new_batched_kernel_SHP(batched_input_1D_grid_regular).block_until_ready()

376 μs ± 1.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Irregular 1D inputs

In [54]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_irregular), new_batched_kernel_SHP(batched_input_1D_grid_irregular))

Array(True, dtype=bool)

In [55]:
%%timeit
old_batched_kernel(batched_input_1D_grid_irregular).block_until_ready()

279 μs ± 1.05 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [56]:
%%timeit
new_batched_kernel_SHP(batched_input_1D_grid_irregular).block_until_ready()

379 μs ± 1.01 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Padded 1D inputs

In [57]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_padded), new_batched_kernel_SHP(batched_input_1D_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [58]:
%%timeit
old_batched_kernel(batched_input_1D_grid_padded).block_until_ready()

253 μs ± 515 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [59]:
%%timeit
new_batched_kernel_SHP(batched_input_1D_grid_padded).block_until_ready()

359 μs ± 1.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Regular ND inputs

In [60]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_regular), new_batched_kernel_SHP(batched_input_ND_grid_regular))

Array(True, dtype=bool)

In [61]:
%%timeit
old_batched_kernel(batched_input_ND_grid_regular).block_until_ready()

4.38 ms ± 5.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [62]:
%%timeit
new_batched_kernel_SHP(batched_input_ND_grid_regular).block_until_ready()

4.49 ms ± 14 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Irregular ND inputs

In [63]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_irregular), new_batched_kernel_SHP(batched_input_ND_grid_irregular))

Array(True, dtype=bool)

In [64]:
%%timeit
old_batched_kernel(batched_input_ND_grid_irregular).block_until_ready()

4.39 ms ± 14.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [65]:
%%timeit
new_batched_kernel_SHP(batched_input_ND_grid_irregular).block_until_ready()

4.49 ms ± 21.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Padded ND inputs

In [66]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_padded), new_batched_kernel_SHP(batched_input_ND_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [67]:
%%timeit
old_batched_kernel(batched_input_ND_grid_padded).block_until_ready()

4.32 ms ± 7.45 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [68]:
%%timeit
new_batched_kernel_SHP(batched_input_ND_grid_padded).block_until_ready()

4.45 ms ± 14.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Batched cross-covariance - Distinct hyperparameters

#### Regular 1D inputs

In [69]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_regular), new_batched_kernel_DHP(batched_input_1D_grid_regular))

Array(True, dtype=bool)

In [70]:
%%timeit
old_batched_kernel(batched_input_1D_grid_regular).block_until_ready()

279 μs ± 416 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [71]:
%%timeit
new_batched_kernel_DHP(batched_input_1D_grid_regular).block_until_ready()

389 μs ± 2.26 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Irregular 1D inputs

In [72]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_irregular), new_batched_kernel_DHP(batched_input_1D_grid_irregular))

Array(True, dtype=bool)

In [73]:
%%timeit
old_batched_kernel(batched_input_1D_grid_irregular).block_until_ready()

279 μs ± 476 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [74]:
%%timeit
new_batched_kernel_DHP(batched_input_1D_grid_irregular).block_until_ready()

395 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Padded 1D inputs

In [75]:
jnp.allclose(old_batched_kernel(batched_input_1D_grid_padded), new_batched_kernel_DHP(batched_input_1D_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [76]:
%%timeit
old_batched_kernel(batched_input_1D_grid_padded).block_until_ready()

255 μs ± 2.47 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [77]:
%%timeit
new_batched_kernel_DHP(batched_input_1D_grid_padded).block_until_ready()

373 μs ± 976 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Regular ND inputs

In [78]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_regular), new_batched_kernel_DHP(batched_input_ND_grid_regular))

Array(True, dtype=bool)

In [79]:
%%timeit
old_batched_kernel(batched_input_ND_grid_regular).block_until_ready()

4.41 ms ± 18.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [80]:
%%timeit
new_batched_kernel_DHP(batched_input_ND_grid_regular).block_until_ready()

4.7 ms ± 109 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Irregular ND inputs

In [81]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_irregular), new_batched_kernel_DHP(batched_input_ND_grid_irregular))

Array(True, dtype=bool)

In [82]:
%%timeit
old_batched_kernel(batched_input_ND_grid_irregular).block_until_ready()

4.4 ms ± 4.86 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [83]:
%%timeit
new_batched_kernel_DHP(batched_input_ND_grid_irregular).block_until_ready()

4.71 ms ± 153 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Padded ND inputs

In [84]:
jnp.allclose(old_batched_kernel(batched_input_ND_grid_padded), new_batched_kernel_DHP(batched_input_ND_grid_padded),
             equal_nan=True)

Array(True, dtype=bool)

In [85]:
%%timeit
old_batched_kernel(batched_input_ND_grid_padded).block_until_ready()

4.36 ms ± 28.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [86]:
%%timeit
new_batched_kernel_DHP(batched_input_ND_grid_padded).block_until_ready()

4.47 ms ± 5.86 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### On un-batched input with distinct hyperparameters

In [49]:
jnp.allclose(new_batched_kernel_SHP(batched_input_1D_grid_regular), new_batched_kernel_DHP_BI(input_1D_grid_regular))

Array(True, dtype=bool)

---
## Conclusion

---