# Benchmarks - Kernels

In MagmaClust and GPs in general, kernels are what makes it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are used in many innermost loops, and their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable at many dimensions, including a batch dimension with distinct hyperparameters for each element
* work on padded inputs (aka inputs with NaNs), maybe using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, to either carry around hyperparameters or be called with them
* easy to override, as users may want to define their own kernels

These goals are conflicting in some cases (e.g: a jittable version of a kernel is not trivial to write for most people), and the best implementation will depend on the specific use case.

---
## Setup

In [1]:
# Standard library
import os

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import numpy as jnp

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.utils import generate_dummy_db, preprocess_db

In [4]:
# Config
key = jax.random.PRNGKey(0)

---
## Data

---
## Current implementation

In [5]:
import Kernax

---
## Custom implementation(s)

### Defaults of the previous implementation that we wish to correct/improve:

**Previous kernels were functions of R -> R, but we want them to be function of R^I -> R**

It's more common to see kernels used on vectors. We have to add one dimension to the inputs

In [6]:
from jax import jit, vmap
from jax.tree_util import register_pytree_node_class
from jax.lax import cond


@register_pytree_node_class
class AbstractKernel:
	def __init__(self, skip_check=False, **kwargs):
		if not skip_check:
			# Check that hyperparameters are all jnp arrays/scalars or kernels
			for key, value in kwargs.items():
				if not isinstance(value, jnp.ndarray):  # Check type
					kwargs[key] = jnp.array(value)
				if len(kwargs[key].shape) > 1:  # Check dimensionality
					raise ValueError(f"Parameter {key} must be a scalar or a 1D array, got shape {value.shape}.")

		# Register hyperparameters in *kwargs* as instance attributes
		self.__dict__.update(kwargs)

	def __str__(self):
		return f"{self.__class__.__name__}({', '.join([f'{key}={value}' for key, value in self.__dict__.items()])})"

	def __repr__(self):
		return str(self)

	@jit
	def check_kwargs(self, **kwargs):
		for key in self.__dict__:
			if key not in kwargs:
				kwargs[key] = self.__dict__[key]
		return kwargs

	@jit
	def __call__(self, x1, x2=None, **kwargs):
		# If no x2 is provided, we compute the covariance between x1 and itself
		if x2 is None:
			x2 = x1

		# Turn scalar inputs into vectors
		x1, x2 = jnp.atleast_2d(x1), jnp.atleast_2d(x2)

		# Check kwargs
		kwargs = self.check_kwargs(**kwargs)

		# Call the appropriate method
		if jnp.ndim(x1) == 1 and jnp.ndim(x2) == 1:
			return self.pairwise_cov_if_not_nan(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 1:
			return self.cross_cov_vector_if_not_nan(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 1 and jnp.ndim(x2) == 2:
			return self.cross_cov_vector_if_not_nan(x2, x1, **kwargs)
		elif jnp.ndim(x1) == 2 and jnp.ndim(x2) == 2:
			return self.cross_cov_matrix(x1, x2, **kwargs)
		elif jnp.ndim(x1) == 3 and jnp.ndim(x2) == 3:
			return self.cross_cov_batch(x1, x2, **kwargs)
		else:
			return jnp.nan

	# Methods to use Kernel as a PyTree
	def tree_flatten(self):
		return tuple(self.__dict__.values()), None  # No static values

	@classmethod
	def tree_unflatten(cls, _, children):
		# This class being abstract, this function fails when called on an "abstract instance",
		# as we don't know the number of parameters the constructor expects, yet we send it children.
		# On a subclass, this will work as expected as long as the constructor has a clear number of
		# kwargs as parameters.
		return cls(*children, skip_check=True)

	@jit
	def pairwise_cov_if_not_nan(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Returns NaN if either x1 or x2 is NaN, otherwise calls the compute_scalar method.

		:param x1: scalar array
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: scalar array
		"""
		return cond(jnp.any(jnp.isnan(x1) | jnp.isnan(x2)), lambda _: jnp.nan,
		            lambda _: self.pairwise_cov(x1, x2, **kwargs), None)

	@jit
	def pairwise_cov(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two vectors.

		:param x1: scalar array
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: scalar array
		"""
		return jnp.array(jnp.nan)  # To be overwritten in subclasses

	@jit
	def cross_cov_vector(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel cross covariance values between an array of vectors (matrix) and a vector.

		:param x1: vector array (N, )
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: vector array (N, )
		"""
		return vmap(lambda x: self.pairwise_cov_if_not_nan(x, x2, **kwargs), in_axes=0)(x1)

	@jit
	def cross_cov_vector_if_not_nan(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Returns an array of NaN if scalar is NaN, otherwise calls the compute_vector method.

		:param x1: vector array (N, )
		:param x2: scalar array
		:param kwargs: hyperparameters of the kernel
		:return: vector array (N, )
		"""
		return cond(jnp.any(jnp.isnan(x2)), lambda _: jnp.full(len(x1), jnp.nan), lambda _: self.cross_cov_vector(x1, x2, **kwargs),
		            None)

	@jit
	def cross_cov_matrix(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two vector arrays.

		:param x1: vector array (N, )
		:param x2: vector array (M, )
		:param kwargs: hyperparameters of the kernel
		:return: matrix array (N, M)
		"""
		return vmap(lambda x: self.cross_cov_vector_if_not_nan(x2, x, **kwargs), in_axes=0)(x1)

	@jit
	def cross_cov_batch(self, x1: jnp.ndarray, x2: jnp.ndarray, **kwargs) -> jnp.ndarray:
		"""
		Compute the kernel covariance matrix between two batched vector arrays.

		:param x1: vector array (B, N)
		:param x2: vector array (B, M)
		:param kwargs: hyperparameters of the kernel. Each HP that is a scalar will be shared to the whole batch, and
		each HP that is a vector will be distinct and thus must have shape (B, )
		:return: tensor array (B, N, M)
		"""
		# vmap(self.compute_matrix)(x1, x2, **kwargs)
		shared_hps = {key: value for key, value in kwargs.items() if jnp.isscalar(value)}
		distinct_hps = {key: value for key, value in kwargs.items() if not jnp.isscalar(value)}

		return vmap(lambda x, y, hps: self.cross_cov_matrix(x, y, **hps, **shared_hps), in_axes=(0, 0, 0))(x1, x2, distinct_hps)


In [7]:
from jax import jit
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp


# from Kernax import AbstractKernel


@register_pytree_node_class
class RBFKernel(AbstractKernel):
	def __init__(self, length_scale=None, variance=None, **kwargs):
		if length_scale is None:
			length_scale = jnp.array([1.])
		if variance is None:
			variance = jnp.array([1.])
		super().__init__(length_scale=length_scale, variance=variance, **kwargs)

	@jit
	def pairwise_cov(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
		return variance * jnp.exp(-0.5 * ((x1 - x2) @ (x1 - x2)) / length_scale ** 2)

In [17]:
@register_pytree_node_class
class SEMagmaKernel(AbstractKernel):
	def __init__(self, length_scale=None, variance=None, **kwargs):
		if length_scale is None:
			length_scale = jnp.array([1.])
		if variance is None:
			variance = jnp.array([1.])
		super().__init__(length_scale=length_scale, variance=variance, **kwargs)

	@jit
	def pairwise_cov(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None) -> jnp.ndarray:
		return jnp.exp(variance - jnp.exp(-length_scale) * jnp.sum((x1 - x2) ** 2) * 0.5)

In [34]:
@register_pytree_node_class
class NoisySEMagmaKernel(AbstractKernel):
	def __init__(self, length_scale=None, variance=None, noise=None, **kwargs):
		if noise is None:
			noise = jnp.array([-1.])
		super().__init__(length_scale=length_scale, variance=variance, noise=noise, **kwargs)

	@jit
	def pairwise_cov(self, x1: jnp.ndarray, x2: jnp.ndarray, length_scale=None, variance=None, noise=None) -> jnp.ndarray:
		return cond(jnp.all(x1 == x2),
		            lambda _: jnp.exp(variance - jnp.exp(-length_scale) * jnp.sum((x1 - x2) ** 2) * 0.5) + jnp.exp(noise),
		            lambda _: jnp.exp(variance - jnp.exp(-length_scale) * jnp.sum((x1 - x2) ** 2) * 0.5)
		            , None)

---
## Comparison

### On padded datasets

In [10]:
db = pd.read_csv("../datasets/small_shared_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
padded_inputs.shape, padded_outputs.shape, mappings.shape, all_inputs.shape,

((20, 15, 1), (20, 15, 1), (20, 15), (15, 1))

In [35]:
old_kernel = Kernax.NoisySEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.), noise=jnp.array(2.5))
new_kernel = NoisySEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.), noise=jnp.array(2.5))

In [36]:
# Covariance on padded matrix
res1 = old_kernel(padded_inputs.squeeze())
res1.shape

(20, 15, 15)

In [37]:
# Covariance on un-padded matrix
# res2 = new_kernel(padded_inputs)
res2 = new_kernel(padded_inputs)
res2.shape

(20, 15, 15)

In [38]:
# Check that values in un-padded matrix correspond to the values in the padded matrix
jnp.allclose(jnp.nan_to_num(res1), jnp.nan_to_num(res2))

Array(True, dtype=bool)

In [39]:
%timeit old_kernel(padded_inputs.squeeze()).block_until_ready()

60.1 μs ± 1.07 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [40]:
%timeit new_kernel(padded_inputs).block_until_ready()

46.4 μs ± 1.03 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [41]:
%timeit old_kernel(padded_inputs[0].squeeze()).block_until_ready()

96.3 μs ± 1.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [42]:
%timeit new_kernel(padded_inputs[0]).block_until_ready()

86.9 μs ± 1.82 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
%timeit new_kernel.cross_covariance(padded_inputs[0].reshape(-1, 1), padded_inputs[0].reshape(-1, 1)).block_until_ready()

2.41 ms ± 49.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
jnp.nansum(masks, axis=1).max()

Array(90, dtype=int64)

In [22]:
# Equivalent dense inputs
key, subkey = jax.random.split(key)
dense_inputs = jax.random.uniform(subkey, (50, int(jnp.nansum(masks, axis=1).max().item())))

In [23]:
%timeit old_kernel(dense_inputs).block_until_ready()

482 μs ± 11.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
%timeit new_kernel.compute_tensor(dense_inputs).block_until_ready()

3.99 ms ± 202 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
padded_inputs.shape, dense_inputs.shape

((50, 398), (50, 90))

### On scalars

In [26]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.67119299, dtype=float64), Array(0.35166634, dtype=float64))

In [27]:
res1 = old_kernel(a, b)
np.asarray(res1)

array(0.56710711)

In [28]:
res2 = new_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2)

array(0.56710711)

In [29]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [30]:
%timeit old_kernel(a, b).block_until_ready()

5.78 μs ± 29.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [31]:
%timeit new_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()

147 μs ± 1.39 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On an array and a scalar

In [32]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array([0.15974131, 0.14370377, 0.62062701, ..., 0.7647444 , 0.49158483,
        0.34969892], dtype=float64),
 Array(0.96106866, dtype=float64))

In [33]:
res1 = old_kernel(a, b)
np.asarray(res1)

array([0.02823017, 0.02443875, 0.52524432, ..., 0.80724426, 0.29389634,
       0.12536618])

In [34]:
res2 = new_kernel(a, b)
np.asarray(res2)

IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.

In [42]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [43]:
%timeit old_kernel(a, b).block_until_ready()

47.1 ms ± 3.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [44]:
%timeit new_kernel(a, b).block_until_ready()

33.9 ms ± 9.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### On two arrays

In [45]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,))
a.shape, b.shape

((10000,), (15000,))

In [46]:
res1 = old_kernel(a, b)
np.asarray(res1).shape

(10000, 15000)

In [47]:
res2 = new_kernel(a, b)
np.asarray(res2).shape

(10000, 15000)

In [48]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [49]:
%timeit old_kernel(a, b).block_until_ready()

1min 10s ± 5.42 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [50]:
%timeit new_kernel(a, b).block_until_ready()

14.5 s ± 4.42 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


### On two batches of arrays with shared HP

In [51]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))
a.shape, b.shape

((50, 100), (50, 150))

In [52]:
# Using shared hyperparameters for all batches
res1 = old_kernel(a, b)
np.asarray(res1).shape

{'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0), 'noise': Array(-2.5, dtype=float64, weak_type=True)} {'length_scale': Array(0.3, dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0), 'noise': Array(-2.5, dtype=float64, weak_type=True), 'length_scale': Array(0.3, dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}


(50, 100, 150)

In [53]:
# Also using shared hyperparameters for all batches
res2 = new_kernel(a, b)
np.asarray(res2).shape

{'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {} {'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}


(50, 100, 150)

In [54]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [55]:
%timeit old_kernel(a, b).block_until_ready()

{'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0), 'noise': Array(-2.5, dtype=float64, weak_type=True)} {'length_scale': Array(0.3, dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0), 'noise': Array(-2.5, dtype=float64, weak_type=True), 'length_scale': Array(0.3, dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}
{'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0), 'noise': Array(-2.5, dtype=float64, weak_type=True)} {'length_scale': Array(0.3, dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0), 'noise': Array(-2.5, dtype=float64, weak_type=True), 'length_scale': Array(0.3, dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}
{'inner_kernel': NewSEMagmaKernel(length_scale=0.3, variance=1.0),

In [56]:
%timeit new_kernel(a, b).block_until_ready()

{'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {} {'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}
{'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {} {'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}
{'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {} {'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}
{'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)} {} {'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dtype=float64, weak_type=True)}
{'length_scale': Array(1., dtype=float64, weak_type=True), 'variance': Array(1., dty

### On two batches of arrays with distinct HP

In [57]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100))  # 5 batches of 10-dimensional data
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150))  # 5 batches of 15-dimensional data

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances

(Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
        0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
        0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
        0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
        1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
        0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
        1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
        0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
        0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
        0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64),
 Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
        0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
        1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
        0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
        0.61045989, 0.9583

In [58]:
res1 = old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

{'length_scale': Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
       0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
       0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
       0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
       1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
       0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
       1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
       0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
       0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
       0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64), 'variance': Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
       0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
       1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
       0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
       0.61

In [59]:
res2 = new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)

{'length_scale': Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
       0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
       0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
       0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
       1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
       0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
       1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
       0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
       0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
       0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64), 'variance': Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
       0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
       1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
       0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
       0.61

In [60]:
jnp.allclose(res1, res2)

Array(False, dtype=bool)

In [61]:
%timeit old_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

{'length_scale': Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
       0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
       0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
       0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
       1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
       0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
       1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
       0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
       0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
       0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64), 'variance': Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
       0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
       1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
       0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
       0.61

In [62]:
%timeit new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances).block_until_ready()

{'length_scale': Array([1.40087618, 0.50347402, 0.81557193, 1.34923058, 1.00457197,
       0.65788761, 0.9219105 , 0.72545942, 1.23310028, 0.91450771,
       0.52153169, 1.07353706, 0.84001156, 0.63915561, 1.28793987,
       0.960267  , 0.5185922 , 1.46299715, 1.2229984 , 0.73096194,
       1.47447511, 1.41269148, 0.62063165, 1.39239143, 0.56357474,
       0.95879739, 1.23934137, 0.66682831, 1.27697895, 1.41619929,
       1.1099631 , 1.14399616, 1.30613247, 0.78241536, 1.13752401,
       0.7997413 , 0.86489189, 1.1392135 , 1.05768611, 0.64485643,
       0.78303156, 0.61316633, 0.53466457, 0.93155447, 0.81706077,
       0.89970457, 0.6447488 , 1.16024019, 1.30514671, 0.9256145 ],      dtype=float64), 'variance': Array([1.34874564, 1.01689555, 1.11981318, 0.96395389, 0.73946461,
       0.57796114, 0.86174911, 0.87934257, 1.08225335, 1.42074674,
       1.38278328, 0.69312996, 0.62528442, 0.52303255, 1.35397314,
       0.62574799, 1.25923923, 0.82557538, 1.32550512, 1.13693679,
       0.61

### On two batches of arrays with both shared and distinct HP

---
## Conclusion

The conditional logic in `compute_vector` that simplifies computations on NaNs gives us roughly a 2x improvement in speed on padded matrices. However, this solution is still 6x slower than computing the covariance batch on a "denser" input. *In fine*, computing a batch kernel is a $O(M*N^2)$ operation. Computing covariance on smallest possible (meaning "padded but not aligned") vectors (N = 90) is obviously faster than on fully padded and aligned ones (N = 400).

It is interesting to note that this speed gain is only obtained if no implementation of `compute_matrix` and `compute_batch` is given, as those implementations would skip `compute_vector`.

**In the future, we should explore whether computing kernels on "compact" padded inputs and then map the values to an "aligned" batch preserves this speed gain. If so, it might be useful to consider using this "compact" representation throughout the whole algorithm.**

---