In [1]:
import jax 
import jax.numpy as jnp 
from jax.scipy.special import lpmn_values
# Truncation level of the Legendre series: passed to lpmn_values as both the
# maximum order and the maximum degree, and used as the upper bound of the
# ell = 1..MAX_ELL sums in the kernels below.
MAX_ELL = 35


def array(els):
    """Build a JAX array from ``els`` while asking for double precision.

    NOTE(review): unless ``jax_enable_x64`` is switched on, JAX is expected to
    warn and hand back float32 instead of float64 -- confirm the intended
    precision for this notebook.
    """
    requested_dtype = 'float64'
    return jnp.array(els, dtype=requested_dtype)


@jax.jit
def legendre_values(x, y):
    """Evaluate the Legendre polynomials at the inner product of ``x`` and ``y``.

    ``x`` and ``y`` are points given in cartesian coordinates. The result is a
    1-D array with one entry per degree 0..MAX_ELL.
    """
    # lpmn_values wants a 1-D batch of evaluation points, hence the [None].
    cos_angle = jnp.dot(x, y)[None]
    # Last positional argument (False) is passed through unchanged; it selects
    # the unnormalized functions -- TODO confirm against the lpmn_values docs.
    all_orders = lpmn_values(MAX_ELL, MAX_ELL, cos_angle, False)
    # Keep index 0 of the leading axis (order zero), then drop the
    # length-one batch axis.
    order_zero = all_orders[0, :, :]
    return jnp.squeeze(order_zero)


@jax.jit
def lambd(ell: int) -> int:
    """Return ``ell * (ell + 1)`` for the degree ``ell``."""
    next_degree = ell + 1
    return ell * next_degree


@jax.jit
def phi(kappa: float, nu: float, lam: float) -> float:
    """Return ``(2 * nu / kappa + lam) ** (-nu - 1)``.

    ``kappa`` and ``nu`` are the kernel hyperparameters and ``lam`` is the
    value produced by ``lambd`` for the degree being weighted.
    """
    base = 2 * nu / kappa + lam
    exponent = -nu - 1
    return jnp.power(base, exponent)


@jax.jit
def legendre_tilde_constant(kappa: float, nu: float, ell: int) -> float:
    """Weight applied to the degree-``ell`` Legendre term.

    Computes ``(2 * ell + 1) / (4 * pi * lambd(ell)) * phi(kappa, nu, lambd(ell))``.
    The denominator vanishes for ``ell == 0``; the callers in this notebook
    only pass ``ell >= 1``.
    """
    eigenvalue = lambd(ell)
    prefactor = (2 * ell + 1) / (4 * jnp.pi * eigenvalue)
    return prefactor * phi(kappa, nu, eigenvalue)


@jax.jit
def legendre_tilde_values(x, y, kappa: float = 1.0, nu: float = 2.5):
    """Weighted Legendre terms for degrees 1..MAX_ELL.

    Args:
        x, y: points in cartesian coordinates.
        kappa, nu: kernel hyperparameters forwarded to ``legendre_tilde_constant``.

    Returns:
        1-D array of length ``MAX_ELL`` whose entry ``ell - 1`` is
        ``legendre_tilde_constant(kappa, nu, ell) * legendre_values(x, y)[ell]``.
    """
    # Degree 0 is skipped: its weight would divide by ell * (ell + 1) == 0.
    legendre_vals = legendre_values(x, y)[1:]
    degrees = jnp.arange(1, MAX_ELL + 1)
    # legendre_tilde_constant only uses elementwise arithmetic, so a single
    # call on the whole degree vector replaces the per-degree Python loop
    # (which traced MAX_ELL separate scalar calls under jit).
    weights = array(legendre_tilde_constant(kappa, nu, degrees))
    return jnp.multiply(legendre_vals, weights)


@jax.jit
def hodge_matern_k_mine(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Author's direct math-to-code translation of the equation for the
    (unnormalized) curl-free kernel: the sum over degrees of the outer product
    of the gradient wrt ``x`` with the gradient wrt ``y`` of each weighted
    Legendre term.

    Args:
        x, y: points in cartesian coordinates.
        kappa, nu: kernel hyperparameters.

    Returns:
        Square matrix ``sum_ell outer(dx[ell], dy[ell])``.

    NOTE(review): both gradients are taken of the already-weighted term, so
    every summand carries the degree weight squared. This is therefore not the
    same quantity as the mixed second derivative computed by
    ``hodge_matern_k``; the printed results in the next cell differ.
    """
    # x, y in cartesian coordinates
    # Row ell - 1 of dx / dy is the gradient of the degree-ell weighted term.
    dx = jax.jacfwd(legendre_tilde_values, argnums=0)(x, y, kappa, nu)
    dy = jax.jacfwd(legendre_tilde_values, argnums=1)(x, y, kappa, nu)
    # Contract over the degree axis directly: sum of per-degree outer products.
    # (The original had a second, unreachable return computing this same value.)
    return jnp.einsum('ij,ik->jk', dx, dy)


# hodge_matern_k (defined next) is taken directly from the code for "Intrinsic Gaussian Vector Fields on Manifolds".
# hodge_matern_k_equivalent_curl_free (defined further below) is an equivalent implementation of it.

@jax.jit
def hodge_matern_k(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Unnormalized hodge matern kernel on the sphere.

    Args:
        x, y: points in cartesian coordinates.
        kappa, nu: kernel hyperparameters.

    Returns:
        Square matrix: the weighted sum over degrees 1..MAX_ELL of the mixed
        second derivative (first wrt ``x``, then wrt ``y``) of each Legendre
        value.
    """
    # x, y in cartesian coordinates
    # Leading axis indexes the degree; [1:] drops degree 0, whose weight would
    # divide by ell * (ell + 1) == 0.
    mixed_second_derivs = jax.jacfwd(
        jax.jacfwd(legendre_values, argnums=0), argnums=1
    )(x, y)[1:]
    # Same weight the scalar kernel uses; the helper is elementwise, so one
    # call on the degree vector replaces the inlined per-degree formula.
    degrees = jnp.arange(1, MAX_ELL + 1)
    weights = array(legendre_tilde_constant(kappa, nu, degrees))
    # Broadcast each degree's weight over its matrix, then sum over degrees.
    return jnp.multiply(mixed_second_derivs, weights[:, None, None]).sum(axis=0)


@jax.jit
def scalar_matern_k(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Unnormalized matern kernel on the sphere.

    Adds up the weighted Legendre terms returned by ``legendre_tilde_values``
    for the cartesian points ``x`` and ``y``.
    """
    weighted_terms = legendre_tilde_values(x, y, kappa, nu)
    return jnp.sum(weighted_terms, axis=0)


@jax.jit
def hodge_matern_k_equivalent_curl_free(x, y, kappa: float = 1.0, nu: float = 2.5):
    """
    Unnormalized hodge matern kernel on the sphere. Equivalent to hodge_matern_k.

    Obtained by differentiating ``scalar_matern_k`` once wrt ``x`` and then
    once wrt ``y``.

    Open question from the author: how does a sum of outer products of
    gradients relate to these second order partial derivatives?
    NOTE(review): for a function f(x . y) the chain rule gives
    d2f / dx_i dy_j = f''(t) * y_i * x_j + f'(t) * delta_ij  (t = x . y),
    whereas an outer product of the two gradients gives f'(t)**2 * y_i * x_j,
    so the two are different expressions in general.
    """
    grad_wrt_x = jax.jacfwd(scalar_matern_k, argnums=0)
    mixed_second_derivative = jax.jacfwd(grad_wrt_x, argnums=1)
    return mixed_second_derivative(x, y, kappa, nu)

In [2]:
# Two orthogonal unit vectors on the sphere.
x = array([0., 0., 1.])
y = array([0., 1., 0.])

# Evaluate the three implementations at coincident points (x, x).
print(hodge_matern_k(x, x))
# Was `hodge_matern_k_equivalent`, a name no cell defines: NameError on a fresh kernel.
print(hodge_matern_k_equivalent_curl_free(x, x))
print(hodge_matern_k_mine(x, x))

  return jnp.array(els, dtype='float64')
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[[0.00019812 0.         0.        ]
 [0.         0.00019812 0.        ]
 [0.         0.         0.0003328 ]]
[[0.00019812 0.         0.        ]
 [0.         0.00019812 0.        ]
 [0.         0.         0.0003328 ]]
[[0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 1.9545936e-08]]
