In [13]:
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import vmap, jit, Array
import equinox as eqx
from equinox import filter_jit

In [14]:
from kernax import AbstractKernel, StaticAbstractKernel, WrapperKernel

In [15]:
class BlockKernel(WrapperKernel):
	"""
	Wrapper kernel evaluating any inner kernel once per block of a block covariance matrix.

	A basic kernel usually works on inputs of shape (N, I), and produces covariance matrices of shape (N, N).

	Wrapped inside a block kernel, the inner kernel can either:
	- still receive inputs of shape (N, I) shared by every block (``block_over_inputs=False``). This is useful when the hyperparameters are distinct to blocks, i.e. each sub-matrix has its own set of hyperparameters.
	- or receive inputs of shape (B, N, I), where B is the number of blocks (``block_over_inputs=True``). This is useful when inputs are different for each block, regardless of whether the hyperparameters are shared between blocks or not.

	The blocks of the upper triangle (diagonal included) of the B x B block grid are enumerated with
	``jnp.triu_indices``, and the inner kernel is vmapped over that enumeration.

	NOTE(review): the vmapped result is returned as-is, stacked along a leading axis with one entry per
	upper-triangular block (B * (B + 1) / 2 entries). Assembling a dense (B*N, B*N) matrix from those
	blocks is not done in this class - confirm whether a caller is expected to do it.
	"""

	inner_kernel: AbstractKernel = eqx.field()
	nb_blocks: int = eqx.field(static=True)
	# Pytree of None / 0 / 1 markers (one per hyperparameter leaf of the inner kernel), not a boolean.
	block_in_axes: object = eqx.field(static=True)
	# vmap axis used for the inputs: 0 when they carry a leading block axis, None when they are shared.
	block_over_inputs: int | None = eqx.field(static=True)

	def __init__(self, inner_kernel: AbstractKernel, nb_blocks: int, block_in_axes=None, block_over_inputs: bool = True) -> None:
		"""
		:param inner_kernel: the kernel to wrap, must be an instance of AbstractKernel
		:param nb_blocks: the number of blocks B
		:param block_in_axes: a pytree, matching the structure of inner_kernel, indicating which hyperparameters change across blocks.
								If 0, block (r, c) uses the r-th value of the hyperparameter (selected by the block's row index).
								If 1, block (r, c) uses the c-th value of the hyperparameter (selected by the block's column index).
								If None, the hyperparameter is shared across all blocks.
								If the argument itself is None (default), every hyperparameter is shared.
								To compute the block matrix, the kernel needs to have at least one of its hyperparameters changing across rows and one across columns.
		:param block_over_inputs: whether to expect inputs of shape (B, N, I) (True) or (N, I) (False)

		Every hyperparameter marked 0 or 1 is tiled nb_blocks times along a new leading axis, so all of its
		per-block values start out as copies of the value given to inner_kernel.

		N.b: the result of this kernel is not always a valid covariance matrix! For example, an RBF kernel with a varying lengthscale across rows and a varying amplitude across column will not produce a symmetric matrix, hence giving an invalid covariance matrix.
		Usually, you want to use this kernel with an appropriate inner_kernel, calculating a function where two hyper-parameters have symmetric roles.
		A good example is a multi-output (convolutional) kernel in GPs, which usually have two distinct lengthscales (and variances) depending on which output dimension is considered.
		"""
		# Initialize the WrapperKernel
		super().__init__(inner_kernel=inner_kernel)

		# TODO: explicit error message when nb_blocks is 1, as vmap is not needed then
		# TODO: check that at least one hyperparameter varies across rows and one across columns

		self.nb_blocks = nb_blocks

		# Default: every hyperparameter leaf is shared, i.e. mapped to None
		if block_in_axes is None:
			block_in_axes = jtu.tree_map(lambda _: None, inner_kernel)
		self.block_in_axes = block_in_axes

		# Stored directly as the vmap in_axes value used for x1 / x2 in __call__
		self.block_over_inputs = 0 if block_over_inputs else None

		def _tile(param, block_in_ax):
			# Shared hyperparameters are left untouched; block-dependent ones (marker 0 or 1)
			# get a new leading axis holding nb_blocks copies.
			if block_in_ax is None:
				return param
			return jnp.repeat(param[None, ...], nb_blocks, axis=0)

		self.inner_kernel = jtu.tree_map(_tile, self.inner_kernel, self.block_in_axes)

	@filter_jit
	def __call__(self, x1: jnp.ndarray, x2: None | jnp.ndarray = None) -> jnp.ndarray:
		"""
		Evaluate the inner kernel on every upper-triangular block using vmap.

		Args:
				x1: Input with a leading block axis of size B when block_over_inputs is 0,
					otherwise a single input shared by all blocks
				x2: Optional second input following the same convention; defaults to x1

		Returns:
				The inner kernel outputs stacked along a new leading axis, one entry per (row, column)
				pair returned by ``jnp.triu_indices(nb_blocks)``.
		"""
		x2 = x1 if x2 is None else x2

		# Block coordinates of the upper triangle (diagonal included)
		rows, cols = jnp.triu_indices(self.nb_blocks)

		def _select(param, block_in_ax):
			# Gather, for each enumerated block, the hyperparameter value it must use
			if block_in_ax == 0:
				return param[rows]
			if block_in_ax == 1:
				return param[cols]
			return param

		full_kernel = jtu.tree_map(_select, self.inner_kernel, self.block_in_axes)

		# x2 is still bound to the un-gathered input here, even when it defaulted to x1
		if self.block_over_inputs == 0:
			x1 = x1[rows]
			x2 = x2[cols]

		# Block-dependent hyperparameters were gathered along axis 0 above, so both markers (0 and 1)
		# map to vmap axis 0; shared hyperparameters (None) are not mapped.
		in_axes = (
			jtu.tree_map(lambda ax: None if ax is None else 0, self.block_in_axes),
			self.block_over_inputs,
			self.block_over_inputs,
		)

		# Each mapped element gets its own version of inner_kernel with the corresponding hyperparameters
		return vmap(lambda kernel, a, b: kernel(a, b), in_axes=in_axes)(full_kernel, x1, x2)

	def __str__(self):
		return f"Block{self.inner_kernel}"

In [17]:
class StaticMOKernel(StaticAbstractKernel):
	"""
	Stateless covariance computation used by MOKernel (referenced through its ``static_class``).
	"""

	@classmethod
	@filter_jit
	def pairwise_cov(cls, kern: AbstractKernel, x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
		"""
		Compute the kernel covariance value between two input points.

		With Σ = diag(exp(length_scale_1) + exp(length_scale_2) + exp(length_scale_u)) and d the
		input dimension, the value is

			exp(variance_1) * exp(variance_2) / ((2π)^(d/2) * sqrt(|Σ|)) * exp(-0.5 * (x1 - x2)^T Σ^{-1} (x1 - x2))

		:param kern: kernel instance containing the hyperparameters (all are exponentiated here)
		:param x1: first input point, a vector of d features or a scalar (0-d) array
		:param x2: second input point, same shape as x1
		:return: scalar array
		"""
		kern = eqx.combine(kern)

		# As the formula only involves diagonal matrices, we can compute directly with vectors
		sigma_diag = jnp.exp(kern.length_scale_1) + jnp.exp(kern.length_scale_2) + jnp.exp(kern.length_scale_u)  # Σ
		# NOTE(review): this equals |Σ| only if the length scales have one entry per input dimension;
		# with scalar length scales and d > 1 the product is a single factor - confirm intended usage.
		sigma_det = jnp.prod(sigma_diag)  # |Σ|
		diff = x1 - x2  # x - x'

		# Compute the quadratic form: (x - x')^T Sigma^{-1} (x - x')
		# Since Sigma^{-1} is diagonal, this simplifies to sum of (diff_i^2 / sigma_diag_i)
		quadratic_form = jnp.sum(diff**2 / sigma_diag)

		# Input dimension d. A 0-d (scalar) input counts as d = 1; len() would raise on it.
		dim = len(x1) if jnp.ndim(x1) > 0 else 1

		amplitude = jnp.exp(kern.variance_1) * jnp.exp(kern.variance_2)
		normaliser = ((2 * jnp.pi) ** (dim / 2)) * jnp.sqrt(sigma_det)

		return amplitude / normaliser * jnp.exp(-0.5 * quadratic_form)


class MOKernel(AbstractKernel):
	"""
	Kernel with three length-scale and two variance hyperparameters, evaluated by StaticMOKernel.

	NOTE(review): "MO" presumably stands for multi-output (index 1 / 2 for the two outputs, "u" for a
	shared component) - confirm naming with the author.

	StaticMOKernel.pairwise_cov applies jnp.exp to every hyperparameter before using it, so the values
	stored here act as logarithms of the effective length scales and variances.
	"""

	# Every field is converted with jnp.asarray, so plain Python numbers or lists are accepted.
	length_scale_1: Array = eqx.field(converter=jnp.asarray)
	length_scale_2: Array = eqx.field(converter=jnp.asarray)
	length_scale_u: Array = eqx.field(converter=jnp.asarray)
	variance_1: Array = eqx.field(converter=jnp.asarray)
	variance_2: Array = eqx.field(converter=jnp.asarray)

	# Class holding the covariance formula for this kernel
	static_class = StaticMOKernel

	def __init__(self, length_scale_1, length_scale_2, length_scale_u, variance_1, variance_2):
		"""
		Store the five hyperparameters on the kernel.

		:param length_scale_1: first length-scale hyperparameter
		:param length_scale_2: second length-scale hyperparameter
		:param length_scale_u: third ("u") length-scale hyperparameter
		:param variance_1: first variance hyperparameter
		:param variance_2: second variance hyperparameter
		"""
		super().__init__()
		self.length_scale_1 = length_scale_1
		self.length_scale_2 = length_scale_2
		self.length_scale_u = length_scale_u
		self.variance_1 = variance_1
		self.variance_2 = variance_2

In [3]:
# Scratch check: indexing an array with an integer index array gathers elements, repeats included.
# BlockKernel.__call__ uses the same mechanism (param[rows] / param[cols]) to pick per-block values.
jnp.array([1, 2, 3])[jnp.array([0, 0, 1, 1, 0, 2])]

Array([1, 1, 2, 2, 1, 3], dtype=int32)