# Benchmarks - Kernels

In MagmaClust, and in Gaussian processes (GPs) in general, kernels make it possible to compute a covariance matrix for any set of points, given a set of hyperparameters. They are called in many innermost loops, so their implementation is critical for performance.

**Main considerations when implementing kernels**

A good kernel implementation must be:
* fast, as it is used in many innermost loops
* usable on inputs of many shapes, including a batch dimension with distinct hyperparameters for each element
* able to work on padded inputs (i.e., inputs containing NaNs), possibly using a mask
* jittable, as it is used in many jit-compiled functions
* modular, as kernels can be combined in many ways
* both static and instance-based, so that it can either carry its hyperparameters around or be called with them
* easy to override, as users may want to define their own kernels

These goals conflict in some cases (e.g., a jittable version of a kernel is not trivial for most people to write), and the best implementation depends on the specific use case.
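
To make the "padded inputs" and "jittable" requirements concrete, here is a minimal sketch of a squared-exponential kernel that tolerates NaN padding. It is not Kernax's implementation: the function name `se_kernel_padded`, the plain (non-exponentiated) parametrisation, and the choice to return NaN for padded entries are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def se_kernel_padded(x1, x2, length_scale, variance):
    """Hypothetical SE kernel on inputs of shape (n, d) and (m, d) with NaN padding."""
    # A point is valid when none of its coordinates is NaN
    valid1 = ~jnp.any(jnp.isnan(x1), axis=-1)
    valid2 = ~jnp.any(jnp.isnan(x2), axis=-1)
    # Replace NaNs with zeros so the arithmetic below stays finite
    x1_safe = jnp.nan_to_num(x1)
    x2_safe = jnp.nan_to_num(x2)
    # Pairwise squared distances via broadcasting: shape (n, m)
    sq_dist = jnp.sum((x1_safe[:, None, :] - x2_safe[None, :, :]) ** 2, axis=-1)
    cov = variance * jnp.exp(-0.5 * sq_dist / length_scale**2)
    # Assumption: entries involving a padded point are reported as NaN
    mask = valid1[:, None] & valid2[None, :]
    return jnp.where(mask, cov, jnp.nan)


x = jnp.array([[0.0], [1.0], [jnp.nan]])  # last row is padding
K = se_kernel_padded(x, x, 1.0, 2.0)
```

Because the mask is applied with `jnp.where` rather than Python control flow, the output shape never depends on how many points are padded, which is what keeps the function compatible with `jax.jit`.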

---
## Setup

In [1]:
# Standard library
import os
from functools import partial

os.environ['JAX_ENABLE_X64'] = "True"

In [2]:
# Third party
import jax
from jax import numpy as jnp
from jax import vmap

import numpy as np
import pandas as pd

In [3]:
# Local
from MagmaClustPy.utils import generate_dummy_db, preprocess_db

In [4]:
# Config
key = jax.random.PRNGKey(0)

---
## Data

---
## Current implementation

In [5]:
import Kernax

---
## Custom implementation(s)

*Start by copy-pasting the original function from the MagmaClustPy module, then modify it*

### Shortcomings of the previous implementation that we wish to correct/improve:

**We want to be able to compose kernels**

As kernels are pytrees, they need to implement the `tree_flatten` and `tree_unflatten` methods.
This is not trivial: a sum kernel, for example, has to know dynamically how many hyperparameters belong to each sub-kernel.
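
One possible way to handle this bookkeeping, sketched below, relies on the fact that the children returned by `tree_flatten` may themselves be pytrees: a sum kernel can return its two sub-kernels as children and let JAX recurse into them. The classes `ToySEKernel` and `ToySumKernel` are hypothetical stand-ins written for this example; they are not the Kernax classes.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToySEKernel:
    """Hypothetical scalar SE kernel carrying its own hyperparameters."""

    def __init__(self, length_scale, variance):
        self.length_scale = length_scale
        self.variance = variance

    def __call__(self, x1, x2):
        return self.variance * jnp.exp(-0.5 * (x1 - x2) ** 2 / self.length_scale**2)

    def tree_flatten(self):
        # Children are the hyperparameters; there is no static auxiliary data
        return (self.length_scale, self.variance), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class ToySumKernel:
    """Hypothetical sum of two kernels."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __call__(self, x1, x2):
        return self.left(x1, x2) + self.right(x1, x2)

    def tree_flatten(self):
        # The sub-kernels are pytrees themselves, so JAX flattens them recursively
        return (self.left, self.right), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


k = ToySumKernel(ToySEKernel(0.3, 1.0), ToySEKernel(2.0, 0.5))
leaves = jax.tree_util.tree_leaves(k)  # hyperparameters of both sub-kernels
value = jax.jit(lambda kern, a, b: kern(a, b))(k, 0.0, 1.0)
```

With this layout the sum kernel never counts hyperparameters itself; whether the same approach fits the real kernel classes depends on how they store their parameters.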

---
## Comparison

### On padded datasets

In [6]:
db = pd.read_csv("../datasets/K=1/small_distinct_input_shared_hp.csv")
padded_inputs, padded_outputs, mappings, all_inputs = preprocess_db(db)
padded_inputs.shape, padded_outputs.shape, mappings.shape, all_inputs.shape

((20, 15, 1), (20, 15, 1), (20, 15), (41, 1))

In [7]:
noises = jnp.arange(1., 2., .05)
noises.shape

(20,)

In [70]:
old_kernel = Kernax.SEMagmaKernel(length_scale=jnp.array(.3), variance=jnp.array(1.))
new_kernel = Kernax.SEMagmaKernel(length_scale=.3, variance=1.) + Kernax.DiagKernel(Kernax.ExpKernel(noises))
diag_kernel = Kernax.DiagKernel(Kernax.ExpKernel(0.1))
linear_kernel = Kernax.LinearKernel(variance_b=0.5, variance_v=1.5, offset_c=1.0)
periodic_kernel = Kernax.PeriodicKernel(length_scale=0.5, variance=1.0, period=2.0)
rational_quadratic_kernel = Kernax.RationalQuadraticKernel(length_scale=0.5, variance=1.0, alpha=1.0)
matern12_kernel = Kernax.Matern12Kernel(length_scale=0.5)
matern32_kernel = Kernax.Matern32Kernel(length_scale=0.5)
matern52_kernel = Kernax.Matern52Kernel(length_scale=0.5)

In [9]:
new_kernel.right_kernel.inner_kernel.inner_kernel

ConstantKernel(value=[1.   1.05 1.1  1.15 1.2  1.25 1.3  1.35 1.4  1.45 1.5  1.55 1.6  1.65
 1.7  1.75 1.8  1.85 1.9  1.95])

In [63]:
# Covariance on padded matrix
res1 = old_kernel(padded_inputs)
res1.shape
res1lin = linear_kernel(padded_inputs)
res1lin.shape
res1perio = periodic_kernel(padded_inputs)
res1perio.shape
res1RQ = rational_quadratic_kernel(padded_inputs)
res1RQ.shape
res1matern12 = matern12_kernel(padded_inputs)
res1matern12.shape
res1matern32 = matern32_kernel(padded_inputs)
res1matern32.shape
res1matern52 = matern52_kernel(padded_inputs)
res1matern52.shape

(20, 15, 15)

In [64]:
# Covariance of the composed kernel (SE + diagonal noise) on the same padded inputs
res2 = new_kernel(padded_inputs)
res2.shape

(20, 15, 15)

In [65]:
# Check whether both kernels give the same values (NaN padding replaced by zeros)
jnp.allclose(jnp.nan_to_num(res1), jnp.nan_to_num(res2))

Array(False, dtype=bool)

In [None]:
%timeit old_kernel(padded_inputs.squeeze()).block_until_ready()
%timeit linear_kernel(padded_inputs.squeeze()).block_until_ready()
%timeit periodic_kernel(padded_inputs.squeeze()).block_until_ready()
%timeit rational_quadratic_kernel(padded_inputs.squeeze()).block_until_ready()
%timeit matern12_kernel(padded_inputs.squeeze()).block_until_ready()
%timeit matern32_kernel(padded_inputs.squeeze()).block_until_ready()
%timeit matern52_kernel(padded_inputs.squeeze()).block_until_ready()

115 μs ± 7.21 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
103 μs ± 7.72 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
117 μs ± 9.71 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit new_kernel(padded_inputs).block_until_ready()
%timeit linear_kernel(padded_inputs).block_until_ready()
%timeit periodic_kernel(padded_inputs).block_until_ready()
%timeit rational_quadratic_kernel(padded_inputs).block_until_ready()
%timeit matern12_kernel(padded_inputs).block_until_ready()
%timeit matern32_kernel(padded_inputs).block_until_ready()
%timeit matern52_kernel(padded_inputs).block_until_ready()

99.3 μs ± 5.72 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit old_kernel(padded_inputs[0].squeeze()).block_until_ready()
%timeit linear_kernel(padded_inputs[0].squeeze()).block_until_ready()
%timeit periodic_kernel(padded_inputs[0].squeeze()).block_until_ready()
%timeit rational_quadratic_kernel(padded_inputs[0].squeeze()).block_until_ready()
%timeit matern12_kernel(padded_inputs[0].squeeze()).block_until_ready()
%timeit matern32_kernel(padded_inputs[0].squeeze()).block_until_ready()
%timeit matern52_kernel(padded_inputs[0].squeeze()).block_until_ready()

187 μs ± 10.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
noise = 0.1
new_kernel = Kernax.SEMagmaKernel(length_scale=.3, variance=1.) + Kernax.DiagKernel(Kernax.ExpKernel(noise))
# README: if you run this code with a DiagKernel initialised from a vector
# (i.e. "noises", defined in cell 7), you will hit the following error:
# "The output of true_fun has type float64[] but the corresponding output of false_fun has type float64[20], so the shapes do not match."
# For the custom array x scalar covariance computation to work, the DiagKernel
# must be redefined with a scalar input.
%timeit new_kernel(padded_inputs[0]).block_until_ready()
%timeit linear_kernel(padded_inputs[0]).block_until_ready()
%timeit periodic_kernel(padded_inputs[0]).block_until_ready()
%timeit rational_quadratic_kernel(padded_inputs[0]).block_until_ready()
%timeit matern12_kernel(padded_inputs[0]).block_until_ready()
%timeit matern32_kernel(padded_inputs[0]).block_until_ready()
%timeit matern52_kernel(padded_inputs[0]).block_until_ready()

166 μs ± 20 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### On scalars

In [69]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, ())
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
a, b

(Array(0.04738122, dtype=float64), Array(0.74779426, dtype=float64))

In [71]:
res1 = old_kernel(a, b)
np.asarray(res1)
res1lin = linear_kernel(a, b)
np.asarray(res1lin)
res1perio = periodic_kernel(a, b)
np.asarray(res1perio)
res1RQ = rational_quadratic_kernel(a, b)
np.asarray(res1RQ)
res1matern12 = matern12_kernel(a, b)
np.asarray(res1matern12)
res1matern32 = matern32_kernel(a, b)
np.asarray(res1matern32)
res1matern52 = matern52_kernel(a, b)
np.asarray(res1matern52)

array(0.32287978)

In [72]:
noise = 0.1
new_kernel = Kernax.SEMagmaKernel(length_scale=.3, variance=1.) + Kernax.DiagKernel(Kernax.ExpKernel(noise))
# README: if you run this code with a DiagKernel initialised from a vector
# (i.e. "noises", defined in cell 7), you will hit the following error:
# "The output of true_fun has type float64[] but the corresponding output of false_fun has type float64[20], so the shapes do not match."
# For the custom array x scalar covariance computation to work, the DiagKernel
# must be redefined with a scalar input.
res2 = new_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2)
res2lin = linear_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2lin)
res2perio = periodic_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2perio)
res2RQ = rational_quadratic_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2RQ)
res2matern12 = matern12_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2matern12)
res2matern32 = matern32_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2matern32)
res2matern52 = matern52_kernel(a.reshape(-1, 1), b.reshape(-1, 1))
np.asarray(res2matern52)

array([[0.32287978]])

In [73]:
jnp.allclose(res1, res2)
jnp.allclose(res1lin, res2lin)
jnp.allclose(res1perio, res2perio)
jnp.allclose(res1RQ, res2RQ)
jnp.allclose(res1matern12, res2matern12)
jnp.allclose(res1matern32, res2matern32)
jnp.allclose(res1matern52, res2matern52)

Array(True, dtype=bool)

In [None]:
%timeit old_kernel(a, b).block_until_ready()
%timeit linear_kernel(a, b).block_until_ready()
%timeit periodic_kernel(a, b).block_until_ready()
%timeit rational_quadratic_kernel(a, b).block_until_ready()
%timeit matern12_kernel(a, b).block_until_ready()
%timeit matern32_kernel(a, b).block_until_ready()
%timeit matern52_kernel(a, b).block_until_ready()

13.2 μs ± 1.39 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
11.5 μs ± 1.31 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
15.5 μs ± 3.49 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit new_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready() 
%timeit linear_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()
%timeit periodic_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()
%timeit rational_quadratic_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()
%timeit matern12_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()
%timeit matern32_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()
%timeit matern52_kernel(a.reshape(-1, 1), b.reshape(-1, 1)).block_until_ready()

110 μs ± 5.86 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
97.3 μs ± 8.51 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
103 μs ± 10 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
120 μs ± 14.6 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On an array and a scalar

In [74]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,1))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, ())
print(a.shape)
print(b.shape)

(10000, 1)
()


In [37]:
res1 = old_kernel(a, b)
np.asarray(res1)
res1lin = linear_kernel(a, b)
np.asarray(res1lin)
res1perio = periodic_kernel(a, b)
np.asarray(res1perio)
res1RQ = rational_quadratic_kernel(a, b)
np.asarray(res1RQ)
res1matern12 = matern12_kernel(a, b)
np.asarray(res1matern12)
res1matern32 = matern32_kernel(a, b)
np.asarray(res1matern32)
res1matern52 = matern52_kernel(a, b)
np.asarray(res1matern52)

array([0.99969112, 0.99773397, 0.24271707, ..., 0.99969418, 0.80140374,
       0.20201437], shape=(10000,))

In [35]:
noise = 0.1
new_kernel = Kernax.SEMagmaKernel(length_scale=.3, variance=1.) + Kernax.DiagKernel(Kernax.ExpKernel(noise))
res2 = new_kernel(a, b) 
# README: if you run this code with a DiagKernel initialised from a vector
# (i.e. "noises", defined in cell 7), you will hit the following error:
# "The output of true_fun has type float64[] but the corresponding output of false_fun has type float64[20], so the shapes do not match."
# For the custom array x scalar covariance computation to work, the DiagKernel
# must be redefined with a scalar input.
np.asarray(res2)

array([2.71818849, 2.71759523, 2.13630804, ..., 2.71818941, 2.6441578 ,
       2.05141667], shape=(10000,))

In [38]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [None]:
%timeit old_kernel(a, b).block_until_ready()
%timeit linear_kernel(a, b).block_until_ready()
%timeit periodic_kernel(a, b).block_until_ready()
%timeit rational_quadratic_kernel(a, b).block_until_ready()
%timeit matern12_kernel(a, b).block_until_ready()
%timeit matern32_kernel(a, b).block_until_ready()
%timeit matern52_kernel(a, b).block_until_ready()

271 μs ± 43.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
98.9 μs ± 7.17 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
104 μs ± 4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
107 μs ± 7.43 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [40]:
%timeit new_kernel(a, b).block_until_ready()

190 μs ± 12.5 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### On two arrays

In [43]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (10000,1))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (15000,1))
a.shape, b.shape

((10000, 1), (15000, 1))

In [None]:
res1 = old_kernel(a,b)
np.asarray(res1).shape
res2 = old_kernel(a,b)
res2_temp = diag_kernel(a, b)
res2 = res2 + res2_temp
np.asarray(res2).shape
# README: given the size of the vectors, running new_kernel(a,b) directly takes a very long time.
#         This is because JAX no longer sees a "simple" SE kernel, but an SE kernel + DiagKernel + ExpKernel.
#         The DiagKernel has a completely different mathematical structure (based on equality comparisons, not matrix multiplications).
#         JIT no longer recognises the simple, optimisable pattern of the SE kernel alone, and no longer knows how to turn it into matrix operations.
#         JAX falls back on the generic method defined in AbstractKernel.py: the nested double vmap loop.
#         It runs the 150 million computations one by one, which is extremely slow.
#         To avoid this problem, we evaluate the SE kernel and the DiagKernel separately, then add them.

res1lin = linear_kernel(a,b)
np.asarray(res1lin).shape
res1perio = periodic_kernel(a,b)
np.asarray(res1perio).shape
res1RQ = rational_quadratic_kernel(a,b)
np.asarray(res1RQ).shape
res1matern12 = matern12_kernel(a,b)
np.asarray(res1matern12).shape
res1matern32 = matern32_kernel(a,b)
np.asarray(res1matern32).shape
res1matern52 = matern52_kernel(a,b)
np.asarray(res1matern52).shape


(10000, 15000)
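
The note in the cell above describes the diagonal term as relying on equality comparisons. As an illustration only, the sketch below computes an SE term and such a diagonal term together using broadcasting. It assumes the diagonal term simply adds a noise value wherever two input points coincide exactly; `se_plus_diag` and its parameters are hypothetical and do not reproduce the behaviour of Kernax's `SEMagmaKernel` or `DiagKernel`.

```python
import jax.numpy as jnp


def se_plus_diag(x1, x2, length_scale, variance, noise):
    """Hypothetical SE + diagonal-noise covariance for inputs of shape (n, d) and (m, d)."""
    # Pairwise differences via broadcasting: shape (n, m, d)
    diff = x1[:, None, :] - x2[None, :, :]
    sq_dist = jnp.sum(diff**2, axis=-1)
    se = variance * jnp.exp(-0.5 * sq_dist / length_scale**2)
    # Assumed diagonal term: non-zero only where the two points are identical
    same_point = jnp.all(diff == 0.0, axis=-1)
    return se + jnp.where(same_point, noise, 0.0)


x_small = jnp.array([[0.0], [1.0]])
K_small = se_plus_diag(x_small, x_small, 1.0, 1.0, 0.1)
K_cross = se_plus_diag(jnp.zeros((4, 1)), jnp.ones((6, 1)), 1.0, 1.0, 0.1)
```

The intermediate `diff` array holds one entry per pair of points and per input dimension, so for inputs as large as those benchmarked here memory use should be checked before relying on this formulation.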

In [29]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [None]:
%timeit old_kernel(a, b).block_until_ready()
%timeit linear_kernel(a, b).block_until_ready()
%timeit periodic_kernel(a, b).block_until_ready()
%timeit rational_quadratic_kernel(a, b).block_until_ready()
%timeit matern12_kernel(a, b).block_until_ready()
%timeit matern32_kernel(a, b).block_until_ready()
%timeit matern52_kernel(a, b).block_until_ready()

674 ms ± 69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
744 ms ± 142 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
855 ms ± 78.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit (old_kernel(a,b) + diag_kernel(a,b)).block_until_ready()
# WARNING: execution can take quite a while; it slows my poor computer down a bit, though it is not very powerful (Alexia)

3.68 s ± 1.16 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


### On two batches of arrays with shared HP

In [48]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100,1))
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150,1))
a.shape, b.shape

((50, 100, 1), (50, 150, 1))

In [None]:
# Using shared hyperparameters for all batches
res1 = old_kernel(a, b)
np.asarray(res1).shape
res1lin = linear_kernel(a, b)
np.asarray(res1lin).shape
res1perio = periodic_kernel(a, b)
np.asarray(res1perio).shape
res1RQ = rational_quadratic_kernel(a, b)
np.asarray(res1RQ).shape
res1matern12 = matern12_kernel(a, b)
np.asarray(res1matern12).shape
res1matern32 = matern32_kernel(a, b)
np.asarray(res1matern32).shape
res1matern52 = matern52_kernel(a, b)
np.asarray(res1matern52).shape

(50, 100, 150)

In [34]:
# Also using shared hyperparameters for all batches
res2 = new_kernel(a, b)
np.asarray(res2).shape

(50, 100, 150)

In [35]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

In [52]:
%timeit old_kernel(a, b).block_until_ready()
%timeit linear_kernel(a, b).block_until_ready()
%timeit periodic_kernel(a, b).block_until_ready()
%timeit rational_quadratic_kernel(a, b).block_until_ready()
%timeit matern12_kernel(a, b).block_until_ready()
%timeit matern32_kernel(a, b).block_until_ready()
%timeit matern52_kernel(a, b).block_until_ready()

6.37 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.2 ms ± 418 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.39 ms ± 1.39 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
5 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [36]:
%timeit new_kernel(a, b).block_until_ready()

2.35 ms ± 141 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### On two batches of arrays with distinct HP

In [53]:
key, subkey = jax.random.split(key)
a = jax.random.uniform(subkey, (50, 100, 1))  # 50 batches of 100 one-dimensional points
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey, (50, 150, 1))  # 50 batches of 150 one-dimensional points

# Create distinct hyperparameters for each batch
key, subkey = jax.random.split(key)
distinct_length_scales = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_variances = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_offsets = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_periods = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_noises = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values
key, subkey = jax.random.split(key)
distinct_alpha = jax.random.uniform(subkey, (50,)) + 0.5  # Ensuring positive values

distinct_length_scales, distinct_variances, distinct_offsets, distinct_periods, distinct_noises, distinct_alpha

(Array([1.13058712, 1.39257398, 1.43719247, 1.4921365 , 1.07411877,
        0.66303972, 0.74972994, 1.26233933, 1.25212426, 0.77106796,
        1.05009263, 1.41423839, 0.81068287, 0.97035455, 0.86315106,
        0.80851423, 0.98519326, 1.29311533, 1.04478941, 0.92615081,
        0.65470157, 0.59665795, 1.36819238, 0.93291217, 0.94619525,
        1.48503262, 1.46194018, 0.62133315, 0.51839394, 0.56565761,
        0.98794474, 0.62636225, 0.86938382, 1.33794021, 0.71495018,
        0.99728644, 0.79130799, 1.27566005, 1.01309037, 1.12478322,
        0.91268421, 1.15867137, 1.0636699 , 0.67294768, 0.99465571,
        1.15434298, 1.36256904, 1.34733359, 0.78277785, 1.0667693 ],      dtype=float64),
 Array([0.54541315, 1.37805604, 0.96281375, 0.84741937, 0.61986569,
        0.65105202, 0.67778848, 1.0729185 , 1.49000765, 1.38680738,
        0.90393972, 0.6773235 , 0.7754256 , 0.71523645, 1.29020206,
        0.88986081, 0.65889881, 0.89519714, 1.30095015, 1.38484195,
        1.25639527, 1.0107

In [None]:
batched_kernel = Kernax.SEMagmaKernel(
    length_scale=distinct_length_scales,
    variance=distinct_variances
)
res1 = batched_kernel(a, b)
print(batched_kernel)

batched_kernel = Kernax.LinearKernel(
    variance_b=distinct_variances,
    variance_v=distinct_variances,
    offset_c=distinct_offsets
)
res1lin = batched_kernel(a, b)
print(batched_kernel)

batched_kernel = Kernax.PeriodicKernel(
    length_scale=distinct_length_scales,
    variance=distinct_variances,
    period=distinct_periods
)
res1perio = batched_kernel(a, b)
print(batched_kernel)

batched_kernel = Kernax.RationalQuadraticKernel(
    length_scale=distinct_length_scales,
    variance=distinct_variances,
    alpha=distinct_alpha
)
# Store each result under its own name so that res1 keeps the SE result
res1RQ = batched_kernel(a, b)
print(batched_kernel)

batched_kernel = Kernax.Matern12Kernel(
    length_scale=distinct_length_scales)
res1matern12 = batched_kernel(a, b)
print(batched_kernel)

batched_kernel = Kernax.Matern32Kernel(
    length_scale=distinct_length_scales)
res1matern32 = batched_kernel(a, b)
print(batched_kernel)

batched_kernel = Kernax.Matern52Kernel(
    length_scale=distinct_length_scales)
res1matern52 = batched_kernel(a, b)
print(batched_kernel)

Matern32Kernel(length_scale=[1.13058712 1.39257398 1.43719247 1.4921365  1.07411877 0.66303972
 0.74972994 1.26233933 1.25212426 0.77106796 1.05009263 1.41423839
 0.81068287 0.97035455 0.86315106 0.80851423 0.98519326 1.29311533
 1.04478941 0.92615081 0.65470157 0.59665795 1.36819238 0.93291217
 0.94619525 1.48503262 1.46194018 0.62133315 0.51839394 0.56565761
 0.98794474 0.62636225 0.86938382 1.33794021 0.71495018 0.99728644
 0.79130799 1.27566005 1.01309037 1.12478322 0.91268421 1.15867137
 1.0636699  0.67294768 0.99465571 1.15434298 1.36256904 1.34733359
 0.78277785 1.0667693 ])
Matern52Kernel(length_scale=[1.13058712 1.39257398 1.43719247 1.4921365  1.07411877 0.66303972
 0.74972994 1.26233933 1.25212426 0.77106796 1.05009263 1.41423839
 0.81068287 0.97035455 0.86315106 0.80851423 0.98519326 1.29311533
 1.04478941 0.92615081 0.65470157 0.59665795 1.36819238 0.93291217
 0.94619525 1.48503262 1.46194018 0.62133315 0.51839394 0.56565761
 0.98794474 0.62636225 0.86938382 1.33794021 0.7

In [38]:
# res2 = new_kernel(a, b, length_scale=distinct_length_scales, variance=distinct_variances)
batched_kernel = Kernax.SEMagmaKernel(length_scale=distinct_length_scales, 
                                      variance=distinct_variances) + Kernax.DiagKernel(Kernax.ExpKernel(distinct_noises))
res2 = batched_kernel(a, b)
print(batched_kernel)

SumKernel(left_kernel=SEMagmaKernel(length_scale=[0.73314298 1.16657917 0.85419874 0.51307126 1.18606573 0.55039322
 0.95887649 1.22698946 0.96474768 0.80743648 1.09766226 0.80717582
 0.89580246 1.3499695  0.77893643 1.40237506 1.46014276 0.78476368
 0.60720141 0.83234395 0.80273261 0.89522365 0.55596665 1.07829468
 1.12246261 1.31228677 1.27370113 1.28929839 1.177215   0.6121177
 0.85959204 0.97278255 0.93860787 0.95404747 1.35596002 1.10412686
 1.49799604 1.3636239  0.9819335  1.49066854 0.57652775 1.42748273
 0.96863449 0.54673638 0.60064792 1.21251168 0.54389353 1.28000685
 1.4706205  1.37892888], variance=[1.40087618 0.50347402 0.81557193 1.34923058 1.00457197 0.65788761
 0.9219105  0.72545942 1.23310028 0.91450771 0.52153169 1.07353706
 0.84001156 0.63915561 1.28793987 0.960267   0.5185922  1.46299715
 1.2229984  0.73096194 1.47447511 1.41269148 0.62063165 1.39239143
 0.56357474 0.95879739 1.23934137 0.66682831 1.27697895 1.41619929
 1.1099631  1.14399616 1.30613247 0.78241536 1.

In [68]:
jnp.allclose(res1, res2)

Array(True, dtype=bool)

### On two batches of arrays with both shared and distinct HP

Unit tests for Alexia's kernels

In [None]:
print("### Matérn kernel test (Corrected Version) ###\n")
print("WARNING: does not work yet, because the Matérn kernel is not implemented in Kernax yet.")
# ===================================================================
# 1. Instantiation and simple call test
# ===================================================================
print("--- 1. Instantiation test ---")
try:
    # Start with simple scalar hyperparameters
    matern_32 = Kernax.MaternKernel(length_scale=1.0, variance=1.0, nu=1.5)
    x1 = jnp.array([[1.0]]) # 2D input -> (1 point, 1 dimension)
    x2 = jnp.array([[2.5]]) # 2D input

    # Simple call
    res = matern_32(x1, x2)

    # A call between two single points should return a single value, here with shape (1, 1)
    assert res.shape == (1, 1), f"Expected shape (1, 1), got {res.shape}"
    print("✅ Instantiation and simple call succeeded.")
except Exception as e:
    print(f"❌ Instantiation or simple call failed: {e}")

# # ===================================================================
# # 2. Mathematical correctness test (nu=0.5 vs. exponential kernel)
# # ===================================================================
# print("\n--- 2. Mathematical correctness test ---")
# try:
#     # Matérn kernel with nu=0.5
#     matern_exp = Kernax.MaternKernel(length_scale=1.2, variance=0.8, nu=0.5)

#     # Manual implementation of the exponential kernel for comparison
#     def manual_exp_kernel(x1, x2, variance, length_scale):
#         d = jnp.sqrt(jnp.sum((x1 - x2) ** 2))
#         return variance * jnp.exp(-d / length_scale)

#     x1_test = jnp.array([[1.0]])
#     x2_test = jnp.array([[2.5]])

#     res_matern_exp = matern_exp(x1_test, x2_test).squeeze()
#     res_manual_exp = manual_exp_kernel(x1_test, x2_test, variance=0.8, length_scale=1.2)

#     assert jnp.allclose(res_matern_exp, res_manual_exp)
#     print("✅ Mathematical correctness (nu=0.5 vs. exponential) verified.")
# except Exception as e:
#     print(f"❌ Correctness check failed: {e}")

# # ===================================================================
# # 3. Cross-covariance test (matrix)
# # ===================================================================
# print("\n--- 3. Shape test (covariance matrix) ---")
# try:
#     key, subkey1, subkey2 = jax.random.split(key, 3)
#     vec1 = jax.random.uniform(subkey1, (10, 1)) # 10 points
#     vec2 = jax.random.uniform(subkey2, (15, 1)) # 15 points

#     # Covariance matrix (10x10)
#     res_matrix = matern_32(vec1, vec1)
#     assert res_matrix.shape == (10, 10), f"Expected shape (10, 10), got {res_matrix.shape}"

#     # Cross-covariance matrix (10x15)
#     res_cross = matern_32(vec1, vec2)
#     assert res_cross.shape == (10, 15), f"Expected shape (10, 15), got {res_cross.shape}"

#     print("✅ Covariance and cross-covariance matrix tests succeeded.")
# except Exception as e:
#     print(f"❌ Shape tests failed: {e}")

# # ===================================================================
# # 4. Test on batches with distinct hyperparameters
# # ===================================================================
# print("\n--- 4. Test on batches with distinct HPs ---")
# try:
#     # 10 batches, 5 points per batch, 1 dimension
#     a_batch = jax.random.uniform(key, (10, 5, 1))

#     # Distinct HPs for each batch (a vector of 10 values)
#     distinct_ls = jax.random.uniform(key, (10,)) + 0.5
#     distinct_var = jax.random.uniform(key, (10,)) + 0.5

#     # The kernel is initialised without specific hyperparameters
#     matern_for_batch = Kernax.MaternKernel(nu=1.5)

#     # Use vmap to map the computation over the batch axis (axis 0)
#     # We specify that the inputs (a_batch) and the HPs (ls, var) are mapped
#     batched_call = vmap(
#         matern_for_batch, 
#         in_axes=(0, 0, None, None, None), # Maps a1, a2, but not the HPs
#         out_axes=0
#     )

#     # The HPs must now be passed as arguments
#     res_batch = batched_call(a_batch, a_batch, distinct_ls, distinct_var, 1.5)

#     assert res_batch.shape == (10, 5, 5)
#     print("✅ Test on batches with distinct HPs succeeded.")

# except Exception as e:
#     print(f"❌ Batch test failed: {e}")

### Matérn kernel test (Corrected Version) ###

--- 1. Instantiation test ---
❌ Instantiation or simple call failed: module 'Kernax' has no attribute 'MaternKernel'


In [8]:
print("### Periodic kernel test ###\n")

# ===================================================================
# 1. Instantiation test
# ===================================================================
print("--- 1. Instantiation test ---")
try:
    # Check that the kernel can be created with its hyperparameters
    periodic_kern = Kernax.PeriodicKernel(variance=1.5, length_scale=2.0, period=3.0)

    # Check that the attributes are assigned correctly
    assert periodic_kern.variance == 1.5
    assert periodic_kern.length_scale == 2.0
    assert periodic_kern.period == 3.0
    print("✅ Instantiation succeeded with the right hyperparameters.")

except Exception as e:
    print(f"❌ Instantiation failed: {e}")

# ===================================================================
# 2. Mathematical correctness test
# ===================================================================
print("\n--- 2. Mathematical correctness test ---")

def manual_periodic(x1, x2, variance, length_scale, period):
    """Manual implementation of the formula, for verification."""
    dist = jnp.linalg.norm(x1 - x2)
    return variance * jnp.exp(-2 * jnp.sin(jnp.pi * dist / period)**2 / length_scale**2)

try:
    x1_test = jnp.array([1.0])
    x2_test = jnp.array([2.5])

    # Using the Kernax kernel
    res_kernax = periodic_kern(x1_test, x2_test)

    # Manual computation for comparison
    res_manual = manual_periodic(x1_test, x2_test, variance=1.5, length_scale=2.0, period=3.0)

    assert jnp.allclose(res_kernax, res_manual)
    print(f"✅ Mathematical correctness verified (Kernax: {res_kernax:.4f}, Manual: {res_manual:.4f}).")

except Exception as e:
    print(f"❌ Correctness check failed: {e}")

# ===================================================================
# 3. Test on matrices (cross-covariance)
# ===================================================================
print("\n--- 3. Shape test (covariance matrix) ---")
try:
    key, subkey1, subkey2 = jax.random.split(key, 3)
    vec1 = jax.random.uniform(subkey1, (10, 2)) # 10 points in 2D
    vec2 = jax.random.uniform(subkey2, (15, 2)) # 15 points in 2D

    # Covariance matrix (10x10)
    res_matrix = periodic_kern(vec1, vec1)
    assert res_matrix.shape == (10, 10), f"Expected shape (10, 10), got {res_matrix.shape}"

    # Cross-covariance matrix (10x15)
    res_cross = periodic_kern(vec1, vec2)
    assert res_cross.shape == (10, 15), f"Expected shape (10, 15), got {res_cross.shape}"

    print("✅ Covariance matrix tests succeeded.")
except Exception as e:
    print(f"❌ Shape tests failed: {e}")

# ===================================================================
# 4. Test on batches with distinct hyperparameters
# ===================================================================
print("\n--- 4. Test on batches with distinct HPs ---")
try:
    key, subkey = jax.random.split(key)
    a_batch = jax.random.uniform(subkey, (5, 8, 2)) # 5 batches of 8 2D points

    # Generate distinct HPs for each batch, with one fresh key per draw
    key, k_var, k_ls, k_per = jax.random.split(key, 4)
    distinct_variances = jax.random.uniform(k_var, (5,)) + 0.5
    distinct_length_scales = jax.random.uniform(k_ls, (5,)) + 1.0
    distinct_periods = jax.random.uniform(k_per, (5,)) * 2 + 1.0

    # Define a function that applies the kernel to a single batch element
    def compute_single_batch(x, var, ls, p):
        # 1. Create the kernel with the HPs of this batch
        kernel = Kernax.PeriodicKernel(variance=var, length_scale=ls, period=p)
        # 2. Compute the covariance matrix for this batch
        return kernel(x, x)

    # Map this function over the matching axes of the data and the HPs
    res_batch = vmap(compute_single_batch, in_axes=(0, 0, 0, 0))(
        a_batch, distinct_variances, distinct_length_scales, distinct_periods
    )

    assert res_batch.shape == (5, 8, 8), f"Expected shape (5, 8, 8), got {res_batch.shape}"
    print("✅ Test on batches with distinct HPs succeeded.")

except Exception as e:
    print(f"❌ Batch test failed: {e}")

### Periodic kernel test ###

--- 1. Instantiation test ---
✅ Instantiation succeeded with the right hyperparameters.

--- 2. Mathematical correctness test ---
✅ Mathematical correctness verified (Kernax: 0.9098, Manual: 0.9098).

--- 3. Shape test (covariance matrix) ---
✅ Covariance matrix tests succeeded.

--- 4. Test on batches with distinct HPs ---
✅ Test on batches with distinct HPs succeeded.


---
## Conclusion

---