# Benchmarks - Kernels

In MagmaClust, and in GPs in general, kernels are what make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are used in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable in many dimensions, including a batch dimension with distinct hyperparameters for each element
* able to work on padded inputs (i.e. inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, to either carry around hyperparameters or be called with them
* easy to override, as users may want to define their own kernels

These goals conflict in some cases (e.g., a jittable version of a kernel is not trivial to write for most people), and the best implementation will depend on the specific use case.
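
For the padded-input requirement, one possible approach is sketched below, assuming 1-D inputs where padding is encoded as NaN: the NaNs are replaced by a dummy value before any arithmetic, so that gradients stay finite, and the rows and columns that correspond to padding are zeroed afterwards. `se_kernel_padded` is a hypothetical helper, not part of MagmaClustPy, and zero-filling is only one possible convention.

```python
import jax
import jax.numpy as jnp


@jax.jit
def se_kernel_padded(x1, x2, length_scale, variance):
    """Hypothetical helper (not part of MagmaClustPy): SE cross-covariance
    on NaN-padded 1-D inputs, with log-scale hyperparameters."""
    mask1 = ~jnp.isnan(x1)  # True where x1 holds a real input
    mask2 = ~jnp.isnan(x2)
    # Replace padding by a dummy value first, so that gradients stay finite
    x1_safe = jnp.where(mask1, x1, 0.0)
    x2_safe = jnp.where(mask2, x2, 0.0)
    sq_dist = (x1_safe[:, None] - x2_safe[None, :]) ** 2
    cov = jnp.exp(variance - jnp.exp(-length_scale) * sq_dist * 0.5)
    # Zero out the rows and columns that correspond to padding
    return jnp.where(mask1[:, None] & mask2[None, :], cov, 0.0)


x1 = jnp.array([0.1, 0.5, jnp.nan])  # 2 real inputs + 1 padding slot
x2 = jnp.array([0.2, jnp.nan])       # 1 real input + 1 padding slot
K = se_kernel_padded(x1, x2, 0.0, 0.0)
print(K.shape)  # (3, 2)
```

Because the mask is computed from the data itself, the same compiled function serves every padding pattern of a given padded shape.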

---
## Setup

In [1]:
# Standard library
import os

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp

import numpy as np

In [3]:
# Local
from MagmaClustPy.kernels import RBFKernel, AbstractKernel

In [4]:
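# Config
# Note: JAX reads JAX_ENABLE_X64 when `jax` is first imported, so setting it here has no effect
# (hence the float32 outputs below); set it before importing jax, or call
# jax.config.update("jax_enable_x64", True), to actually enable float64.
os.environ['JAX_ENABLE_X64'] = "True"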

---
## Data

---
## Current implementation

---
## Custom implementation(s)

In [7]:
@register_pytree_node_class
class NewSEMagmaKernel(AbstractKernel):
	def __init__(self, length_scale=None, variance=None):
		if length_scale is None:
			length_scale = jnp.array([1.])
		if variance is None:
			variance = jnp.array([1.])
		super().__init__(length_scale=length_scale, variance=variance)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
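		# SE kernel, with hyperparameters entering through exp() (i.e. on a log scale):
		# k(x1, x2) = exp(variance - exp(-length_scale) * ||x1 - x2||^2 / 2)
		return jnp.exp(variance - jnp.exp(-length_scale) * jnp.sum((x1 - x2) ** 2) * 0.5)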
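
The benchmarks below call the same kernel on scalars, vectors and batches. One common way to lift a scalar function like `compute_scalar` to a full cross-covariance matrix is to nest two `jax.vmap` calls; whether `AbstractKernel` does exactly this is an assumption here. The sketch uses a standalone copy of the formula (`se_scalar` and `se_matrix` are illustrative names) and cross-checks it against a plain broadcasting version.

```python
import jax
import jax.numpy as jnp


def se_scalar(x1, x2, length_scale, variance):
    # Same formula as NewSEMagmaKernel.compute_scalar (log-scale hyperparameters)
    return jnp.exp(variance - jnp.exp(-length_scale) * jnp.sum((x1 - x2) ** 2) * 0.5)


# Inner vmap maps over x2, outer vmap maps over x1: (N,), (M,) -> (N, M).
# `None` keeps the hyperparameters unmapped (shared by every pair of points).
se_matrix = jax.jit(
    jax.vmap(
        jax.vmap(se_scalar, in_axes=(None, 0, None, None)),
        in_axes=(0, None, None, None),
    )
)

a = jnp.array([0.0, 1.0, 2.0])
b = jnp.array([0.5, 1.5])
K = se_matrix(a, b, 1.0, 1.0)

# Cross-check against a broadcasting implementation of the same formula
K_ref = jnp.exp(1.0 - jnp.exp(-1.0) * (a[:, None] - b[None, :]) ** 2 * 0.5)
print(K.shape, bool(jnp.allclose(K, K_ref)))  # (3, 2) True
```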

### Shortcomings of the previous implementation that we wish to correct/improve:

**The `compute_*` methods are not jittable.**

-> Jitting them requires registering the class as a PyTree, so that its attributes can be used inside the jitted logic (see https://docs.jax.dev/en/latest/faq.html#how-to-use-jit-with-methods and https://docs.jax.dev/en/latest/pytrees.html#extending-pytrees).

-> It also requires storing the hyperparameters as attributes instead of in a parameter dictionary. However, we still want the abstract "Kernel" class to work with any set of hyperparameters provided as kwargs.
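
Combining the two requirements above, here is a minimal sketch of a base class that turns arbitrary hyperparameter kwargs into attributes and registers itself as a PyTree, following the JAX FAQ pattern linked above. `SketchKernel` and `SketchSE` are hypothetical names, not the MagmaClustPy classes, and the actual `AbstractKernel` may be organised differently.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SketchKernel:
    """Hypothetical base class: any hyperparameter kwargs become attributes."""

    def __init__(self, **hyperparameters):
        # Sorted names give a deterministic flattening order
        self.hp_names = tuple(sorted(hyperparameters))
        for name, value in hyperparameters.items():
            setattr(self, name, value)

    def tree_flatten(self):
        # Hyperparameter values are the (traced) leaves...
        children = tuple(getattr(self, name) for name in self.hp_names)
        # ...while their names are static, hashable metadata
        return children, self.hp_names

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(**dict(zip(aux_data, children)))

    @jax.jit
    def compute(self, x1, x2):
        # `self` is a pytree, so its attributes are traced like any other input
        hps = {name: getattr(self, name) for name in self.hp_names}
        return self.compute_scalar(x1, x2, **hps)

    def __call__(self, x1, x2):
        return self.compute(x1, x2)


# Pytree registration is per class, so each subclass must be registered too
@register_pytree_node_class
class SketchSE(SketchKernel):
    def compute_scalar(self, x1, x2, length_scale, variance):
        return jnp.exp(variance - jnp.exp(-length_scale) * jnp.sum((x1 - x2) ** 2) * 0.5)


k = SketchSE(length_scale=jnp.array(1.0), variance=jnp.array(1.0))
print(k(jnp.array(0.0), jnp.array(1.0)))
```

Since the hyperparameter values are leaves rather than static data, changing them does not trigger a recompilation; only the set of names (part of the tree structure) does.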

-> Finally, it requires reworking the optional parameter handling, as the current conditional logic is not jittable. Doing so might require separate implementations for the `common_hp` and `distinct_hp` setups, selected through a boolean parameter. In `common_hp`, the hyperparameter is a scalar, common to every task. In `distinct_hp`, it is an array, with one value per task. Because the boolean is marked as static, the Python `if` is resolved at trace time, so the following function is jittable and handles this difference in dimensionality:

```python
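from functools import partial

from jax import jit, vmap
from jax import numpy as jnp


def my_fun(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
	return x * y

# `batched` is static: the Python `if` below is evaluated at trace time,
# and one compiled version is cached per value of `batched`
@partial(jit, static_argnums=2)
def my_fun_2(x: jnp.ndarray, y: jnp.ndarray, batched: bool) -> jnp.ndarray:
	# batched = whether y has a batch dimension (distinct_hp) or not (common_hp)
	if batched:
		return vmap(my_fun)(x, y)
	else:
		return vmap(my_fun, in_axes=(0, None))(x, y)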
```

**Hyperparameters are either all common to the batch or all distinct.**

-> We want to allow for a mix of both, where some hyperparameters are common to the whole batch and some are distinct. This is useful for example when the lengthscale is common to the whole batch but the variance is distinct for each element.
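
With `jax.vmap`, this choice can be made per argument through `in_axes`: `0` maps an argument over the batch axis (distinct hyperparameter), while `None` broadcasts it to every element (common hyperparameter). A minimal sketch, assuming the batch is the leading axis; `se_matrix` and `batched_se` are illustrative names.

```python
import jax
import jax.numpy as jnp


def se_matrix(x1, x2, length_scale, variance):
    # (N,), (M,) -> (N, M), hyperparameters on a log scale
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return jnp.exp(variance - jnp.exp(-length_scale) * sq_dist * 0.5)


# Inputs batched on axis 0; length_scale shared (None), variance per element (0)
batched_se = jax.jit(jax.vmap(se_matrix, in_axes=(0, 0, None, 0)))

B, N, M = 4, 3, 2
x1 = jnp.linspace(0.0, 1.0, B * N).reshape(B, N)
x2 = jnp.linspace(0.0, 1.0, B * M).reshape(B, M)
length_scale = jnp.array(0.0)               # common to the whole batch
variance = jnp.array([0.0, 0.5, 1.0, 1.5])  # one value per batch element

K = batched_se(x1, x2, length_scale, variance)
print(K.shape)  # (4, 3, 2)
```

Since `in_axes` is fixed when the vmapped function is built, each common/distinct combination corresponds to a separate function (and compilation), much like the static `batched` flag above.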

### Compromises we are willing to make:

**Error prevention/messages can be less informative.**

-> Under JIT, Python control flow (including raising exceptions) cannot depend on array values, which rules out value-based input checks. We are willing to trade informative errors for performance, even if some users might not understand how they are misusing the API.
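
To make the limitation concrete: checks that depend only on shapes still work under JIT (shapes are static at trace time), whereas checks on array values fail during tracing. Where an informative error is still wanted, JAX provides an opt-in functional alternative, `jax.experimental.checkify`. The sketch below is illustrative (the function names are hypothetical), and the current implementation does not use `checkify`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def sq_dist_with_python_check(x1, x2):
    # Shape checks are fine under jit: shapes are static at trace time
    if x1.ndim != 1 or x2.ndim != 1:
        raise ValueError("expected 1-D inputs")
    # Value checks are not: converting a traced array to bool fails at trace time
    if jnp.any(jnp.isnan(x1)):
        raise ValueError("x1 contains NaNs")
    return (x1[:, None] - x2[None, :]) ** 2


def sq_dist_with_checkify(x1, x2):
    # Functional check: recorded in an error value instead of raised eagerly
    checkify.check(~jnp.any(jnp.isnan(x1)), "x1 contains NaNs")
    return (x1[:, None] - x2[None, :]) ** 2


x1 = jnp.array([0.0, jnp.nan])
x2 = jnp.array([1.0])

python_check_failed = False
try:
    jax.jit(sq_dist_with_python_check)(x1, x2)
except jax.errors.ConcretizationTypeError as exc:
    python_check_failed = True
    print("Python `if` on a traced value:", type(exc).__name__)

err, out = jax.jit(checkify.checkify(sq_dist_with_checkify))(x1, x2)
print(err.get())  # error message string, or None if the check passed
```

The cost of `checkify` is an extra error value threaded through the computation, so it remains a deliberate trade-off rather than a free fix.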

---
## Comparison

In [8]:
old_SE = OldSEMagmaKernel(length_scale=1.0, variance=1.0)
new_SE = NewSEMagmaKernel(length_scale=jnp.array(1.), variance=jnp.array(1.))
key = jax.random.PRNGKey(0)

### On scalars

In [9]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.00729382, dtype=float32), Array(0.10429037, dtype=float32))

In [10]:
res1 = old_SE(a, b)
np.asarray(res1)

array(2.7135818, dtype=float32)

In [11]:
res2 = new_SE(a, b)
np.asarray(res2)

array(2.7135818, dtype=float32)

In [12]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [13]:
%timeit old_SE(a, b).block_until_ready()

64.2 μs ± 105 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [14]:
%timeit new_SE(a, b).block_until_ready()

5.4 μs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
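
`%timeit` only exists in IPython. Outside a notebook, a comparable measurement can be sketched with the standard `timeit` module. Two details matter for JAX: a warm-up call, so that compilation is not timed, and `block_until_ready()`, because JAX dispatches work asynchronously. The function `se` below is an illustrative stand-in for the kernels compared here.

```python
import statistics
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def se(x1, x2):
    # Fixed hyperparameters (length_scale = variance = 1) for brevity
    return jnp.exp(1.0 - jnp.exp(-1.0) * (x1 - x2) ** 2 * 0.5)


a, b = jnp.array(0.1), jnp.array(0.2)

# Warm-up: the first call traces and compiles, which we do not want to time
se(a, b).block_until_ready()

# JAX dispatches asynchronously: without block_until_ready() we would only
# measure how long it takes to enqueue the computation
n_loops = 1000
runs = timeit.repeat(lambda: se(a, b).block_until_ready(), number=n_loops, repeat=7)
per_call_us = [run / n_loops * 1e6 for run in runs]
print(f"{statistics.mean(per_call_us):.1f} us ± {statistics.stdev(per_call_us):.1f} us per call")
```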


### On an array and a scalar

In [15]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array([0.08260381, 0.7807871 , 0.93827987, ..., 0.99492884, 0.05307555,
        0.8959063 ], dtype=float32),
 Array(0.01066005, dtype=float32))

In [16]:
res1 = old_SE(a, b)
np.asarray(res1)

array([2.7156951, 2.4373372, 2.3203633, ..., 2.2745948, 2.7173824,
       2.3533823], dtype=float32)

In [17]:
res2 = new_SE(a, b)
np.asarray(res2)

array([2.7156951, 2.4373372, 2.3203633, ..., 2.2745948, 2.7173824,
       2.3533823], dtype=float32)

In [18]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [19]:
%timeit old_SE(a, b).block_until_ready()

773 μs ± 130 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [20]:
%timeit new_SE(a, b).block_until_ready()

20.8 μs ± 348 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On two arrays

In [21]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,))
a.shape, b.shape

((10000,), (15000,))

In [22]:
res1 = old_SE(a, b)
np.asarray(res1).shape

(10000, 15000)

In [23]:
res2 = new_SE(a, b)
np.asarray(res2).shape

(10000, 15000)
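
The output here is a dense 10000 × 15000 matrix. `jax.eval_shape` can report its shape and dtype without computing anything, which gives its memory footprint: 150 million float32 entries, i.e. 600 MB (1.2 GB in float64). `se_matrix` is an illustrative stand-in for the kernel.

```python
import math

import jax
import jax.numpy as jnp


def se_matrix(x1, x2):
    # Broadcasting version of the SE kernel with both hyperparameters set to 1
    return jnp.exp(1.0 - jnp.exp(-1.0) * (x1[:, None] - x2[None, :]) ** 2 * 0.5)


# eval_shape traces abstractly: shapes and dtypes only, nothing is allocated
out = jax.eval_shape(
    se_matrix,
    jax.ShapeDtypeStruct((10000,), jnp.float32),
    jax.ShapeDtypeStruct((15000,), jnp.float32),
)
n_bytes = math.prod(out.shape) * out.dtype.itemsize
print(out.shape, out.dtype, f"{n_bytes / 1e6:.0f} MB")  # (10000, 15000) float32 600 MB
```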

In [24]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [25]:
%timeit old_SE(a, b).block_until_ready()

294 ms ± 15.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit new_SE(a, b).block_until_ready()

34 ms ± 283 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### On two batches of arrays with common HP

In [27]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))
a.shape, b.shape

((50, 100), (50, 150))

In [28]:
# Using common hyperparameters for all batches
res1 = old_SE(a, b)
np.asarray(res1).shape

(50, 100, 150)

In [29]:
# Also using common hyperparameters for all batches
res2 = new_SE(a, b)
np.asarray(res2).shape

(50, 100, 150)

In [30]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [31]:
%timeit old_SE(a, b).block_until_ready()

1.97 ms ± 83 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [32]:
%timeit new_SE(a, b).block_until_ready()

324 μs ± 14.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### On two batches of arrays with distinct HP

In [33]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 50 batches of 100 inputs
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 50 batches of 150 inputs

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances

(Array([0.69834626, 1.1941253 , 0.74907124, 1.4650234 , 0.6623002 ,
        0.5150695 , 1.3077172 , 0.8598883 , 1.3379186 , 0.5345807 ,
        0.9925809 , 1.0119901 , 0.6867107 , 1.3121855 , 1.4377677 ,
        1.397761  , 0.66495323, 1.1306376 , 0.95631206, 0.6137631 ,
        0.86514413, 1.2786899 , 1.0824803 , 0.89162004, 1.080331  ,
        0.70636487, 0.88181865, 1.2177788 , 0.9001007 , 0.88983023,
        1.0880367 , 1.0784531 , 1.1646242 , 1.3984015 , 1.2634037 ,
        0.5314708 , 0.53026426, 1.309241  , 0.7298783 , 1.2418199 ,
        0.9763527 , 0.52760875, 1.0559071 , 1.2554548 , 1.3978969 ,
        1.2237661 , 0.7706127 , 1.2761428 , 1.1990191 , 0.59786415],      dtype=float32),
 Array([0.9009923 , 0.7235527 , 0.8779906 , 1.007704  , 0.8343816 ,
        0.961216  , 0.652822  , 0.7110219 , 0.84446895, 1.1748973 ,
        0.5548018 , 1.3568562 , 1.4311233 , 0.9435116 , 0.68918073,
        0.751114  , 1.4021089 , 1.2387986 , 0.9405923 , 1.2024236 ,
        0.8921603 , 0.5305

In [34]:
res1 = old_SE(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [35]:
res2 = new_SE(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [36]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [38]:
%timeit old_SE(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

1.79 ms ± 24.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [39]:
%timeit new_SE(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

311 μs ± 1.27 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


---
## Conclusion


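This benchmark shows that the new, jittable implementation is roughly 6x to 37x faster than the previous one on CPU, depending on the input shapes (about 6x on batched inputs, up to 37x for an array against a scalar), with little to no compromise on the API side.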
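
The speedups can be recomputed from the mean timings reported in the benchmarks above:

```python
# Mean timings recorded in the benchmarks above, in microseconds: (old, new)
timings_us = {
    "scalars": (64.2, 5.4),
    "array vs scalar": (773.0, 20.8),
    "two arrays": (294_000.0, 34_000.0),
    "batches, common HP": (1_970.0, 324.0),
    "batches, distinct HP": (1_790.0, 311.0),
}

speedups = {case: old / new for case, (old, new) in timings_us.items()}
for case, ratio in speedups.items():
    print(f"{case:>20}: {ratio:4.1f}x")
```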

The new implementation is also more flexible, as it allows for a mix of common and distinct hyperparameters for each batch, which can be useful in many cases.

Next steps to improve this implementation would be to check if the `vmap` calls are optimal (maybe lambda functions would work better with kwargs), and to test the implementation on GPU to see if the speedup is even more significant.
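
On the `vmap`/kwargs question: per the `jax.vmap` documentation, arguments passed as keywords are always mapped over their leading axis, whatever `in_axes` says. Closing over shared hyperparameters in a lambda (or `functools.partial`) is one way around this, sketched below with an illustrative `se_matrix`; whether it is actually faster than the current approach would need its own benchmark.

```python
import jax
import jax.numpy as jnp


def se_matrix(x1, x2, length_scale=0.0, variance=0.0):
    # (N,), (M,) -> (N, M), hyperparameters on a log scale
    return jnp.exp(variance - jnp.exp(-length_scale) * (x1[:, None] - x2[None, :]) ** 2 * 0.5)


x1 = jnp.ones((4, 3))
x2 = jnp.zeros((4, 2))
common_ls = jnp.array(0.0)  # scalar, meant to be shared by the whole batch

# Keyword arguments are mapped over axis 0, which a 0-d array does not have
kwarg_mapping_failed = False
try:
    jax.vmap(se_matrix)(x1, x2, length_scale=common_ls)
except ValueError as exc:
    kwarg_mapping_failed = True
    print("kwarg was mapped over axis 0:", type(exc).__name__)

# Closing over the shared value keeps it out of vmap's mapped arguments
out = jax.vmap(lambda a, b: se_matrix(a, b, length_scale=common_ls))(x1, x2)
print(out.shape)  # (4, 3, 2)
```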