# Load in data
Currently supported datasets: power, protein, kin8nm

In [81]:
import torch 
import gpytorch 
import geometric_kernels.torch 

torch.set_default_dtype(torch.float64)

### Create data on S^D

In [5]:
from geometric_kernels.spaces import Hypersphere


def get_space_and_data(dim):
    space = Hypersphere(dim)
    x = torch.tensor(space.random_uniform(100))
    y = torch.sin(x[:, 0])
    return space, x, y

### Make sure that these god-forsaken functions are correctly implemented by reproducing the geometric_kernels kernel

In [6]:
from torch import Tensor 

import torch 
from math import comb 
from spherical_harmonics import SphericalHarmonics


def num_harmonics_single(ell: int, d: int) -> int:
    r"""
    Number of spherical harmonics of degree ell on S^{d - 1}.
    """
    if ell == 0:
        return 1
    if d == 3:
        return 2 * ell + 1
    else:
        return (2 * ell + d - 2) * comb(ell + d - 3, ell - 1) // ell


def num_harmonics(ell: Tensor | int, d: int) -> Tensor:
    """
    Vectorized version of num_harmonics_single
    """
    if isinstance(ell, int):
        return num_harmonics_single(ell, d)
    return ell.apply_(lambda e: num_harmonics_single(ell=e, d=d))


def total_num_harmonics(max_ell: int, d: int) -> int:
    """
    Total number of spherical harmonics on S^{d-1} with degree <= max_ell
    """
    return int(sum(num_harmonics(ell=torch.arange(max_ell + 1), d=d)))


def eigenvalue_laplacian(ell: Tensor, d: int) -> Tensor:
    """
    Eigenvalue of the Laplace-Beltrami operator for a spherical harmonic of degree ell on S_{d-1}
    ell: [...]
    d: []
    return: [...]
    """
    return ell * (ell + d - 2)


def unnormalized_matern_spectral_density(n: Tensor, d: int, kappa: Tensor | float, nu: Tensor | float) -> Tensor | float: 
    """
    compute (unnormalized) spectral density of the matern kernel on S_{d-1}
    n: [N]
    d: []
    kappa: [O, 1, 1]
    nu: [O, 1, 1]
    return: [O, 1, N]
    """
    # Squared exponential kernel 
    if torch.all(nu.isinf()):
        exponent = -kappa ** 2 / 2 * eigenvalue_laplacian(ell=n, d=d) # [O, N, 1]
        return torch.exp(exponent)
    # Matern kernel
    else:
        base = (
            2.0 * nu / kappa**2 + # [O, 1, 1]
            eigenvalue_laplacian(ell=n, d=d).unsqueeze(-1) # [N, 1]
        ) # [O, N, 1]
        exponent = -nu - (d - 1) / 2.0 # [O, 1, 1]
        return base ** exponent # [O, N, 1]


def matern_spectral_density_normalizer(d: int, max_ell: int, kappa: Tensor | float, nu: Tensor | float) -> Tensor:
    """
    Normalizing constant for the spectral density of the Matern kernel on S^{d-1}. 
    Depends on kappa and nu. Also depends on max_ell, as truncation of the infinite 
    sum from Karhunen-Loeve decomposition. 
    """
    n = torch.arange(max_ell + 1)
    spectral_values = unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu) # [O, max_ell + 1, 1]
    num_harmonics_per_level = num_harmonics(torch.arange(max_ell + 1), d=d).type(spectral_values.dtype) # [max_ell + 1]
    normalizer = spectral_values.mT @ num_harmonics_per_level # [O, 1, max_ell + 1] @ [max_ell + 1] -> [O, 1]
    return normalizer.unsqueeze(-2) # [O, 1, 1]


def matern_spectral_density(n: Tensor, d: int, kappa: Tensor, nu: Tensor, max_ell: int, sigma: float = 1.0) -> Tensor:
    """
    Spectral density of the Matern kernel on S^{d-1}
    """
    return (
        unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu) / # [O, N, 1]
        matern_spectral_density_normalizer(d=d, max_ell=max_ell, kappa=kappa, nu=nu) * # [O, 1, 1]
        (sigma ** 2)[..., *(None,) * (kappa.ndim - 1)] # [O, 1, 1] NOTE the reason for this seemingly overcomplicated broadcasting is that sigma can be a scalar if O is empty
    ) # [O, N, 1] / [O, 1, 1] * [O, 1, 1] -> [O, N, 1]


def matern_ahat(ell: Tensor, d: int, max_ell: int, kappa: Tensor | float, nu: Tensor | float, 
                m: int | None = None, sigma: Tensor | float = 1.0) -> float:
    """
    :math: `\hat{a} = \rho(\ell)` where :math: `\rho` is the spectral density on S^{d-1}
    """
    return matern_spectral_density(n=ell, d=d, kappa=kappa, nu=nu, max_ell=max_ell, sigma=sigma) # [O, N, 1]


def matern_repeated_ahat(max_ell: int, d: int, kappa: Tensor | float, nu: Tensor | float, sigma: Tensor | float = 1.0) -> Tensor:
    """
    Returns a tensor of repeated ahat values for each ell. 
    """
    ells = torch.arange(max_ell + 1) # [max_ell + 1]
    ahat = matern_ahat(ell=ells, d=d, max_ell=max_ell, kappa=kappa, nu=nu, sigma=sigma) # [O, max_ell + 1, 1]
    repeats = num_harmonics(ell=ells, d=d) # [max_ell + 1]
    return torch.repeat_interleave(ahat, repeats=repeats, dim=-2) # [O, num_harmonics, 1]


def matern_Kuu(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor : 
    """
    Returns the covariance matrix, which is a diagonal matrix with entries 
    equal to inv_ahat of the corresponding ell. 
    """
    return torch.diag(1 / matern_repeated_ahat(max_ell, d, kappa, nu, sigma=sigma).squeeze(-1)) # [O, num_harmonics, num_harmonics]


def spherical_harmonics(x: Tensor, max_ell: int, d: int) -> Tensor: 
    # Make sure that x is at least 2d and flatten it
    x = torch.atleast_2d(x)
    batch_shape, n = x.shape[:-2], x.shape[-2]
    x = x.flatten(0, -2)

    # Get spherical harmonics callable
    f = SphericalHarmonics(dimension=d, degrees=max_ell + 1) # [... * O, N, num_harmonics]

    # Evaluate x and reintroduce batch dimensions
    return f(x).reshape(*batch_shape, n, total_num_harmonics(max_ell, d)) # [..., O, N, num_harmonics]


def matern_Kux(x: Tensor, max_ell: int, d: int) -> Tensor: 
    return spherical_harmonics(x, max_ell=max_ell, d=d).mT # [..., O, num_harmonics, N]


def num_spherical_harmonics_to_degree(num_spherical_harmonics: int, dimension: int) -> tuple[int, int]:
    """
    Returns the minimum degree for which there are at least
    `num_eigenfunctions` in the collection.
    """
    n, degree = 0, 0  # n: number of harmonics, d: degree (or level)
    while n < num_spherical_harmonics:
        n += num_harmonics(d=dimension, ell=degree)
        degree += 1

    if n > num_spherical_harmonics:
        print(
            "The number of spherical harmonics requested does not lead to complete "
            "levels of spherical harmonics. We have thus increased the number to "
            f"{n}, which includes all spherical harmonics up to degree {degree} (incl.)"
        )
    return degree - 1, n


In [7]:
from geometric_kernels.kernels import MaternKarhunenLoeveKernel

In [30]:
def matern_kernel(x: Tensor, y:Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor:
    """
    Returns the kernel matrix for the Matern kernel on S^{d-1}
    x: [..., O, N, D]
    max_ell: []
    d: []
    kappa: [O, 1, 1]
    nu: [O, 1, 1]
    sigma: [O, 1, 1]
    """
    Kux = matern_Kux(x, max_ell=max_ell, d=d) # [..., O, num_harmonics, N]
    Kuy = matern_Kux(y, max_ell=max_ell, d=d) # [..., O, num_harmonics, M]
    ahat = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma) # [O, num_harmonics, 1]
    return (ahat * Kux).mT @ Kuy

In [31]:
x = torch.tensor([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
y = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
nu = torch.tensor([[2.5]])
kappa = torch.tensor([[1.0]])
max_ell = 2
d = 3
sigma = torch.tensor(1.0)
K_mine = matern_kernel(x, y, max_ell, d, kappa, nu, sigma)

params = {
    'lengthscale': kappa, 
    'nu': nu,
}
K_theirs = MaternKarhunenLoeveKernel(Hypersphere(d - 1), num_levels=max_ell + 1, normalize=True).K(
    params=params, X=x, X2=y
)

print(K_mine, K_theirs)

tensor([[0.3757, 0.3757, 1.0000],
        [0.3757, 1.0000, 0.3757]]) tensor([[0.3757, 0.3757, 1.0000],
        [0.3757, 1.0000, 0.3757]])


### Since the two implementation coincide, we can trust in the implementation of Kux and repeated_ahat. Since Kuu is correct iff repeated_ahat is correct, we can also trust Kuu. In principle the correctness of the implementation of the variational posterior depends only on Kux and Kuu, so it seems that we can trust the implemenation of the variational posterior.

### Ensure that all implementations of conditioning on variational posterior give the same results

In [135]:
def matern_Phi(x, max_ell, d, kappa, nu, sigma):
    Kux = matern_Kux(x, max_ell, d)
    ahat_sqrt = matern_repeated_ahat(max_ell, d, kappa, nu, sigma)
    return Kux * ahat_sqrt


def matern_LT_Phi(x: Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor: 
    Kux = matern_Kux(x, max_ell=max_ell, d=d) # [..., O, num_harmonics, N]
    ahat_sqrt = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma).sqrt() # [O, num_harmonics, 1]
    return Kux * ahat_sqrt # [..., O, num_harmonics, N]


def variational_posterior_covar_full(x, K, S, max_ell, kappa, nu, sigma):
    Kxx = K(x)
    Kuu = matern_Kuu(max_ell, d, kappa, nu, sigma)
    Kuu_inv = Kuu.inverse()
    Kux = matern_Kux(x, max_ell, d)
    Kxu = Kux.mT
    return Kxx + Kxu @ Kuu_inv @ (S - Kuu) @ Kuu_inv @ Kux


def variational_posterior_covar_paper(x, K, S, max_ell, kappa, nu, sigma):
    Kxx = K(x)
    Kuu = matern_Kuu(max_ell, d, kappa, nu, sigma)
    Phiux = matern_Phi(x, max_ell, d, kappa, nu, sigma)
    Phixu = Phiux.mT
    return Kxx + Phixu @ (S - Kuu) @ Phiux


def variational_posterior_covar_whitened(x, K, Linv_S_LTinv, max_ell, kappa, nu, sigma):
    Kxx = K(x)
    Linv_Kuu_LTinv = torch.eye(total_num_harmonics(max_ell, d))
    LT_Phiux = matern_LT_Phi(x, max_ell, d, kappa, nu, sigma)
    Phixu_L = LT_Phiux.mT
    return Kxx + Phixu_L @ (Linv_S_LTinv - Linv_Kuu_LTinv) @ LT_Phiux


# assumes zero mean prior 
def variational_posterior_mean_full(x, m, max_ell, kappa, nu, sigma):
    Kuu_inv = matern_Kuu(max_ell, d, kappa, nu, sigma).inverse()
    Kxu = matern_Kux(x, max_ell, d).mT
    return Kxu @ Kuu_inv @ m


def variational_posterior_mean_paper(x, m, max_ell, kappa, nu, sigma):
    Phixu = matern_Phi(x, max_ell, d, kappa, nu, sigma).mT
    return Phixu @ m


def variational_posterior_mean_whitened(x, Linv_m, max_ell, kappa, nu, sigma):
    Phixu_L = matern_LT_Phi(x, max_ell, d, kappa, nu, sigma).mT
    return Phixu_L @ Linv_m


In [137]:
# Variational distribution
z = total_num_harmonics(max_ell, d)
C = torch.tril(torch.randn(z, z))
S = C @ C.mT

# Whitened variational distribution
L = matern_Kuu(max_ell, d, kappa, nu, sigma).sqrt()
Linv_S_LTinv = L.inverse() @ S @ L.inverse().mT

# Prior covariance
K = lambda x: matern_kernel(x, x, max_ell, d, kappa, nu, sigma)
q_covar_full = variational_posterior_covar_full(x, K, S, max_ell, kappa, nu, sigma)
q_covar_paper = variational_posterior_covar_paper(x, K, S, max_ell, kappa, nu, sigma)
q_covar_whitened = variational_posterior_covar_whitened(x, K, Linv_S_LTinv, max_ell, kappa, nu, sigma)

assert torch.allclose(q_covar_full, q_covar_paper) and torch.allclose(q_covar_full, q_covar_whitened), "All variational posterior covariances should be equivalent."


# Prior mean
m = torch.randn(z)
Linv_m = L.inverse() @ m
q_mean_full = variational_posterior_mean_full(x, m, max_ell, kappa, nu, sigma)
q_mean_paper = variational_posterior_mean_paper(x, m, max_ell, kappa, nu, sigma)
q_mean_whitened = variational_posterior_mean_whitened(x, Linv_m, max_ell, kappa, nu, sigma)

assert torch.allclose(q_mean_full, q_mean_paper) and torch.allclose(q_mean_full, q_mean_whitened), "All variational posterior means should be equivalent."

### Maybe manually implement the ELBO?

In [142]:
import math 


def log_gaussian_likelihood(y, f, sigma):
    n = len(y)
    log_p_y_given_f = -0.5 * (y - f).T @ (y -f) / sigma**2
    log_p_y_given_f += -0.5 * n * math.log(2 * math.pi) - n * torch.log(sigma)
    return log_p_y_given_f


def kl_divergence_between_gaussians(mu1, sigma1, mu2, sigma2):
    d = mu1.shape[-1]
    sigma2_inv = sigma2.inverse()
    mu_diff = mu2 - mu1
    return 0.5 * (torch.logdet(sigma2) - torch.logdet(sigma1) - d + (mu_diff.T @ sigma2_inv @ mu_diff).item() + torch.trace(sigma2_inv @ sigma1))


def elbo(y, p_mu, p_sigma, q_mu, q_sigma, epsilon_sigma):
    return log_gaussian_likelihood(y, q_mu, epsilon_sigma) - kl_divergence_between_gaussians(q_mu, q_sigma, p_mu, p_sigma)


def torch_log_gaussian_likelihood(y, f, sigma):
    return torch.distributions.MultivariateNormal(f, torch.eye(len(y)) * sigma ** 2).log_prob(y)


def torch_kl_divergence_between_gaussians(mu1, sigma1, mu2, sigma2):
    p = torch.distributions.MultivariateNormal(mu1, sigma1)
    q = torch.distributions.MultivariateNormal(mu2, sigma2)
    return torch.distributions.kl.kl_divergence(p, q)

In [143]:
y = torch.randn(len(x))
p_mu = torch.zeros_like(y)
p_sigma = K(x)
q_mu = torch.randn_like(y)
q_sigma = q_covar_full
epsilon_sigma = torch.tensor(1e-4)

# Test KL divergence 
kl_mine = kl_divergence_between_gaussians(q_mu, q_sigma, p_mu, p_sigma)
kl_theirs = torch_kl_divergence_between_gaussians(q_mu, q_sigma, p_mu, p_sigma)
assert torch.allclose(kl_mine, kl_theirs), "KL divergences should be equal."

# Test log likelihood 
log_likelihood_mine = log_gaussian_likelihood(y, q_mu, epsilon_sigma)
log_likelihood_theirs = torch_log_gaussian_likelihood(y, q_mu, epsilon_sigma)
assert torch.allclose(log_likelihood_mine, log_likelihood_theirs), "Log likelihoods should be equal."

#### Seeing that both log likelihood and KL divergence is correct, we can use our ELBO to optimise the model

In [144]:
def add_jitter(x, jitter):
    return x + torch.eye(x.shape[-1]) * jitter


class SGP(torch.nn.Module):
    def __init__(self, max_ell, d, epsilon_sigma=0.01, jitter=1e-6): 
        super().__init__()
        self.jitter = jitter
        self.max_ell = max_ell

        # Variational parameters
        m = total_num_harmonics(self.max_ell, d)
        self.Linv_S_LTinv = torch.nn.Parameter(torch.eye(m))
        self.Linv_m = torch.nn.Parameter(torch.randn(m))
        
        # fix prior hyperparams 
        self.kappa = torch.tensor(1.0)
        self.nu = torch.tensor(2.5)
        self.sigma = torch.tensor(1.0)
        self.epsilon_sigma = torch.tensor(epsilon_sigma)

    def K(self, x):
        return matern_kernel(x, x, self.max_ell, d, self.kappa, self.nu, self.sigma)
    
    def q(self, x):
        sigma = variational_posterior_covar_whitened(x, self.K, self.Linv_S_LTinv, self.max_ell, self.kappa, self.nu, self.sigma)
        sigma = add_jitter(sigma, self.jitter)
        mu = variational_posterior_mean_whitened(x, self.Linv_m, self.max_ell, self.kappa, self.nu, self.sigma)
        return mu, sigma
    
    def forward(self, x):
        return self.q(x)
    
    def p(self, x):
        sigma = self.K(x)
        sigma = add_jitter(sigma, self.jitter)

        mu = torch.zeros(x.shape[:-1])
        return mu, sigma


In [145]:
from mdgp.utils.sphere import sphere_uniform_grid, spherical_harmonic


model = SGP(max_ell, d)
train_x = sphere_uniform_grid(100)
train_f = spherical_harmonic(train_x, 2, 3)
train_f = (train_f - train_f.mean()) / train_f.std()
train_y = train_f + torch.randn_like(train_f) * 0.01

In [146]:
q_mu, q_sigma = model.q(train_x)
p_mu, p_sigma = model.p(train_x)
epsilon_sigma = model.epsilon_sigma
q_mu, q_sigma, p_mu, p_sigma, epsilon_sigma

log_gaussian_likelihood(train_y, q_mu, epsilon_sigma), kl_divergence_between_gaussians(q_mu, q_sigma, p_mu, p_sigma), torch_log_gaussian_likelihood(train_y, q_mu, epsilon_sigma), torch_kl_divergence_between_gaussians(q_mu, q_sigma, p_mu, p_sigma)

(tensor(-827190.9862, grad_fn=<AddBackward0>),
 tensor(4.6184, grad_fn=<MulBackward0>),
 tensor(-827190.9862, grad_fn=<SubBackward0>),
 tensor(4.6184, grad_fn=<AddBackward0>))

In [149]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for step in range(1000):
    optimizer.zero_grad()
    q_mu, q_sigma = model.q(train_x)
    p_mu, p_sigma = model.p(train_x)
    loss = -elbo(train_y, p_mu, p_sigma, q_mu, q_sigma, model.epsilon_sigma) / len(train_y)
    loss.backward()
    optimizer.step()
    print(loss.item())


4966.510410591241
4964.254661275754
4961.380098217402
4959.121456981891
4957.124198089693
4955.096837987876
4953.174545093312
4951.5003428840455
4950.0931000608925
4948.858512496029
4947.713687540028
4946.6632234499975
4945.754541522602
4945.015534488616
4944.426029200804
4943.935985819044
4943.511117907346
4943.153338227842
4942.8820519031615
4942.705250999055
4942.606083392888
4942.554097142201
4942.5282201183445
4942.5265635163705
4942.5561268540305
4942.616595281867
4942.694056455088
4942.769425669033
4942.831731094461
4942.882073202071
4942.925962114064
4942.963731063731
4942.9887573770575
4942.994264931321
4942.980381139038
4942.953728389382
4942.92098312645
4942.883898372489
4942.840650395798
4942.791016636282
4942.738938248274
4942.6898501569385
4942.646436257241
4942.607716502531
4942.572019156372
4942.539881540654
4942.513543336129
4942.49418259317
4942.4805346911835
4942.470406310856
4942.462919067646
4942.45870465946
4942.458158953975
4942.460225242537
4942.4631732076205
49

KeyboardInterrupt: 

### Plot results

In [151]:
from mdgp.utils.sphere import sphere_meshgrid


test_x = sphere_meshgrid(50, 50)
with torch.no_grad():
    x, y, z = test_x.unbind(-1)
    surfacecolor = model(test_x.view(-1, d))[0].view(x)

In [152]:
from plotly import graph_objects as go 


fig = go.Figure(
    go.Surface(
        x=x, y=y, z=z, surfacecolor=surfacecolor, colorscale='Viridis'
    )
)


ValueError: 
    Invalid value of type 'builtins.int' received for the 'z' property of surface
        Received value: 9

    The 'z' property is an array that may be specified as a tuple,
    list, numpy array, or pandas Series