In [4]:
from torch import Tensor 


import geometric_kernels.torch
import torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from gpytorch.kernels import Kernel, ScaleKernel
from mdgp.variational.spherical_harmonic_features.utils import *

# Test that the utils functions build the Matérn kernel correctly (checked against geometric_kernels)

In [5]:
from geometric_kernels.kernels import MaternKarhunenLoeveKernel
from geometric_kernels.spaces import Hypersphere 


def matern_kernel(x: Tensor, y: Tensor, max_ell: int, max_ell_prior: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Matérn kernel matrix on the hypersphere, assembled from the spherical-harmonic utils.

    The matrix is computed as ``Kux^T diag(ahat) Kuy``, where ``Kux`` / ``Kuy`` are the
    harmonic features of ``x`` / ``y`` produced by ``matern_Kux`` and ``ahat`` are the
    per-harmonic weights produced by ``matern_repeated_ahat``.

    NOTE(review): ``d`` is forwarded unchanged to the utils. It is assumed to follow the
    same convention as geometric_kernels' ``Hypersphere(d)``, i.e. S^d embedded in
    R^{d+1} (the original docstring said S^{d-1}). TODO confirm against ``matern_Kux``.

    Args:
        x: [..., O, N, D] first set of points (presumably unit-norm).
        y: [..., O, M, D] second set of points (presumably unit-norm).
        max_ell: [] truncation level of the harmonic features for ``x`` and ``y``.
        max_ell_prior: [] forwarded only to ``matern_repeated_ahat``. Presumably this is
            the truncation used when normalising the prior weights. TODO confirm.
        d: [] sphere dimension passed to the utils.
        kappa: float or tensor broadcastable to [O, 1, 1]; lengthscale.
        nu: float or tensor broadcastable to [O, 1, 1]; smoothness.
        sigma: float or tensor broadcastable to [O, 1, 1]; passed to
            ``matern_repeated_ahat`` (presumably the output scale).

    Returns:
        [..., O, N, M] kernel matrix between ``x`` and ``y``.
    """
    Kux = matern_Kux(x, max_ell=max_ell, d=d) # [..., O, num_harmonics, N]
    Kuy = matern_Kux(y, max_ell=max_ell, d=d) # [..., O, num_harmonics, M]
    ahat = matern_repeated_ahat(max_ell=max_ell, max_ell_prior=max_ell_prior, d=d, kappa=kappa, nu=nu, sigma=sigma) # [O, num_harmonics, 1]
    # Scale each harmonic row of Kux by its weight, then contract over the harmonic axis.
    return (ahat * Kux).mT @ Kuy


def geometric_kernels_matern_kernel(x: Tensor, y: Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Reference Matérn kernel matrix from geometric_kernels, used to validate ``matern_kernel``.

    ``Hypersphere(d)`` in geometric_kernels is the d-dimensional sphere S^d embedded in
    R^{d+1}. Points therefore have ``D = d + 1`` coordinates. The original docstring
    said S^{d-1}, which does not match that convention.

    Args:
        x: first set of points, shape [N, d + 1] in the test below. Support for extra
            batch dimensions ([..., O, N, D]) is assumed, not verified. TODO confirm.
        y: second set of points, shape [M, d + 1] in the test below.
        max_ell: [] passed as geometric_kernels' ``num_levels`` (the number of levels of
            eigenfunctions kept in the truncated expansion).
        d: [] intrinsic dimension of the sphere.
        kappa: lengthscale, passed as ``params['lengthscale']``.
        nu: smoothness, passed as ``params['nu']``.
        sigma: output scale; the kernel is multiplied by ``sigma ** 2``.

    Returns:
        Kernel matrix between ``x`` and ``y``, scaled by ``sigma ** 2``.
    """
    params = {
        'lengthscale': kappa,
        'nu': nu,
    }
    kernel = MaternKarhunenLoeveKernel(
        space=Hypersphere(d),
        num_levels=max_ell,
        # Normalise the kernel (geometric_kernels: unit average variance).
        normalize=True,
    )
    return kernel.K(params, x, y) * sigma ** 2


# Compare the utils-based construction against geometric_kernels' reference
# implementation on a few points of S^2 (d=2, ambient dimension 3).
x = torch.tensor([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
y = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
nu = torch.tensor([[2.5]])
kappa = torch.tensor([[1.0]])
max_ell = 5
d = 2
sigma = torch.tensor(1.0)


K_mine = matern_kernel(x, y, max_ell, max_ell, d, kappa, nu, sigma)
K_theirs = geometric_kernels_matern_kernel(x, y, max_ell, d, kappa, nu, sigma)

# On failure, report how large the discrepancy is instead of a bare AssertionError.
# The message is only evaluated when the assertion fails.
assert torch.allclose(K_mine, K_theirs), (
    f"Kernel mismatch: max |K_mine - K_theirs| = {(K_mine - K_theirs).abs().max().item():.3e}"
)

    https://beartype.readthedocs.io/en/latest/api_roar/#pep-585-deprecations
  warn(


# Test computing the minimal level that contains at least the desired number of spherical harmonics

In [6]:
# `d` comes from the Matérn comparison cell above.
num_spherical_harmonics = 100
level, least_upper_bound = num_spherical_harmonics_to_num_levels(num_spherical_harmonics, d)
assert least_upper_bound >= num_spherical_harmonics
assert least_upper_bound == total_num_harmonics(level, d), (
    f"The least upper bound {least_upper_bound} is not equal to the number of spherical harmonics {total_num_harmonics(level, d)}"
)
# Minimality: one level fewer must NOT already contain enough harmonics,
# otherwise `level` is an upper bound but not the *least* one.
if level > 0:
    assert total_num_harmonics(level - 1, d) < num_spherical_harmonics, (
        f"Level {level} is not minimal: level {level - 1} already has "
        f"{total_num_harmonics(level - 1, d)} >= {num_spherical_harmonics} harmonics"
    )