In [1]:
from copy import deepcopy
from typing import Tuple
import numpy as np
import jax.numpy as jnp
from kernax import AbstractKernel, StaticAbstractKernel, WrapperKernel, FeatureKernel, BlockKernel

In [2]:
k = FeatureKernel(length_scale=1., length_scale_u=1.0, variance=1.0)
bia = deepcopy(k)
bia._unconstrained_length_scale = 0
bia._unconstrained_length_scale_u = None
bia._unconstrained_variance = 0
bk = BlockKernel(inner_kernel=k, nb_blocks=2, block_in_axes=bia, block_over_inputs=False)

In [3]:
bk = deepcopy(bk)
bk_in = deepcopy(bk.inner_kernel)

In [4]:
bk_in._unconstrained_length_scale = jnp.array([0.5, 0.75])
bk_in._unconstrained_variance = jnp.array([1.0, 2.0])
bk.inner_kernel = bk_in

In [5]:
x = jnp.array([[1.,], [2,], [3,]])

In [6]:
np.asarray(bk(x))

array([[0.2820948 , 0.21969566, 0.10377688, 0.531923  , 0.42593062,
        0.21868007],
       [0.21969566, 0.2820948 , 0.21969566, 0.42593062, 0.531923  ,
        0.42593062],
       [0.10377688, 0.21969566, 0.2820948 , 0.21868007, 0.42593062,
        0.531923  ],
       [0.531923  , 0.42593062, 0.21868007, 1.009253  , 0.8263065 ,
        0.45348662],
       [0.42593062, 0.531923  , 0.42593062, 0.8263065 , 1.009253  ,
        0.8263065 ],
       [0.21868007, 0.42593062, 0.531923  , 0.45348662, 0.8263065 ,
        1.009253  ]], dtype=float32)

In [15]:
def mf_func(x1, x2, ls1, ls2, lsu, var1, var2):
	# As the formula only involves diagonal matrices, we can compute directly with vectors
	sigma_diag = (ls1 + ls2 + lsu)  # Σ
	sigma_det = jnp.prod(sigma_diag)  # |Σ|
	diff = x1 - x2  # x - x'

	# Compute the quadratic form: (x - x')^T Sigma^{-1} (x - x')
	# Since Sigma^{-1} is diagonal, this simplifies to sum of (diff_i^2 * sigma_inv_diag_i)
	quadratic_form = jnp.sum(diff**2 / sigma_diag)

	return (
		var1 * var2
		/ (((2 * jnp.pi) ** (1 / 2)) * jnp.sqrt(sigma_det))
		* jnp.exp(-0.5 * quadratic_form)
	)

In [16]:
mo_func(1., 1., 0.5, 0.75, 1., 1., 2.)

Array(3.14682, dtype=float32)