In [1]:
from torch import Tensor 


import geometric_kernels.torch
import torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from gpytorch.kernels import Kernel, ScaleKernel
from mdgp.variational.spherical_harmonic_features.utils import *

INFO: Using numpy backend


In [21]:
from torch import Tensor 


import torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from gpytorch.kernels import Kernel, ScaleKernel



def num_harmonics_single(ell: int, d: int) -> int:
    r"""
    Dimension of the space of degree-`ell` spherical harmonics on S^d.

    ell: degree of the harmonics
    d: dimension of the sphere S^d
    return: number of linearly independent harmonics of that degree
    """
    # Degree 0 is handled separately: the general formula below divides by ell.
    if ell == 0:
        return 1
    if d != 2:
        binomial = comb(ell + d - 2, ell - 1)
        return (2 * ell + d - 1) * binomial // ell
    # Shortcut for S^2.
    return 2 * ell + 1


def num_harmonics(ell: Tensor, d: int) -> Tensor:
    """
    Number of spherical harmonics of degree ell on S^d, evaluated elementwise.

    ell: [...] integer-valued degrees
    d: []
    return: [...] int32 tensor of counts; `ell` itself is not modified
    """
    # Tensor.apply_ writes its results back into `ell`, which would silently
    # overwrite the caller's degrees; build a fresh tensor instead.
    counts = [num_harmonics_single(ell=int(e), d=d) for e in ell.flatten().tolist()]
    return torch.tensor(counts, dtype=torch.int64, device=ell.device).reshape(ell.shape).int()


def total_num_harmonics(max_ell: int, d: int) -> int:
    """
    Count all spherical harmonics on S^d whose degree is strictly below max_ell.

    max_ell: exclusive upper bound on the degree
    d: dimension of the sphere S^d
    return: sum of the per-degree counts for degrees 0, ..., max_ell - 1
    """
    degrees = torch.arange(max_ell)
    per_degree = num_harmonics(ell=degrees, d=d)
    return per_degree.sum().item()


def eigenvalue_laplacian(ell: Tensor, d: int) -> Tensor:
    """
    Laplace-Beltrami eigenvalue ell * (ell + d - 1) associated with spherical harmonics
    of degree ell on S^d.
    ell: [...]
    d: []
    return: [...]
    """
    shifted_degree = ell + d - 1
    return ell * shifted_degree


def unnormalized_matern_spectral_density(n: Tensor, d: int, kappa: Tensor, nu: Tensor) -> Tensor:
    """
    Compute the (unnormalized) spectral density of the Matern kernel on S^d.
    n: [N] degrees at which to evaluate the density
    d: []
    kappa: [O, 1, 1]
    nu: [O, 1, 1]
    return: [O, N, 1]
    """
    # The degree axis is placed second-to-last so that BOTH branches return
    # [O, N, 1]. Previously the squared-exponential branch skipped this unsqueeze
    # and broadcast to [O, 1, N], which does not match the Matern branch.
    eigenvalues = eigenvalue_laplacian(ell=n, d=d).unsqueeze(-1) # [N, 1]

    # Squared exponential kernel (every nu is infinite)
    if nu.isinf().all():
        exponent = -kappa ** 2 / 2 * eigenvalues # [O, N, 1]
        return torch.exp(exponent)

    # Matern kernel
    base = 2.0 * nu / kappa**2 + eigenvalues # [O, 1, 1] + [N, 1] -> [O, N, 1]
    exponent = -nu - d / 2.0 # [O, 1, 1]
    return base ** exponent # [O, N, 1]


def matern_spectral_density_normalizer(d: int, max_ell: int, kappa: Tensor, nu: Tensor) -> Tensor:
    """
    Normalizing constant of the Matern spectral density on S^d: the unnormalized
    density at each degree, weighted by the number of harmonics of that degree and
    summed over degrees 0, ..., max_ell - 1.

    Depends on kappa and nu, and on max_ell because the infinite Karhunen-Loeve
    sum is truncated there.
    return: [O, 1, 1]
    """
    levels = torch.arange(max_ell) # [max_ell]
    density = unnormalized_matern_spectral_density(n=levels, d=d, kappa=kappa, nu=nu) # [O, max_ell, 1]
    # A separate range is handed to num_harmonics so `levels` is never shared with it.
    multiplicities = num_harmonics(torch.arange(max_ell), d=d).type(density.dtype) # [max_ell]
    weighted_sum = torch.matmul(density.mT, multiplicities) # [O, 1, max_ell] @ [max_ell] -> [O, 1]
    return weighted_sum.unsqueeze(-2) # [O, 1, 1]


def matern_spectral_density(n: Tensor, d: int, kappa: Tensor, nu: Tensor, max_ell: int, sigma: Tensor | float = 1.0) -> Tensor:
    """
    Normalized spectral density of the Matern kernel on S^d, scaled by sigma^2.

    n: [N] degrees at which to evaluate the density
    d: []
    kappa: [O, 1, 1]
    nu: [O, 1, 1]
    max_ell: number of levels used to compute the normalizing constant
    sigma: tensor or Python scalar; (kappa.ndim - 1) trailing singleton axes are
        appended to sigma^2 before it is multiplied in
    return: unnormalized density / normalizer * sigma^2, broadcast together
    """
    if not isinstance(sigma, Tensor):
        # A Python scalar (including the default 1.0) cannot be subscripted;
        # promote it to a 0-dim tensor matching kappa.
        sigma = torch.as_tensor(sigma, dtype=kappa.dtype, device=kappa.device)

    # Tuple indexing is equivalent to `[..., None, None, ...]` without relying on
    # star-unpacking inside a subscript.
    trailing_axes = (None,) * (kappa.ndim - 1)
    variance = (sigma ** 2)[(..., *trailing_axes)]

    density = unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu) # [O, N, 1]
    normalizer = matern_spectral_density_normalizer(d=d, max_ell=max_ell, kappa=kappa, nu=nu) # [O, 1, 1]
    return density / normalizer * variance


def matern_ahat(ell: Tensor, d: int, max_ell: int, kappa: Tensor | float, nu: Tensor | float,
                m: int | None = None, sigma: Tensor | float = 1.0) -> Tensor:
    r"""
    :math:`\hat{a} = \rho(\ell)` where :math:`\rho` is the spectral density on S^d.

    ell: [N] degrees at which to evaluate the density
    d: []
    max_ell: number of levels used to normalize the spectral density
    kappa, nu, sigma: forwarded unchanged to `matern_spectral_density`
    m: unused; accepted only so existing call sites keep working
    return: [O, N, 1]
    """
    return matern_spectral_density(n=ell, d=d, kappa=kappa, nu=nu, max_ell=max_ell, sigma=sigma) # [O, N, 1]


def matern_repeated_ahat(max_ell: int, max_ell_prior: int, d: int, kappa: Tensor | float, nu: Tensor | float, sigma: Tensor | float = 1.0) -> Tensor:
    """
    ahat values expanded to one entry per harmonic: the value for each degree
    below max_ell is repeated once for every harmonic of that degree.

    max_ell: number of degrees to expand
    max_ell_prior: number of degrees used to normalize the spectral density
    return: [O, num_harmonics, 1]
    """
    degrees = torch.arange(max_ell) # [max_ell]
    ahat_per_degree = matern_ahat(
        ell=degrees, d=d, max_ell=max_ell_prior, kappa=kappa, nu=nu, sigma=sigma,
    ) # [O, max_ell, 1]
    multiplicities = num_harmonics(ell=degrees, d=d) # [max_ell]
    return ahat_per_degree.repeat_interleave(multiplicities, dim=-2) # [O, num_harmonics, 1]


def matern_Kuu(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0,
               max_ell_prior: int | None = None) -> Tensor:
    """
    Returns the covariance matrix, which is a diagonal matrix with entries
    equal to 1 / ahat of the corresponding ell.

    max_ell: number of degrees included
    d: []
    kappa, nu, sigma: forwarded to `matern_repeated_ahat`
    max_ell_prior: number of degrees used to normalize the spectral density;
        defaults to max_ell when omitted
    return: [O, num_harmonics, num_harmonics]
    """
    if max_ell_prior is None:
        max_ell_prior = max_ell

    # Keyword arguments: the previous positional call omitted max_ell_prior, which
    # shifted every following argument and left `nu` unfilled.
    ahat = matern_repeated_ahat(
        max_ell=max_ell, max_ell_prior=max_ell_prior, d=d, kappa=kappa, nu=nu, sigma=sigma,
    ) # [O, num_harmonics, 1]

    # diag_embed builds a diagonal matrix from the last axis for any number of
    # leading batch axes; torch.diag would instead extract the diagonal of a 2-D input.
    return torch.diag_embed(1 / ahat.squeeze(-1)) # [O, num_harmonics, num_harmonics]


def spherical_harmonics(x: Tensor, max_ell: int, d: int) -> Tensor:
    """
    Evaluate the spherical harmonics on S^d with degree < max_ell at the points x.

    x: [..., N, D]; promoted to at least 2-D, so a single point [D] is treated as [1, D]
    return: [..., N, num_harmonics] with num_harmonics = total_num_harmonics(max_ell, d)
    """
    # Flatten -> evaluate -> unflatten
    points = torch.atleast_2d(x)
    *leading_dims, num_points, _ = points.shape
    flat_points = torch.flatten(points, start_dim=0, end_dim=-2) # [prod(...) * N, D]

    # SphericalHarmonics works with S^{d-1}, while we work with S^d as in GeometricKernels;
    # hence dimension=d + 1. It uses levels up to `degrees` (exclusive).
    harmonics = SphericalHarmonics(dimension=d + 1, degrees=max_ell)
    values = harmonics(flat_points)

    return values.reshape(*leading_dims, num_points, total_num_harmonics(max_ell, d)) # [..., N, num_harmonics]


def matern_Kux(x: Tensor, max_ell: int, d: int) -> Tensor:
    """
    Spherical harmonics evaluated at x, with the harmonic axis placed before the point axis.

    x: [..., O, N, D]
    return: [..., O, num_harmonics, N]
    """
    features = spherical_harmonics(x, max_ell=max_ell, d=d) # [..., O, N, num_harmonics]
    return features.transpose(-2, -1) # [..., O, num_harmonics, N]


def num_spherical_harmonics_to_num_levels(num_spherical_harmonics: int, dimension: int) -> tuple[int, int]:
    """
    Find the smallest number of complete levels that contains at least the
    requested number of spherical harmonics.

    Prints a notice when the request does not coincide with a whole number of levels.

    :return: (level, least_upper_bound) where `level` is the exclusive upper degree and
        `least_upper_bound` is the number of harmonics in degrees 0, ..., level - 1
    """
    level = 0
    least_upper_bound = 0
    # Add whole levels until the running total reaches the request.
    while not least_upper_bound >= num_spherical_harmonics:
        least_upper_bound = least_upper_bound + num_harmonics_single(d=dimension, ell=level)
        level = level + 1

    overshoots = least_upper_bound > num_spherical_harmonics
    if overshoots:
        message = (
            "The number of spherical harmonics requested does not lead to complete "
            "levels of spherical harmonics. We have thus increased the number to "
            f"{least_upper_bound}, which includes all spherical harmonics up to level {level} (exclusive)"
        )
        print(message)
    return level, least_upper_bound


def matern_LT_Phi(x: Tensor, max_ell: int, max_ell_prior: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Spherical harmonic features at x, each scaled by the square root of its ahat value.

    x: [..., O, N, D]
    max_ell: number of degrees of harmonics used as features
    max_ell_prior: number of degrees used to normalize the spectral density
    return: [..., O, num_harmonics, N]
    """
    features = matern_Kux(x, max_ell=max_ell, d=d) # [..., O, num_harmonics, N]
    ahat = matern_repeated_ahat(
        max_ell=max_ell, max_ell_prior=max_ell_prior, d=d, kappa=kappa, nu=nu, sigma=sigma,
    ) # [O, num_harmonics, 1]
    scale = torch.sqrt(ahat) # [O, num_harmonics, 1]
    return features * scale # [..., O, num_harmonics, N]


def matern_LT_Phi_from_kernel(x: Tensor, covar_module: Kernel, num_levels: int, num_levels_prior: int) -> Tensor:
    """
    Compute `matern_LT_Phi` with hyperparameters read from a kernel object.

    covar_module: either a ScaleKernel (sigma is the square root of its outputscale and
        the wrapped base_kernel supplies the rest) or a bare kernel (sigma is all ones).
        The base kernel must expose `lengthscale`, `nu` and `space.dimension`.
    num_levels: passed on as max_ell
    num_levels_prior: passed on as max_ell_prior
    """
    has_outputscale = isinstance(covar_module, ScaleKernel)
    base_kernel = covar_module.base_kernel if has_outputscale else covar_module

    if has_outputscale:
        sigma = covar_module.outputscale.sqrt()
    else:
        # NOTE(review): plain torch Tensors have no `batch_shape`; confirm that
        # `lengthscale` exposes it here (or whether `base_kernel.batch_shape` was intended).
        sigma = torch.ones(base_kernel.lengthscale.batch_shape, dtype=x.dtype, device=x.device)

    # TODO Obtaining dimension in this way seems a bit error-prone
    return matern_LT_Phi(
        x,
        max_ell=num_levels,
        max_ell_prior=num_levels_prior,
        d=base_kernel.space.dimension,
        kappa=base_kernel.lengthscale,
        nu=base_kernel.nu,
        sigma=sigma,
    )


# Test that the utility functions reproduce the GeometricKernels Matérn kernel

In [22]:
from geometric_kernels.kernels import MaternKarhunenLoeveKernel
from geometric_kernels.spaces import Hypersphere 


def matern_kernel(x: Tensor, y: Tensor, max_ell: int, max_ell_prior: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Truncated Matern kernel matrix between x and y on S^d, assembled from
    spherical harmonic features weighted by ahat.
    x: [..., O, N, D]
    y: [..., O, M, D]
    max_ell: [] number of degrees of harmonics used
    max_ell_prior: [] number of degrees used to normalize the spectral density
    d: []
    kappa: [O, 1, 1]
    nu: [O, 1, 1]
    sigma: [O, 1, 1]
    """
    features_x = matern_Kux(x, max_ell=max_ell, d=d) # [..., O, num_harmonics, N]
    features_y = matern_Kux(y, max_ell=max_ell, d=d) # [..., O, num_harmonics, M]
    weights = matern_repeated_ahat(
        max_ell=max_ell, max_ell_prior=max_ell_prior, d=d, kappa=kappa, nu=nu, sigma=sigma,
    ) # [O, num_harmonics, 1]
    weighted_x = weights * features_x
    return torch.matmul(weighted_x.mT, features_y)


def geometric_kernels_matern_kernel(x: Tensor, y: Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Reference Matern kernel matrix computed with GeometricKernels'
    MaternKarhunenLoeveKernel on Hypersphere(d), multiplied by sigma^2.
    x: [..., O, N, D]
    y: [..., O, M, D]
    max_ell: [] passed as num_levels
    d: []
    kappa: [O, 1, 1] passed as the 'lengthscale' parameter
    nu: [O, 1, 1]
    sigma: [O, 1, 1]
    """
    reference_kernel = MaternKarhunenLoeveKernel(
        space=Hypersphere(d),
        num_levels=max_ell,
        normalize=True,
    )
    hyperparameters = {'lengthscale': kappa, 'nu': nu}
    gram = reference_kernel.K(hyperparameters, x, y)
    return gram * sigma ** 2


# Hand-picked unit vectors in R^3 (points on S^2): 2 query points and 3 reference points.
x = torch.tensor([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
y = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
# Hyperparameters without a leading batch axis: nu and kappa are [1, 1], sigma is [1].
nu = torch.tensor([[2.5]])
kappa = torch.tensor([[1.0]])
sigma = torch.tensor([1.0])
# NOTE: max_ell and d are notebook globals that later cells read as well.
max_ell = 5
d = 2


# max_ell is passed twice: once as max_ell and once as max_ell_prior, so both
# implementations are truncated at the same number of levels.
K_mine = matern_kernel(x, y, max_ell, max_ell, d, kappa, nu, sigma)
K_theirs = geometric_kernels_matern_kernel(x, y, max_ell, d, kappa, nu, sigma)

# The hand-rolled kernel must agree with the GeometricKernels reference.
assert torch.allclose(K_mine, K_theirs)

# Test correct kernel behaviour with batches

In [23]:
from mdgp.utils import sphere_random_uniform

# Batched inputs; presumably a [2, 7, 3] tensor of points on S^2 — TODO confirm
# against sphere_random_uniform.
x = sphere_random_uniform(2, 7, 3)

# Two identical (nu, kappa) settings, each of shape [2, 1, 1]; only sigma differs.
nu = torch.tensor([[[2.5]], [[2.5]]])
kappa = torch.tensor([[[1.0]], [[1.0]]])
sigma = torch.tensor([[1.0], [2.0]])
max_ell = 5
d = 2

# Unpacking along the leading axis yields one kernel per sigma value.
K1, K2 = matern_kernel(x, x, max_ell, max_ell, d, kappa, nu, sigma)
# Two separate asserts: in `assert A, B` the second expression is only the failure
# message, so K2's shape was never actually checked.
assert K1.shape == (2, 7, 7)
assert K2.shape == (2, 7, 7)
# With sigma = 1 for K1, scaling by the second sigma^2 must reproduce K2.
assert torch.allclose(K1 * sigma[1] ** 2, K2)

# Test calculation of minimal level containing the desired number of harmonics

In [4]:
# NOTE: `d` is not defined in this cell; it is the notebook global assigned in an
# earlier cell, so that cell must be run first.
num_spherical_harmonics = 100
level, least_upper_bound = num_spherical_harmonics_to_num_levels(num_spherical_harmonics, d)
# The returned count may exceed the request, but never fall short of it.
assert least_upper_bound >= num_spherical_harmonics
# The returned count must equal the total over all degrees below `level`.
assert least_upper_bound == total_num_harmonics(level, d), (
    f"The least upper bound {least_upper_bound} is not equal to the number of spherical harmonics {total_num_harmonics(level, d)}"
)

# Test that spherical_harmonics handles inputs with and without batch dimensions

In [11]:
from mdgp.utils import sphere_random_uniform

# NOTE: `max_ell` and `d` are notebook globals assigned in earlier cells.
# One input with two leading batch axes and one without any; shapes presumably
# follow the arguments of sphere_random_uniform — TODO confirm.
x_batch = sphere_random_uniform(17, 3, 12, 3)
x_nobatch = sphere_random_uniform(12, 3)
# The trailing coordinate axis is replaced by the harmonic axis; leading axes are preserved.
assert spherical_harmonics(x_nobatch, max_ell, d).shape == (12, total_num_harmonics(max_ell, d))
assert spherical_harmonics(x_batch, max_ell=max_ell, d=d).shape == (17, 3, 12, total_num_harmonics(max_ell, d))