# Benchmarks - NoiseKernel

**Main considerations when implementing NoiseKernel**


---
## Setup

In [26]:
# Standard library
import os

# Enable 64-bit mode through the environment; this cell runs before the cell that
# imports jax, which is what lets the flag take effect (outputs below are float64).
os.environ.update(JAX_ENABLE_X64="True")

In [27]:
# Third party
from jax import jit
from jax.tree_util import register_pytree_node_class
from jax import numpy as jnp
from jax.lax import cond

import numpy as np

In [28]:
# Local
from MagmaClustPy.kernels import AbstractKernel, RBFKernel, SEMagmaKernel

In [29]:
# Config

---
## Data

In [30]:
# Eight 1-D input locations, deliberately left unsorted, used in the comparison below.
inputs = jnp.asarray(
	[0.40, 4.45, 7.60, 8.30, 3.50, 5.10, 8.85, 9.35]
)

---
## Current implementation

In [31]:
@register_pytree_node_class
class NoiseKernel(AbstractKernel):
	"""Wrap another kernel and add ``exp(noise)`` wherever the two inputs coincide.

	For inputs ``x1`` and ``x2`` the value is ``inner_kernel(x1, x2)``, plus
	``exp(noise)`` when ``x1`` equals ``x2`` (element-wise for vector inputs).
	``noise`` is therefore given on a log scale, e.g. ``noise=-2.5`` adds
	``exp(-2.5) ≈ 0.082`` on matching inputs.

	:param inner_kernel: kernel evaluated for every pair of inputs.
	:param noise: log of the value added on matching inputs; defaults to ``-1.``.
	:param kwargs: forwarded unchanged to ``AbstractKernel.__init__``.
	"""

	def __init__(self, inner_kernel, noise=None, **kwargs):
		if noise is None:
			# Scalar rather than shape (1,): a (1,)-shaped noise would make the
			# noisy output shape differ from the inner kernel's scalar output.
			noise = jnp.array(-1.)
		super().__init__(inner_kernel=inner_kernel, noise=noise, **kwargs)

	@jit
	def compute_scalar(self, x1: jnp.ndarray, x2: jnp.ndarray, inner_kernel=None, noise=None) -> jnp.ndarray:
		"""Return ``inner_kernel(x1, x2)``, plus ``exp(noise)`` if ``x1 == x2``.

		:param x1: first input point.
		:param x2: second input point.
		:param inner_kernel: wrapped kernel, supplied as a keyword argument.
		:param noise: log-scale noise parameter, supplied as a keyword argument.
		:return: the (possibly noise-augmented) kernel value.
		"""
		base = inner_kernel(x1, x2)
		# jnp.all collapses element-wise equality of multi-dimensional inputs into a
		# single boolean; for scalar inputs it is a no-op.
		same_input = jnp.all(x1 == x2)
		# jnp.where avoids evaluating inner_kernel twice and, unlike lax.cond, does not
		# require a scalar predicate at trace time.
		return base + jnp.where(same_input, jnp.exp(noise), 0.)

	def __str__(self):
		"""Describe the wrapped kernel's class and parameters, followed by ``noise``."""
		# super().__class__ is the built-in `super` type, not a kernel class, so the
		# name and parameters are taken from the wrapped kernel directly.
		inner = self.inner_kernel
		params = [f'{key}={value}' for key, value in vars(inner).items()]
		params.append(f'noise={self.noise}')
		return f"Noisy{type(inner).__name__}({', '.join(params)})"

---
## Custom implementation(s)

In [32]:
# exp(noise) for the noise value used below: the amount NoiseKernel adds to each
# entry where the two inputs coincide (i.e. the diagonal of the Gram matrix).
log_noise = -2.5
jnp.exp(log_noise)

Array(0.082085, dtype=float64, weak_type=True)

---
## Comparison

In [33]:
# SEMagmaKernel wrapped so that exp(-2.5) is added wherever the two inputs coincide.
noise_kern = NoiseKernel(
	noise=jnp.array(-2.5),
	inner_kernel=SEMagmaKernel(length_scale=jnp.array(0.3), variance=jnp.array(1.)),
)

In [34]:
# Evaluate the noisy kernel on all input pairs and show it with NumPy formatting.
np.array(noise_kern(inputs))

array([[2.80036683e+00, 6.24711556e-03, 1.24442362e-08, 2.48089245e-10,
        7.73332286e-02, 7.59889715e-04, 8.87186129e-12, 3.53569699e-13],
       [6.24711556e-03, 2.80036683e+00, 6.88803459e-02, 1.12162794e-02,
        1.94586338e+00, 2.32449347e+00, 2.08889434e-03, 3.73152737e-04],
       [1.24442362e-08, 6.88803459e-02, 2.80036683e+00, 2.26709559e+00,
        5.37188104e-03, 2.68459037e-01, 1.52384351e+00, 8.74259619e-01],
       [2.48089245e-10, 1.12162794e-02, 2.26709559e+00, 2.80036683e+00,
        5.34474839e-04, 6.12378823e-02, 2.43014577e+00, 1.80692039e+00],
       [7.73332286e-02, 1.94586338e+00, 5.37188104e-03, 5.34474839e-04,
        2.80036683e+00, 1.05311525e+00, 6.75910648e-05, 8.49251154e-06],
       [7.59889715e-04, 2.32449347e+00, 2.68459037e-01, 6.12378823e-02,
        1.05311525e+00, 2.80036683e+00, 1.48630336e-02, 3.37785443e-03],
       [8.87186129e-12, 2.08889434e-03, 1.52384351e+00, 2.43014577e+00,
        6.75910648e-05, 1.48630336e-02, 2.80036683e+00, 2.

---
## Conclusion

---