# Benchmarks - Kernels

In MagmaClust, and in GPs in general, kernels are what make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are called in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable on inputs with several dimensions, including a batch dimension with distinct hyperparameters for each element
* able to work on padded inputs (i.e. inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, to either carry hyperparameters around or be called with them
* easy to override, as users may want to define their own kernels

Some of these goals conflict (e.g. writing a jittable version of a kernel is not trivial for most people), so the best implementation depends on the specific use case.
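
To make some of these requirements concrete, here is a minimal sketch in plain JAX. It is not the Kernax or GPJax API, and the names `rbf`, `cov_matrix` and `batched_cov` are hypothetical. The kernel is written once as a pure function of two scalar inputs, then lifted to a covariance matrix with two nested `vmap`s, and to a batch with one set of hyperparameters per element with a third `vmap`. Everything stays jittable.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap

jax.config.update("jax_enable_x64", True)  # float64, as in this notebook


def rbf(x, y, length_scale, variance):
    """Squared-exponential (RBF) kernel on two scalar inputs."""
    return variance * jnp.exp(-0.5 * (x - y) ** 2 / length_scale ** 2)


def cov_matrix(xs, ys, length_scale, variance):
    """Covariance between 1-D arrays xs (N,) and ys (P,), with shared hyperparameters."""
    # Inner vmap over ys builds one row; outer vmap over xs stacks the rows
    row = lambda x: vmap(lambda y: rbf(x, y, length_scale, variance))(ys)
    return vmap(row)(xs)


# Batched version: inputs *and* hyperparameters carry a leading batch axis
batched_cov = jit(vmap(cov_matrix, in_axes=(0, 0, 0, 0)))

xs = jnp.linspace(0.0, 1.0, 4)
K = jit(cov_matrix)(xs, xs, 0.3, 1.0)  # shape (4, 4)
Kb = batched_cov(jnp.stack([xs, xs]), jnp.stack([xs, xs]),
                 jnp.array([0.3, 1.0]), jnp.array([1.0, 2.0]))  # shape (2, 4, 4)
```

In this sketch, the scalar `rbf` is the only kernel-specific piece: combining or overriding kernels only touches that function, while the lifting code is reused as-is.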

---
## Setup

In [1]:
# Standard library
import os

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax import lax
from jax.lax import cond

import gpjax as gpx
from flax import nnx

import numpy as np

In [3]:
# Local
from Kernax import RBFKernel, SEMagmaKernel, NoisySEMagmaKernel
from MagmaClustPy.utils import generate_dummy_db, preprocess_db

In [4]:
# Config
key = jax.random.PRNGKey(0)

---
## Data

---
## Current implementation

In [5]:
old_kernel = RBFKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))

---
## Custom implementation(s)

### Shortcomings of the previous implementation that we wish to correct/improve:

**Kernels on padded arrays are slow**

This has two causes:
- we compute the kernel on every pair of inputs, whether or not they are NaN
- the "aligned" padded arrays are huge, but kernels do not require the inputs to be aligned
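
One way to address the second point is to give up alignment: move the valid (non-NaN) entries of each row to the front, keeping their order, and truncate every row to the largest number of valid entries. Below is a minimal sketch in plain JAX. `compact_padded` is a hypothetical helper, not part of MagmaClustPy, and it detects padding through NaNs directly rather than through the `masks` returned by `preprocess_db`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def compact_padded(padded):
    """Move the valid (non-NaN) entries of each row of a (M, N) array to the front.

    Returns (compact, order, n_max): compact has shape (M, n_max), where n_max is the
    largest number of valid entries in a row; shorter rows keep NaN padding at the end.
    order[m] is the permutation applied to row m, so results can be mapped back.
    """
    is_pad = jnp.isnan(padded).astype(jnp.int32)
    # jnp.argsort is stable by default: valid entries (0) first, original order kept
    order = jnp.argsort(is_pad, axis=-1)
    # n_max becomes a static shape, so this helper must run outside jit
    n_max = int((1 - is_pad).sum(axis=-1).max())
    compact = jnp.take_along_axis(padded, order, axis=-1)[:, :n_max]
    return compact, order, n_max


padded = jnp.array([[0.0, jnp.nan, 2.0, jnp.nan, jnp.nan],
                    [jnp.nan, 1.0, jnp.nan, 3.0, 4.0]])
compact, order, n_max = compact_padded(padded)  # compact has shape (2, 3)
```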

---
## Comparison

In [6]:
new_kernel = gpx.kernels.RBF(variance=1., lengthscale=0.3)

In [7]:
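#@jit
def _compute_tensor(x):
    # One (N, N) covariance matrix per row of x; cross_covariance expects (N, D) inputs, hence the reshape to (N, 1)
    return vmap(lambda row: new_kernel.cross_covariance(row.reshape(-1, 1), row.reshape(-1, 1)))(x)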

In [8]:
new_kernel.compute_tensor = _compute_tensor

In [9]:
test_inputs = jnp.array([0.40, 4.45, 7.60, 8.30, 3.50, 5.10, 8.85, 9.35]).reshape(-1, 1)

In [10]:
np.asarray(old_kernel(test_inputs.flatten()))

array([[1.00000000e+000, 2.66020642e-040, 8.37894253e-126,
        2.63300774e-151, 6.50878852e-024, 5.03983226e-054,
        5.29448405e-173, 5.40659309e-194],
       [2.66020642e-040, 1.00000000e+000, 1.14687658e-024,
        1.72605960e-036, 6.64501128e-003, 9.56344448e-002,
        1.94632663e-047, 1.17473960e-058],
       [8.37894253e-126, 1.14687658e-024, 1.00000000e+000,
        6.57285286e-002, 2.76516394e-041, 8.32396968e-016,
        1.69856677e-004, 4.08283604e-008],
       [2.63300774e-151, 1.72605960e-036, 6.57285286e-002,
        1.00000000e+000, 2.57220937e-056, 1.96548382e-025,
        1.86270464e-001, 2.18749112e-003],
       [6.50878852e-024, 6.64501128e-003, 2.76516394e-041,
        2.57220937e-056, 1.00000000e+000, 6.65836147e-007,
        8.73263905e-070, 2.69005790e-083],
       [5.03983226e-054, 9.56344448e-002, 8.32396968e-016,
        1.96548382e-025, 6.65836147e-007, 1.00000000e+000,
        1.17691094e-034, 2.62878528e-044],
       [5.29448405e-173, 1.9463266

In [11]:
np.asarray(new_kernel.cross_covariance(test_inputs, test_inputs))

array([[1.00000000e+000, 2.66020642e-040, 8.37894253e-126,
        2.63300774e-151, 6.50878852e-024, 5.03983226e-054,
        5.29448405e-173, 5.40659309e-194],
       [2.66020642e-040, 1.00000000e+000, 1.14687658e-024,
        1.72605960e-036, 6.64501128e-003, 9.56344448e-002,
        1.94632663e-047, 1.17473960e-058],
       [8.37894253e-126, 1.14687658e-024, 1.00000000e+000,
        6.57285286e-002, 2.76516394e-041, 8.32396968e-016,
        1.69856677e-004, 4.08283604e-008],
       [2.63300774e-151, 1.72605960e-036, 6.57285286e-002,
        1.00000000e+000, 2.57220937e-056, 1.96548382e-025,
        1.86270464e-001, 2.18749112e-003],
       [6.50878852e-024, 6.64501128e-003, 2.76516394e-041,
        2.57220937e-056, 1.00000000e+000, 6.65836147e-007,
        8.73263905e-070, 2.69005790e-083],
       [5.03983226e-054, 9.56344448e-002, 8.32396968e-016,
        1.96548382e-025, 6.65836147e-007, 1.00000000e+000,
        1.17691094e-034, 2.62878528e-044],
       [5.29448405e-173, 1.9463266

### On padded datasets

In [12]:
grid = jnp.arange(-200, 200, 1, dtype=jnp.float64)
db = generate_dummy_db(50, 10, 100, grid, key)
all_inputs, padded_inputs, padded_outputs, masks = preprocess_db(db)
all_inputs.shape, padded_inputs.shape, padded_outputs.shape, masks.shape

((398,), (50, 398), (50, 398), (50, 398))

In [13]:
# Covariance on padded matrix
res1 = old_kernel(padded_inputs)
res1.shape

(50, 398, 398)

In [14]:
# Covariance on the same padded matrix, with the GPJax kernel
# res2 = new_kernel(padded_inputs)
res2 = new_kernel.compute_tensor(padded_inputs)
res2.shape

(50, 398, 398)

In [15]:
# Check that both kernels agree (NaN entries coming from padding are replaced by 0 before comparing)
jnp.allclose(jnp.nan_to_num(res1), jnp.nan_to_num(res2))

Array(True, dtype=bool)
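
Instead of comparing through `jnp.nan_to_num`, a kernel can handle the padding mask itself (the "possibly using a mask" option from the requirements above). The sketch below, in plain JAX with a hypothetical `masked_rbf_batch`, sets every entry that involves a padded input to 0 instead of NaN. Note that `jnp.where` evaluates both of its branches, so the kernel is still computed on every pair: this fixes NaN propagation, not the cost.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def masked_rbf_batch(x, length_scale=0.3, variance=1.0):
    """Batched RBF covariance on a (M, N) array padded with NaNs, returning (M, N, N).

    Entries involving at least one padded input are set to 0 instead of NaN.
    """
    mask = ~jnp.isnan(x)                      # True on valid inputs
    x_safe = jnp.where(mask, x, 0.0)          # keep NaNs out of the arithmetic (and gradients)
    diff = x_safe[:, :, None] - x_safe[:, None, :]
    K = variance * jnp.exp(-0.5 * diff ** 2 / length_scale ** 2)
    pair_mask = mask[:, :, None] & mask[:, None, :]
    return jnp.where(pair_mask, K, 0.0)


x = jnp.array([[0.0, 0.5, jnp.nan],
               [1.0, jnp.nan, jnp.nan]])
K = masked_rbf_batch(x)  # shape (2, 3, 3), no NaNs
```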

In [16]:
%timeit old_kernel(padded_inputs).block_until_ready()

3.26 ms ± 82.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%timeit new_kernel.compute_tensor(padded_inputs).block_until_ready()

23.6 ms ± 809 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
%timeit old_kernel(padded_inputs[0]).block_until_ready()

136 μs ± 1.91 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
%timeit new_kernel.compute_tensor(padded_inputs[0]).block_until_ready()

2.68 ms ± 26.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%timeit new_kernel.cross_covariance(padded_inputs[0].reshape(-1, 1), padded_inputs[0].reshape(-1, 1)).block_until_ready()

2.41 ms ± 49.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
jnp.nansum(masks, axis=1).max()

Array(90, dtype=int64)

In [22]:
# Equivalent dense inputs
key, subkey = jax.random.split(key)
dense_inputs = jax.random.uniform(subkey, (50, int(jnp.nansum(masks, axis=1).max().item())))

In [23]:
%timeit old_kernel(dense_inputs).block_until_ready()

482 μs ± 11.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
%timeit new_kernel.compute_tensor(dense_inputs).block_until_ready()

3.99 ms ± 202 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
padded_inputs.shape, dense_inputs.shape

((50, 398), (50, 90))

### On scalars

In [26]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.67119299, dtype=float64), Array(0.35166634, dtype=float64))

In [27]:
res1 = old_kernel(a, b)
np.asarray(res1)

array(0.56710711)

In [28]:
res2 = new_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2)

array(0.56710711)

In [29]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [30]:
%timeit old_kernel(a, b).block_until_ready()

5.78 μs ± 29.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [31]:
%timeit new_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()

147 μs ± 1.39 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On an array and a scalar

In [32]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array([0.15974131, 0.14370377, 0.62062701, ..., 0.7647444 , 0.49158483,
        0.34969892], dtype=float64),
 Array(0.96106866, dtype=float64))

In [33]:
res1 = old_kernel(a, b)
np.asarray(res1)

array([0.02823017, 0.02443875, 0.52524432, ..., 0.80724426, 0.29389634,
       0.12536618])

In [34]:
res2 = new_kernel(a, b)
np.asarray(res2)

IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.

In [42]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [43]:
%timeit old_kernel(a, b).block_until_ready()

47.1 ms ± 3.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [44]:
%timeit new_kernel(a, b).block_until_ready()

33.9 ms ± 9.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
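
For an array and a scalar, `old_kernel(a, b)` returns one value per element of `a` (shape `(10000,)` above). A minimal plain-JAX sketch of that behaviour relies on broadcasting. The function `rbf` is hypothetical; its formula, $\sigma^2 \exp\left(-(x - y)^2 / (2\ell^2)\right)$ with $\ell = 0.3$ and $\sigma^2 = 1$, is the one that reproduces the first three printed values of `res1` from the first three values of `a` and from `b`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def rbf(x, y, length_scale=0.3, variance=1.0):
    """Element-wise RBF kernel: x and y are broadcast against each other."""
    return variance * jnp.exp(-0.5 * (x - y) ** 2 / length_scale ** 2)


a = jnp.array([0.15974131, 0.14370377, 0.62062701])  # first values of `a` above
b = jnp.array(0.96106866)                             # the scalar `b` above
k = rbf(a, b)  # shape (3,): one covariance per element of a
```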


### On two arrays

In [45]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,))
a.shape, b.shape

((10000,), (15000,))

In [46]:
res1 = old_kernel(a, b)
np.asarray(res1).shape

(10000, 15000)

In [47]:
res2 = new_kernel(a, b)
np.asarray(res2).shape

(10000, 15000)

In [48]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [49]:
%timeit old_kernel(a, b).block_until_ready()

1min 10s ± 5.42 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [50]:
%timeit new_kernel(a, b).block_until_ready()

14.5 s ± 4.42 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
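
For two arrays, or two batches of arrays that share their hyperparameters, broadcasting over the last axis is enough to get the full cross-covariance without any `vmap`. Below is a minimal sketch in plain JAX with a hypothetical `rbf_cross`. Keep memory in mind: for the `(10000,)` and `(15000,)` inputs above, the float64 output alone takes 10000 × 15000 × 8 bytes = 1.2 GB.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def rbf_cross(x, y, length_scale=0.3, variance=1.0):
    """Cross-covariance between x (..., N) and y (..., P), returning (..., N, P).

    Leading (batch) dimensions must broadcast; hyperparameters are shared.
    """
    diff = x[..., :, None] - y[..., None, :]
    return variance * jnp.exp(-0.5 * diff ** 2 / length_scale ** 2)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (50, 100))
b = jax.random.uniform(key_b, (50, 150))
K = rbf_cross(a, b)          # batched case, shape (50, 100, 150)
K1 = rbf_cross(a[0], b[0])   # unbatched case, shape (100, 150)
```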


### On two batches of arrays with shared HP

In [51]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))
a.shape, b.shape

((50, 100), (50, 150))

In [52]:
# Using shared hyperparameters for all batches
res1 = old_kernel(a, b)
np.asarray(res1).shape

(50, 100, 150)

In [53]:
# Also using shared hyperparameters for all batches
res2 = new_kernel(a, b)
np.asarray(res2).shape

(50, 100, 150)

In [54]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [55]:
%timeit old_kernel(a, b).block_until_ready()

In [56]:
%timeit new_kernel(a, b).block_until_ready()

### On two batches of arrays with distinct HP

In [57]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 50 batches of 100 inputs
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 50 batches of 150 inputs

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances

(Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
        0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
        0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
        0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
        1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
        0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
        1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
        0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
        0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
        0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64),
 Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
        0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
        1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
        0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
        0.61045989, 0.9583
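
When each batch element has its own hyperparameters, as with `distinct_length_scales` and `distinct_variances` here, one option in plain JAX is to write the kernel for a single batch element and `vmap` it over the leading axis of both the inputs and the hyperparameters. Below is a minimal sketch; `rbf_cross` and `batched_rbf_cross` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import vmap

jax.config.update("jax_enable_x64", True)


def rbf_cross(x, y, length_scale, variance):
    """Cross-covariance between x (N,) and y (P,) with scalar hyperparameters."""
    diff = x[:, None] - y[None, :]
    return variance * jnp.exp(-0.5 * diff ** 2 / length_scale ** 2)


# in_axes=0 everywhere: inputs and hyperparameters are all batched on axis 0
batched_rbf_cross = jax.jit(vmap(rbf_cross, in_axes=(0, 0, 0, 0)))

keys = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.uniform(keys[0], (50, 100))
b = jax.random.uniform(keys[1], (50, 150))
length_scales = jax.random.uniform(keys[2], (50,)) + 0.5
variances = jax.random.uniform(keys[3], (50,)) + 0.5
K = batched_rbf_cross(a, b, length_scales, variances)  # shape (50, 100, 150)
```

Broadcasting the hyperparameters as `length_scales[:, None, None]` would work too. Using `vmap` keeps the per-element kernel readable and lets `in_axes` decide which arguments are batched.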

In [58]:
res1 = old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [59]:
res2 = new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

In [60]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [61]:
%timeit old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

In [62]:
%timeit new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

### On two batches of arrays with both shared and distinct HP
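
The `in_axes` argument of `vmap` can mix batched and shared arguments. Below is a minimal sketch in plain JAX with hypothetical names. The length scale is shared by all batch elements (`None`), while the inputs and the variance are batched on axis 0 (`0`).

```python
import jax
import jax.numpy as jnp
from jax import vmap

jax.config.update("jax_enable_x64", True)


def rbf_cross(x, y, length_scale, variance):
    """Cross-covariance between x (N,) and y (P,) with scalar hyperparameters."""
    diff = x[:, None] - y[None, :]
    return variance * jnp.exp(-0.5 * diff ** 2 / length_scale ** 2)


# Inputs and variance batched on axis 0, length scale shared (in_axes=None)
mixed_rbf_cross = jax.jit(vmap(rbf_cross, in_axes=(0, 0, None, 0)))

keys = jax.random.split(jax.random.PRNGKey(0), 3)
a = jax.random.uniform(keys[0], (50, 100))
b = jax.random.uniform(keys[1], (50, 150))
variances = jax.random.uniform(keys[2], (50,)) + 0.5
K = mixed_rbf_cross(a, b, 0.3, variances)  # shape (50, 100, 150)
```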

---
## Conclusion

The conditional logic in `compute_vector` that simplifies computations on NaNs gives us roughly a 2x speed-up on padded matrices. However, this solution is still roughly 6x slower than computing the batch covariance on a "denser" input. Ultimately, computing a batch kernel is an $O(MN^2)$ operation, for $M$ batch elements of $N$ inputs each. Computing the covariance on the smallest possible vectors (meaning "padded but not aligned", $N = 90$) is therefore faster than on fully padded and aligned ones ($N = 398$): the number of pairwise kernel evaluations drops by a factor of $(398/90)^2 \approx 20$.

Note that this speed gain is only obtained when neither `compute_matrix` nor `compute_batch` is implemented, since those implementations would bypass `compute_vector`.

**In the future, we should explore whether computing kernels on "compact" padded inputs and then mapping the values to an "aligned" batch preserves this speed gain. If so, it might be worth using this "compact" representation throughout the whole algorithm.**
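
Here is a minimal plain-JAX sketch of the mapping step. It computes the covariance on compact inputs, then scatters it into the aligned layout. It assumes a hypothetical `positions` array that gives, for each compact entry, its column index in an aligned grid of size `grid_size`. Padding entries point to an extra dummy row/column that is dropped at the end. How such an index would be built in MagmaClustPy is not covered here. Filling the aligned `(M, G, G)` output is itself an $O(MG^2)$ write, so any gain depends on how often that aligned tensor actually needs to be materialised.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf_cross(x, y, length_scale=0.3, variance=1.0):
    """Cross-covariance between x (..., N) and y (..., P), returning (..., N, P)."""
    diff = x[..., :, None] - y[..., None, :]
    return variance * jnp.exp(-0.5 * diff ** 2 / length_scale ** 2)


@partial(jax.jit, static_argnames="grid_size")
def compact_to_aligned(compact_cov, positions, grid_size):
    """Scatter (M, n, n) compact covariances into an (M, G, G) aligned tensor.

    positions[m, i] is the grid index of compact entry i of batch m, or grid_size for
    padding entries: those land in an extra row/column that is sliced off at the end.
    Grid points a batch element does not observe stay NaN, as in the padded layout.
    """
    def scatter_one(cov, pos):
        out = jnp.full((grid_size + 1, grid_size + 1), jnp.nan, dtype=cov.dtype)
        out = out.at[pos[:, None], pos[None, :]].set(cov)
        return out[:grid_size, :grid_size]

    return jax.vmap(scatter_one)(compact_cov, positions)


# Compact inputs (valid values first, NaN padding last) and their grid indices
compact = jnp.array([[0.0, 2.0, jnp.nan],
                     [1.0, 3.0, 4.0]])
positions = jnp.array([[0, 2, 5],   # 5 == grid_size marks the padding entry
                       [1, 3, 4]])
K_aligned = compact_to_aligned(rbf_cross(compact, compact), positions, grid_size=5)
```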

---