# Benchmarks - Kernels

In MagmaClust and GPs in general, kernels are what make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are used in many innermost loops, and their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must:
* be fast, as it is used in many innermost loops
* be usable at many dimensions, including a batch dimension with distinct hyperparameters for each element
* work on padded inputs (i.e. inputs containing NaNs), possibly using a mask
* be jittable, as it is used in many jit-compiled functions
* be modular, as kernels can be combined in many ways
* be both static and instance-based, so that it can either carry its hyperparameters around or be called with them
* be easy to override, as users may want to define their own kernels

These goals sometimes conflict (e.g. writing a jittable version of a kernel is not trivial for most people), and the best implementation depends on the specific use case.
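To make the padded-inputs requirement concrete, here is a minimal sketch (not the MagmaClust implementation) of an RBF covariance between 1D inputs padded with NaNs. The helper name `masked_rbf` is hypothetical, and the sketch assumes padding is encoded as NaN. It replaces NaNs with a finite placeholder before any arithmetic (so neither the values nor the gradients ever see a NaN), then zeroes every entry that involves a padded position.

```python
import jax.numpy as jnp


def masked_rbf(x1, x2, length_scale, variance):
    """RBF covariance between two NaN-padded 1D inputs (hypothetical helper)."""
    mask1 = ~jnp.isnan(x1)  # True where x1 holds a real observation
    mask2 = ~jnp.isnan(x2)
    # Swap NaNs for a finite placeholder before computing the kernel
    x1_safe = jnp.where(mask1, x1, 0.0)
    x2_safe = jnp.where(mask2, x2, 0.0)
    sq_diff = (x1_safe[:, None] - x2_safe[None, :]) ** 2
    cov = variance * jnp.exp(-0.5 * sq_diff / length_scale ** 2)
    # Any entry involving a padded position is set to 0
    return jnp.where(mask1[:, None] & mask2[None, :], cov, 0.0)


x = jnp.array([0.0, 1.0, jnp.nan])  # the last entry is padding
K = masked_rbf(x, x, 1.0, 1.0)
```

Setting the padded diagonal entries to 0 makes the matrix singular; whether to put 1 (or a jitter term) there instead depends on how downstream code inverts the covariance, so that choice is left open here.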

---
## Setup

In [1]:
# Standard library
import os

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp

import numpy as np

In [3]:
# Local

In [4]:
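# Config
# Note: JAX reads JAX_ENABLE_X64 when it is imported, so setting it here (after `import jax`)
# has no effect, which is why the outputs below are float32. To enable 64-bit mode, set the
# variable before importing jax, or call jax.config.update("jax_enable_x64", True).
os.environ['JAX_ENABLE_X64'] = "True"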

---
## Data

---
## Current implementation

In [5]:
class Kernel:
	def __init__(self, **kwargs):
		self.params = kwargs

	def __call__(self, x1, x2=None, **kwargs):
		x2 = x2 if x2 is not None else x1

		# Prepare keyword arguments
		if jnp.ndim(x1) < 2 or jnp.ndim(x2) < 2:
			kwargs = self.load_and_check_kwargs(kwargs, method="single")
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			kwargs = self.load_and_check_kwargs(kwargs, method="batch", batch_size=x1.shape[0])
		else:
			raise ValueError("Invalid input shapes.")

		# Call the appropriate method
		if jnp.isscalar(x1) and jnp.isscalar(x2):
			return self.compute_scalar(x1, x2, **kwargs)
		elif jnp.isscalar(x1) and jnp.ndim(x2) == 1:
			return self.compute_vector(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.isscalar(x2):
			return self.compute_vector(x2, x1, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.compute_matrix(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.compute_batch(x1, x2, **kwargs)
		else:
			raise ValueError(f"Invalid input shapes: {x1.shape} and {x2.shape}.")

	@classmethod
	def compute(cls, x1, x2=None, **kwargs):
		return cls(**kwargs)(x1, x2)

	def load_and_check_kwargs(self, kwargs, method="single", batch_size=None):
		"""
		Load parameters from instance or kwargs, ensure correct shapes.
		- method="single": Scalar parameters only.
		- method="batch": Parameters must match batch size or be broadcastable.
		"""
		for key in self.params:
			if key not in kwargs:
				kwargs[key] = self.params[key]

		for key, value in kwargs.items():
			if value is None:
				raise ValueError(f"Missing parameter: {key}")

			if method == "single":
				if not jnp.isscalar(value):
					raise ValueError(f"Parameter {key} must be a scalar.")
			elif method == "batch":
				if jnp.isscalar(value):
					kwargs[key] = jnp.full((batch_size,), value)
				elif value.shape[0] != batch_size:
					raise ValueError(f"Parameter {key} must be a scalar or match batch size {batch_size}")

		return kwargs

	def params_to_tensor(self):
		return jnp.array([v for v in self.params.values()])

	def tensor_to_params(self, tensor):
		assert len(tensor) == len(self.params), f"Invalid tensor size. Got {len(tensor)}, expected {len(self.params)}."

		self.params = {k: tensor[i] for i, k in enumerate(self.params.keys())}

	def compute_scalar(self, x1: float, x2: float, **kwargs) -> float:
		raise NotImplementedError

	def compute_vector(self, x1: float, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		return vmap(lambda x: self.compute_scalar(x1, x, **kwargs), in_axes=0)(x2)

	def compute_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray = None, **kwargs) -> jnp.ndarray:
		x2 = x2 if x2 is not None else x1
		return vmap(lambda x: self.compute_vector(x, x2, **kwargs), in_axes=0)(x1)

	def compute_batch(self, x1: jnp.ndarray, x2: jnp.ndarray = None, **kwargs) -> jnp.ndarray:
		if x2 is None:
			x2 = x1
		if x1.shape[0] != x2.shape[0]:
			raise ValueError("Batch sizes must match.")

		# Just compute covariances between corresponding pairs from each batch
		return vmap(self.compute_matrix)(x1, x2, **kwargs)

In [6]:
class RBFKernel(Kernel):
	"""
	Radial Basis Function (RBF) kernel, also known as the squared exponential kernel.

	Parameters:
		- length_scale (float): Determines how quickly the function decays as points move apart.
		- variance (float): Determines the amplitude of the kernel.
	"""

	def compute_scalar(self, x1: float, x2: float, **kwargs) -> float:
		return kwargs["variance"] * jnp.exp(-0.5 * (x1 - x2) ** 2 / kwargs["length_scale"] ** 2)

	def compute_vector(self, x1: float, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		return kwargs["variance"] * jnp.exp(-0.5 * (x1 - x2) ** 2 / kwargs["length_scale"] ** 2)

	def compute_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray = None, **kwargs) -> jnp.ndarray:
		x2 = x2 if x2 is not None else x1
		return kwargs["variance"] * jnp.exp(-0.5 * (x1[:, None] - x2) ** 2 / kwargs["length_scale"] ** 2)

	def compute_batch(self, x1: jnp.ndarray, x2: jnp.ndarray = None, **kwargs) -> jnp.ndarray:
		if x1.shape[0] != x2.shape[0]:
			raise ValueError("Batch sizes must match.")

		length_scale = kwargs["length_scale"][:, None, None]
		variance = kwargs["variance"][:, None, None]

		# Calculate pairwise distances and apply kernel
		squared_diff = (x1[:, :, None] - x2[:, None, :]) ** 2
		return variance * jnp.exp(-0.5 * squared_diff / length_scale ** 2)
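`RBFKernel.compute_batch` replaces the generic `vmap(self.compute_matrix)` of the base class with a closed-form broadcast. The standalone sketch below (function names are illustrative, the formulas are re-declared from the cell above) checks that both approaches give the same `(B, N, M)` tensor when every hyperparameter has shape `(B,)`.

```python
import jax
import jax.numpy as jnp


def rbf_matrix(x1, x2, length_scale, variance):
    # (N,) x (M,) -> (N, M), with scalar hyperparameters
    return variance * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / length_scale ** 2)


def rbf_batch_broadcast(x1, x2, length_scale, variance):
    # (B, N) x (B, M) -> (B, N, M); HPs of shape (B,) reshaped to broadcast over N and M
    ls = length_scale[:, None, None]
    var = variance[:, None, None]
    return var * jnp.exp(-0.5 * (x1[:, :, None] - x2[:, None, :]) ** 2 / ls ** 2)


# vmap maps every argument (inputs and hyperparameters) over axis 0
rbf_batch_vmap = jax.vmap(rbf_matrix)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jax.random.uniform(key1, (3, 4))
x2 = jax.random.uniform(key2, (3, 5))
length_scale = jnp.array([0.5, 1.0, 2.0])
variance = jnp.array([1.0, 2.0, 3.0])

K_broadcast = rbf_batch_broadcast(x1, x2, length_scale, variance)
K_vmap = rbf_batch_vmap(x1, x2, length_scale, variance)
```

Both express the same computation; once jitted, XLA may well compile them to similar code, so any speed difference between the two should be measured rather than assumed.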

In [7]:
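class SEMagmaKernel(Kernel):
	"""
	Squared exponential kernel, in the log-parametrisation implemented below:
	k(x1, x2) = exp(variance - exp(-length_scale) * ||x1 - x2||^2 / 2),
	i.e. variance = log(sigma^2) and length_scale = log(ell^2) of the usual RBF kernel.
	"""

	def compute_scalar(self, x1: float, x2: float, **kwargs) -> float:
		return jnp.exp(kwargs["variance"] - jnp.exp(-kwargs["length_scale"]) * jnp.sum((x1 - x2) ** 2) * 0.5)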
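Reading the formula in `SEMagmaKernel`, both hyperparameters are on a log scale: `exp(variance)` plays the role of σ² and `exp(length_scale)` that of ℓ². The kernel should therefore coincide with `RBFKernel` when `variance = log(σ²)` and `length_scale = log(ℓ²)`. A quick numerical check of that reading, with standalone re-declarations of both formulas (argument names are ours):

```python
import jax.numpy as jnp


def rbf(x1, x2, length_scale, variance):
    # Same formula as RBFKernel.compute_scalar
    return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)


def se_magma(x1, x2, log_length_scale_sq, log_variance):
    # Same formula as SEMagmaKernel.compute_scalar
    return jnp.exp(log_variance - jnp.exp(-log_length_scale_sq) * jnp.sum((x1 - x2) ** 2) * 0.5)


ell, sigma2 = 0.7, 1.3
x1, x2 = 0.3, 1.1
k_rbf = rbf(x1, x2, ell, sigma2)
k_magma = se_magma(x1, x2, jnp.log(ell ** 2), jnp.log(sigma2))
```

A common motivation for such a parametrisation is that an optimiser can then work on unconstrained real values while σ² and ℓ² stay positive.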

---
## Custom implementation(s)

### Shortcomings of the previous implementation that we wish to correct:

**The `compute_*` methods are not jittable.**

-> Jitting them requires registering the class as a PyTree, so that jitted methods can use instance attributes in their logic (see https://docs.jax.dev/en/latest/faq.html#how-to-use-jit-with-methods and https://docs.jax.dev/en/latest/pytrees.html#extending-pytrees).

-> It also requires storing the hyperparameters as attributes instead of in a `params` dictionary. However, we still want the abstract `Kernel` class to work with any set of hyperparameters provided as kwargs.

-> Finally, it requires reworking the optional parameter handling, as the current conditional logic is not jittable. This might require separate implementations for the `common_hp` and `distinct_hp` setups, selected by a boolean parameter. In `common_hp`, a hyperparameter is a scalar shared by every task. In `distinct_hp`, it is an array with one value per task. For example, the following function is jittable and handles this difference in dimensionality, because `batched` is a static argument and the Python `if` is therefore resolved at trace time:

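```python
from functools import partial

from jax import jit, vmap
from jax import numpy as jnp

def my_fun(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
	return x * y

@partial(jit, static_argnums=2)
def my_fun_2(x: jnp.ndarray, y: jnp.ndarray, batched: bool) -> jnp.ndarray:
	# batched = whether y has a batch dimension or not
	# As a static argument, it is a plain Python bool at trace time: each value compiles its own version
	if batched:
		# y has shape (N,): map over x and y together
		return vmap(my_fun)(x, y)
	else:
		# y is a scalar: map over x only and broadcast y
		return vmap(my_fun, in_axes=(0, None))(x, y)
```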
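For the first two points (PyTree registration and hyperparameters stored as attributes), one possible design, sketched here under our own assumptions, is to store the hyperparameter *names* as static auxiliary data in `tree_flatten`. `tree_unflatten` can then rebuild any kernel, including an abstract one, without knowing its constructor signature. `KwKernel` and `KwRBF` are hypothetical names. `tree_unflatten` bypasses `__init__` with `object.__new__`, following the JAX pytree documentation's warning that transformations may rebuild custom pytrees with unexpected values (tracers, placeholder objects), which can break validation done in `__init__`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class KwKernel:
    """Hypothetical base kernel holding any set of hyperparameters as attributes."""

    def __init__(self, **hps):
        # No validation here: JAX may rebuild objects with tracers or placeholder values
        self.__dict__.update(hps)

    def tree_flatten(self):
        names = tuple(self.__dict__)  # HP names: hashable, stored as static aux data
        return tuple(self.__dict__[name] for name in names), names

    @classmethod
    def tree_unflatten(cls, names, children):
        obj = object.__new__(cls)  # bypass __init__, so any subclass signature works
        obj.__dict__.update(zip(names, children))
        return obj


@register_pytree_node_class  # registration is per class, so subclasses need it too
class KwRBF(KwKernel):
    @jax.jit
    def __call__(self, x1, x2):
        # self is a pytree argument: its HP values are traced, its HP names are static
        return self.variance * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / self.length_scale ** 2)


kernel = KwRBF(length_scale=jnp.array(1.0), variance=jnp.array(2.0))
x = jnp.array([0.0, 1.0])
K = kernel(x, x)
```

Since the names are part of the tree structure, kernels with different sets of hyperparameters get separate compilations, which is the expected behaviour.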

**Hyperparameters are either all common to the batch or all distinct.**

-> We want to allow a mix of both, where some hyperparameters are common to the whole batch and others are distinct. This is useful, for example, when the length scale is shared across the batch but the variance differs for each element.
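With pure broadcasting, mixing common and distinct hyperparameters needs no per-case logic: reshaping every hyperparameter to `(-1, 1, 1)` turns a scalar into `(1, 1, 1)` (shared by the whole batch) and a `(B,)` vector into `(B, 1, 1)` (one value per element). A minimal sketch for a batched RBF, with a hypothetical function name and assuming 1D inputs of shapes `(B, N)` and `(B, M)`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def rbf_batch_mixed(x1, x2, length_scale, variance):
    # Scalar HP -> (1, 1, 1), shared; (B,) HP -> (B, 1, 1), one value per batch element
    ls = jnp.asarray(length_scale).reshape(-1, 1, 1)
    var = jnp.asarray(variance).reshape(-1, 1, 1)
    sq_diff = (x1[:, :, None] - x2[:, None, :]) ** 2  # (B, N, M)
    return var * jnp.exp(-0.5 * sq_diff / ls ** 2)


x1 = jnp.linspace(0.0, 1.0, 12).reshape(3, 4)
x2 = jnp.linspace(0.0, 2.0, 15).reshape(3, 5)
common_length_scale = 0.8  # shared by the whole batch
distinct_variances = jnp.array([1.0, 2.0, 3.0])  # one per batch element
K = rbf_batch_mixed(x1, x2, common_length_scale, distinct_variances)
```

The trade-off is that a `(1,)`-shaped hyperparameter is also treated as shared, and a wrongly sized vector only fails with a generic broadcasting error, which is in line with the error-message compromise this notebook accepts.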

### Compromises we are willing to make:

**Error prevention/messages can be less informative.**

-> Jitted code cannot raise exceptions that depend on array values, nor branch on those values with Python if/else. We are willing to trade informative errors for performance, even if some users may not understand how they are misusing the API.
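Two kinds of checks remain possible under `jit`, sketched below with hypothetical function names. Checks on shapes can still raise ordinary Python exceptions, because shapes are static and the check runs at trace time. Checks on values can be expressed with `jax.experimental.checkify`, which returns the error as a value instead of raising inside compiled code. This module is experimental, and whether its overhead is acceptable in innermost loops would need benchmarking.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


@jax.jit
def batch_sq_diff(x1, x2):
    # Shapes are static under jit, so this Python check runs at trace time and can raise
    if x1.shape[0] != x2.shape[0]:
        raise ValueError(f"Batch sizes must match, got {x1.shape[0]} and {x2.shape[0]}.")
    return (x1[:, :, None] - x2[:, None, :]) ** 2


def rbf_from_sq_diff(sq_diff, length_scale):
    # `if length_scale <= 0: raise ...` would fail on traced values;
    # checkify.check records the failure instead
    checkify.check(jnp.all(length_scale > 0), "length_scale must be positive")
    return jnp.exp(-0.5 * sq_diff / length_scale ** 2)


# checkify makes the function return (error, output); it composes with jit
checked_rbf = jax.jit(checkify.checkify(rbf_from_sq_diff))

sq_diff = batch_sq_diff(jnp.ones((2, 3)), jnp.zeros((2, 4)))
err, K = checked_rbf(sq_diff, jnp.array(-1.0))
# err.get() returns the message (None if every check passed); err.throw() raises it
```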

In [8]:
#TODO: check vmap along kwargs
# To make vmap work, this implementation converts kwargs to positional args.
# This relies on the hyperparameters reaching compute_* in the order of its signature, which can break if the provided kwargs or the class attributes are ordered differently.
# An alternative is to use a "lambda" version of each compute_* method, taking kwargs as a parameter, and to vmap this lambda (see the commented-out lines below). However, this may jit-compile the same function many times, which is not optimal. This still needs to be checked.

@register_pytree_node_class
class AbstractKernel:
	def __init__(self, **kwargs):
		# Check that hyperparameters are all jnp arrays/scalars
		for key, value in kwargs.items():
			if not isinstance(value, jnp.ndarray):  # Check type
				raise ValueError(f"Parameter {key} must be a jnp.ndarray.")
			else:  # Check dimensionality
				if len(value.shape) > 1:
					raise ValueError(f"Parameter {key} must be a scalar or a 1D array, got shape {value.shape}.")

		# Register hyperparameters in *kwargs* as instance attributes
		self.__dict__.update(kwargs)

	@jit
	def check_kwargs(self, **kwargs):
		for key in self.__dict__:
			if key not in kwargs:
				kwargs[key] = self.__dict__[key]
		return kwargs

	@jit
	def __call__(self, x1, x2=None, **kwargs):
		# If no x2 is provided, we compute the covariance between x1 and itself
		if x2 is None:
			x2 = x1

		# Check kwargs
		kwargs = self.check_kwargs(**kwargs)
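		# Order the HPs by instance attributes (constructor order), which must match the compute_* signatures
		args = tuple(kwargs[key] for key in self.__dict__)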

		# Call the appropriate method
		if jnp.isscalar(x1) and jnp.isscalar(x2):
			return self.compute_scalar(x1, x2, *args)
		elif jnp.ndim(x1) == 1 and jnp.isscalar(x2):
			return self.compute_vector(x1, x2, *args)
		elif jnp.isscalar(x1) and jnp.ndim(x2) == 1:
			return self.compute_vector(x2, x1, *args)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.compute_matrix(x1, x2, *args)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.compute_batch(x1, x2, *args)
		else:
			return jnp.nan

	# Methods to use Kernel as a PyTree
	def tree_flatten(self):
		return tuple(self.__dict__.values()), None  # No static values

	@classmethod
	def tree_unflatten(cls, _, children):
		# This class being abstract, this function fails when called on an "abstract instance",
		# as we don't know the number of parameters the constructor expects yet we send it children.
		# On a subclass, this will work as expected as long as the constructor has a clear number of
		# kwargs as parameters.
		return cls(*children)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, *args) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two scalar arrays.

		:param x1: scalar array
		:param x2: scalar array
		:param args: hyperparameters of the kernel, in constructor order
		:return: scalar array
		"""
		return jnp.nan  # To be overwritten

	@jit
	def compute_vector(self, x1: jnp.ndarray, x2: jnp.ndarray, *args) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between a vector and a scalar.

		:param x1: vector array (N,)
		:param x2: scalar array
		:param args: hyperparameters of the kernel, in constructor order
		:return: vector array (N,)
		"""
		#return vmap(lambda x: self.compute_scalar(x, x2, **kwargs), in_axes=0)(x1)
		return vmap(self.compute_scalar, in_axes=(0, None) + (None,)*len(args))(x1, x2, *args).squeeze()

	@jit
	def compute_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray, *args) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two vector arrays.

		:param x1: vector array (N,)
		:param x2: vector array (M,)
		:param args: hyperparameters of the kernel, in constructor order
		:return: matrix array (N, M)
		"""
		#vmap(lambda x: self.compute_vector(x, x2, **kwargs), in_axes=0)(x1)
		return vmap(self.compute_vector, in_axes=(None, 0) + (None,)*len(args))(x2, x1, *args)

	@jit
	def compute_batch(self, x1: jnp.ndarray, x2: jnp.ndarray, *args) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two batched vector arrays.

		:param x1: vector array (B, N)
		:param x2: vector array (B, M)
		:param args: hyperparameters of the kernel, in constructor order. Each scalar HP is common to the whole batch, and each vector HP is distinct per element and must have shape (B,)
		:return: tensor array (B, N, M)
		"""
		#vmap(self.compute_matrix)(x1, x2, **kwargs)
		args_axes = tuple(None if jnp.isscalar(hp) else 0 for hp in args)

		return vmap(self.compute_matrix, in_axes=(0, 0) + args_axes)(x1, x2, *args)
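Regarding the ordering issue noted in the TODO at the top of this cell: `vmap`'s `in_axes` accepts a pytree prefix, so hyperparameters can stay in a dict and be looked up by name, which removes any dependence on their order. The standalone sketch below (hypothetical names) builds the per-hyperparameter axes from `jnp.ndim`, which is static under `jit`, so shared (0-d) and distinct (`(B,)`) hyperparameters can be mixed. It does not settle the recompilation question raised in the TODO, which would still need measuring.

```python
import jax
import jax.numpy as jnp


def rbf_matrix(x1, x2, hp):
    # HPs are read by name, so the order of the dict entries never matters
    return hp["variance"] * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / hp["length_scale"] ** 2)


@jax.jit
def rbf_batch(x1, x2, hp):
    # 0-d HP -> None (shared across the batch); (B,) HP -> 0 (mapped over the batch)
    hp_axes = {name: (None if jnp.ndim(value) == 0 else 0) for name, value in hp.items()}
    return jax.vmap(rbf_matrix, in_axes=(0, 0, hp_axes))(x1, x2, hp)


x1 = jnp.linspace(0.0, 1.0, 12).reshape(3, 4)
x2 = jnp.linspace(0.0, 2.0, 15).reshape(3, 5)
# Keys deliberately listed in a different order from the formula
hp = {"variance": jnp.array([1.0, 2.0, 3.0]), "length_scale": jnp.array(0.5)}
K = rbf_batch(x1, x2, hp)
```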

In [9]:
@register_pytree_node_class
class NewRBFKernel(AbstractKernel):
	def __init__(self, length_scale=None, variance=None):
		# Scalar (0-d) defaults, so that compute_batch treats them as common HPs
		if length_scale is None:
			length_scale = jnp.array(1.)
		if variance is None:
			variance = jnp.array(1.)
		super().__init__(length_scale=length_scale, variance=variance)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
		return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)

---
## Comparison

In [10]:
old_RBF = RBFKernel(length_scale=1.0, variance=1.0)
new_RBF = NewRBFKernel(length_scale=jnp.array(1.), variance=jnp.array(1.))
key = jax.random.PRNGKey(0)
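`%timeit` runs each call many times, so the figures below measure the cached, already compiled function: the one-off tracing and compilation cost of the first call (paid again whenever input shapes, dtypes or static arguments change) is excluded. If that cost matters, a rough way to report both, sketched here with a standalone jitted RBF and a hypothetical helper name, is:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def rbf_matrix(x1, x2, length_scale, variance):
    return variance * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / length_scale ** 2)


def time_first_and_steady(f, *args, n_runs=10):
    """Return (first-call seconds, mean steady-state seconds) for a jitted f (hypothetical helper)."""
    start = time.perf_counter()
    f(*args).block_until_ready()  # first call: tracing + compilation + execution
    first = time.perf_counter() - start
    start = time.perf_counter()
    for _ in range(n_runs):
        f(*args).block_until_ready()  # later calls reuse the cached executable
    steady = (time.perf_counter() - start) / n_runs
    return first, steady


xs1 = jnp.linspace(0.0, 1.0, 1000)
xs2 = jnp.linspace(0.0, 1.0, 1500)
first, steady = time_first_and_steady(rbf_matrix, xs1, xs2, 1.0, 1.0)
print(f"first call: {first * 1e3:.1f} ms, steady state: {steady * 1e3:.3f} ms")
```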

### On scalars

In [11]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.00729382, dtype=float32), Array(0.10429037, dtype=float32))

In [12]:
res1 = old_RBF(a, b)
np.asarray(res1)

array(0.9953069, dtype=float32)

In [13]:
res2 = new_RBF(a, b)
np.asarray(res2)

array(0.9953069, dtype=float32)

In [14]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [15]:
%timeit old_RBF(a, b).block_until_ready()

57.4 μs ± 734 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
%timeit new_RBF(a, b).block_until_ready()

5.68 μs ± 50.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### On an array and a scalar

In [17]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array([0.08260381, 0.7807871 , 0.93827987, ..., 0.99492884, 0.05307555,
        0.8959063 ], dtype=float32),
 Array(0.01066005, dtype=float32))

In [18]:
res1 = old_RBF(a, b)
np.asarray(res1)

array([0.9974154 , 0.74338007, 0.6503535 , ..., 0.61607134, 0.99910086,
       0.6758187 ], dtype=float32)

In [19]:
res2 = new_RBF(a, b)
np.asarray(res2)

array([0.9974154 , 0.74338007, 0.6503535 , ..., 0.61607134, 0.99910086,
       0.6758187 ], dtype=float32)

In [20]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [21]:
%timeit old_RBF(a, b).block_until_ready()

487 μs ± 28.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
%timeit new_RBF(a, b).block_until_ready()

25.4 μs ± 268 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On two arrays

In [23]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,))
a.shape, b.shape

((10000,), (15000,))

In [24]:
res1 = old_RBF(a, b)
np.asarray(res1).shape

(10000, 15000)

In [25]:
res2 = new_RBF(a, b)
np.asarray(res2).shape

(10000, 15000)

In [26]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [27]:
%timeit old_RBF(a, b).block_until_ready()

233 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
%timeit new_RBF(a, b).block_until_ready()

36.5 ms ± 277 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### On two batches of arrays with common HP

In [29]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))
a.shape, b.shape

((50, 100), (50, 150))

In [30]:
# Using common hyperparameters for all batches
res1 = old_RBF(a, b)
np.asarray(res1).shape

(50, 100, 150)

In [31]:
# Also using common hyperparameters for all batches
res2 = new_RBF(a, b)
np.asarray(res2).shape

(50, 100, 150)

In [32]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [33]:
%timeit old_RBF(a, b).block_until_ready()

1.53 ms ± 10.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [34]:
%timeit new_RBF(a, b).block_until_ready()

341 μs ± 27.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### On two batches of arrays with distinct HP

In [35]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 50 batches of 100 points each
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 50 batches of 150 points each

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances

(Array([0.69834626, 1.1941253 , 0.74907124, 1.4650234 , 0.6623002 ,
        0.5150695 , 1.3077172 , 0.8598883 , 1.3379186 , 0.5345807 ,
        0.9925809 , 1.0119901 , 0.6867107 , 1.3121855 , 1.4377677 ,
        1.397761  , 0.66495323, 1.1306376 , 0.95631206, 0.6137631 ,
        0.86514413, 1.2786899 , 1.0824803 , 0.89162004, 1.080331  ,
        0.70636487, 0.88181865, 1.2177788 , 0.9001007 , 0.88983023,
        1.0880367 , 1.0784531 , 1.1646242 , 1.3984015 , 1.2634037 ,
        0.5314708 , 0.53026426, 1.309241  , 0.7298783 , 1.2418199 ,
        0.9763527 , 0.52760875, 1.0559071 , 1.2554548 , 1.3978969 ,
        1.2237661 , 0.7706127 , 1.2761428 , 1.1990191 , 0.59786415],      dtype=float32),
 Array([0.9009923 , 0.7235527 , 0.8779906 , 1.007704  , 0.8343816 ,
        0.961216  , 0.652822  , 0.7110219 , 0.84446895, 1.1748973 ,
        0.5548018 , 1.3568562 , 1.4311233 , 0.9435116 , 0.68918073,
        0.751114  , 1.4021089 , 1.2387986 , 0.9405923 , 1.2024236 ,
        0.8921603 , 0.5305

In [36]:
res1 = old_RBF(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [37]:
res2 = new_RBF(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [38]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [39]:
%timeit old_RBF(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

1.5 ms ± 79.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [40]:
%timeit new_RBF(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

318 μs ± 4.64 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


---
## Conclusion

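This benchmark shows that the new, jittable implementation is faster in every case tested on CPU: about 10x on two scalars, 19x on an array and a scalar, 6x on two arrays, and 4.5x on batches (with either common or distinct hyperparameters), with little to no compromise on the API side.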

The new implementation is also more flexible, as it allows mixing common (scalar) and distinct (shape `(B,)`) hyperparameters in the same batched call, which can be useful in many cases, although this mixed case is not benchmarked above.

Next steps to improve this implementation would be to check whether the `vmap` calls are optimal (lambda functions taking kwargs might work better), and to test the implementation on GPU to see whether the speedup is even larger.