In [1]:
import geometric_kernels.torch 
import torch 
import gpytorch 
from mdgp.kernels import GeometricMaternKernel
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.kernels import ScaleKernel
from geometric_kernels.spaces import Hypersphere
from spherical_harmonics import SphericalHarmonics
from gpytorch import Module
from gpytorch.distributions import MultivariateNormal
from gpytorch.utils.memoize import cached, clear_cache_hook
from gpytorch import settings 
from gpytorch.kernels import ScaleKernel
from torch import Tensor
from scipy.special import comb
from linear_operator.operators import DiagLinearOperator

# float64 by default: float tensors created below without an explicit dtype
# (e.g. torch.ones / torch.zeros in the variational strategy) are double precision.
torch.set_default_dtype(torch.float64)

INFO: Using numpy backend


**Idea:** we could reuse the spherical harmonics computed in the variational strategy for the kernel itself.

There appears to be a simplified form of the posterior update when the kernel approximation and the inducing variables use the same number of spherical harmonics. However, that may require too many inducing variables.

In [2]:
def num_harmonics_single(ell: int, d: int) -> int:
    r"""
    Number of spherical harmonics of degree ell on S^{d - 1}.

    Uses the closed form N(0, d) = 1 and, for ell >= 1,
    N(ell, d) = (2 ell + d - 2) / ell * C(ell + d - 3, ell - 1).
    It is evaluated in exact integer arithmetic, so the result is a true ``int``
    for every d rather than a (possibly rounded) float.

    Args:
        ell: Degree (level) of the harmonics, ell >= 0.
        d: Ambient dimension; the sphere is S^{d-1}.

    Returns:
        The number of spherical harmonics of degree ell on S^{d-1}.
    """
    ell, d = int(ell), int(d)
    if ell == 0:
        return 1
    if d == 3:
        # On S^2 the closed form reduces to 2 * ell + 1.
        return 2 * ell + 1
    # N(ell, d) is an integer, so the numerator is divisible by ell and // is exact.
    return (2 * ell + d - 2) * comb(ell + d - 3, ell - 1, exact=True) // ell


def num_harmonics(ell: Tensor | int, d: int) -> Tensor:
    """
    Vectorized version of num_harmonics_single.

    Args:
        ell: A single degree (``int``) or a tensor of degrees of any shape.
        d: Ambient dimension; the sphere is S^{d-1}.

    Returns:
        If ``ell`` is an ``int``, the count for that degree (from num_harmonics_single).
        Otherwise, a new tensor with the same shape, dtype and device as ``ell``
        holding the count for each entry. ``ell`` itself is left unchanged.
    """
    if isinstance(ell, int):
        return num_harmonics_single(ell, d)
    # Build a new tensor instead of using the in-place ``Tensor.apply_``: that
    # overwrites the caller's tensor and only works for CPU tensors.
    counts = [num_harmonics_single(ell=e, d=d) for e in ell.flatten().tolist()]
    return torch.tensor(counts, dtype=ell.dtype, device=ell.device).reshape(ell.shape)


def total_num_harmonics(max_ell: int, d: int) -> int:
    """
    Count every spherical harmonic on S^{d-1} whose degree lies in 0..max_ell.
    """
    degrees = torch.arange(max_ell + 1)
    per_degree = num_harmonics(ell=degrees, d=d)
    return int(per_degree.sum())


def eigenvalue_laplacian(ell: int, d: int) -> float:
    """
    Non-negative Laplace-Beltrami eigenvalue ell * (ell + d - 2) shared by all
    spherical harmonics of degree ell on S_{d-1}. Applies elementwise when
    ``ell`` is a tensor.
    """
    shifted = ell + d - 2
    return ell * shifted


def unnormalized_matern_spectral_density(n: Tensor | float, d: int, kappa: float, nu: float) -> Tensor | float: 
    """
    Unnormalized Matern spectral density on S_{d-1} at degree(s) n:

        (2 * nu / kappa**2 + lambda_n) ** (-nu - (d - 1) / 2),

    where lambda_n = n * (n + d - 2) is the Laplace-Beltrami eigenvalue.
    """
    base = 2.0 * nu / kappa**2 + eigenvalue_laplacian(ell=n, d=d)
    exponent = -nu - (d - 1) / 2.0
    return base ** exponent


def matern_spectral_density_normalizer(d: int, max_ell: int, kappa: float, nu: float) -> Tensor:
    """
    Normalizing constant for the spectral density of the Matern kernel on S^{d-1}.

    It depends on kappa and nu. It also depends on max_ell, which truncates the
    infinite sum from the Karhunen-Loeve decomposition: each degree's spectral
    value is weighted by the number of harmonics of that degree.

    Returns:
        A 0-d tensor when the spectral values are 1-D (scalar kappa / nu).
        If kappa / nu carry leading batch dimensions (e.g. a batched lengthscale;
        confirm shapes with the kernel in use), the sum runs over the degree
        axis only. That axis is kept with size 1 so the result broadcasts
        against per-degree spectral values.
    """
    ells = torch.arange(max_ell + 1)
    spectral_values = unnormalized_matern_spectral_density(
        n=ells,
        d=d,
        kappa=kappa,
        nu=nu,
    )
    num_harmonics_per_level = num_harmonics(torch.arange(max_ell + 1), d=d)
    weighted = spectral_values * num_harmonics_per_level
    # Reduce over the degree (last) axis only. A full reduction would also sum
    # over batch dimensions and mix the normalizers of independent batch members.
    return weighted.sum(dim=-1, keepdim=weighted.dim() > 1)


def matern_spectral_density(n: Tensor, d: int, kappa: float, nu: float, max_ell: int, sigma: float = 1.0) -> Tensor:
    """
    Matern spectral density on S^{d-1} at degree(s) n.

    The unnormalized density is divided by its truncated normalizer
    (degrees 0..max_ell) and scaled by the variance sigma**2.
    """
    unnormalized = unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu)
    normalizer = matern_spectral_density_normalizer(d=d, max_ell=max_ell, kappa=kappa, nu=nu)
    variance = sigma ** 2
    return unnormalized / normalizer * variance


def matern_ahat(ell: int, d: int, max_ell: int, kappa: float, nu: float, m: int | None = None, sigma: float = 1.0) -> float:
    """
    ahat = rho(\sqrt{\ell(\ell + d - 2)}) where rho is the spectral density on S^{d-1}
    """
    return matern_spectral_density(n=ell, d=d, kappa=kappa, nu=nu, max_ell=max_ell, sigma=sigma)


def matern_repeated_ahat(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Returns a tensor of ahat values, one per spherical harmonic.

    The ahat of each degree ell in 0..max_ell is repeated once for every
    harmonic of that degree. The last axis therefore has length
    total_num_harmonics(max_ell, d). Any leading (batch) axes of ahat, e.g. from
    a batched kappa / nu / sigma, are preserved.
    """
    ells = torch.arange(max_ell + 1)
    ahat = matern_ahat(ell=ells, d=d, max_ell=max_ell, kappa=kappa, nu=nu, sigma=sigma)
    repeats = num_harmonics(ell=ells, d=d)
    # Repeat along the degree (last) axis only. Without ``dim``, repeat_interleave
    # flattens its input, which breaks once ahat has leading batch dimensions.
    # For a 1-D ahat the result is unchanged.
    return torch.repeat_interleave(ahat, repeats=repeats, dim=-1)


def matern_Kuu(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor | DiagLinearOperator: 
    """
    Prior covariance of the spherical-harmonic inducing features.

    It is a diagonal operator whose entries are 1 / ahat. Each degree's ahat is
    repeated once per harmonic of that degree.
    """
    repeated_ahat = matern_repeated_ahat(max_ell, d, kappa, nu, sigma=sigma)
    inverse_ahat = 1 / repeated_ahat
    return DiagLinearOperator(inverse_ahat)


def spherical_harmonics(x: Tensor, max_ell: int, d: int) -> Tensor: 
    """
    Evaluate every spherical harmonic of degree <= max_ell at the points ``x``.

    ``x`` may carry arbitrary leading batch dimensions; a 1-D ``x`` is promoted
    to a single row. All leading dimensions are flattened for the evaluation
    and restored afterwards. The result has shape
    ``[*batch, n, total_num_harmonics(max_ell, d)]``, where n is ``x.shape[-2]``.
    """
    points = torch.atleast_2d(x)
    *batch_shape, num_points, _ = points.shape
    flat_points = points.flatten(0, -2)
    harmonics_fn = SphericalHarmonics(dimension=d, degrees=max_ell + 1)
    values = harmonics_fn(flat_points)
    return values.reshape(*batch_shape, num_points, total_num_harmonics(max_ell, d))


def matern_Kux(x: Tensor, max_ell: int, d: int) -> Tensor: 
    """
    Inducing-input cross-covariance: the spherical harmonics evaluated at ``x``,
    with the last two axes swapped to give ``[*batch, num_harmonics, n]``.
    """
    features = spherical_harmonics(x, max_ell=max_ell, d=d)
    return features.transpose(-2, -1)


def matern_LT_Phi(x: Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor: 
    """
    Spherical-harmonic features at ``x``, with each harmonic's row scaled by the
    square root of that harmonic's ahat.

    Because ``matern_Kuu`` is diag(1 / ahat), this equals L^{-1} Kux with
    L = Kuu^{1/2}: the whitened inducing-input cross-covariance.
    Shape: ``[*B, num_harmonics, n]``.
    """
    cross_cov = matern_Kux(x, max_ell=max_ell, d=d)
    scale = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma).sqrt()
    return cross_cov * scale[..., None]


def num_spherical_harmonics_to_degree(num_spherical_harmonics: int, dimension: int) -> int:
    """
    Returns the smallest maximum degree L such that the complete levels 0..L
    contain at least ``num_spherical_harmonics`` spherical harmonics.

    If the request does not fall on a level boundary, a message reports the
    rounded-up count and the (inclusive) maximum degree.

    Args:
        num_spherical_harmonics: Requested number of harmonics (>= 1).
        dimension: Ambient dimension; the sphere is S^{dimension-1}.

    Returns:
        The maximum degree L (inclusive).

    Raises:
        ValueError: If ``num_spherical_harmonics`` is smaller than 1.
    """
    if num_spherical_harmonics < 1:
        raise ValueError(
            f"num_spherical_harmonics must be >= 1, got {num_spherical_harmonics}"
        )

    total, degree = 0, 0  # total: harmonics counted so far; degree: next level to add
    while total < num_spherical_harmonics:
        total += int(num_harmonics(d=dimension, ell=degree))
        degree += 1
    # The loop exits one past the last level it added.
    max_degree = degree - 1

    if total > num_spherical_harmonics:
        print(
            "The number of spherical harmonics requested does not lead to complete "
            "levels of spherical harmonics. We have thus increased the number to "
            f"{total}, which includes all spherical harmonics up to degree {max_degree} (incl.)"
        )
    return max_degree

In [3]:
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid, spherical_harmonic


def target_fnc(x):
    """
    Noisy target: ``spherical_harmonic(x, 2, 3)`` plus Gaussian noise with
    standard deviation 0.01, one draw per entry of ``x[..., 0]``.

    NOTE(review): the arguments (2, 3) presumably select a degree-2 harmonic
    on S^2; confirm against mdgp.utils.spherical_harmonic.
    """
    return spherical_harmonic(x, 2, 3) + 0.01 * torch.randn_like(x[..., 0])


# target_fnc draws random noise: seed so Restart & Run All reproduces the data.
SEED = 0
torch.manual_seed(SEED)

test_inputs = sphere_meshgrid(100, 100)
test_targets = target_fnc(test_inputs)

train_inputs = sphere_uniform_grid(200)
train_targets = target_fnc(train_inputs)

# TODO
- [x] Make the functions work with batch dimensions

In [4]:
class WhitenedSphericalHarmonicFeaturesVariationalStrategy(Module):
    """
    Whitened variational strategy with spherical-harmonic inducing features.

    The inducing features are the spherical harmonics of degree <= ``self.degree``.
    Their prior covariance Kuu = diag(1 / ahat) is diagonal (see ``matern_Kuu``).
    The strategy works with the whitened variables v = L^{-1} u, so the prior
    over v is N(0, I). The predictive distribution is built from
    ``matern_LT_Phi`` (= L^{-1} Kux).

    Args:
        covar_module: Kernel giving the prior Kxx. If it is a ``ScaleKernel``,
            its ``base_kernel`` supplies ``lengthscale`` (used as kappa) and ``nu``.
            Its ``outputscale`` supplies sigma**2.
        mean_module: Prior mean function.
        variational_distribution: Distribution over the whitened inducing variables.
        dimension: Ambient dimension d; inputs live on S^{d-1}.
        num_spherical_harmonics: Requested number of inducing features, rounded up
            to complete levels.

    Raises:
        ValueError: If the variational distribution's ``num_inducing_points``
            does not match the (rounded-up) number of spherical harmonics.
    """

    def __init__(
        self,
        covar_module, 
        mean_module,
        variational_distribution: CholeskyVariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
    ):
        super().__init__()

        # modules
        # object.__setattr__ bypasses torch.nn.Module.__setattr__, so these modules
        # are stored as plain attributes and are NOT registered as submodules of this
        # strategy. NOTE(review): presumably because the owning model registers them.
        # FIXME Somehow passing ScaleKernel
        object.__setattr__(self, "covar_module", covar_module)
        # kappa and nu are read from the innermost kernel; unwrap a ScaleKernel.
        base_kernel = covar_module
        if isinstance(base_kernel, ScaleKernel):
            base_kernel = base_kernel.base_kernel
        object.__setattr__(self, "base_kernel", base_kernel)
        object.__setattr__(self, "mean_module", mean_module)

        # Variational distribution over the whitened inducing variables
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics: the request is rounded up to complete levels
        self.dimension = dimension
        self.degree = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        self.num_spherical_harmonics = total_num_harmonics(self.degree, dimension)

        # Fail early on a size mismatch instead of with an opaque shape error in forward.
        num_inducing = getattr(variational_distribution, "num_inducing_points", None)
        if num_inducing is not None and num_inducing != self.num_spherical_harmonics:
            raise ValueError(
                f"variational_distribution has {num_inducing} inducing points, but the strategy "
                f"uses {self.num_spherical_harmonics} spherical harmonics "
                f"(all levels up to degree {self.degree})."
            )

    @property
    def kappa(self) -> float:
        """Matern kappa: the base kernel's ``lengthscale`` (likely a tensor; confirm shape)."""
        return self.base_kernel.lengthscale
    
    @property
    def nu(self) -> float:
        """Matern smoothness taken from the base kernel."""
        return self.base_kernel.nu
    
    @property 
    def outputscale(self) -> Tensor:
        """The kernel's ``outputscale`` if it has one (e.g. a ScaleKernel), else 1."""
        return self.covar_module.outputscale if hasattr(self.covar_module, "outputscale") else torch.tensor(1.0)

    @property 
    def sigma(self) -> Tensor: 
        """Square root of the output scale."""
        return self.outputscale.sqrt()

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        """Whitened prior N(0, I) over the inducing variables (memoized)."""
        covariance_matrix = DiagLinearOperator(torch.ones(self.num_spherical_harmonics))
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
    
    @property 
    @cached(name="cholesky_factor_prior_memo")
    def cholesky_factor_prior(self) -> DiagLinearOperator:
        """Cholesky factor of the (unwhitened) diagonal Kuu. Not used by ``forward``."""
        Kuu = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        return Kuu.cholesky() # Kuu is a DiagLinearOperator, so .cholesky() is equivalent to .sqrt() 
        
    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> MultivariateNormal:
        """The current q(v), obtained by calling the variational distribution module (memoized)."""
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        """
        Predictive distribution at ``x`` under the whitened parameterization:

            mean = mu(x) + (L^{-1} Kux)^T (m - mu_v)
            cov  = Kxx + (L^{-1} Kux)^T (S - I) (L^{-1} Kux)

        Here N(m, S) is q(v) and N(mu_v, I) is the whitened prior.
        """
        # inducing-inducing prior (whitened: zero mean, identity covariance)
        pu = self.prior_distribution
        invL_muu, invL_Kuu_invLt = pu.mean, pu.lazy_covariance_matrix

        # input-input prior
        mux, Kxx = self.mean_module(x), self.covar_module(x)

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_m = qu.mean
        invL_S_invLt = qu.lazy_covariance_matrix

        # inducing-input prior, whitened: L^{-1} Kux
        LT_Phi = matern_LT_Phi(x, max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        
        # Update the mean and covariance matrix
        updated_mean = mux + LT_Phi.mT @ (invL_m - invL_muu)
        updated_covariance_matrix = Kxx + LT_Phi.mT @ (invL_S_invLt - invL_Kuu_invLt) @ LT_Phi

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        """KL(q(v) || N(0, I)), computed with preconditioning disabled."""
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # Prior mode: this strategy keeps no reference to its owning model, so
        # build the prior directly from its own mean and covariance modules.
        if prior:
            return MultivariateNormal(self.mean_module(x), self.covar_module(x))

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution from the whitened prior
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x, **kwargs)

In [5]:
from gpytorch.models import ApproximateGP


class SimpleApproximateGP(ApproximateGP):
    """
    Minimal ApproximateGP wrapper around a given variational strategy.

    ``forward`` returns the GP prior at ``x``, built from ``mean_module``
    and ``covar_module``.
    """

    def __init__(self, mean_module, covar_module, variational_strategy):
        super().__init__(variational_strategy=variational_strategy)
        # Registration order (covar, then mean) fixes the order of model.parameters().
        self.covar_module = covar_module
        self.mean_module = mean_module

    def forward(self, x):
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

In [11]:
# Requested number of inducing features; rounded up to complete levels below.
num_spherical_harmonics = 80
dimension = 3  # ambient dimension d: inputs live on S^{d-1}
degree = num_spherical_harmonics_to_degree(num_spherical_harmonics=num_spherical_harmonics, dimension=dimension)
# Overwrite with the rounded-up count (all harmonics of degree <= `degree`), so it
# matches the count the variational strategy computes internally.
num_spherical_harmonics = total_num_harmonics(max_ell=degree, d=dimension)

# Variational distribution: one inducing variable per spherical harmonic
num_inducing_points = num_spherical_harmonics
# NOTE(review): batch of 2 independent GPs. Confirm the spectral-density helpers
# support batched kappa / nu / outputscale before relying on this.
batch_shape = torch.Size([2])

# covar
space = Hypersphere(dim=dimension - 1)
# NOTE(review): confirm whether `num_eigenfunctions=degree` means levels or eigenfunctions
# for GeometricMaternKernel. The inducing features cover degrees 0..degree (degree + 1 levels).
base_kernel = GeometricMaternKernel(space=space, trainable_nu=False, num_eigenfunctions=degree, batch_shape=batch_shape)
covar_module = ScaleKernel(base_kernel, batch_shape=batch_shape)

# mean
mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)

# The same covar_module / mean_module objects are shared by the strategy and the model.
wshf = WhitenedSphericalHarmonicFeaturesVariationalStrategy(
    covar_module=covar_module,
    mean_module=mean_module,
    variational_distribution=CholeskyVariationalDistribution(num_inducing_points=num_inducing_points, batch_shape=batch_shape),
    dimension=dimension,
    num_spherical_harmonics=num_spherical_harmonics,
)

# model
model = SimpleApproximateGP(
    covar_module=covar_module,
    mean_module=mean_module,
    variational_strategy=wshf, 
)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=batch_shape)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 81, which includes all spherical harmonics up to degree 9 (incl.)


In [98]:
num_epochs = 1000
model.train()
likelihood.train()


optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_inputs.size(0), beta=1.)

for i in range(num_epochs):
    optimizer.zero_grad()
    output = model(train_inputs)
    # For a batched model the ELBO has one entry per batch member, and backward()
    # needs a scalar. Summing trains the members independently and is a no-op
    # for an unbatched model.
    loss = -mll(output, train_targets).sum()
    # Fixed-width number so the carriage-return progress line is not left with stale digits.
    print(f"Epoch {i+1}/{num_epochs}, Loss: {loss.item():.6f}", end="\r")
    loss.backward()
    optimizer.step()

Epoch 403/1000, Loss: -1.145989912659798355

KeyboardInterrupt: 

In [102]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go


def sphere_surface(x, y, z, color) -> go.Surface:
    """Surface trace over the (x, y, z) sphere mesh, coloured by `color` on the shared colour axis."""
    return go.Surface(x=x, y=y, z=z, surfacecolor=color, coloraxis="coloraxis")


with torch.no_grad():
    model.eval()
    preds = model(test_inputs.view(-1, 3)).mean
    x, y, z = test_inputs.unbind(-1)

# A batched model yields one row of predictions per batch member; plot the first.
# For an unbatched model this is the same as preds.view_as(x).
pred_surface = preds.reshape(-1, *x.shape)[0]

fig = make_subplots(
    rows=1, cols=2,
    specs=[[{'type': 'surface'}, {'type': 'surface'}]],
    subplot_titles=["Posterior Mean", "Target"],
)
fig.add_trace(sphere_surface(x, y, z, pred_surface), row=1, col=1)
fig.add_trace(sphere_surface(x, y, z, test_targets.view_as(x)), row=1, col=2)

fig.update_layout(
    coloraxis=dict(
        colorscale="plasma", 
    )
)

fig.show()