In [None]:
from jax import config 
config.update("jax_enable_x64", True)


import jax 
import jax.numpy as jnp
import gpjax 
from gpjax.typing import Float, ScalarFloat
from jaxtyping import Num 
from gpjax.base import static_field, param_field, Module
from jax.tree_util import Partial
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Key

from matplotlib import pyplot as plt 

from dataclasses import dataclass


def array(x):
    """Build a JAX array from `x`, requesting the `float64` dtype."""
    as_float64 = jnp.array(x, dtype=jnp.float64)
    return as_float64


@jax.jit
def sph_to_car(sph):
    """
    Map spherical (colat, lon) coordinates to Cartesian (x, y, z) coordinates.

    The last axis of `sph` holds (colat, lon); any leading axes are kept, so a
    single point and a batch of points are both accepted.
    """
    colat = sph[..., 0]
    lon = sph[..., 1]
    # sin(colat) is the radius of the circle of constant colatitude.
    sin_colat = jnp.sin(colat)
    return jnp.stack(
        [sin_colat * jnp.cos(lon), sin_colat * jnp.sin(lon), jnp.cos(colat)],
        axis=-1,
    )


@jax.jit
def car_to_sph(car):
    """
    Map Cartesian (x, y, z) coordinates to spherical (colat, lon) coordinates.

    The last axis of `car` holds (x, y, z); any leading axes are kept.
    The returned longitude is wrapped into [0, 2π).
    """
    x, y, z = car[..., 0], car[..., 1], car[..., 2]
    # Rounding can push |z| marginally above 1 for (numerically) unit vectors,
    # which would make arccos return NaN; clip to the valid domain first.
    colat = jnp.arccos(jnp.clip(z, -1.0, 1.0))
    lon = jnp.arctan2(y, x)
    # arctan2 returns values in [-π, π]; shift and wrap into [0, 2π).
    lon = (lon + 2 * jnp.pi) % (2 * jnp.pi)
    return jnp.stack([colat, lon], axis=-1)



from pathlib import Path
from typing import Callable

import numpy as np
from jax import Array


class FundamentalSystemNotPrecomputedError(ValueError):
    """Raised when no precomputed fundamental system exists for a dimension."""

    def __init__(self, dimension: int):
        super().__init__(
            f"Fundamental system for dimension {dimension} has not been precomputed."
        )


def fundamental_set_loader(dimension: int, load_dir="fundamental_system") -> Callable[[int], Array]:
    """
    Build a lookup function for a precomputed fundamental system.

    The archive `../../<load_dir>/fs_<dimension>D.npz` is read once, eagerly;
    the returned function then serves arrays from memory.

    Args:
        dimension: Dimension used to select the archive file name.
        load_dir: Directory (relative to `../../`) holding the archives.

    Returns:
        A function mapping a degree to the array stored under `degree_<degree>`.
        It raises `ValueError` when the archive has no such entry.

    Raises:
        FundamentalSystemNotPrecomputedError: If the archive file does not exist.
    """
    archive_path = Path("../../") / load_dir / f"fs_{dimension}D.npz"
    if not archive_path.exists():
        raise FundamentalSystemNotPrecomputedError(dimension)

    # Materialise every entry while the file is open so lookups need no further I/O.
    with np.load(archive_path) as archive:
        cache = dict(archive.items())

    def load(degree: int) -> Array:
        key = f"degree_{degree}"
        if key in cache:
            return cache[key]
        raise ValueError(f"key: {key} not in cache.")

    return load


@Partial(jax.jit, static_argnames=('max_ell', 'alpha',))
def gegenbauer(x: Float[Array, "N D"], max_ell: int, alpha: float = 0.5) -> Float[Array, "N L"]:
    """
    Evaluate all Gegenbauer polynomials Cᵅ₀(x), …, Cᵅ_{max_ell}(x) with the three-term recurrence.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array (any shape).
        max_ell: Highest polynomial order to evaluate (static under `jit`).
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static under `jit`).

    Returns:
        An array of shape `(*x.shape, max_ell + 1)` whose last axis is indexed by the order n.
    """
    C_0 = jnp.ones_like(x, dtype=x.dtype)

    res = jnp.empty((*x.shape, max_ell + 1), dtype=x.dtype)
    res = res.at[..., 0].set(C_0)

    # `max_ell` is static, so this branch is resolved in Python at trace time.
    # That avoids tracing a write to index 1 of a length-1 axis when max_ell == 0.
    if max_ell == 0:
        return res

    C_1 = 2 * alpha * x
    res = res.at[..., 1].set(C_1)

    def step(n: int, res_and_Cs: tuple[Float, Float, Float]) -> tuple[Float, Float, Float]:
        # Carry = (table so far, Cᵅₙ₋₁, Cᵅₙ₋₂); one recurrence step produces Cᵅₙ.
        res, C, C_prev = res_and_Cs
        C, C_prev = (2 * x * (n + alpha - 1) * C - (n + 2 * alpha - 2) * C_prev) / n, C
        res = res.at[..., n].set(C)
        return res, C, C_prev

    # For max_ell == 1 the loop range is empty and the table is returned as is.
    return jax.lax.fori_loop(2, max_ell + 1, step, (res, C_1, C_0))[0]


@Partial(jax.jit, static_argnames=('alpha',)) # NOTE ell is not static, since it will be most often different with each call 
def gegenbauer_single(x: Float, ell: int, alpha: float) -> Float:
    """
    Evaluate the single Gegenbauer polynomial Cᵅ_ell(x) with the three-term recurrence.

    Cᵅ₀(x) = 1
    Cᵅ₁(x) = 2αx
    Cᵅₙ(x) = (2x(n + α - 1) Cᵅₙ₋₁(x) - (n + 2α - 2) Cᵅₙ₋₂(x)) / n

    Args:
        x: Input array.
        ell: The order of the polynomial (traced, not static).
        alpha: The hyper-sphere constant given by (d - 2) / 2 for the Sᵈ⁻¹ sphere (static under `jit`).

    Returns:
        Cᵅ_ell evaluated at `x`, with the same shape as `x`.
    """
    zeroth = jnp.ones_like(x, dtype=x.dtype)
    first = 2 * alpha * x

    def keep_going(state):
        # Continue while the next order to compute does not exceed `ell`.
        _, _, n = state
        return n <= ell

    def advance(state):
        current, previous, n = state
        upcoming = (2 * x * (n + alpha - 1) * current - (n + 2 * alpha - 2) * previous) / n
        return upcoming, current, n + 1

    def run_recurrence():
        # Starts at order 2; for ell == 1 the loop body never runs and Cᵅ₁ is returned.
        initial = (first, zeroth, jnp.array(2, jnp.float64))
        return jax.lax.while_loop(keep_going, advance, initial)[0]

    return jax.lax.cond(ell == 0, lambda: zeroth, run_recurrence)


@dataclass
class SphericalHarmonics(gpjax.Module):
    """
    Spherical harmonics inducing features for sparse inference in Gaussian processes.

    The spherical harmonics, Yₙᵐ(·) of frequency n and phase m are eigenfunctions on the sphere and,
    as such, they form an orthogonal basis.

    To construct the harmonics, we use a fundamental set of points on the sphere {vᵢ}ᵢ and compute
    b = {Cᵅₙ(<vᵢ, x>)}ᵢ. b now forms a complete basis on the sphere and we can orthogonalise it via
    a Cholesky decomposition. However, we only need to run the Cholesky decomposition once during
    initialisation.

    Attributes:
        max_ell: The highest frequency up to which the harmonics are computed.
        sphere_dim: Dimension of the sphere; the fundamental system is loaded for `sphere_dim + 1`.
        alpha: Gegenbauer constant `(sphere_dim + 1 - 2) / 2`, set in `__post_init__`.
        orth_basis: One lower Cholesky factor per frequency, computed once in `__post_init__`.
        Vs: One fundamental-set array per frequency, as returned by `fundamental_set_loader`.
        num_phases_per_frequency: Number of rows of each entry of `Vs`.
        num_phases: Sum of `num_phases_per_frequency`.
    """

    max_ell: int = static_field()
    sphere_dim: int = static_field()
    alpha: float = static_field(init=False)
    orth_basis: Array = static_field(init=False)
    Vs: list[Array] = static_field(init=False)
    num_phases_per_frequency: list[int] = static_field(init=False)
    num_phases: int = static_field(init=False)


    @property
    def levels(self):
        """All frequencies 0, …, max_ell as an int32 array."""
        return jnp.arange(self.max_ell + 1, dtype=jnp.int32)
    

    def __post_init__(self) -> None:
        """
        Load the fundamental sets and precompute every derived (static) field.

        Raises:
            FundamentalSystemNotPrecomputedError: Propagated from `fundamental_set_loader`
                when no archive exists for dimension `sphere_dim + 1`.
        """
        dim = self.sphere_dim + 1

        # Try loading a pre-computed fundamental set.
        fund_set = fundamental_set_loader(dim)

        # Gegenbauer constant for this dimension.
        self.alpha = (dim - 2.0) / 2.0

        # One fundamental set per frequency.
        self.Vs = [fund_set(n) for n in self.levels]

        # pre-compute and save the orthogonal basis 
        self.orth_basis = self._orthogonalise_basis()


        # set these things instead of computing every time 
        self.num_phases_per_frequency = [v.shape[0] for v in self.Vs]
        self.num_phases = sum(self.num_phases_per_frequency)


    @property
    def Ls(self) -> list[Array]:
        """
        Alias for the orthogonal basis at every frequency.
        """
        return self.orth_basis

    def _orthogonalise_basis(self) -> list[Array]:
        """
        Compute the basis from the fundamental set and orthogonalise it via Cholesky decomposition.

        Returns:
            One lower Cholesky factor per frequency, in the same order as `self.Vs`.
        """
        alpha = self.alpha
        # Split into per-frequency length-1 arrays so the lists line up with `self.Vs` in tree.map.
        levels = jnp.split(self.levels, self.max_ell + 1)
        const = alpha / (alpha + self.levels.astype(jnp.float64))
        const = jnp.split(const, self.max_ell + 1)

        def _func(v, n, c):
            # Gram matrix of the fundamental set under the scaled zonal function.
            x = jnp.matmul(v, v.T)
            B = c * self.custom_gegenbauer_single(x, ell=n[0], alpha=self.alpha)
            L = jnp.linalg.cholesky(B + 1e-16 * jnp.eye(B.shape[0], dtype=B.dtype))
            return L

        return jax.tree.map(_func, self.Vs, levels, const)

    def custom_gegenbauer_single(self, x, ell, alpha):
        """Evaluate Cᵅ_ell(x) by computing all orders up to `max_ell` and selecting order `ell`."""
        return gegenbauer(x, self.max_ell, alpha)[..., ell]

    @jax.jit
    def polynomial_expansion(self, X: Float[Array, "N D"]) -> Float[Array, "M N"]:
        """
        Evaluate the polynomial expansion of an input on the sphere given the harmonic basis.

        Args:
            X: Input Array.

        Returns:
            The harmonics evaluated at the input as a polynomial expansion of the basis,
            concatenated over frequencies along the first axis.
        """
        levels = jnp.split(self.levels, self.max_ell + 1)

        def _func(v, n, L):
            vxT = jnp.dot(v, X.T)
            zonal = self.custom_gegenbauer_single(vxT, ell=n[0], alpha=self.alpha)
            # Apply L⁻¹ to the zonal functions to obtain the orthogonalised harmonics.
            harmonic = jax.lax.linalg.triangular_solve(L, zonal, left_side=True, lower=True)
            return harmonic

        harmonics = jax.tree.map(_func, self.Vs, levels, self.Ls)
        return jnp.concatenate(harmonics, axis=0)
    
    def __eq__(self, other: "SphericalHarmonics") -> bool:
        """
        Check if two spherical harmonic features are equal.

        Args:
            other: The other spherical harmonic features.

        Returns:
            A boolean indicating if the two features are equal, or `NotImplemented`
            when `other` is not a `SphericalHarmonics` instance.
        """
        if not isinstance(other, SphericalHarmonics):
            return NotImplemented
        # Given the first two parameters, the rest are deterministic. 
        # The user must not mutate all other fields, but that is not enforced for now.
        return (
            self.max_ell == other.max_ell 
            and self.sphere_dim == other.sphere_dim 
        )    

def angles_to_radians_colat(x: Array) -> Array:
    """Convert degrees to radians and add a quarter turn: returns `π·x/180 + π/2`."""
    # NOTE(review): with the colatitude convention of `sph_to_car` (colat = 0 at z = +1),
    # a latitude of +90° maps to colat = π here — confirm this sign is intended.
    quarter_turn = jnp.pi / 2
    return jnp.pi * x / 180 + quarter_turn

def angles_to_radians_lon(x: Array) -> Array:
    """Convert an angle given in degrees to radians: returns `π·x/180`."""
    radians = jnp.pi * x / 180
    return radians


from gpjax.base import static_field, param_field
from gpjax.kernels import AbstractKernel
from gpjax.likelihoods import AbstractLikelihood
from gpjax.gps import AbstractPosterior
import tensorflow_probability.substrates.jax.bijectors as tfb
from jax.scipy.special import gammaln
from jaxtyping import Int


@jax.jit 
def comb(N, k) -> Int:
    """Binomial coefficient "N choose k", computed through log-gamma and rounded to an integer."""
    log_binomial = gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1)
    return jnp.round(jnp.exp(log_binomial)).astype(jnp.int64)


@Partial(jax.jit, static_argnames=("sphere_dim"))
def num_phases_in_frequency(sphere_dim: int, frequency: Int) -> Int:
    """
    Count the phases at a given frequency.

    Returns 1 where `frequency == 0`, and otherwise
    `comb(l + d - 2, l - 1) + comb(l + d - 1, l)` with `l = frequency`, `d = sphere_dim`.
    """
    degree = frequency
    count_for_positive_degree = (
        comb(degree + sphere_dim - 2, degree - 1) + comb(degree + sphere_dim - 1, degree)
    )
    count_for_degree_zero = jnp.ones_like(degree, dtype=jnp.int64)
    return jnp.where(degree == 0, count_for_degree_zero, count_for_positive_degree)


@Partial(jax.jit, static_argnames=("max_ell", "sphere_dim"))
def sphere_addition_theorem(x: Float[Array, "D"], y: Float[Array, "D"], *, max_ell: int, sphere_dim: int) -> Float:
    """
    Per-frequency terms for a pair of points, for frequencies 0, …, max_ell.

    For each frequency the result is
    `num_phases_in_frequency * C_ell(<x, y>) / C_ell(1)` with Gegenbauer
    constant `alpha = (sphere_dim - 1) / 2`.

    Returns:
        An array of length `max_ell + 1`.
    """
    alpha = (sphere_dim - 1) / 2.0
    degrees = jnp.arange(max_ell + 1)
    multiplicities = num_phases_in_frequency(sphere_dim=sphere_dim, frequency=degrees)
    gegenbauer_at_one = gegenbauer(1.0, max_ell=max_ell, alpha=alpha)
    gegenbauer_at_inner = gegenbauer(jnp.dot(x, y), max_ell=max_ell, alpha=alpha)
    return multiplicities / gegenbauer_at_one * gegenbauer_at_inner


def addition_theorem_scalar_kernel(spectrum: Float[Array, "I"], z: Float[Array, "I"]) -> Float[Array, ""]:
    """Combine per-frequency terms `z` with the `spectrum` weights via a dot product."""
    kernel_value = jnp.dot(spectrum, z)
    return kernel_value


@Partial(jax.jit, static_argnames=('dim',))
def matern_spectrum(ell: Float, kappa: Float, nu: Float, variance: Float, dim: int) -> Float:
    """
    Matérn spectral weights at the frequencies `ell`.

    The unnormalised log-weight is `-(nu + dim / 2) * log1p(λ κ² / (2 ν))` with
    `λ = ell * (ell + dim - 1)`. The weights are then scaled so that their sum,
    weighted by `num_phases_in_frequency`, equals `variance`.

    Returns:
        An array with the same shape as `ell`.
    """
    eigenvalues = ell * (ell + dim - 1)
    exponent = -(nu + dim / 2)
    log_weights = exponent * jnp.log1p((eigenvalues * kappa**2) / (2 * nu))

    # Shift by the maximum before exponentiating for numerical stability;
    # the shift cancels in the normalisation below.
    weights = jnp.exp(log_weights - jnp.max(log_weights))

    # Each frequency contributes once per phase to the total mass.
    multiplicities = num_phases_in_frequency(frequency=ell, sphere_dim=dim)
    total_mass = jnp.dot(multiplicities, weights)
    return variance * weights / total_mass


@dataclass
class SphereMaternKernel(Module):
    """
    Single-output Matérn kernel on the sphere, truncated at frequency `max_ell`.

    The kernel value is the dot product of `matern_spectrum` with the
    per-frequency terms of `sphere_addition_theorem`.
    """

    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array(1.5), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    max_ell: int = static_field(25)

    def __post_init__(self):
        """Cast the three hyperparameters to float64 arrays."""
        self.variance = jnp.asarray(self.variance, dtype=jnp.float64)
        self.nu = jnp.asarray(self.nu, dtype=jnp.float64)
        self.kappa = jnp.asarray(self.kappa, dtype=jnp.float64)

    @property 
    def ells(self):
        """Frequencies 0, …, max_ell as a float64 array."""
        num_frequencies = self.max_ell + 1
        return jnp.arange(num_frequencies, dtype=jnp.float64)
    
    def spectrum(self) -> Num[Array, "I"]:
        """Matérn spectral weights for the current hyperparameters, one per frequency."""
        return matern_spectrum(
            self.ells,
            self.kappa,
            self.nu,
            self.variance,
            dim=self.sphere_dim,
        )

    @jax.jit 
    def from_spectrum(self, spectrum: Float[Array, "M"], x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value between `x` and `y` for explicitly supplied spectral weights."""
        per_frequency_terms = sphere_addition_theorem(
            x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim
        )
        return addition_theorem_scalar_kernel(spectrum, per_frequency_terms)
    
    @jax.jit 
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
        """Kernel value between `x` and `y` using the kernel's own spectrum."""
        own_spectrum = self.spectrum()
        return self.from_spectrum(own_spectrum, x, y)


@dataclass 
class MultioutputSphereMaternKernel(Module):
    """
    Independent Matérn kernels on the sphere, one per output, truncated at `max_ell`.

    `kappa`, `nu` and `variance` are broadcast to shape `(num_outputs,)` on
    construction, so each output has its own hyperparameters.
    """

    num_outputs: int = static_field()
    sphere_dim: int = static_field(2)
    kappa: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    nu: ScalarFloat = param_field(jnp.array([1.5]), bijector=tfb.Softplus())
    variance: ScalarFloat = param_field(jnp.array([1.0]), bijector=tfb.Softplus())
    max_ell: int = static_field(25)

    def _validate_params(self) -> None:
        """Cast each hyperparameter to float64 and broadcast it to `(num_outputs,)`."""
        per_output_shape = (self.num_outputs,)
        # float64 for numerical stability, then one value per output
        self.kappa = jnp.broadcast_to(jnp.asarray(self.kappa, dtype=jnp.float64), per_output_shape)
        self.nu = jnp.broadcast_to(jnp.asarray(self.nu, dtype=jnp.float64), per_output_shape)
        self.variance = jnp.broadcast_to(jnp.asarray(self.variance, dtype=jnp.float64), per_output_shape)

    def __post_init__(self):
        self._validate_params()

    @property 
    def ells(self):
        """Frequencies 0, …, max_ell."""
        num_frequencies = self.max_ell + 1
        return jnp.arange(num_frequencies)
    
    @jax.jit 
    def spectrum(self) -> Num[Array, "O L"]:
        """Matérn spectral weights for every output, stacked along the first axis."""

        def single_output_spectrum(kappa, nu, variance):
            return matern_spectrum(self.ells, kappa, nu, variance, dim=self.sphere_dim)

        return jax.vmap(single_output_spectrum)(self.kappa, self.nu, self.variance)
    
    @jax.jit 
    def from_spectrum(self, spectrum: Float[Array, "O L"], x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output kernel values between `x` and `y` for explicitly supplied spectra."""

        def single_output_kernel(output_spectrum):
            # The per-frequency terms depend only on (x, y), not on the output.
            per_frequency_terms = sphere_addition_theorem(
                x, y, max_ell=self.max_ell, sphere_dim=self.sphere_dim
            )
            return addition_theorem_scalar_kernel(output_spectrum, per_frequency_terms)

        return jax.vmap(single_output_kernel)(spectrum)
    
    @jax.jit 
    def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "O"]:
        """Per-output kernel values between `x` and `y` using the kernel's own spectra."""
        own_spectrum = self.spectrum()
        return self.from_spectrum(own_spectrum, x, y)


@dataclass 
class MultioutputPrior(Module):
    """Prior container pairing a multi-output sphere kernel with a jitter value."""

    # Kernel of the prior (registered as a parameter field).
    kernel: MultioutputSphereMaternKernel = param_field()
    # Static jitter value; defaults to 1e-12.
    jitter: Float = static_field(1e-12)

    @property 
    def num_outputs(self):
        """Number of outputs, read from the kernel."""
        return self.kernel.num_outputs


@dataclass 
class Prior(Module):
    """Prior container pairing a single-output sphere kernel with a jitter value."""

    # Kernel of the prior (registered as a parameter field).
    kernel: SphereMaternKernel = param_field()
    # Static jitter value; defaults to 1e-12.
    jitter: Float = static_field(1e-12)
    

@dataclass
class Posterior(Module):
    """Posterior container holding a single-output prior and a likelihood module."""

    # Prior over the latent function.
    prior: Prior = param_field()
    # Likelihood module (any `Module`).
    likelihood: Module = param_field()


@dataclass
class MultioutputPosterior(Module):
    """Posterior container holding a multi-output prior and a likelihood module."""

    # Prior over the latent functions.
    prior: MultioutputPrior = param_field()
    # Likelihood module (any `Module`).
    likelihood: Module = param_field()

    @property 
    def num_outputs(self) -> int:
        """Number of outputs, read from the prior."""
        return self.prior.num_outputs


@Partial(jax.jit, static_argnames=('jitter',))
def spherical_harmonic_features_moments(
    Kxz: Float[Array, "M"], 
    Kzz_inv_diag: Float[Array, "M"], 
    m: Float[Array, "M"], 
    sqrtS: Float[Array, "M M"], 
    jitter: float = 1e-12
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """
    Mean and variational variance term at one input for whitened spherical-harmonic features.

    Args:
        Kxz: Features evaluated at the input.
        Kzz_inv_diag: Diagonal used to scale `Kxz`; its square root, damped by `jitter`, is applied.
        m: Whitened variational mean.
        sqrtS: Square-root factor of the whitened variational covariance.
        jitter: Static damping constant in `1 / sqrt(1 + jitter * Kzz_inv_diag)`.

    Returns:
        `(mean, covariance)` where `mean = (Kxz * scale) @ m` and
        `covariance = ||(Kxz * scale) @ sqrtS||²`. No prior term (Kxx) is included.
    """
    scale = jnp.sqrt(Kzz_inv_diag) / jnp.sqrt(1 + jitter * Kzz_inv_diag)
    whitened_Kxz = Kxz * scale
    projected = whitened_Kxz @ sqrtS

    mean = whitened_Kxz @ m

    # Only the variational term is returned here:
    #   + projected @ projected.T
    #   - whitened_Kxz @ whitened_Kxz.T
    # The negative term is not needed as it is absorbed into Kxx.
    covariance = jnp.sum(jnp.square(projected))

    return mean, covariance


@Partial(jax.jit, static_argnames=('jitter',))
def pathwise_sample_spherical_harmonic_features_posterior(
    Kxz: Float[Array, "M"],
    Kzz_inv_diag: Float[Array, "M"],
    m: Float[Array, "M"],
    sqrtS: Float[Array, "M M"],
    jitter: float = 1e-12,
    *, 
    key: Key
) -> Float[Array, ""]:
    """
    Draw one pathwise sample at one input for whitened spherical-harmonic features.

    A standard-normal vector `u` with the shape of `m` is drawn from `key`; the
    result is `(Kxz * scale) @ (sqrtS @ u + m)` expanded into two terms, with
    `scale = sqrt(Kzz_inv_diag) / sqrt(1 + jitter * Kzz_inv_diag)`.
    Using the same `key` at different inputs reuses the same `u`.
    """
    # f(x) + Kxz Kzz^{-1} (u - f(z)) = Kxz Kzz^{-1} u
    scale = jnp.sqrt(Kzz_inv_diag) / jnp.sqrt(1 + jitter * Kzz_inv_diag)
    whitened_Kxz = Kxz * scale
    projected = whitened_Kxz @ sqrtS

    noise = jax.random.normal(key=key, shape=m.shape)
    return projected @ noise + whitened_Kxz @ m


@jax.jit
def whitened_prior_kl(m: Float, sqrtS: Float) -> Float:
    """KL divergence from q = N(m, sqrtS sqrtSᵀ) to the standard normal prior N(0, I)."""
    num_inducing = m.shape[0]
    standard_normal = tfd.MultivariateNormalFullCovariance(
        loc=jnp.zeros(m.shape),
        covariance_matrix=jnp.eye(num_inducing),
    )
    variational = tfd.MultivariateNormalFullCovariance(
        loc=m,
        covariance_matrix=sqrtS @ sqrtS.T,
    )
    return tfd.kl_divergence(variational, standard_normal)


def inducing_points_prior_kl(m: Float, sqrtS: Float) -> Float:
    """Prior KL term for inducing points; delegates to `whitened_prior_kl`."""
    kl = whitened_prior_kl(m, sqrtS)
    return kl


@dataclass 
class DummyPosterior(Module):
    """Posterior-like container that holds only a single-output prior (no likelihood field)."""

    # Prior over the latent function.
    prior: Prior = param_field()


@dataclass 
class MultioutputDummyPosterior(Module):
    """Posterior-like container that holds only a multi-output prior (no likelihood field)."""

    # Prior over the latent functions.
    prior: MultioutputPrior = param_field()

    @property 
    def num_outputs(self):
        """Number of outputs, read from the prior."""
        return self.prior.num_outputs
    

@dataclass
class SphericalHarmonicFeaturesPosterior(Module):
    """
    Single-output variational posterior with spherical-harmonic inducing features.

    The whitened variational parameters `m` (zeros) and `sqrtS` (identity) are
    created in `__post_init__` with one entry per phase of `spherical_harmonics`.
    """

    posterior: Posterior = param_field()
    spherical_harmonics: SphericalHarmonics = static_field()
    m: Float[Array, "M"] = param_field(init=False)
    sqrtS: Float[Array, "M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    num_inducing: int = static_field(init=False)

    def __post_init__(self):
        """Size and initialise the variational parameters from the number of phases."""
        self.num_inducing = self.spherical_harmonics.num_phases
        self.m = jnp.zeros(self.num_inducing)
        self.sqrtS = jnp.eye(self.num_inducing)

    @jax.jit 
    def Kzz_diag(self, spectrum: Float[Array, "L"]) -> Float[Array, "M"]:
        """
        Repeat each spectral weight once per phase of its frequency.

        Only the first `spherical_harmonics.max_ell + 1` entries of `spectrum` are used.
        The result is what callers in this class pass as the `Kzz_inv_diag` argument
        of the moment / sampling helpers.
        """
        shf = self.spherical_harmonics
        # Static repeat counts and total length keep the output shape known under jit.
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        return jnp.repeat(
            spectrum[:shf.max_ell + 1], 
            repeats=repeats,
            total_repeat_length=total_repeat_length,
        )
    
    def Kxz(self, x: Float[Array, "D"]) -> Float[Array, "M"]:
        """Spherical-harmonic features evaluated at `x` (transposed polynomial expansion)."""
        return self.spherical_harmonics.polynomial_expansion(x).T
    
    def prior_kl(self) -> Float[Array, ""]:
        """KL divergence of the whitened variational distribution from the standard normal."""
        return whitened_prior_kl(self.m, self.sqrtS)

    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, ""], Float[Array, ""]]:
        """Mean and variance at a single input `x`; `diag` maps this over a batch."""
        kernel = self.posterior.prior.kernel

        spectrum = kernel.spectrum()

        Kzz_diag = self.Kzz_diag(spectrum)
        Kxz = self.Kxz(x)

        return spherical_harmonic_features_moments(Kxz, Kzz_diag, self.m, self.sqrtS)
    
    @jax.jit 
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent normal marginals at each row of `x`."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))
    
    @jax.jit 
    def pathwise_sample_single(self, x: Float[Array, "D"], *, key: Key) -> Float[Array, ""]:
        """Value of one pathwise sample (determined by `key`) at a single input `x`."""
        kernel = self.posterior.prior.kernel

        Kxz = self.Kxz(x)
        Kzz_diag = self.Kzz_diag(kernel.spectrum())
        return pathwise_sample_spherical_harmonic_features_posterior(
            Kxz, Kzz_diag, self.m, self.sqrtS, key=key
        )
    
    @jax.jit
    def pathwise_sample(self, x: Float[Array, "N D"], *, key: Key) -> Float[Array, "N"]:
        """Evaluate one pathwise sample at every row of `x`."""
        # The same key is used for every row so all rows come from the same sample path.
        return jax.vmap(lambda x: self.pathwise_sample_single(x, key=key))(x)


@dataclass
class MultioutputSphericalHarmonicFeaturesPosterior(Module):
    """
    Multi-output variational posterior with spherical-harmonic inducing features.

    Each output has its own whitened variational parameters `m` and `sqrtS`
    (initialised to zeros / identity) and a per-frequency `sqrtS_augment`
    (initialised to 0 for frequencies covered by `spherical_harmonics` and 1 for
    the remaining kernel frequencies). All outputs share the same features.
    """

    num_outputs: int = static_field(init=False)

    posterior: MultioutputPosterior = param_field()
    spherical_harmonics: SphericalHarmonics = static_field()
    m: Float[Array, "O M"] = param_field(init=False)
    sqrtS: Float[Array, "O M M"] = param_field(init=False, bijector=tfb.FillTriangular())
    sqrtS_augment: Float[Array, "O L"] = param_field(init=False)

    def __post_init__(self):
        """Initialise the variational parameters and broadcast them over the outputs."""
        kernel = self.posterior.prior.kernel

        self.num_outputs = self.posterior.num_outputs
        
        num_inducing = self.spherical_harmonics.num_phases
        self.m = jnp.zeros(num_inducing)
        self.sqrtS = jnp.eye(num_inducing)
        # Zero for frequencies handled by the inducing features, one for the rest.
        self.sqrtS_augment = jnp.ones(kernel.max_ell + 1).at[:self.spherical_harmonics.max_ell + 1].set(0.0)

        self.m = jnp.broadcast_to(self.m, (self.num_outputs, num_inducing))
        self.sqrtS = jnp.broadcast_to(self.sqrtS, (self.num_outputs, num_inducing, num_inducing))
        self.sqrtS_augment = jnp.broadcast_to(self.sqrtS_augment, (self.num_outputs, kernel.max_ell + 1))

    @jax.jit
    def prior_kl(self) -> Float:
        """Sum over outputs of the whitened prior KL divergences."""
        return jnp.sum(jax.vmap(whitened_prior_kl)(self.m, self.sqrtS), axis=0)

    @jax.jit 
    def Kzz_diag(self, spectrum: Float[Array, "O L"]) -> Float[Array, "O M"]:
        """
        For each output, repeat each spectral weight once per phase of its frequency.

        Only the first `spherical_harmonics.max_ell + 1` frequencies are used. The result
        is what callers in this class pass as the `Kzz_inv_diag` argument of the helpers.
        """
        shf = self.spherical_harmonics
        # Static repeat counts and total length keep the output shape known under jit.
        repeats = np.array(shf.num_phases_per_frequency)
        total_repeat_length = shf.num_phases
        return jax.vmap(
            lambda spectrum: jnp.repeat(spectrum, repeats=repeats, total_repeat_length=total_repeat_length)
        )(spectrum[:, :shf.max_ell + 1])
    

    def Kxz(self, x: Float[Array, "D"]) -> Float[Array, "M"]:
        """Spherical-harmonic features evaluated at `x`; shared by all outputs."""
        return self.spherical_harmonics.polynomial_expansion(x).T
    
    
    @jax.jit
    def moments(self, x: Float[Array, "D"]) -> tuple[Float[Array, "O"], Float[Array, "O"]]:
        """
        Per-output mean and variance at a single input `x`.

        The variance is the variational term from `spherical_harmonic_features_moments`
        plus the kernel evaluated with the spectrum scaled by `sqrtS_augment²`.
        """
        kernel = self.posterior.prior.kernel

        # prior covariance adjusted by the diagonal variational parameters 
        spectrum = kernel.spectrum() # [O L]
        S_augment = jnp.square(self.sqrtS_augment) # [O L]
        Kxx = kernel.from_spectrum(spectrum * S_augment, x, x) # [O]

        # variational covariance 
        Kzz_diag = self.Kzz_diag(spectrum) # [O M]
        Kxz = self.Kxz(x) # [M], shared across outputs

        # The helper takes (Kxz, Kzz_inv_diag, m, sqrtS); passing Kxx as a leading
        # argument would shift every parameter and put sqrtS into the static `jitter` slot.
        mean, covariance = jax.vmap(
            lambda Kzz_diag, m, sqrtS: spherical_harmonic_features_moments(Kxz, Kzz_diag, m, sqrtS)
        )(Kzz_diag, self.m, self.sqrtS)

        return mean, Kxx + covariance
    
    @jax.jit 
    def diag(self, x: Float[Array, "N D"]) -> tfd.Normal:
        """Independent normal marginals at each row of `x`, for every output."""
        mean, variance = jax.vmap(self.moments)(x)
        return tfd.Normal(loc=mean, scale=jnp.sqrt(variance))
    
    @jax.jit
    def pathwise_sample_single(self, x: Float[Array, "D"], *, key: Key) -> Float[Array, "O"]:
        """Value of one pathwise sample per output at a single input `x`."""
        # One independent key per output.
        output_dim_keys = jax.random.split(key, self.num_outputs)

        kernel = self.posterior.prior.kernel

        Kxz = self.Kxz(x)
        Kzz_diag = self.Kzz_diag(kernel.spectrum())

        return jax.vmap(
            lambda Kzz_diag, m, sqrtS, key: pathwise_sample_spherical_harmonic_features_posterior(
                Kxz, Kzz_diag, m, sqrtS, key=key
        ))(Kzz_diag, self.m, self.sqrtS, output_dim_keys)
    
    @jax.jit 
    def pathwise_sample(self, x: Float[Array, "N D"], *, key: Key) -> Float[Array, "N O"]:
        """Evaluate one pathwise sample per output at every row of `x`."""
        # The same key is used for every row so all rows come from the same sample paths.
        return jax.vmap(lambda x: self.pathwise_sample_single(x, key=key))(x)


# Exponential map on the unit sphere: exp_x(v) = cos(|v|) x + sin(|v|) v / |v|.
# TODO verify against a reference implementation.
@jax.jit
def sphere_expmap(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """
    Move from `x` along the tangent vector `v` using the sphere's exponential map.

    For (near-)zero `v` the normalised first-order step `(x + v) / |x + v|` is used
    instead of the closed form.

    Args:
        x: Base points; the last axis holds the coordinates.
        v: Tangent vectors at `x`, same shape as `x`.

    Returns:
        Points with the same shape as `x`.
    """
    sq_norm = jnp.sum(jnp.square(v), axis=-1, keepdims=True)
    # |v| < 1e-12 expressed on the squared norm, so no sqrt is taken at zero.
    is_tiny = sq_norm < 1e-24
    # Substitute a harmless value where |v| is tiny: the closed form below is discarded
    # there, but evaluating it (or sqrt) at zero would leak NaNs into the gradient
    # through `jnp.where`.
    theta = jnp.sqrt(jnp.where(is_tiny, 1.0, sq_norm))

    t = x + v
    first_order_approx = t / jnp.linalg.norm(t, axis=-1, keepdims=True)
    true_expmap = jnp.cos(theta) * x + jnp.sin(theta) * v / theta

    return jnp.where(
        is_tiny,
        first_order_approx,
        true_expmap,
    )


@jax.jit 
def sphere_to_tangent(x: Float[Array, "N D"], v: Float[Array, "N D"]) -> Float[Array, "N D"]:
    """Remove from `v` its component along `x`: returns `v - <x, v> x` along the last axis."""
    along_x = jnp.sum(x * v, axis=-1, keepdims=True) * x
    return v - along_x


@dataclass 
class SphereResidualDeepGP(Module):
    """
    Residual deep GP on the sphere.

    Each hidden layer produces a vector per input, which is projected onto the
    tangent space at the input (`sphere_to_tangent`) and used to move the input
    along the sphere (`sphere_expmap`). The output layer is evaluated at the
    resulting points.
    """

    hidden_layers: list[MultioutputSphericalHarmonicFeaturesPosterior] = param_field()
    output_layer: SphericalHarmonicFeaturesPosterior = param_field()
    num_samples: int = static_field(1)

    @property 
    def posterior(self) -> Posterior:
        """Posterior of the output layer."""
        return self.output_layer.posterior      
    
    def prior_kl(self) -> Float:
        """Sum of the prior KL terms of all hidden layers and the output layer."""
        return sum(layer.prior_kl() for layer in self.hidden_layers) + self.output_layer.prior_kl()
    
    def sample_moments(self, x: Float[Array, "N D"], *, key: Key) -> tuple[Float[Array, "N"], Float[Array, "N"]]:
        """
        Output-layer mean and variance after one stochastic pass through the hidden layers.

        Each hidden layer's marginals are sampled with its own key derived from `key`.
        """
        hidden_layer_keys = jax.random.split(key, len(self.hidden_layers))
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.diag(x).sample(seed=hidden_layer_key)
            u = sphere_to_tangent(x, v)
            x = sphere_expmap(x, u)
        return jax.vmap(self.output_layer.moments)(x)

    def diag(self, x: Float[Array, "N D"], *, key: Key) -> tfd.MixtureSameFamily:
        """Equally weighted mixture of normals built from `num_samples` stochastic passes."""
        sample_keys = jax.random.split(key, self.num_samples)

        # In MixtureSameFamily batch size goes last; hence, out_axes = 1
        mean, variance = jax.vmap(lambda k: self.sample_moments(x, key=k), out_axes=1)(sample_keys) 

        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=jnp.zeros(self.num_samples)), 
            components_distribution=tfd.Normal(loc=mean, scale=jnp.sqrt(variance)), 
        )
    
    def pathwise_sample(self, x: Float[Array, "N D"], *, key: Key) -> Float[Array, "N"]:
        """Draw one pathwise sample of the whole model at every row of `x`."""
        # One key per hidden layer plus a separate one for the output layer;
        # the parent key is not used again after being split.
        layer_keys = jax.random.split(key, len(self.hidden_layers) + 1)
        hidden_layer_keys, output_layer_key = layer_keys[:-1], layer_keys[-1]
        for hidden_layer_key, layer in zip(hidden_layer_keys, self.hidden_layers):
            v = layer.pathwise_sample(x, key=hidden_layer_key)
            u = sphere_to_tangent(x, v)
            x = sphere_expmap(x, u)
        return self.output_layer.pathwise_sample(x, key=output_layer_key)


@dataclass
class DeepGaussianLikelihood(Module):
    """Gaussian observation noise applied to every component of a mixture of normals."""

    noise_variance: Float = param_field(jnp.array(1.0), bijector=tfb.Softplus())
    
    @jax.jit 
    def diag(self, pf: tfd.MixtureSameFamily) -> tfd.MixtureSameFamily:
        """
        Add `noise_variance` to the variance of each mixture component of `pf`.

        The mixture weights and component means are carried over unchanged.
        """
        latent_components = pf.components_distribution
        noisy_variance = latent_components.variance() + self.noise_variance
        noisy_components = tfd.Normal(
            loc=latent_components.mean(),
            scale=jnp.sqrt(noisy_variance),
        )
        return tfd.MixtureSameFamily(
            mixture_distribution=pf.mixture_distribution,
            components_distribution=noisy_components,
        )


def create_residual_deep_gp_with_spherical_harmonic_features(
    num_layers: int, total_hidden_variance: float, max_ell: int, x: Float[Array, "N D"], num_samples: int = 3, *, 
    nu: float = 2.5
) -> SphereResidualDeepGP:
    """
    Assemble a residual deep GP on the sphere with spherical-harmonic features.

    Args:
        num_layers: Total number of layers; `num_layers - 1` hidden layers plus one output layer.
        total_hidden_variance: Divided evenly over the hidden layers to give each hidden kernel's variance.
        max_ell: Highest frequency used by both the kernels and the spherical harmonics.
        x: Inputs; only `x.shape[1]` is read, to infer the sphere dimension.
        num_samples: Number of stochastic passes used by the model's `diag`.
        nu: Matérn smoothness shared by all layers.

    Returns:
        The assembled `SphereResidualDeepGP`. All layers share one `SphericalHarmonics` instance.
    """
    sphere_dim = x.shape[1] - 1
    num_hidden_layers = num_layers - 1

    # Hyperparameters shared by the hidden layers and the output layer.
    shared_nu = jnp.array(nu)
    shared_kappa = jnp.array(1.0)

    # Guard against division by zero when there are no hidden layers.
    hidden_variance = jnp.array(total_hidden_variance / max(num_hidden_layers, 1))
    output_variance = jnp.array(1.0)

    spherical_harmonics = SphericalHarmonics(max_ell=max_ell, sphere_dim=sphere_dim)

    def make_hidden_layer() -> MultioutputSphericalHarmonicFeaturesPosterior:
        hidden_kernel = MultioutputSphereMaternKernel(
            num_outputs=sphere_dim + 1, 
            sphere_dim=sphere_dim, 
            nu=shared_nu,
            kappa=shared_kappa,
            variance=hidden_variance,
            max_ell=max_ell,
        )
        hidden_prior = MultioutputPrior(kernel=hidden_kernel)
        hidden_posterior = MultioutputDummyPosterior(prior=hidden_prior)
        return MultioutputSphericalHarmonicFeaturesPosterior(
            posterior=hidden_posterior, spherical_harmonics=spherical_harmonics
        )

    hidden_layers = [make_hidden_layer() for _ in range(num_hidden_layers)]

    output_kernel = SphereMaternKernel(
        sphere_dim=sphere_dim,
        nu=shared_nu,
        kappa=shared_kappa,
        variance=output_variance,
        max_ell=max_ell,
    )
    output_prior = Prior(kernel=output_kernel)
    likelihood = DeepGaussianLikelihood()
    output_posterior = Posterior(prior=output_prior, likelihood=likelihood)
    output_layer = SphericalHarmonicFeaturesPosterior(
        posterior=output_posterior, spherical_harmonics=spherical_harmonics
    )

    return SphereResidualDeepGP(hidden_layers=hidden_layers, output_layer=output_layer, num_samples=num_samples)

# Notes on data
- Tangent vectors need not have unit norm.

In [None]:
import numpy as np 
import pandas as pd 
import plotly.express as px 
from plotly import graph_objects as go
from plotly.subplots import make_subplots


# Load the headerless CSV files; column names are assigned on read.
xyz_columns = ['x', 'y', 'z']

mean_inputs = pd.read_csv("../mean_inputs.csv", header=None, names=xyz_columns)
mean_outputs = pd.read_csv(
    "../mean_outputs.csv",
    header=None,
    names=xyz_columns + ['u', 'v', 'w'],
)
std_inputs = pd.read_csv("../std_inputs.csv", header=None, names=xyz_columns)
std_outputs = pd.read_csv("../std_outputs.csv", header=None, names=['y'])

In [None]:
# Inputs for the hidden-layer panels and for the output panel.
x = jnp.asarray(mean_inputs.values)
x_output = jnp.asarray(std_inputs.values)

# Five layers in total (four hidden + output), a single sample per prediction.
model_settings = dict(
    num_layers=5,
    total_hidden_variance=0.5,
    max_ell=10,
    num_samples=1,
    nu=1.5,
)
model = create_residual_deep_gp_with_spherical_harmonic_features(
    x=mean_inputs.values,
    **model_settings,
)

In [None]:
# Root PRNG key used by the sampling cells below.
seed = 3
key = jax.random.key(seed)

In [None]:
# Sphere (background): a 100 x 100 longitude/colatitude grid on the unit sphere.
theta, phi = jnp.meshgrid(
    jnp.linspace(0, 2 * jnp.pi, 100),
    jnp.linspace(0, jnp.pi, 100),
)
sphere_inputs = jnp.stack(
    [
        jnp.sin(phi) * jnp.cos(theta),
        jnp.sin(phi) * jnp.sin(theta),
        jnp.cos(phi),
    ],
    axis=-1,
)
sphere_outputs = jnp.zeros((100, 100))

# Inputs (Left-most image)
input_inputs = x
input_outputs = jnp.zeros((x.shape[0],))

# GVF (Middle-left image): first hidden layer's sample, projected onto the tangent space.
gvf_inputs = x
v = model.hidden_layers[0].pathwise_sample(x, key=key)
u = sphere_to_tangent(x, v)
# Columns 0-2 hold the base points, columns 3-5 the tangent vectors.
gvf_outputs = jnp.concat([x, u], axis=-1)

# exp GVF (Middle-right image): inputs moved along the tangent vectors.
exp_gvf_inputs = x
y = sphere_expmap(x, u)
exp_gvf_outputs = y

# Outputs (Right-most image): a sample of the full model.
output_inputs = x_output
output_outputs = model.pathwise_sample(x_output, key=key)

In [None]:
# Plot the four panels side by side: every panel shows a grey sphere with one
# foreground trace drawn on top of it.
r_scatter = 1.01
marker_size = 3

panel_titles = ("Inputs", "GVF", "Exp+GVF", "Outputs")
fig = make_subplots(
    rows=1, 
    cols=4, 
    subplot_titles=panel_titles, 
    specs=[[{'type': 'surface'}, {'type': 'surface'}, {'type': 'scatter3d'}, {'type': 'scatter3d'}]]
)


def _grey_sphere():
    """Uniformly light-grey sphere surface used as the background of each panel."""
    return go.Surface(
        x=sphere_inputs[:, :, 0],
        y=sphere_inputs[:, :, 1],
        z=sphere_inputs[:, :, 2],
        surfacecolor=sphere_outputs,
        colorscale=['lightgrey', 'lightgrey'],
        showscale=False,
    )


def _lifted_scatter(points, marker):
    """Markers at `points`, scaled by `r_scatter` so they sit just above the sphere surface."""
    return go.Scatter3d(
        x=points[:, 0] * r_scatter, 
        y=points[:, 1] * r_scatter, 
        z=points[:, 2] * r_scatter, 
        mode='markers', 
        marker=marker,
    )


def _hidden_axes():
    """Fresh scene settings with all three axes hidden."""
    return dict(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
    )


foreground_traces = [
    # Inputs (Scatter3d with black marker color on a grey sphere)
    _lifted_scatter(input_inputs, dict(color='black', size=marker_size)),
    # GVF (Cone plot with cone position set to gvf_inputs, cone direction set to gvf_outputs[:, 3:], and black cone color)
    go.Cone(
        x=gvf_inputs[:, 0], 
        y=gvf_inputs[:, 1], 
        z=gvf_inputs[:, 2], 
        u=gvf_outputs[:, 3], 
        v=gvf_outputs[:, 4], 
        w=gvf_outputs[:, 5], 
        colorscale=['black', 'black'],
        sizemode='scaled',
        sizeref=1.2,
        showscale=False,
    ),
    # Exp+GVF (Scatter3d with black marker color)
    _lifted_scatter(exp_gvf_outputs, dict(color='black', size=marker_size)),
    # Outputs (Scatter3d with color set to output_outputs)
    _lifted_scatter(
        output_inputs,
        dict(color=output_outputs, size=marker_size, colorscale='Plasma'),
    ),
]

# Background first, then the foreground trace, column by column.
for col, trace in enumerate(foreground_traces, start=1):
    fig.add_trace(_grey_sphere(), row=1, col=col)
    fig.add_trace(trace, row=1, col=col)

fig.update_layout(
    scene=_hidden_axes(),
    scene2=dict(
        **_hidden_axes(),
        camera=dict(
            eye=dict(x=1.25, y=-1.25, z=1.25),  # Camera position
            up=dict(x=0, y=0, z=1)  # Up direction
        ),
    ),
    scene3=_hidden_axes(),
    scene4=_hidden_axes(),
    width=1300,
    height=600,
    showlegend=False,
)
fig.write_image("residual_deep_gp-schematic.pdf")
fig.show()

In [None]:
# Save the data behind each figure panel (a-d) as a headerless CSV named after the panel.
# Panel b is a single file: `gvf_outputs` already holds the base points (columns 0-2)
# together with the tangent vectors (columns 3-5). Exporting `gvf_inputs` instead
# would only duplicate the inputs and drop the vector field.
data = [
    input_inputs, 
    input_outputs, 
    gvf_outputs, 
    exp_gvf_inputs, 
    exp_gvf_outputs, 
    output_inputs, 
    output_outputs
]

names = [
    "a-inputs", 
    "a-outputs", 
    "b", 
    "c-inputs", 
    "c-outputs",  
    "d-inputs", 
    "d-outputs"
]



for datum, name in zip(data, names):
    pd.DataFrame(datum).to_csv(f"{name}.csv", header=False, index=False)

In [None]:
# NOTE(review): no cell in this notebook writes "gvf_outputs.csv" (the export cell writes
# "a-inputs.csv", "b.csv", ...), so this presumably reads a file left over from an earlier
# run — confirm. If the file is headerless like the exported ones, the first data row is
# consumed as the header here; `header=None` would be needed.
pd.read_csv("./gvf_outputs.csv")