In [93]:
import geometric_kernels.torch 
import torch 
import gpytorch 
from mdgp.kernels import GeometricMaternKernel
from gpytorch.variational import CholeskyVariationalDistribution, _VariationalStrategy, VariationalStrategy
from gpytorch.kernels import ScaleKernel
from geometric_kernels.spaces import Hypersphere
from torch import Tensor
from spherical_harmonics import SphericalHarmonics # TODO Use this to implement Phi


torch.set_default_dtype(torch.float64)

Idea: We could reuse the spherical harmonics computed in the variational strategy for the kernel itself.

There actually seems to be a simplified form for the update when using the same number of spherical harmonics for the kernel approximation and for the inducing variables. But that might be too many inducing variables. 

In [94]:
from torch import Tensor
from scipy.special import comb
from linear_operator.operators import DiagLinearOperator
from mdgp.utils import rsh 
    

def num_harmonics_single(ell: int, d: int) -> int:
    r"""
    Number of spherical harmonics of degree ell on S^{d - 1}.
    """
    if ell == 0:
        return 1
    if d == 3:
        return 2 * ell + 1
    else:
        return (2 * ell + d - 2) * comb(ell + d - 3, ell - 1) / ell


def num_harmonics(ell: Tensor | int, d: int) -> Tensor:
    """
    Vectorized version of num_harmonics_single
    """
    if isinstance(ell, int):
        return num_harmonics_single(ell, d)
    return ell.apply_(lambda e: num_harmonics_single(ell=e, d=d))


def total_num_harmonics(max_ell: int, d: int) -> int:
    """
    Total number of spherical harmonics on S^{d-1} with degree <= max_ell
    """
    return int(sum(num_harmonics(ell=torch.arange(max_ell + 1), d=d)))


def eigenvalue_laplacian(ell: int, d: int) -> float:
    """
    Eigenvalue of the Laplace-Beltrami operator for a spherical harmonic of degree ell on S_{d-1}
    """
    return ell * (ell + d - 2)


def unnormalized_matern_spectral_density(n: Tensor | float, d: int, kappa: float, nu: float) -> Tensor | float: 
    """
    compute (unnormalized) spectral density of the matern kernel on S_{d-1}
    """
    return (
        (2.0 * nu / kappa**2 + eigenvalue_laplacian(ell=n, d=d)) **
        (-nu - (d - 1) / 2.0)
    )


def matern_ahat_normalizer(d: int, max_ell: int, kappa: float, nu: float) -> Tensor:
    """
    Normalizing constant for the spectral density of the Matern kernel on S^{d-1}. 
    Depends on kappa and nu. Also depends on max_ell, as truncation of the infinite 
    sum from Karhunen-Loeve decomposition. 
    """
    n = torch.arange(max_ell + 1)
    spectral_values = unnormalized_matern_spectral_density(
        n=n,
        d=d,
        kappa=kappa,
        nu=nu,
    )
    num_harmonics_per_level = num_harmonics(torch.arange(max_ell + 1), d=d)
    normalizer = torch.sum(spectral_values * num_harmonics_per_level)

    return normalizer 


def matern_spectral_density(n: Tensor, d: int, kappa: float, nu: float, max_ell: int, sigma: float = 1.0) -> Tensor:
    """
    Spectral density of the Matern kernel on S^{d-1}
    """
    return (
        unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu) / 
        matern_ahat_normalizer(d=d, max_ell=max_ell, kappa=kappa, nu=nu) * 
        (sigma ** 2)
    )


def matern_ahat(ell: int, d: int, max_ell: int, kappa: float, nu: float, m: int | None = None, sigma: float = 1.0) -> float:
    """
    ahat = rho(\sqrt{\ell(\ell + d - 2)}) where rho is the spectral density on S^{d-1}
    """
    return matern_spectral_density(n=ell, d=d, kappa=kappa, nu=nu, max_ell=max_ell, sigma=sigma)


def matern_repeated_ahat(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Returns a tensor of repeated ahat values for each ell. 
    """
    ells = torch.arange(max_ell + 1)
    ahat = matern_ahat(
        ell=ells, 
        d=d, 
        max_ell=max_ell,
        kappa=kappa,
        nu=nu,
        sigma=sigma,
    )
    repeats = num_harmonics(ell=ells, d=d)
    return torch.repeat_interleave(ahat, repeats=repeats)


def matern_Kuu(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor | DiagLinearOperator: 
    """
    Returns the covariance matrix, which is a diagonal matrix with entries 
    equal to inv_ahat of the corresponding ell. 
    """
    return DiagLinearOperator(1 / matern_repeated_ahat(max_ell, d, kappa, nu, sigma=sigma))


def matern_inv_Kuu(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor | DiagLinearOperator:
    """
    Returns the inverse of the covariance matrix, which is a diagonal matrix with entries 
    equal to ahat of the corresponding ell. 
    """
    return DiagLinearOperator(matern_repeated_ahat(max_ell, d, kappa, nu, sigma=sigma))


def spherical_harmonics_rsh(max_ell: int, d: int):
    assert d == 3, "spherical_harmonics_rsh is only implemented for d=3."
    assert max_ell <= 8, "spherical_harmonics_rsh is only implemented for max_ell <= 8."
    from mdgp.utils import rsh 
    return getattr(rsh, f"rsh_cart_{max_ell}")


def spherical_harmonics(x: Tensor, max_ell: int, d: int, rsh: bool = True) -> Tensor: 
    # Make sur ethat x is at least 2d and flatten it
    x = torch.atleast_2d(x)
    batch_shape, n = x.shape[:-2], x.shape[-2]
    x = x.flatten(0, -2)

    # Choose method of evaluating spherical harmonics at x
    if max_ell <= 8 and d == 3 and rsh is True:
        f = spherical_harmonics_rsh(max_ell, d)
    else: 
        f = SphericalHarmonics(dimension=d, degrees=max_ell + 1)

    # Evaluate x and reintroduce batch dimensions
    return f(x).reshape(*batch_shape, n, total_num_harmonics(max_ell, d))


def matern_kernel(x1: Tensor, x2: Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0, rsh: bool = True) -> Tensor:
    """
    Returns the Matern kernel evaluated at (x1, x2). 
    """
    f_x1 = spherical_harmonics(x1, max_ell=max_ell, d=d, rsh=rsh) # [*B, n, num_harmonics]
    f_x2 = spherical_harmonics(x2, max_ell=max_ell, d=d, rsh=rsh) # [*B, n, num_harmonics]
    ahat = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma) # [num_harmonics]
    return torch.sum(ahat * f_x1 * f_x2, dim=-1) # [*B, n, num_harmonics] @ [num_harmonics] -> [*B, n]


def matern_Kux(x: Tensor, max_ell: int, d: int, kappa: float | None = None, nu: float | None = None, sigma: float = 1.0, rsh: bool = True) -> Tensor: 
    return spherical_harmonics(x, max_ell=max_ell, d=d, rsh=rsh).mT


def matern_Phi(x: Tensor, max_ell: int, d: int, kappa: float, nu: float, rsh: bool = True, sigma: float = 1.0) -> Tensor: 
    """
    Returns the feature vector of spherical harmonics evaluated at x. 
    """
    Kux = matern_Kux(x, max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma, rsh=rsh) # [*B, num_harmonics, n]
    ahat = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma).unsqueeze(-1) # [num_harmonics, 1]
    return Kux * ahat # [*B, num_harmonics, n]


def covar_update(x: Tensor, S: Tensor, max_ell: int, d: int = 3, kappa: float = 1.0, nu: float = 2.5) -> Tensor: 
    """
    Returns update term for the covariance matrix.
    """
    Phi = matern_Phi(x, max_ell=max_ell, d=d, kappa=kappa, nu=nu)
    Kuu = matern_Kuu(max_ell=max_ell, d=d, kappa=kappa, nu=nu)
    return Phi.mT @ (S - Kuu) @ Phi 


def num_spherical_harmonics_to_degree(num_spherical_harmonics: int, dimension: int) -> int:
    """
    Returns the minimum degree for which there are at least
    `num_eigenfunctions` in the collection.
    """
    n, degree = 0, 0  # n: number of harmonics, d: degree (or level)
    while n < num_spherical_harmonics:
        n += num_harmonics(d=dimension, ell=degree)
        degree += 1

    if n > num_spherical_harmonics:
        print(
            "The number of spherical harmonics requested does not lead to complete "
            "levels of spherical harmonics. We have thus increased the number to "
            f"{n}, which includes all spherical harmonics up to degree {degree} (incl.)"
        )
    return degree - 1

In [95]:
d = 3
nu = 2.5 
kappa = 1.0
max_ell = 8
sigma = 2.0
space = Hypersphere(dim=d - 1)
rsh = True

In [96]:
base_kernel = GeometricMaternKernel(space=space, lengthscale=kappa, nu=nu, num_eigenfunctions=max_ell + 1)
geometric_matern_kernel = ScaleKernel(base_kernel)
geometric_matern_kernel.outputscale = sigma ** 2

x1 = torch.tensor([[0.0, 0.0, 1.0]])
x2 = torch.tensor([[0.0, 1.0, 0.0]])
K1_x1x2 = geometric_matern_kernel(x1, x2).evaluate()
K1_x1x1 = geometric_matern_kernel(x1, x1).evaluate()
K2_x1x2 = matern_kernel(x1, x2, max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma, rsh=rsh)
K2_x1x1 = matern_kernel(x1, x1, max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma, rsh=rsh)

In [97]:
print(K1_x1x2, K2_x1x2)
print(K1_x1x1, K2_x1x1)

tensor([[1.4269]], grad_fn=<MulBackward0>) tensor([0.1135])
tensor([[4.0000]], grad_fn=<MulBackward0>) tensor([0.3183])


In [98]:
d, kappa, nu, max_ell = 3, 1.0, 2.5, 8
kernel = GeometricMaternKernel(
    space=Hypersphere(dim=d - 1), 
    num_eigenfunctions=max_ell + 1, 
    lengthscale=kappa, 
    nu=nu, 
    trainable_nu=False
) # NOTE trainable_nu=False for simplicity in debugging
x = torch.tensor([[0., 0., 1.], [0., 0., -1.]])
S = torch.eye(total_num_harmonics(max_ell, d)) # TODO Replace with ChoelskyVariationalDistribution

Kxx = kernel(x, x)

# TODO Package the code below and functions above into a VariationalStrategy subclass 
Kuu = matern_Kuu(max_ell, d, kappa, nu)
Phi = matern_Phi(x, max_ell, d, kappa, nu)
Kxx + covar_update(x, S, max_ell)

<linear_operator.operators.sum_linear_operator.SumLinearOperator at 0x7f1345a03c50>

In [99]:
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid


spherical_harmonic = spherical_harmonics_rsh(8, 3) 
def target_fnc(x):
    # return torch.ones_like(x[..., 0]) + 0.01 * torch.randn_like(x[..., 0])
    return spherical_harmonic(x)[..., 14] + 0.01 * torch.randn_like(x[..., 0])


test_inputs = sphere_meshgrid(100, 100)
test_targets = target_fnc(test_inputs)

train_inputs = sphere_uniform_grid(200)
train_targets = target_fnc(train_inputs)

# TODO 
[X] make the functions work with batch dimensions

In [100]:
from gpytorch import Module
from gpytorch.models import ApproximateGP
from gpytorch.variational import _VariationalDistribution
from gpytorch.distributions import MultivariateNormal, Distribution
from linear_operator.operators import LinearOperator
from gpytorch.utils.memoize import cached, clear_cache_hook
from gpytorch import settings 
from gpytorch.kernels import ScaleKernel
from gpytorch.means import ZeroMean 


class SphericalHarmonicFeaturesVariationalStrategy(Module):
    """
    Abstract base class for all Variational Strategies.
    """

    has_fantasy_strategy = False

    def __init__(
        self,
        covar_module, # TODO currently we pass modules for simplicity; however, they could be obtained from model 
        mean_module,
        variational_distribution: _VariationalDistribution,
        dimension: int, # TODO Given either the dimension or the number of spherical harmonics we can compute the other based on the variational distribution
        num_spherical_harmonics: int, 
    ):
        super().__init__()
        # modules
        object.__setattr__(self, "covar_module", covar_module) # FIXME Somehow passing ScaleKernel 
        base_kernel = covar_module 
        if isinstance(base_kernel, ScaleKernel):
            base_kernel = base_kernel.base_kernel
        object.__setattr__(self, "base_kernel", base_kernel)
        object.__setattr__(self, "mean_module", mean_module)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        self.num_spherical_harmonics = total_num_harmonics(self.degree, dimension)

    @property
    def kappa(self) -> float:
        return self.base_kernel.lengthscale
    
    @property
    def nu(self) -> float:
        return self.base_kernel.nu
    
    @property 
    def outputscale(self) -> float:
        return self.covar_module.outputscale if hasattr(self.covar_module, "outputscale") else torch.tensor(1.0)

    @property 
    def sigma(self) -> float: 
        return self.outputscale.sqrt()

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> tuple[Tensor, Tensor]:
        """
        Pre-processing step in __call__ to make x the same batch_shape as the inducing points
        """
        batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
        inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
        x = x.expand(*batch_shape, *x.shape[-2:])
        return x, inducing_points

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=self.inducing_points.dtype)
        return self._jitter_val

    @jitter_val.setter
    def jitter_val(self, jitter_val: float):
        self._jitter_val = jitter_val

    # @property
    # @cached(name="prior_distribution_memo")
    @property
    def prior_distribution(self) -> MultivariateNormal:
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the
        GP prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))`. Most commonly,
        this is done simply by calling the user defined GP prior on the inducing point data directly.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
        :return: The distribution :math:`p( \mathbf u)`
        """
        covariance_matrix = matern_Kuu(
            max_ell=self.degree, 
            d=self.dimension,
            kappa=self.kappa,
            nu=self.nu,
            sigma=self.sigma,
        )
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)

    @property
    # @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> Distribution:
        return self._variational_distribution()

    def forward(
        self,
        x: Tensor,
        inducing_mean: Tensor,
        variational_inducing_covar: LinearOperator | None = None,
        **kwargs,
    ) -> MultivariateNormal:
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
        inducing point function values. Specifically, forward defines how to transform a variational distribution
        over the inducing point values, :math:`q(u)`, in to a variational distribution over the function values at
        specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`

        :param x: Locations :math:`\mathbf X` to get the
            variational posterior of the function values at.
        :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
        :param inducing_values: Mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian.
        :param variational_inducing_covar: If
            the distribuiton :math:`q(\mathbf u)` is
            Gaussian, then this variable is the covariance matrix of that Gaussian.
            Otherwise, it will be None.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
        :return: The distribution :math:`q( \mathbf f(\mathbf X))`
        """
        # inducing variables prior
        fu_mvn = self.prior_distribution
        muu, Kuu = fu_mvn.mean, fu_mvn.lazy_covariance_matrix

        # input points prior 
        mux, Kxx = self.mean_module(x), self.covar_module(x)

        # compute Phi, consdering that matern_Phi does not accept batch dimensions 
        Phi = matern_Phi(x, self.degree, d=self.dimension, kappa=self.kappa, 
                         nu=self.nu, sigma=self.sigma, rsh=False) 
        updated_covariance_matrix = Kxx + Phi.mT @ (variational_inducing_covar - Kuu) @ Phi # If possible convert this to linear operators 
        # updated_covariance_matrix = Kxx - torch.einsum('...nm, mk, ...kn -> ...', Phi.mT, (variational_inducing_covar - Kuu).to_dense(), Phi)

        # assert isinstance(self.mean_module, ZeroMean), "Currently, only zero mean priors are supported"
        # updated_mean = torch.einsum('...mn, m -> ...n', Phi, inducing_mean)
        updated_mean = mux + Phi.mT @ (inducing_mean - muu)

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)


    def kl_divergence(self) -> Tensor:
        r"""
        Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
        and the prior inducing distribution :math:`p(\mathbf u)`.
        """
        with settings.max_preconditioner_size(0):
            prior_distribution = self.prior_distribution
            variational_distribution = self.variational_distribution
            kl_divergence = torch.distributions.kl.kl_divergence(variational_distribution, prior_distribution)
        return kl_divergence

    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        # Get p(u)/q(u)
        variational_dist_u = self.variational_distribution

        # Get q(f)
        assert isinstance(variational_dist_u, MultivariateNormal)

        return super().__call__(
            x,
            inducing_mean=variational_dist_u.mean,
            variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
            **kwargs,
        )

In [106]:
from gpytorch import Module
from gpytorch.models import ApproximateGP
from gpytorch.variational import _VariationalDistribution
from gpytorch.distributions import MultivariateNormal, Distribution
from linear_operator.operators import LinearOperator
from gpytorch.utils.memoize import cached, clear_cache_hook
from gpytorch import settings 
from gpytorch.kernels import ScaleKernel
from gpytorch.means import ZeroMean 


class WhitenedSphericalHarmonicFeaturesVariationalStrategy(Module):
    def __init__(
        self,
        covar_module, 
        mean_module,
        variational_distribution: CholeskyVariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
    ):
        super().__init__()

        # modules
        object.__setattr__(self, "covar_module", covar_module) # FIXME Somehow passing ScaleKernel 
        base_kernel = covar_module 
        if isinstance(base_kernel, ScaleKernel):
            base_kernel = base_kernel.base_kernel
        object.__setattr__(self, "base_kernel", base_kernel)
        object.__setattr__(self, "mean_module", mean_module)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        self.num_spherical_harmonics = total_num_harmonics(self.degree, dimension)

    @property
    def kappa(self) -> float:
        return self.base_kernel.lengthscale
    
    @property
    def nu(self) -> float:
        return self.base_kernel.nu
    
    @property 
    def outputscale(self) -> float:
        return self.covar_module.outputscale if hasattr(self.covar_module, "outputscale") else torch.tensor(1.0)

    @property 
    def sigma(self) -> float: 
        return self.outputscale.sqrt()

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        covariance_matrix = DiagLinearOperator(torch.ones(self.num_spherical_harmonics))
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
    
    @property 
    @cached(name="cholesky_factor_prior_memo")
    def cholesky_factor_prior(self) -> DiagLinearOperator:
        Kuu = matern_Kuu(
            max_ell=self.degree, 
            d=self.dimension,
            kappa=self.kappa,
            nu=self.nu,
            sigma=self.sigma,
        )
        return Kuu.cholesky()
        
    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        # inducing-inducing prior
        pu = self.prior_distribution
        invL_muu, invL_Kuu_invLt = pu.mean, pu.lazy_covariance_matrix

        # input-input prior
        mux, Kxx = self.mean_module(x), self.covar_module(x)

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_m = qu.mean
        invL_S_invLt = qu.lazy_covariance_matrix

        # inducing-input prior  
        Phi = matern_Phi(x, self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma, rsh=False)
        L = self.cholesky_factor_prior
        Delta = L.mT @ Phi
        
        # updated_covariance_matrix = Kxx + Phi.mT @ (L @ invL_S_invLt @ L.mT - L @ invL_Kuu_invLt @ L.mT) @ Phi 
        # updated_covariance_matrix = Kxx + (L.mT @ Phi).mT @ (invL_S_invLt - invL_Kuu_invLt) @ (L.mT @ Phi) 
        updated_covariance_matrix = Kxx + Delta.mT @ (invL_S_invLt - invL_Kuu_invLt) @ Delta

        # updated_mean = mux + Phi.mT @ (L @ invL_m - L @ invL_muu)
        # updated_mean = mux + (L.mT @ Phi).mT @ (invL_m - invL_muu)
        updated_mean = mux + Delta.mT @ (invL_m - invL_muu)

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x, **kwargs)

In [107]:
from gpytorch import Module
from gpytorch.variational import _VariationalDistribution
from gpytorch.distributions import MultivariateNormal
from gpytorch.utils.memoize import clear_cache_hook
from gpytorch import settings 
from gpytorch.kernels import ScaleKernel


class FullSphericalHarmonicFeaturesVariationalStrategy(Module):
    def __init__(
        self,
        covar_module, 
        mean_module,
        variational_distribution: _VariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
    ):
        super().__init__()

        # modules
        object.__setattr__(self, "covar_module", covar_module) # FIXME Somehow passing ScaleKernel 
        base_kernel = covar_module 
        if isinstance(base_kernel, ScaleKernel):
            base_kernel = base_kernel.base_kernel
        object.__setattr__(self, "base_kernel", base_kernel)
        object.__setattr__(self, "mean_module", mean_module)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        self.num_spherical_harmonics = total_num_harmonics(self.degree, dimension)

    @property
    def kappa(self) -> float:
        return self.base_kernel.lengthscale
    
    @property
    def nu(self) -> float:
        return self.base_kernel.nu
    
    @property 
    def outputscale(self) -> float:
        return self.covar_module.outputscale if hasattr(self.covar_module, "outputscale") else torch.tensor(1.0)

    @property 
    def sigma(self) -> float: 
        return self.outputscale.sqrt()

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    def inputs_inputs_prior(self, x) -> tuple[Tensor, Tensor]:
        return self.mean_module(x), self.covar_module(x)
    
    def inducing_inputs_covariance(self, x) -> Tensor:
        return matern_Kux(
            x=x, 
            max_ell=self.degree, 
            d=self.dimension,
            kappa=self.kappa,
            nu=self.nu,
            sigma=self.sigma,
            rsh=False, 
        )
    
    @property 
    def inducing_inducing_prior(self) -> tuple[Tensor, Tensor | DiagLinearOperator]: 
        covariance_matrix = matern_Kuu(
            max_ell=self.degree, 
            d=self.dimension,
            kappa=self.kappa,
            nu=self.nu,
            sigma=self.sigma,
        )
        mean = torch.zeros(self.num_spherical_harmonics)
        return mean, covariance_matrix

    @property 
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()

    @property
    def inducing_inducing_posterior(self) -> tuple[Tensor, Tensor]:
        qu = self.variational_distribution
        return qu.mean, qu.lazy_covariance_matrix
    
    @property 
    def prior_distribution(self) -> MultivariateNormal:
        muu, Kuu = self.inducing_inducing_prior
        return MultivariateNormal(mean=muu, covariance_matrix=Kuu)

    def forward(self, x: Tensor) -> MultivariateNormal:
        # inducing-inducing prior
        muu, Kuu = self.inducing_inducing_prior 
        inv_Kuu = Kuu.inverse() #/ self.sigma ** 2

        # inducing-inputs prior 
        Kux = self.inducing_inputs_covariance(x)

        # inputs-inputs prior 
        mux, Kxx = self.inputs_inputs_prior(x)

        # variational distribution 
        m, S = self.inducing_inducing_posterior

        # update 
        updated_covariance_matrix = Kxx + Kux.mT @ inv_Kuu @ (S - Kuu) @ inv_Kuu @ Kux
        updated_mean = mux + Kux.mT @ inv_Kuu @ (m - muu)

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        with settings.max_preconditioner_size(0):
            return torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)

    def __call__(self, x: Tensor, prior: bool = True) -> MultivariateNormal:
        if prior is True: 
            return self.inputs_inputs_prior(x)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x)

In [108]:
from gpytorch.models import ApproximateGP

class SimpleApproximateGP(ApproximateGP):
    def __init__(self, mean_module, covar_module, variational_strategy):
        super().__init__(variational_strategy=variational_strategy)
        self.covar_module = covar_module
        self.mean_module = mean_module

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return MultivariateNormal(mean, covar)


In [109]:
num_spherical_harmonics = 80
dimension = 3
degree = num_spherical_harmonics_to_degree(num_spherical_harmonics=num_spherical_harmonics, dimension=dimension)
num_spherical_harmonics = total_num_harmonics(max_ell=degree, d=dimension)

# Variational distribution 
num_inducing_points = num_spherical_harmonics
batch_shape = torch.Size([])

# covar
space = Hypersphere(dim=dimension - 1)
base_kernel = GeometricMaternKernel(space=space, trainable_nu=False, num_eigenfunctions=degree)
covar_module = ScaleKernel(base_kernel)
covar_module = base_kernel

# mean
mean_module = gpytorch.means.ZeroMean()

# variational_strategies 
shf = SphericalHarmonicFeaturesVariationalStrategy(
    covar_module=covar_module,
    mean_module=mean_module,
    variational_distribution=CholeskyVariationalDistribution(num_inducing_points=num_inducing_points, batch_shape=batch_shape),
    dimension=dimension,
    num_spherical_harmonics=num_spherical_harmonics,
)

fshf = FullSphericalHarmonicFeaturesVariationalStrategy(
    covar_module=covar_module,
    mean_module=mean_module,
    variational_distribution=CholeskyVariationalDistribution(num_inducing_points=num_inducing_points, batch_shape=batch_shape),
    dimension=dimension,
    num_spherical_harmonics=num_spherical_harmonics,
)

wshf = WhitenedSphericalHarmonicFeaturesVariationalStrategy(
    covar_module=covar_module,
    mean_module=mean_module,
    variational_distribution=CholeskyVariationalDistribution(num_inducing_points=num_inducing_points, batch_shape=batch_shape),
    dimension=dimension,
    num_spherical_harmonics=num_spherical_harmonics,
)

# model
model = SimpleApproximateGP(
    covar_module=covar_module,
    mean_module=mean_module,
    variational_strategy=wshf, 
)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 81, which includes all spherical harmonics up to degree 9 (incl.)


In [110]:
num_epochs = 6000
model.train()
likelihood.train()


optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_inputs.size(0), beta=1.)

for i in range(num_epochs):
    optimizer.zero_grad()
    output = model(train_inputs)
    loss = -mll(output, (train_targets - train_targets.mean()) / train_targets.std())
    # loss = -likelihood.expected_log_prob(train_targets, output).sum(-1)
    print(f"Epoch {i+1}/{num_epochs}, Loss: {loss}", end="\r")
    # print(f"{model.variational_strategy._variational_distribution.variational_mean[14]}", end='\r')
    # print(f"{model.variational_strategy._variational_distribution.variational_mean[14]}")
    # print(f"{model.covar_module.outputscale}", end='\r')
    loss.backward()
    optimizer.step()

Epoch 106/6000, Loss: 1.3513591555418975

KeyboardInterrupt: 

In [92]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go


with torch.no_grad():
    model.eval()
    preds = model(test_inputs.view(-1, 3)).mean

fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'surface'}, {'type': 'scatter3d'},]], 
                    subplot_titles=["Posterior Mean", "Target"])

x, y, z = test_inputs.unbind(-1)
fig.add_trace(
    go.Surface(
        x=x, 
        y=y,
        z=z,
        surfacecolor=preds.view_as(x),
        coloraxis="coloraxis",
    ),
    row=1, col=1
)


fig.add_trace(
    go.Surface(
        x=x, 
        y=y,
        z=z,
        surfacecolor=((test_targets - train_targets.mean()) / train_targets.std()).view_as(x),
        coloraxis="coloraxis",
    ),
    row=1, col=2
)


fig.update_layout(
    coloraxis=dict(
        colorscale="plasma", 
    )
)

fig.show()