# Benchmarks - Kernels

In MagmaClust, and in Gaussian processes (GPs) in general, kernels are what make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are called in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable with inputs of various dimensionalities, including a batch dimension with distinct hyperparameters for each element
* able to work on padded inputs (i.e. inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, so that it can either carry its hyperparameters around or be called with them
* easy to override, as users may want to define their own kernels

These goals conflict in some cases (e.g., a jittable version of a kernel is not trivial for most people to write), and the best implementation will depend on the specific use case.
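
As an illustration of the padded-input requirement, here is a minimal sketch of an RBF covariance that tolerates NaN-padded 1-D inputs. The function name `masked_rbf` and the convention that every entry involving a padded point is set to 0 are assumptions made for this example, not MagmaClust's behaviour. NaNs are replaced before the exponential (not only masked afterwards) so that they cannot leak into gradients.

```python
import jax
import jax.numpy as jnp


def masked_rbf(x1, x2, length_scale=1.0, variance=1.0):
    """RBF covariance between NaN-padded 1-D inputs x1 (N,) and x2 (M,).

    Assumed convention: any entry involving a padded (NaN) point is 0.
    """
    mask1 = ~jnp.isnan(x1)
    mask2 = ~jnp.isnan(x2)
    # Replace NaNs *before* the computation so they cannot reach exp() or the gradients
    x1_safe = jnp.where(mask1, x1, 0.0)
    x2_safe = jnp.where(mask2, x2, 0.0)
    sq_dist = (x1_safe[:, None] - x2_safe[None, :]) ** 2
    cov = variance * jnp.exp(-0.5 * sq_dist / length_scale ** 2)
    # Zero out every row/column that belongs to a padded point
    return jnp.where(mask1[:, None] & mask2[None, :], cov, 0.0)


x1 = jnp.array([0.0, 1.0, jnp.nan])  # 2 real points + 1 padding slot
x2 = jnp.array([0.0, jnp.nan])       # 1 real point + 1 padding slot
cov = jax.jit(masked_rbf)(x1, x2)    # shape (3, 2), no NaNs
grad_ls = jax.grad(lambda ls: masked_rbf(x1, x2, ls).sum())(1.0)  # finite gradient
```

Because the function only uses `jnp` operations on static shapes, it stays jittable and could be vmapped over a batch dimension in the same way as the kernels below.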

---
## Setup

In [1]:
# Standard library
import os

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp

import numpy as np

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel, AbstractKernel

In [4]:
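# Config
# JAX reads JAX_ENABLE_X64 when it is imported, so setting it here (after `import jax`) has no effect,
# which is why the outputs below are float32. To enable 64-bit precision after import, use
# `jax.config.update("jax_enable_x64", True)` instead.
os.environ['JAX_ENABLE_X64'] = "True"
key = jax.random.PRNGKey(0)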

---
## Data

---
## Current implementation

In [5]:
old_kernel = RBFKernel(length_scale=jnp.array(1.), variance=jnp.array(1.))

---
## Custom implementation(s)

### Shortcomings of the previous implementation that we wish to correct or improve

**The kwargs are converted to args**

To make vmap work, the old implementation converts kwargs to positional args.

This can lead to bugs where the order of the kwargs is not respected, either in the provided params or in the class definition.
An alternative is to use a lambda version of each compute_* method that receives the kwargs through its closure, and to vmap this lambda instead; this is what the new implementation below does, with the previous args-based calls left in comments. However, this may cause the same function to be JIT-compiled many times, which would not be optimal. We are not sure whether this is the case, so it should be checked.
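
A small standalone illustration of the ordering bug described above, using a plain function `rbf` rather than the kernel classes (all names here are illustrative only): when hyperparameters are turned into positional args, the order of the dict silently decides which value becomes which parameter, whereas a lambda closing over `**hps` matches them by name.

```python
import jax
import jax.numpy as jnp


def rbf(x1, x2, length_scale, variance):
    return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)


x = jnp.linspace(0.0, 1.0, 4)
x2 = jnp.array(0.0)
# Same hyperparameters, but the dict order differs from the signature order of `rbf`
hps = {"variance": jnp.array(2.0), "length_scale": jnp.array(0.5)}

# Args-based: HPs are passed positionally, so variance=2.0 silently becomes length_scale
args = tuple(hps.values())
wrong = jax.vmap(rbf, in_axes=(0, None) + (None,) * len(args))(x, x2, *args)

# Kwargs-based: the lambda closes over **hps, so each value is matched by name
right = jax.vmap(lambda xi: rbf(xi, x2, **hps))(x)
```

At `x = 0` the correct covariance equals the variance (2.0), while the args-based call returns 0.5, with no error raised.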

In [30]:
@register_pytree_node_class
class NewAbstractKernel:
	def __init__(self, **kwargs):
		# Check that hyperparameters are all jnp arrays, either 0-D (scalar) or 1-D
		for key, value in kwargs.items():
			if not isinstance(value, jnp.ndarray):  # Check type
				raise ValueError(f"Parameter {key} must be a jnp.ndarray.")
			else:  # Check dimensionality
				if len(value.shape) > 1:
					raise ValueError(f"Parameter {key} must be a scalar or a 1D array, got shape {value.shape}.")

		# Register hyperparameters in *kwargs* as instance attributes
		self.__dict__.update(kwargs)

	@jit
	def check_kwargs(self, **kwargs):
		for key in self.__dict__:
			if key not in kwargs:
				kwargs[key] = self.__dict__[key]
		return kwargs

	@jit
	def __call__(self, x1, x2=None, **kwargs):
		# If no x2 is provided, we compute the covariance between x1 and itself
		if x2 is None:
			x2 = x1

		# Check kwargs
		kwargs = self.check_kwargs(**kwargs)

		# Call the appropriate method
		if jnp.isscalar(x1) and jnp.isscalar(x2):
			return self.compute_scalar(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.isscalar(x2):
			return self.compute_vector(x1, x2, **kwargs)
		elif jnp.isscalar(x1) and jnp.ndim(x2) == 1:
			return self.compute_vector(x2, x1, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.compute_matrix(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.compute_batch(x1, x2, **kwargs)
		else:
			return jnp.nan

	# Methods to use Kernel as a PyTree
	def tree_flatten(self):
		return tuple(self.__dict__.values()), None  # No static values

	@classmethod
	def tree_unflatten(cls, _, children):
		# This class being abstract, this function fails when called on an "abstract instance",
		# as we don't know the number of parameters the constructor expects, yet we send it children.
		# On a subclass, this will work as expected as long as the constructor has a clear number of
		# kwargs as parameters.
		return cls(*children)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two scalar arrays.

		:param x1: scalar array
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: scalar array
		"""
		return jnp.array(jnp.nan)  # To be overwritten

	@jit
	def compute_vector(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between a vector and a scalar.

		:param x1: vector array (N, )
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: vector array (N, )
		"""
		return vmap(lambda x: self.compute_scalar(x, x2, **kwargs), in_axes=0)(x1)
		#  return vmap(self.compute_scalar, in_axes=(0, None) + (None,) * len(args))(x1, x2, *args).squeeze()

	@jit
	def compute_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two vector arrays.

		:param x1: vector array (N, )
		:param x2: vector array (M, )
		:param kwargs: hyperparameters of the kernel
		:return: matrix array (N, M)
		"""
		return vmap(lambda x: self.compute_vector(x2, x, **kwargs), in_axes=0)(x1)
		# return vmap(self.compute_vector, in_axes=(None, 0) + (None,) * len(args))(x2, x1, *args)

	@jit
	def compute_batch(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two batched vector arrays.

		:param x1: vector array (B, N)
		:param x2: vector array (B, M)
		:param kwargs: hyperparameters of the kernel. Each HP that is a scalar will be shared by the whole batch, and
		each HP that is a vector will be distinct for each batch element and thus must have shape (B, )
		:return: tensor array (B, N, M)
		"""
		# vmap(self.compute_matrix)(x1, x2, **kwargs)
		common_hps = {key: value for key, value in kwargs.items() if jnp.isscalar(value)}
		distinct_hps = {key: value for key, value in kwargs.items() if not jnp.isscalar(value)}

		return vmap(lambda x, y, hps: self.compute_matrix(x, y, **hps, **common_hps), in_axes=(0, 0, 0))(x1, x2, distinct_hps)
		# kwargs_axes = tuple(None if jnp.isscalar(hp) else 0 for hp in kwargs)
		# return vmap(self.compute_matrix, in_axes=(0, 0) + kwargs_axes)(x1, x2, **kwargs)
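
`compute_batch` above splits the hyperparameters into `common_hps` and `distinct_hps` by hand. A possible alternative, sketched below with hypothetical standalone functions `rbf_matrix` and `rbf_batch`, is to pass the hyperparameters as a single dict positional argument and give `vmap` an `in_axes` dict with the same keys (`None` for a shared HP, `0` for a per-batch HP). The dict has to be passed positionally because `jax.vmap` always maps keyword arguments over their leading axis. This variant has not been benchmarked against the version above.

```python
import jax
import jax.numpy as jnp


def rbf_matrix(x1, x2, hps):
    """RBF covariance between x1 (N,) and x2 (M,) for one set of scalar hyperparameters."""
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return hps["variance"] * jnp.exp(-0.5 * sq_dist / hps["length_scale"] ** 2)


@jax.jit
def rbf_batch(x1, x2, hps):
    """x1 (B, N), x2 (B, M); each hyperparameter is either 0-D (shared) or of shape (B,)."""
    # in_axes mirrors the dict structure of `hps`: None broadcasts a shared HP, 0 maps a per-batch HP
    hp_axes = {name: (None if jnp.ndim(value) == 0 else 0) for name, value in hps.items()}
    return jax.vmap(rbf_matrix, in_axes=(0, 0, hp_axes))(x1, x2, hps)


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jax.random.uniform(key1, (3, 4))
x2 = jax.random.uniform(key2, (3, 5))
hps = {"length_scale": jnp.array(1.0), "variance": jnp.array([0.5, 1.0, 2.0])}
out = rbf_batch(x1, x2, hps)  # shape (3, 4, 5)
```

Since shapes are static under `jit`, `hp_axes` is computed once per trace and does not add runtime work.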

In [31]:
@register_pytree_node_class
class NewRBFKernel(NewAbstractKernel):
	def __init__(self, length_scale=None, variance=None):
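		# Default to 0-D (scalar) arrays: they are then treated as HPs shared across a batch.
		# 1-D defaults of shape (1,) would be vmapped over in compute_batch and fail when B != 1.
		if length_scale is None:
			length_scale = jnp.array(1.)
		if variance is None:
			variance = jnp.array(1.)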
		super().__init__(length_scale=length_scale, variance=variance)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
		return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)
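
`NewAbstractKernel.tree_unflatten` rebuilds kernels with `cls(*children)`, so it depends on the positional order of the subclass constructor, the same class of ordering bug discussed earlier. It also re-runs the type checks in `__init__`, and as we understand the JAX pytree documentation, unflattening may be done with placeholder leaves that are not arrays. A minimal sketch of a keyword-safe alternative, assuming a hypothetical `KeywordRBF` class that stores its hyperparameters in a dict and keeps their (sorted) names as static auxiliary data:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class KeywordRBF:
    """Hypothetical kernel whose pytree round-trip does not depend on argument order."""

    def __init__(self, **hps):
        self.hps = dict(hps)

    def __call__(self, x1, x2):
        sq_dist = (x1[:, None] - x2[None, :]) ** 2
        return self.hps["variance"] * jnp.exp(-0.5 * sq_dist / self.hps["length_scale"] ** 2)

    def tree_flatten(self):
        names = tuple(sorted(self.hps))  # hyperparameter names are static auxiliary data
        return tuple(self.hps[n] for n in names), names

    @classmethod
    def tree_unflatten(cls, names, children):
        # Rebuild by name and skip __init__, since leaves may be placeholders rather than arrays
        obj = object.__new__(cls)
        obj.hps = dict(zip(names, children))
        return obj


@jax.jit
def gram(kernel, x):
    return kernel(x, x)


k1 = KeywordRBF(length_scale=jnp.array(0.5), variance=jnp.array(2.0))
k2 = KeywordRBF(variance=jnp.array(2.0), length_scale=jnp.array(0.5))  # same HPs, other order
x = jnp.linspace(0.0, 1.0, 4)
```

Sorting the names also gives both instances the same tree structure, so `gram` should reuse a single compiled version for `k1` and `k2`.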

---
## Comparison

In [32]:
new_kernel = NewRBFKernel(length_scale=jnp.array(1.), variance=jnp.array(1.))

### On scalars

In [9]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.00729382, dtype=float32), Array(0.10429037, dtype=float32))

In [10]:
res1 = old_kernel(a, b)
np.asarray(res1)

array(0.9953069, dtype=float32)

In [11]:
res2 = new_kernel(a, b)
np.asarray(res2)

array(0.9953069, dtype=float32)

In [12]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [13]:
%timeit old_kernel(a, b).block_until_ready()

5.68 μs ± 165 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [14]:
%timeit new_kernel(a, b).block_until_ready()

5.67 μs ± 116 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
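
These `%timeit` figures are steady-state timings: the kernels were already called (and therefore compiled) in the preceding cells, and `block_until_ready()` makes each loop wait for JAX's asynchronous result. If compilation cost matters too, a rough sketch like the one below separates the first call (tracing, compilation and execution) from cached calls. `time_call` is a hypothetical helper and `rbf` a standalone stand-in for the kernels.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def rbf(x1, x2, length_scale, variance):
    return variance * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / length_scale ** 2)


def time_call(fn, *args, n_runs=20):
    """Return (seconds for the first call, mean seconds per call afterwards)."""
    start = time.perf_counter()
    fn(*args).block_until_ready()  # first call: tracing + XLA compilation + execution
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(n_runs):
        fn(*args).block_until_ready()  # cached executable: dispatch + execution only
    steady = (time.perf_counter() - start) / n_runs
    return first, steady


first, steady = time_call(rbf, jnp.ones(1000), jnp.ones(500), jnp.array(1.0), jnp.array(1.0))
print(f"first call: {first * 1e3:.2f} ms, steady state: {steady * 1e6:.1f} μs")
```

The first-call figure is only meaningful on a function that has not yet been called with these shapes and dtypes; otherwise it is already served from the cache.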


### On an array and a scalar

In [15]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array([0.08260381, 0.7807871 , 0.93827987, ..., 0.99492884, 0.05307555,
        0.8959063 ], dtype=float32),
 Array(0.01066005, dtype=float32))

In [16]:
res1 = old_kernel(a, b)
np.asarray(res1)

array([0.9974154 , 0.74338007, 0.6503535 , ..., 0.61607134, 0.99910086,
       0.6758187 ], dtype=float32)

In [17]:
res2 = new_kernel(a, b)
np.asarray(res2)

array([0.9974154 , 0.74338007, 0.6503535 , ..., 0.61607134, 0.99910086,
       0.6758187 ], dtype=float32)

In [18]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [19]:
%timeit old_kernel(a, b).block_until_ready()

25.7 μs ± 366 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
%timeit new_kernel(a, b).block_until_ready()

26.4 μs ± 1.66 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On two arrays

In [21]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,))
a.shape, b.shape

((10000,), (15000,))

In [22]:
res1 = old_kernel(a, b)
np.asarray(res1).shape

(10000, 15000)

In [23]:
res2 = new_kernel(a, b)
np.asarray(res2).shape

(10000, 15000)

In [24]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [25]:
%timeit old_kernel(a, b).block_until_ready()

39 ms ± 1.81 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [26]:
%timeit new_kernel(a, b).block_until_ready()

36.3 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
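
A quick standalone cross-check of the nested-`vmap` construction used by `compute_vector`/`compute_matrix` (a scalar function vmapped over `x2`, then over `x1`): it should give the same (N, M) matrix as plain broadcasting. The functions below are illustrative stand-ins, not the library code, and this sketch checks values only, not speed.

```python
import jax
import jax.numpy as jnp


def rbf_scalar(x1, x2, length_scale, variance):
    return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)


def rbf_matrix_vmap(x1, x2, length_scale, variance):
    # Inner vmap: one x1 point against every x2 point; outer vmap: every x1 point
    row = lambda a: jax.vmap(lambda b: rbf_scalar(a, b, length_scale, variance))(x2)
    return jax.vmap(row)(x1)


def rbf_matrix_broadcast(x1, x2, length_scale, variance):
    # Same result via broadcasting (N, 1) against (1, M)
    return rbf_scalar(x1[:, None], x2[None, :], length_scale, variance)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(k1, (100,))
b = jax.random.uniform(k2, (150,))
m_vmap = jax.jit(rbf_matrix_vmap)(a, b, 1.0, 1.0)
m_bcast = jax.jit(rbf_matrix_broadcast)(a, b, 1.0, 1.0)
```

A subclass therefore only needs to define `compute_scalar`; the broadcast form is a convenient reference when testing custom kernels.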


### On two batches of arrays with common HP

In [33]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))
a.shape, b.shape

((50, 100), (50, 150))

In [34]:
# Using common hyperparameters for all batches
res1 = old_kernel(a, b)
np.asarray(res1).shape

(50, 100, 150)

In [35]:
# Also using common hyperparameters for all batches
res2 = new_kernel(a, b)
np.asarray(res2).shape

(50, 100, 150)

In [36]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [37]:
%timeit old_kernel(a, b).block_until_ready()

323 μs ± 5.46 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [38]:
%timeit new_kernel(a, b).block_until_ready()

323 μs ± 5.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### On two batches of arrays with distinct HP

In [39]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 50 batches of 100 points
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 50 batches of 150 points

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances

(Array([1.4664395 , 1.1674494 , 1.2656815 , 1.4723855 , 0.56868327,
        0.9240197 , 1.2534711 , 1.0043712 , 1.3627535 , 0.7144728 ,
        0.8851514 , 0.64984024, 0.94593215, 1.4062843 , 1.0102545 ,
        1.4037178 , 1.137649  , 0.5411558 , 0.9424176 , 1.1860164 ,
        0.89514005, 0.83264935, 1.1943005 , 1.0778471 , 0.78289664,
        1.052645  , 0.87137353, 1.145605  , 0.8976754 , 1.3369253 ,
        0.86346245, 0.87112653, 0.9242132 , 0.535846  , 1.2206455 ,
        0.7226651 , 1.1459886 , 0.7372732 , 1.3149244 , 1.2490811 ,
        0.567139  , 1.4629116 , 0.79252124, 1.3502349 , 1.1053848 ,
        0.7974944 , 0.80799735, 1.2222965 , 0.6568692 , 0.7821989 ],      dtype=float32),
 Array([0.55955625, 0.7153524 , 0.9721482 , 1.344615  , 0.7779318 ,
        1.3943778 , 1.4921435 , 0.8077892 , 0.6169007 , 0.75196445,
        1.2752986 , 1.1109926 , 0.63451767, 1.3250592 , 0.9057598 ,
        0.6925794 , 1.1794176 , 1.006102  , 1.1518829 , 1.093091  ,
        0.72369075, 0.6473

In [40]:
res1 = old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [41]:
res2 = new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [42]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [45]:
%timeit old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

330 μs ± 15.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [46]:
%timeit new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

323 μs ± 2.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### On two batches of arrays with both common and distinct HP

In [47]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 50 batches of 100 points
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 50 batches of 150 points

# Create common and distinct hyperparameters for each batch
common_length_scale = jnp.array(1.)
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

common_length_scale, distinct_variances

(Array(1., dtype=float32, weak_type=True),
 Array([1.4910165 , 0.5416913 , 0.70949113, 0.62269235, 0.73912156,
        0.61616087, 0.96010077, 1.4451766 , 1.1460713 , 1.0582161 ,
        0.5557885 , 1.475741  , 0.77598464, 1.0314102 , 1.4972448 ,
        1.1258096 , 1.0357378 , 0.66969085, 1.0149455 , 1.0599906 ,
        1.0586796 , 0.9898578 , 0.63467884, 1.2513552 , 1.4747337 ,
        0.7394875 , 1.161504  , 1.2392048 , 1.2144717 , 1.0941921 ,
        1.3867334 , 1.223813  , 1.476492  , 1.3929558 , 1.357845  ,
        0.5265378 , 1.1944593 , 1.292541  , 1.0719829 , 1.0557784 ,
        1.1659764 , 1.0423886 , 1.3658253 , 0.7954941 , 1.2128425 ,
        1.2827029 , 1.1218177 , 1.4021076 , 1.1493796 , 1.3820095 ],      dtype=float32))

In [48]:
res1 = old_kernel(a, b, length_scale=common_length_scale, variance=distinct_variances)

In [49]:
res2 = new_kernel(a, b, length_scale=common_length_scale, variance=distinct_variances)

In [50]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [51]:
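# Note: this passes `distinct_length_scales` (from the previous section), so it re-times the distinct-HP case.
# To time the mixed case, pass `length_scale=common_length_scale` instead.
%timeit old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

331 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [52]:
# Same caveat as above: this re-times the distinct-HP case, not the mixed one.
%timeit new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()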

318 μs ± 2.08 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


---
## Conclusion

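In these benchmarks, handling kwargs through lambda functions for vmap introduces no measurable slowdown: the new version is as fast as or marginally faster than the old one in every case shown, with differences mostly within the reported standard deviations. Since it should also prevent bugs related to the order of those kwargs, it appears to be the better alternative.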
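
The benchmarks above measure steady-state calls, so they do not settle the open question of whether the lambda-based `vmap` causes the same function to be JIT-compiled many times. A minimal way to probe this, sketched below with standalone stand-ins for `compute_scalar` and `compute_vector` (not the notebook's classes), is to count Python-level traces: a side effect in the traced function only runs while JAX traces it, not when a cached executable is reused. In this sketch we would expect tracing on the first call, none for new values with the same shapes, and a new trace for a new input shape.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}


def rbf_scalar(x1, x2, length_scale, variance):
    trace_count["n"] += 1  # Python side effect: only runs while JAX traces, not on cached calls
    return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)


@jax.jit
def rbf_vector(x1, x2, **hps):
    # Same pattern as compute_vector: a fresh lambda closing over the keyword hyperparameters
    return jax.vmap(lambda x: rbf_scalar(x, x2, **hps))(x1)


x = jnp.linspace(0.0, 1.0, 8)
hps = dict(length_scale=jnp.array(1.0), variance=jnp.array(2.0))

rbf_vector(x, jnp.array(0.5), **hps)
after_first_call = trace_count["n"]

for v in (1.0, 3.0, 4.0):  # new values, same shapes/dtypes: expected to hit the jit cache
    rbf_vector(x, jnp.array(0.5), length_scale=jnp.array(v), variance=jnp.array(2.0))
after_cached_calls = trace_count["n"]

rbf_vector(jnp.linspace(0.0, 1.0, 16), jnp.array(0.5), **hps)  # new input shape: retrace
after_new_shape = trace_count["n"]
```

Placing the same kind of counter inside `NewRBFKernel.compute_scalar` would be one way to run this check on the notebook's own implementation.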

---