# Tests of the functions implemented in delight.photoz_kernels_jx

In [None]:
from jax import jit
import jax.numpy as jnp
from jax import vmap
from delight import photoz_kernels_jx
from delight.photoz_kernels_jx import kernel_parts_interp_jx,kernelparts_diag_jx, kernelparts_jax

In [None]:
# List the attributes of the module under test, to see which kernels it provides.
dir(photoz_kernels_jx) 

## kernel_parts_interp

In [None]:
# Example call of the jitted interpolation on a constant grid.
fzGrid = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])  # redshift grid of Kgrid

b1 = jnp.array([0, 1, 2])  # band index of each object in the first set
fz1 = jnp.array([0.1, 0.2, 0.3])  # redshift positions of the first set
p1s = jnp.array([0, 1, 2])  # grid cell of each fz1: fzGrid[p] <= fz1 <= fzGrid[p + 1]
b2 = jnp.array([0, 1, 2, 3])  # band index of each object in the second set
fz2 = jnp.array([0.1, 0.2, 0.3, 0.4])  # redshift positions of the second set
p2s = jnp.array([0, 1, 2, 3])  # grid cell of each fz2: fzGrid[p] <= fz2 <= fzGrid[p + 1]

# Number of objects in each set, derived from the inputs so they cannot drift apart.
NO1, NO2 = len(b1), len(b2)

# Kgrid is 4-D: (numBands1, numBands2, nz1, nz2). Both redshift axes must have
# len(fzGrid) entries so that every grid point of fzGrid has a table value.
nz = len(fzGrid)
Kgrid = jnp.ones((3, 4, nz, nz))

# Call the jitted function under test.
result = kernel_parts_interp_jx(NO1, NO2, Kgrid, b1, fz1, p1s, b2, fz2, p2s, fzGrid)
print(result)

In [None]:
from jax import random

# JAX has no global RNG state: randomness is driven by an explicit key,
# so a fixed seed makes any random inputs derived from it reproducible.
SEED = 758493
key = random.PRNGKey(SEED)

In [None]:
# Small hand-checkable case: two objects per set, two bands per set,
# and a three-point redshift grid.
NO1 = 2
NO2 = 2
b1 = jnp.array([0, 1])  # band index of each object in the first set
b2 = jnp.array([0, 1])  # band index of each object in the second set
fz1 = jnp.array([0.5, 1.5])  # redshift positions of the first set
fz2 = jnp.array([0.5, 1.5])  # redshift positions of the second set
p1s = jnp.array([0, 1])  # grid cell of each fz1: fzGrid[p] <= fz1 <= fzGrid[p + 1]
p2s = jnp.array([0, 1])  # grid cell of each fz2: fzGrid[p] <= fz2 <= fzGrid[p + 1]
fzGrid = jnp.array([0.0, 1.0, 2.0])  # redshift grid of Kgrid

# Kgrid is 4-D: (numBands1, numBands2, nz1, nz2). Its two redshift axes must
# have len(fzGrid) entries: with p = 1 the interpolation presumably reads grid
# index p + 1 = 2, and JAX clamps out-of-bounds gathers silently instead of
# raising, so an undersized grid gives wrong values without any error.
numBands1, numBands2 = 2, 2
nz = len(fzGrid)
# Values are 1 + row-major flat index, i.e. linear in every index, which makes
# bilinear interpolation exact and the expected result closed-form.
Kgrid = jnp.arange(1.0, 1.0 + numBands1 * numBands2 * nz * nz).reshape(
    numBands1, numBands2, nz, nz
)

# Call the function under test.
Kinterp_result = kernel_parts_interp_jx(NO1, NO2, Kgrid, b1, fz1, p1s, b2, fz2, p2s, fzGrid)
print(Kinterp_result)

# Reference value, assuming kernel_parts_interp_jx bilinearly interpolates
# Kgrid[b1, b2] over the grid cell (p, p + 1) -- confirm against the module.
# Since fzGrid[i] == i here, Kgrid[a, b, i, j] = 1 + a*numBands2*nz*nz + b*nz*nz
# + i*nz + j evaluated at (fz1, fz2) gives [[3, 13], [24, 34]].
expected = (
    1.0
    + numBands2 * nz * nz * b1[:, None]
    + nz * nz * b2[None, :]
    + nz * fz1[:, None]
    + fz2[None, :]
)
print("matches bilinear reference:", bool(jnp.allclose(Kinterp_result, expected)))


In [None]:
# Shape of the grid passed to kernel_parts_interp_jx in the cell above.
jnp.shape(Kgrid)

In [None]:
# Display the grid values passed to kernel_parts_interp_jx in the cell above.
Kgrid

In [None]:
# Example call of kernelparts_diag_jx with constant coefficient tables.
NO1 = 10  # number of objects
NC = 5  # number of coefficients per row of the fcoefs_* tables
NL = 3  # length of lines_mu / lines_sig
alpha_C, alpha_L = 0.1, 0.2

# Three independent (NO1, NC) tables filled with ones.
fcoefs_amp, fcoefs_mu, fcoefs_sig = (jnp.ones((NO1, NC)) for _ in range(3))
lines_mu, lines_sig = jnp.ones(NL), jnp.ones(NL)
norms = jnp.ones(NO1)

b1 = jnp.arange(NO1)  # one band index per object
fz1 = jnp.linspace(0.1, 1.0, NO1)  # evenly spaced redshift positions
grad_needed = True

KC, KL, D_alpha_C, D_alpha_L = kernelparts_diag_jx(
    NO1, NC, NL, alpha_C, alpha_L,
    fcoefs_amp, fcoefs_mu, fcoefs_sig,
    lines_mu, lines_sig, norms, b1, fz1, grad_needed,
)

for label, value in (("KC", KC), ("KL", KL), ("D_alpha_C", D_alpha_C), ("D_alpha_L", D_alpha_L)):
    print(f"{label}:", value)


In [None]:

# Example call of kernelparts_jax between two object sets.
NO1 = 10  # number of objects in the first set
NO2 = 10  # number of objects in the second set
NC = 5  # number of coefficients per row of the fcoefs_* tables
NL = 3  # length of lines_mu / lines_sig
alpha_C = 0.1
alpha_L = 0.2

# The per-band tables are presumably indexed by the band indices in b1 and b2
# (confirm against photoz_kernels_jx), so size them to cover both sets: JAX
# clamps out-of-range gathers silently, so an undersized table would not raise
# if NO2 were made larger than NO1.
numBands = max(NO1, NO2)
fcoefs_amp = jnp.ones((numBands, NC))
fcoefs_mu = jnp.ones((numBands, NC))
fcoefs_sig = jnp.ones((numBands, NC))
lines_mu = jnp.ones(NL)
lines_sig = jnp.ones(NL)
norms = jnp.ones(numBands)

b1 = jnp.arange(NO1)  # one band index per object of the first set
fz1 = jnp.linspace(0.1, 1.0, NO1)
b2 = jnp.arange(NO2)  # one band index per object of the second set
fz2 = jnp.linspace(0.1, 1.0, NO2)
grad_needed = True

KC, KL, D_alpha_C, D_alpha_L, D_alpha_z = kernelparts_jax(
    NO1, NO2, NC, NL, alpha_C, alpha_L,
    fcoefs_amp, fcoefs_mu, fcoefs_sig,
    lines_mu, lines_sig, norms, b1, fz1, b2, fz2, grad_needed,
)

# Show the outputs, as done for kernelparts_diag_jx, so the cell is inspectable.
for label, value in (
    ("KC", KC),
    ("KL", KL),
    ("D_alpha_C", D_alpha_C),
    ("D_alpha_L", D_alpha_L),
    ("D_alpha_z", D_alpha_z),
):
    print(f"{label}:", value)

