# Benchmarks - Kernels

In MagmaClust, and in Gaussian processes (GPs) in general, kernels make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are called in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable at many dimensions, including a batch dimension with distinct hyperparameters for each element
* able to work on padded inputs (i.e. inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, to either carry around hyperparameters or be called with them
* easy to override, as users may want to define their own kernels

These goals sometimes conflict (e.g., a jittable version of a kernel is not trivial for most people to write), and the best implementation depends on the specific use case.
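As an illustration of the "padded inputs" and "jittable" requirements, here is a minimal sketch of a mask-based RBF kernel. The function name `masked_rbf` is hypothetical and is not part of MagmaClustPy; it only shows one way a mask can keep NaN padding out of the arithmetic while staying jit-compatible.

```python
import jax.numpy as jnp
from jax import jit


@jit
def masked_rbf(x, length_scale, variance):
    """RBF covariance of a padded 1-D input with itself; NaN wherever a point is padding."""
    valid = ~jnp.isnan(x)                          # (N,) True for real observations
    pair_valid = valid[:, None] & valid[None, :]   # (N, N) True when both points are real
    x_safe = jnp.where(valid, x, 0.0)              # replace padding so the arithmetic stays finite
    sq_dist = (x_safe[:, None] - x_safe[None, :]) ** 2
    cov = variance * jnp.exp(-0.5 * sq_dist / length_scale ** 2)
    return jnp.where(pair_valid, cov, jnp.nan)     # restore NaN on padded rows and columns


x = jnp.array([0.0, 1.0, jnp.nan])
cov = masked_rbf(x, jnp.array(1.0), jnp.array(1.0))
```

Note that this variant still evaluates every pair, padded or not; it addresses correctness on padded inputs, not speed.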

---
## Setup

In [1]:
# Standard library
import os

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax import lax

import numpy as np

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel, AbstractKernel
from MagmaClustPy.utils import generate_dummy_db, preprocess_db

In [4]:
# Config
key = jax.random.PRNGKey(0)

---
## Data

---
## Current implementation

In [5]:
old_kernel = RBFKernel(length_scale=jnp.array(1.), variance=jnp.array(1.))

---
## Custom implementation(s)

### Shortcomings of the previous implementation that we wish to correct/improve:

**Kernels on padded arrays are slow**

There are two reasons for this:
- we compute the kernel on every pair of inputs, regardless of whether they are NaN or not
- the "aligned" padded arrays are huge, but kernels do not require the inputs to be aligned
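To make the cost of alignment concrete, the sketch below counts how many input pairs actually carry information in a toy padded batch (the array `padded` is made up for illustration). With the figures reported later in this notebook (at most 90 observed points on a 398-point aligned grid), the same calculation gives at most $(90/398)^2 \approx 5\%$ of useful pairs per individual.

```python
import jax.numpy as jnp

# Toy padded batch: 2 individuals aligned on a 5-point grid, NaN marks a missing input
padded = jnp.array([
    [0.0, jnp.nan, 2.0, jnp.nan, jnp.nan],
    [jnp.nan, 1.0, 2.0, 3.0, jnp.nan],
])

n_valid = jnp.sum(~jnp.isnan(padded), axis=1)  # observed points per individual
pairs_computed = padded.shape[1] ** 2          # pairs evaluated per individual on the aligned grid
pairs_needed = n_valid ** 2                    # pairs that involve two observed points

useful_fraction = pairs_needed / pairs_computed
```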

In [6]:
@register_pytree_node_class
class NewAbstractKernel:
	def __init__(self, **kwargs):
		# Check that hyperparameters are all jnp arrays/scalars
		for key, value in kwargs.items():
			if not isinstance(value, jnp.ndarray):  # Check type
				raise ValueError(f"Parameter {key} must be a jnp.ndarray.")
			else:  # Check dimensionality
				if len(value.shape) > 1:
					raise ValueError(f"Parameter {key} must be a scalar or a 1D array, got shape {value.shape}.")

		# Register hyperparameters in *kwargs* as instance attributes
		self.__dict__.update(kwargs)

	@jit
	def check_kwargs(self, **kwargs):
		for key in self.__dict__:
			if key not in kwargs:
				kwargs[key] = self.__dict__[key]
		return kwargs

	@jit
	def __call__(self, x1, x2=None, **kwargs):
		# If no x2 is provided, we compute the covariance between x1 and itself
		if x2 is None:
			x2 = x1

		# Check kwargs
		kwargs = self.check_kwargs(**kwargs)

		# Call the appropriate method
		if jnp.isscalar(x1) and jnp.isscalar(x2):
			return self.compute_scalar_if_not_nan(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.isscalar(x2):
			return self.compute_vector_if_not_nan(x1, x2, **kwargs)
		elif jnp.isscalar(x1) and jnp.ndim(x2) == 1:
			return self.compute_vector_if_not_nan(x2, x1, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.compute_matrix(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.compute_batch(x1, x2, **kwargs)
		else:
			return jnp.nan

	# Methods to use Kernel as a PyTree
	def tree_flatten(self):
		return tuple(self.__dict__.values()), None  # No static values

	@classmethod
	def tree_unflatten(cls, _, children):
		# This class being abstract, this function fails when called on an "abstract instance",
		# as we don't know the number of parameters the constructor expects, yet we send it children.
		# On a subclass, this will work as expected as long as the constructor has a clear number of
		# kwargs as parameters.
		return cls(*children)

	@jit
	def compute_scalar_if_not_nan(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Returns NaN if either x1 or x2 is NaN, otherwise calls the compute_scalar method.

		:param x1: scalar array
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: scalar array
		"""
		return lax.cond(jnp.isnan(x1) | jnp.isnan(x2), lambda _: jnp.nan, lambda _: self.compute_scalar(x1, x2, **kwargs), None)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two scalar arrays.

		:param x1: scalar array
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: scalar array
		"""
		return jnp.array(jnp.nan)  # To be overwritten in subclasses

	@jit
	def compute_vector(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between a vector and a scalar.

		:param x1: vector array (N, )
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: vector array (N, )
		"""
		return vmap(lambda x: self.compute_scalar_if_not_nan(x, x2, **kwargs), in_axes=0)(x1)

	@jit
	def compute_vector_if_not_nan(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Returns an array of NaN if scalar is NaN, otherwise calls the compute_vector method.

		:param x1: vector array (N, )
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: vector array (N, )
		"""
		return lax.cond(jnp.any(jnp.isnan(x2)), lambda _: x1 * jnp.nan, lambda _: self.compute_vector(x1, x2, **kwargs), None)

	@jit
	def compute_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two vector arrays.

		:param x1: vector array (N, )
		:param x2: vector array (M, )
		:param kwargs: hyperparameters of the kernel
		:return: matrix array (N, M)
		"""
		return vmap(lambda x: self.compute_vector_if_not_nan(x2, x, **kwargs), in_axes=0)(x1)

	@jit
	def compute_batch(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two batched vector arrays.

		:param x1: vector array (B, N)
		:param x2: vector array (B, M)
		:param kwargs: hyperparameters of the kernel. Each HP that is a scalar will be common to the whole batch, and
		each HP that is a vector will be distinct and thus must have shape (B, )
		:return: tensor array (B, N, M)
		"""
		# vmap(self.compute_matrix)(x1, x2, **kwargs)
		common_hps = {key: value for key, value in kwargs.items() if jnp.isscalar(value)}
		distinct_hps = {key: value for key, value in kwargs.items() if not jnp.isscalar(value)}

		return vmap(lambda x, y, hps: self.compute_matrix(x, y, **hps, **common_hps), in_axes=(0, 0, 0))(x1, x2,
		                                                                                                distinct_hps)
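The class above can decorate its methods with `@jit` because it is registered as a PyTree: `self` is then flattened into its hyperparameters, which are traced like any other array argument. A minimal, self-contained sketch of that mechanism follows; the `Scale` class is a hypothetical toy and not part of MagmaClustPy.

```python
import jax.numpy as jnp
from jax import jit
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Scale:  # hypothetical toy class, for illustration only
    def __init__(self, factor):
        self.factor = factor

    def tree_flatten(self):
        # Children are traced by jit; the second element (aux data) would be treated as static
        return (self.factor,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jit
    def __call__(self, x):
        # `self` is a valid jit argument because the class is a registered PyTree
        return self.factor * x


double = Scale(jnp.array(2.0))
out = double(jnp.arange(3.0))
```

Because the hyperparameters are children rather than static data, changing their values does not trigger a recompilation; changing their shapes does.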

In [7]:
@register_pytree_node_class
class NewRBFKernel(NewAbstractKernel):
	def __init__(self, length_scale=None, variance=None):
		if length_scale is None:
			length_scale = jnp.array(1.)  # Scalar default: a 1D array would be read as per-batch values
		if variance is None:
			variance = jnp.array(1.)  # Scalar default, same reason
		super().__init__(length_scale=length_scale, variance=variance)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
		return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)

# Disabled overrides (string literal echoed as cell output, truncated in the source):
#
# 	@jit
# 	def compute_vector(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
# 		return variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / length_scale ** 2)  # Works the same as scalar
#
# 	@jit
# 	def compute_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
# 		return variance * jnp.exp(-0.5 * (x1[:, None] - x2) ** 2 / length_scale ** 2)  # Broadcast over x1
#
# 	@jit
# 	def compute_batch(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
# 		# Broadcasts
# 		if length_scale.ndim == 0:
# 			length_scale = jnp.broadcast_to(length_scale, (len(x1), 1, 1))
# 		else:
# 			length_scale = length_scale[:, None, None]
#
# 		if variance.ndim == 0:
# 			variance = jnp.broadcast_to(variance, (len(x1), 1, 1))
# 		else:
# 			variance = variance[:, None, None]
#
# 		squared_diff = (x1[:, :, None] - x2[:, None, :]) ** 2
#
# 		return variance * jnp.exp(-0.5 * sq
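For reference, `compute_scalar` evaluates the usual squared-exponential (RBF) kernel. In the notation below, $\sigma^2$ stands for `variance` and $\ell$ for `length_scale`; these symbols are chosen here for readability and are not used elsewhere in the notebook.

$$k(x, x') = \sigma^2 \exp\left(-\frac{(x - x')^2}{2\ell^2}\right)$$

As a sanity check against the scalar comparison later in this notebook: with $x \approx 0.6712$, $x' \approx 0.3517$ and $\sigma^2 = \ell = 1$, this gives $\exp(-0.3195^2 / 2) \approx 0.9502$, which matches the reported output.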

---
## Comparison

In [8]:
new_kernel = NewRBFKernel(length_scale=jnp.array(1.), variance=jnp.array(1.))

### On padded datasets

In [9]:
grid = jnp.arange(-200, 200, 1, dtype=jnp.float64)
db = generate_dummy_db(50, 10, 100, grid, key)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape, padded_outputs.shape, masks.shape

((398,), (50, 398), (50, 398), (50, 398))

In [10]:
# Covariance on the padded inputs, old kernel
res1 = old_kernel(padded_inputs)
res1.shape

(50, 398, 398)

In [11]:
# Covariance on the same padded inputs, new kernel
res2 = new_kernel(padded_inputs)
res2.shape

(50, 398, 398)

In [12]:
# Check that both kernels agree on the padded inputs (NaN entries are mapped to 0 before comparing)
jnp.allclose(jnp.nan_to_num(res1), jnp.nan_to_num(res2))

Array(True, dtype=bool)

In [13]:
%timeit old_kernel(padded_inputs).block_until_ready()

6.09 ms ± 109 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%timeit new_kernel(padded_inputs).block_until_ready()

3.29 ms ± 50.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%timeit old_kernel(padded_inputs[0]).block_until_ready()

202 μs ± 5.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [16]:
%timeit new_kernel(padded_inputs[0]).block_until_ready()

141 μs ± 3.67 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
jnp.nansum(masks, axis=1).max()

Array(90, dtype=int64)

In [18]:
# Equivalent dense inputs
key, subkey = jax.random.split(key)
dense_inputs = jax.random.uniform(subkey, (50, int(jnp.nansum(masks, axis=1).max().item())))

In [19]:
%timeit old_kernel(dense_inputs).block_until_ready()

557 μs ± 10.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [50]:
padded_inputs.shape, dense_inputs.shape

((50, 398), (50, 90))

### On scalars

In [20]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.67119299, dtype=float64), Array(0.35166634, dtype=float64))

In [21]:
res1 = old_kernel(a, b)
np.asarray(res1)

array(0.95023245)

In [22]:
res2 = new_kernel(a, b)
np.asarray(res2)

array(0.95023245)

In [23]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [24]:
%timeit old_kernel(a, b).block_until_ready()

6.12 μs ± 241 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [25]:
%timeit new_kernel(a, b).block_until_ready()

6.47 μs ± 77.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### On an array and a scalar

In [26]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array([0.15974131, 0.14370377, 0.62062701, ..., 0.7647444 , 0.49158483,
        0.34969892], dtype=float64),
 Array(0.96106866, dtype=float64))

In [27]:
res1 = old_kernel(a, b)
np.asarray(res1)

array([0.72537772, 0.71602322, 0.94369689, ..., 0.9809129 , 0.89564824,
       0.8295379 ])

In [28]:
res2 = new_kernel(a, b)
np.asarray(res2)

array([0.72537772, 0.71602322, 0.94369689, ..., 0.9809129 , 0.89564824,
       0.8295379 ])

In [29]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [30]:
%timeit old_kernel(a, b).block_until_ready()

54.8 μs ± 1.97 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [31]:
%timeit new_kernel(a, b).block_until_ready()

56.6 μs ± 1.79 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On two arrays

In [32]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,))
a.shape, b.shape

((10000,), (15000,))

In [33]:
res1 = old_kernel(a, b)
np.asarray(res1).shape

(10000, 15000)

In [34]:
res2 = new_kernel(a, b)
np.asarray(res2).shape

(10000, 15000)

In [35]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [36]:
%timeit old_kernel(a, b).block_until_ready()

177 ms ± 15.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [37]:
%timeit new_kernel(a, b).block_until_ready()

131 ms ± 3.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### On two batches of arrays with common HP

In [38]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))
a.shape, b.shape

((50, 100), (50, 150))

In [39]:
# Using common hyperparameters for all batches
res1 = old_kernel(a, b)
np.asarray(res1).shape

(50, 100, 150)

In [40]:
# Also using common hyperparameters for all batches
res2 = new_kernel(a, b)
np.asarray(res2).shape

(50, 100, 150)

In [41]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [42]:
%timeit old_kernel(a, b).block_until_ready()

949 μs ± 20.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [43]:
%timeit new_kernel(a, b).block_until_ready()

818 μs ± 96.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### On two batches of arrays with distinct HP

In [44]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 50 batches of 100 input points
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 50 batches of 150 input points

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances

(Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
        0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
        0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
        0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
        1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
        0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
        1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
        0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
        0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
        0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64),
 Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
        0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
        1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
        0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
        0.61045989, 0.9583

In [45]:
res1 = old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [46]:
res2 = new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [47]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [48]:
%timeit old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

949 μs ± 14.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [49]:
%timeit new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

782 μs ± 19.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
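The split between common and distinct hyperparameters maps directly onto the `in_axes` argument of `vmap`: distinct hyperparameters are mapped over their leading axis, common ones are broadcast with `None`. The sketch below is a standalone illustration; `rbf_matrix` is a hypothetical helper, not the MagmaClustPy implementation.

```python
import jax.numpy as jnp
from jax import vmap


def rbf_matrix(x1, x2, length_scale, variance):
    """RBF covariance matrix between two 1-D inputs, shapes (N,) and (M,) -> (N, M)."""
    return variance * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / length_scale ** 2)


B, N, M = 4, 3, 5
x1 = jnp.linspace(0.0, 1.0, B * N).reshape(B, N)
x2 = jnp.linspace(0.0, 2.0, B * M).reshape(B, M)

length_scales = jnp.linspace(0.5, 1.5, B)  # distinct: one value per batch element, shape (B,)
variance = jnp.array(1.0)                  # common: shared by the whole batch

# Map over axis 0 of x1, x2 and length_scales; None broadcasts the shared variance
batched = vmap(rbf_matrix, in_axes=(0, 0, 0, None))(x1, x2, length_scales, variance)
```

Each slice `batched[i]` equals `rbf_matrix(x1[i], x2[i], length_scales[i], variance)`, so one call covers a mix of common and distinct hyperparameters.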


### On two batches of arrays with both common and distinct HP

---
## Conclusion

The conditional logic around `compute_vector` (the `*_if_not_nan` wrappers), which simplifies computations on NaNs, gives us a roughly 2x speed-up on padded matrices (6.09 ms vs. 3.29 ms). However, this solution is still about 6x slower than computing the covariance batch on a "denser" input (557 μs). Ultimately, computing a batch kernel is an $O(B \cdot N^2)$ operation for a batch of $B$ vectors of length $N$. Computing the covariance on the smallest possible vectors (meaning "padded but not aligned", N = 90) is therefore faster than on fully padded and aligned ones (N = 398).
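The ratios quoted above can be recomputed from the timings and shapes reported in this notebook. The snippet only restates those figures; the variable names are illustrative.

```python
# Figures reported earlier in this notebook
n_aligned, n_compact = 398, 90  # aligned grid length vs. largest number of observed points
t_old_padded_ms = 6.09          # old kernel on padded inputs
t_new_padded_ms = 3.29          # new kernel on padded inputs
t_dense_ms = 0.557              # old kernel on dense (50, 90) inputs

speedup_new_vs_old = t_old_padded_ms / t_new_padded_ms  # about 1.85
slowdown_vs_dense = t_new_padded_ms / t_dense_ms        # about 5.9
pair_ratio = (n_aligned / n_compact) ** 2               # about 19.6
```

The aligned representation evaluates roughly 20 times more pairs than the compact one, while the measured gap between the new kernel on padded inputs and the dense baseline is about 6x.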

It is interesting to note that this speed gain is only obtained if no implementation of `compute_matrix` and `compute_batch` is given, as those implementations would skip `compute_vector`.

**In the future, we should explore whether computing kernels on "compact" padded inputs and then mapping the values to an "aligned" batch preserves this speed gain. If so, it might be useful to consider using this "compact" representation throughout the whole algorithm.**
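One possible shape for that experiment is sketched below. It is an assumption-laden illustration, not a tested design: the helper names (`rbf_matrix`, `aligned_cov_from_compact`) and the convention of marking padding with an out-of-range grid index are hypothetical. Whether the scatter step preserves the speed gain is exactly what would need to be benchmarked.

```python
import jax.numpy as jnp
from jax import jit, vmap


def rbf_matrix(x, length_scale=1.0, variance=1.0):
    """RBF covariance of a 1-D input with itself, shape (n,) -> (n, n)."""
    return variance * jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / length_scale ** 2)


@jit
def aligned_cov_from_compact(compact_x, grid_idx, grid):
    """
    compact_x: (n_max,) observed inputs, NaN-padded at the end
    grid_idx:  (n_max,) position of each input on the aligned grid; padding uses len(grid)
    grid:      (G,) aligned grid, used only for its (static) length
    """
    G = grid.shape[0]
    cov = rbf_matrix(compact_x)  # (n_max, n_max), NaN on padded rows/columns
    aligned = jnp.full((G, G), jnp.nan, dtype=cov.dtype)
    # Out-of-bounds indices (the padding) are dropped by the scatter
    return aligned.at[grid_idx[:, None], grid_idx[None, :]].set(cov, mode="drop")


grid = jnp.arange(5.0)
compact_batch = jnp.array([[1.0, 3.0, jnp.nan], [0.0, 2.0, 4.0]])
idx_batch = jnp.array([[1, 3, 5], [0, 2, 4]])  # index 5 == len(grid) marks padding

batched = vmap(aligned_cov_from_compact, in_axes=(0, 0, None))(compact_batch, idx_batch, grid)
```

Here the kernel itself only ever sees `n_max` points per individual; the aligned `(G, G)` matrices are produced by a scatter, whose cost would have to be measured against the savings.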

---