# Tests of the functions implemented in delight.utils_jx

In [None]:
from jax import jit
import jax.numpy as jnp
from delight.utils_jx import find_positions,bilininterp_precomputedbins,kernel_parts_interp_jax
from delight.utils_jx import approx_flux_likelihood_jax, gauss_prob,gauss_lnprob,logsumexp

In [None]:
from jax import vmap

## find_positions

In [None]:
# Example usage: locate each fz1 value on fzGrid.
fz1 = jnp.array([0.5, 1.5, 2.5])
fzGrid = jnp.array([0.0, 1.0, 2.0, 3.0])
positions = find_positions(fz1, fzGrid)
# NOTE(review): presumably returns, for each fz1 value, the index p of the grid
# bin [fzGrid[p], fzGrid[p+1]] containing it, i.e. [0 1 2] here — confirm
# against delight.utils_jx.
print(positions)

## bilininterp_precomputedbins

In [None]:
# Example usage
# numBands and nobj are plain Python ints, not jnp arrays: they are meant to be
# static dimensions.
numBands = 3
nobj = 2

Kinterp = jnp.zeros((numBands, nobj))  # output buffer, numBands x nobj
v1s = jnp.array([0.5, 1.5])  # per-object coordinate on grid1
v2s = jnp.array([0.5, 1.5])  # per-object coordinate on grid2
p1s = jnp.array([0, 1])  # presumably the grid1 bin holding each v1s value
p2s = jnp.array([0, 1])  # presumably the grid2 bin holding each v2s value
grid1 = jnp.array([0.0, 1.0, 2.0])
grid2 = jnp.array([0.0, 1.0, 2.0])

# Kgrid is numBands x len(grid1) x len(grid2). Interpolating in bin p uses grid
# nodes p and p+1, so with p1s/p2s up to 1 the node axes must reach index 2.
# JAX clamps out-of-bounds gather indices by default instead of raising, so an
# undersized Kgrid would silently give wrong numbers.
# Values are linear in every index (Kgrid[b, i, j] = 1 + 9*b + 3*i + j), so
# bilinear interpolation reproduces them exactly. The expected output is
# [[3, 7], [12, 16], [21, 25]] (rows = bands, columns = objects).
Kgrid = jnp.arange(1.0, 1.0 + numBands * len(grid1) * len(grid2)).reshape(
    numBands, len(grid1), len(grid2)
)

Kinterp = bilininterp_precomputedbins(numBands, nobj, Kinterp, v1s, v2s, p1s, p2s, grid1, grid2, Kgrid)
print(Kinterp)

## kernel_parts_interp_jax

In [None]:
from jax import random
# JAX has no global RNG state: randomness is drawn from an explicit PRNG key,
# here seeded with a fixed value for reproducibility.
key = random.PRNGKey(758493)

In [None]:
# Example inputs to exercise the function
NO1 = 2
NO2 = 2
Kinterp = jnp.zeros((NO1, NO2))  # output matrix for the interpolated values
b1 = jnp.array([0, 1])  # band indices along the first axis
b2 = jnp.array([0, 1])  # band indices along the second axis
fz1 = jnp.array([0.5, 1.5])  # fz1 positions
fz2 = jnp.array([0.5, 1.5])  # fz2 positions
p1s = jnp.array([0, 1])  # p1 indices into fzGrid for fz1
p2s = jnp.array([0, 1])  # p2 indices into fzGrid for fz2
fzGrid = jnp.array([0.0, 1.0, 2.0])  # the fz grid

# Kgrid is 4-D: (numBands1, numBands2, nz, nz) with nz = len(fzGrid).
# Interpolating in bin p uses grid nodes p and p+1, so with p1s/p2s up to 1 the
# redshift axes must reach index 2. JAX clamps out-of-bounds gather indices by
# default instead of raising, so an undersized Kgrid silently gives wrong values.
numBands1 = int(b1.max()) + 1
numBands2 = int(b2.max()) + 1
nz = len(fzGrid)
# Values are linear in every index
# (Kgrid[b, c, i, j] = 1 + 18*b + 9*c + 3*i + j), so bilinear interpolation
# reproduces them exactly. If kernel_parts_interp_jax interpolates
# Kgrid[b1[o1], b2[o2]] at (fz1[o1], fz2[o2]), the expected output is
# [[3, 13], [24, 34]].
Kgrid = jnp.arange(1.0, 1.0 + numBands1 * numBands2 * nz * nz).reshape(
    numBands1, numBands2, nz, nz
)

# Call the function
Kinterp_result = kernel_parts_interp_jax(NO1, NO2, Kinterp, b1, b2, fz1, fz2, p1s, p2s, fzGrid, Kgrid)

print(Kinterp_result)


In [None]:
# Inspect the shape of Kgrid (shown as the cell output in a notebook).
Kgrid.shape

In [None]:
# Inspect the values of Kgrid (shown as the cell output in a notebook).
Kgrid

## test_approx_flux_likelihood_jax

In [None]:
# Test of the JAX function
def test_approx_flux_likelihood_jax():
    """Run approx_flux_likelihood_jax on a small fixture, eagerly and under jit.

    The fixture uses nz=2, nt=2, nf=3: f_obs and f_obs_var have shape (nf,),
    f_mod and f_mod_covar have shape (nz, nt, nf). Both results are printed,
    and the jitted result must match the eager one.
    """
    nz = 2
    nt = 2
    nf = 3

    # Dummy data for the test
    f_obs = jnp.array([1.0, 2.0, 3.0])
    f_obs_var = jnp.array([0.1, 0.2, 0.3])
    f_mod = jnp.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
                       [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]])
    f_mod_covar = jnp.array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
                             [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]])
    ell_hat = jnp.array([0.5, 1.0])
    ell_var = jnp.array([0.1, 0.2])

    # Keep the fixture consistent with the declared dimensions.
    assert f_obs.shape == f_obs_var.shape == (nf,)
    assert f_mod.shape == f_mod_covar.shape == (nz, nt, nf)

    result = approx_flux_likelihood_jax(f_obs, f_obs_var, f_mod, f_mod_covar, ell_hat, ell_var)
    # Also call the module-level jitted wrapper, so the function is checked to
    # trace and compile under jit rather than only run eagerly.
    result_jit = approx_flux_likelihood_jax_jit(f_obs, f_obs_var, f_mod, f_mod_covar, ell_hat, ell_var)

    # Print the results
    print(result)
    print(result_jit)
    # NOTE(review): assumes the function returns a single array; adapt the
    # comparison if it returns a tuple/pytree.
    assert jnp.allclose(result, result_jit), "jitted result differs from eager result"


# Apply jit to the function
approx_flux_likelihood_jax_jit = jit(approx_flux_likelihood_jax)


    

In [None]:
# Run the approx_flux_likelihood_jax test function defined above.
test_approx_flux_likelihood_jax()

## Gaussian probabilities and logsumexp

In [None]:
# Example usage of the Gaussian helper functions
def test_gaussian():
    """Evaluate gauss_prob and gauss_lnprob at x=1 for mu=0, var=1 and check them.

    The values are printed, then compared with the closed-form standard normal
    log-density -0.5*(x-mu)**2/var - 0.5*log(2*pi*var) and its exponential.
    With var=1, the check holds whether var is read as a variance or as a
    standard deviation.
    """
    x = 1.0
    mu = 0.0
    var = 1.0

    # Gaussian probabilities
    prob = gauss_prob(x, mu, var)
    lnprob = gauss_lnprob(x, mu, var)

    print("Gaussian probability:", prob)
    print("Log Gaussian probability:", lnprob)

    # NOTE(review): assumes both helpers return the normalized density / its log.
    expected_lnprob = -0.5 * (x - mu) ** 2 / var - 0.5 * jnp.log(2 * jnp.pi * var)
    assert jnp.allclose(lnprob, expected_lnprob), "gauss_lnprob mismatch"
    assert jnp.allclose(prob, jnp.exp(expected_lnprob)), "gauss_prob mismatch"

def test_logsumexp():
    """Check logsumexp on a small 1-D array against the naive log(sum(exp(arr)))."""
    arr = jnp.array([1.0, 2.0, 3.0])
    lse = logsumexp(arr)
    print("LogSumExp:", lse)
    # The naive formula is safe here: the inputs are far too small to overflow.
    expected = jnp.log(jnp.sum(jnp.exp(arr)))
    assert jnp.allclose(lse, expected), "logsumexp mismatch"

# Run the Gaussian and logsumexp test functions
test_gaussian()
test_logsumexp()
