# Load in data
Currently supported datasets: power, protein, kin8nm

In [2]:
import torch 
import geometric_kernels.torch 

# Make float64 the default dtype for tensors created after this point.
# NOTE(review): presumably chosen because the covariance eigenvalues inspected
# further down are of order 1e-7..1e-6 — confirm before relaxing to float32.
torch.set_default_dtype(torch.float64)

### Create data on S^D

In [3]:
from geometric_kernels.spaces import Hypersphere


def get_space_and_data(dim, num_points=100):
    """Build a hypersphere and a toy regression dataset sampled on it.

    Args:
        dim: Dimension handed to ``Hypersphere``.
        num_points: Number of points to draw with ``space.random_uniform``.
            Defaults to 100, the value that used to be hard-coded.

    Returns:
        A tuple ``(space, x, y)``: the ``Hypersphere`` instance, the sampled
        inputs converted to a torch tensor, and the targets, which are the
        sine of the first coordinate of every input.
    """
    space = Hypersphere(dim)
    x = torch.tensor(space.random_uniform(num_points))
    # Toy target: depends only on the first coordinate of each sample.
    y = torch.sin(x[:, 0])
    return space, x, y

INFO: Using numpy backend


# Instantiate model

This is done using the same model arguments as for the benchmarking experiment

In [4]:
from torch import Tensor 
from gpytorch.means import Mean
from gpytorch.kernels import ScaleKernel
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from mdgp.kernels import GeometricMaternKernel


import torch 
from gpytorch import Module, settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.utils.memoize import cached, clear_cache_hook
from linear_operator.operators import DiagLinearOperator
from functools import cached_property
# TODO Maybe just move the functions from spherical_harmonic_features.py into this file?
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree, matern_Kuu, matern_LT_Phi

In [5]:
class VariationalDistribution(torch.nn.Module):
    """Zero-mean Gaussian whose covariance is parameterised by a square factor.

    The learnable parameter ``chol_variational_covar`` holds a factor ``F`` and
    the covariance is ``F @ F^T``. Nothing constrains ``F`` to be lower
    triangular, so despite the name it is only a factor of the covariance.
    """

    def __init__(self, num_inducing: int):
        """Create the distribution with an identity factor.

        Args:
            num_inducing: Side length of the square factor.
        """
        super().__init__()
        self.num_inducing = num_inducing
        # Identity factor -> identity covariance at initialisation.
        self.chol_variational_covar = torch.nn.Parameter(torch.eye(num_inducing))

    def forward(self) -> MultivariateNormal:
        """Return the current zero-mean Gaussian ``N(0, F @ F^T)``."""
        factor = self.chol_variational_covar
        covar = factor @ factor.mT
        # Derive the mean from the factor itself rather than from
        # ``num_inducing``: the parameter is overwritten after construction
        # elsewhere in this notebook, and this keeps shape, dtype and device of
        # the mean in step with whatever factor is currently installed.
        mean = factor.new_zeros(factor.shape[:-1])
        return MultivariateNormal(mean, covar)
    

class SHVariationalStrategy(torch.nn.Module):
    """Variational strategy using spherical-harmonic features, whitened form.

    The prior over the inducing variables is a standard normal (identity
    covariance), and ``forward`` corrects the input prior ``Kxx`` by
    ``LT_Phi^T (S - I) LT_Phi``, where ``S`` is the variational covariance and
    ``LT_Phi`` comes from ``matern_LT_Phi``.

    NOTE(review): a later cell of this notebook defines another class with the
    same name; once that cell has run, this definition is shadowed.
    """

    def __init__(self,
        covar_module: GeometricMaternKernel,
        dimension: int,
        num_spherical_harmonics: int,
    ):
        """Set up the strategy.

        Args:
            covar_module: Kernel evaluated on the inputs to obtain ``Kxx``.
            dimension: Passed as ``d`` to the spherical-harmonic helpers.
            num_spherical_harmonics: Requested number of harmonics. The helper
                ``num_spherical_harmonics_to_degree`` returns the degree and
                the count that is actually stored on the instance.
        """
        super().__init__()
        self.covar_module = covar_module
        self.dimension = dimension
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        # Size the variational distribution with the count returned by the
        # helper, not the raw argument: the prior below uses
        # ``self.num_spherical_harmonics``, and the two must have the same
        # dimension for the KL term and the covariance update to be defined.
        self._variational_distribution = VariationalDistribution(self.num_spherical_harmonics)

    @property
    def kappa(self) -> Tensor:
        """Fixed value forwarded as ``kappa`` to ``matern_LT_Phi``."""
        return torch.tensor([[0.001]])
    
    @property
    def nu(self) -> Tensor | float:
        """Fixed value forwarded as ``nu`` to ``matern_LT_Phi``."""
        return torch.tensor([[2.5]])

    @property 
    def sigma(self) -> Tensor: 
        """Fixed value forwarded as ``sigma`` to ``matern_LT_Phi``."""
        return torch.tensor(1.0)

    @property
    def prior_distribution(self) -> MultivariateNormal:
        """Standard-normal prior over the (whitened) inducing variables."""
        covariance_matrix = torch.eye(self.num_spherical_harmonics)
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
        
    @property
    def variational_distribution(self) -> MultivariateNormal:
        """Current variational distribution over the inducing variables."""
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        """Return the zero-mean approximate distribution at inputs ``x``.

        Args:
            x: Inputs; the mean is created with shape ``x.shape[:-1]``.
            **kwargs: Accepted and ignored.

        Returns:
            ``MultivariateNormal`` with zero mean and covariance
            ``Kxx + LT_Phi^T (S - I) LT_Phi``.
        """
        # inducing-inducing prior (identity in the whitened parameterisation)
        pu = self.prior_distribution
        invL_Kuu_invLt = pu.lazy_covariance_matrix

        # input-input prior
        Kxx = self.covar_module(x)

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_S_invLt = qu.lazy_covariance_matrix

        # inducing-input prior  
        LT_Phi = matern_LT_Phi(x, max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma,) # [..., O, num_harmonics, N]

        # Update the covariance matrix
        # [..., N, num_harmonics] @ [..., num_harmonics, num_harmonics] @ [..., num_harmonics, N] -> [..., N, N]
        covariance_matrix_update = LT_Phi.mT @ (invL_S_invLt - invL_Kuu_invLt) @ LT_Phi
        updated_covariance_matrix = Kxx + covariance_matrix_update # [..., N, N] + [..., N, N] -> [..., N, N]

        return MultivariateNormal(mean=torch.zeros(x.shape[:-1]), covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        """KL divergence from the variational distribution to the prior.

        Computed with gpytorch's ``max_preconditioner_size`` setting at 0.
        """
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

In [16]:
from mdgp.utils.spherical_harmonic_features import matern_Kux


class SHVariationalStrategy(torch.nn.Module):
    """Variational strategy using spherical-harmonic features, unwhitened form.

    The prior over the inducing variables has covariance ``Kuu`` (from
    ``matern_Kuu``), and ``forward`` returns a zero-mean distribution with
    covariance ``Kxx + Kux^T Kuu^{-T} (S - Kuu) Kuu^{-1} Kux``, where ``S`` is
    the variational covariance.

    NOTE(review): ``Kxx`` comes from ``covar_module`` while ``Kuu`` is built
    from the hard-coded ``kappa``/``nu``/``sigma`` properties below. If those
    describe different kernels, the update need not keep the sum positive
    semi-definite — confirm they are meant to agree.
    """

    def __init__(self,
        covar_module: GeometricMaternKernel,
        dimension: int,
        num_spherical_harmonics: int,
    ):
        """Set up the strategy and initialise the variational factor from ``Kuu``.

        Args:
            covar_module: Kernel evaluated on the inputs to obtain ``Kxx``.
            dimension: Passed as ``d`` to the spherical-harmonic helpers.
            num_spherical_harmonics: Requested number of harmonics. The helper
                ``num_spherical_harmonics_to_degree`` returns the degree and
                the count that is actually stored on the instance.
        """
        super().__init__()
        self.covar_module = covar_module
        self.dimension = dimension
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        # Size the variational distribution with the count returned by the
        # helper, not the raw argument, so its mean has the same dimension as
        # the prior and as the Kuu-derived factor installed just below.
        self._variational_distribution = VariationalDistribution(self.num_spherical_harmonics)
        # Start the factor at sqrt(Kuu).
        # NOTE(review): presumably this makes S equal Kuu at initialisation —
        # that holds only if ``.sqrt()`` yields a true matrix square root of
        # Kuu (e.g. Kuu diagonal); confirm against ``matern_Kuu``.
        self._variational_distribution.chol_variational_covar = torch.nn.Parameter(
            matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma).sqrt().to_dense()
        )

    @property
    def kappa(self) -> Tensor:
        """Fixed value forwarded as ``kappa`` to ``matern_Kuu``."""
        return torch.tensor([[0.001]])
    
    @property
    def nu(self) -> Tensor | float:
        """Fixed value forwarded as ``nu`` to ``matern_Kuu``."""
        return torch.tensor([[2.5]])

    @property 
    def sigma(self) -> Tensor: 
        """Fixed value forwarded as ``sigma`` to ``matern_Kuu``."""
        return torch.tensor(1.0)

    @property
    def prior_distribution(self) -> MultivariateNormal:
        """Zero-mean prior over the inducing variables with covariance ``Kuu``.

        Unlike ``forward``, no jitter is added to ``Kuu`` here.
        """
        covariance_matrix = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
        
    @property
    def variational_distribution(self) -> MultivariateNormal:
        """Current variational distribution over the inducing variables."""
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        """Return the zero-mean approximate distribution at inputs ``x``.

        Args:
            x: Inputs; the mean is created with shape ``x.shape[:-1]``.
            **kwargs: Accepted and ignored.

        Returns:
            ``MultivariateNormal`` with zero mean and covariance
            ``Kxx + Kux^T Kuu^{-T} (S - Kuu) Kuu^{-1} Kux``, with a jitter of
            1e-6 added to both ``Kuu`` and ``Kxx``.
        """
        jitter = 1e-6
        # inducing-inducing prior
        Kuu = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        Kuu = Kuu.add_jitter(jitter)
        Kuu_inv = Kuu.to_dense().inverse()

        # input-input prior
        Kxx = self.covar_module(x)
        Kxx = Kxx.add_jitter(jitter)

        # inducing-input prior
        Kux = matern_Kux(x, max_ell=self.degree, d=self.dimension)

        # inducing-inducing variational
        S = self.variational_distribution.covariance_matrix

        # Difference between variational and (jittered) prior covariance;
        # computed once and reused in the sandwich product below.
        R = S - Kuu
        covariance_matrix_update = Kux.mT @ Kuu_inv.mT @ R @ Kuu_inv @ Kux
        updated_covariance_matrix = Kxx + covariance_matrix_update # Since Kxx is always PSD, if the updated cov is not PSD, then the update must not be PSD

        return MultivariateNormal(mean=torch.zeros(x.shape[:-1]), covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        """KL divergence from the variational distribution to the prior.

        Computed with gpytorch's ``max_preconditioner_size`` setting at 0.
        """
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

In [None]:
from gpytorch.variational import VariationalStrategy

# Simple shallow variational GP

In [17]:
import gpytorch


class SHApproximateGP(gpytorch.models.ApproximateGP):
    """Approximate GP whose variational strategy is ``SHVariationalStrategy``."""

    def __init__(self, dimension: int, num_spherical_harmonics: int, 
                 mean: Mean, covar_module: GeometricMaternKernel,
                 variational_distribution: CholeskyVariationalDistribution | None = None):
        """Build the model.

        Args:
            dimension: Forwarded to ``SHVariationalStrategy``.
            num_spherical_harmonics: Forwarded to ``SHVariationalStrategy``.
            mean: Mean module used by ``forward``.
            covar_module: Kernel shared by the strategy and by ``forward``.
            variational_distribution: Ignored. ``SHVariationalStrategy``
                creates its own variational distribution, so this argument is
                kept only for call-site compatibility and may be omitted.
        """
        variational_strategy = SHVariationalStrategy(
            covar_module=covar_module,
            dimension=dimension,
            num_spherical_harmonics=num_spherical_harmonics,
        )
        super().__init__(variational_strategy)
        self.mean_module = mean
        self.covar_module = covar_module
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Return the prior at ``x`` from the mean and covariance modules."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
    

class IPApproximateGP(gpytorch.models.ApproximateGP):
    """Approximate GP using gpytorch's unwhitened inducing-point strategy.

    The inducing locations are held fixed (``learn_inducing_locations=False``).
    """

    def __init__(self, mean: Mean, covar_module: GeometricMaternKernel, inducing_points: Tensor, variational_distribution: CholeskyVariationalDistribution):
        """Wire the strategy, mean, kernel and a Gaussian likelihood together.

        Args:
            mean: Mean module used by ``forward``.
            covar_module: Kernel used by ``forward``.
            inducing_points: Fixed inducing locations.
            variational_distribution: Variational distribution handed to the
                strategy.
        """
        strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
            self,
            inducing_points=inducing_points,
            variational_distribution=variational_distribution,
            learn_inducing_locations=False,
        )
        super().__init__(strategy)
        self.mean_module = mean
        self.covar_module = covar_module
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Return the prior at ``x`` from the mean and covariance modules."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

# Train model
Training has to be done differently from the benchmarking experiment, because we need minibatch SGD with the larger datasets and minibatch metrics.

In [18]:
def print_smallest_eigenvalues(covar):
    """Print the smallest eigenvalue of the symmetric matrix ``covar``.

    The eigenvalues are obtained with ``torch.linalg.eigvalsh`` and the
    minimum over all of them is printed; nothing is returned.
    """
    spectrum = torch.linalg.eigvalsh(covar)
    print(f"Smallest eigenvalue: {spectrum.min().item()}")

In [19]:
from geometric_kernels.spaces import Hypersphere
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree


# Build the toy dataset and the two models compared below. Every name assigned
# here is notebook-global state that later cells read (x, y, model_sh, model_ip).
# NOTE(review): no random seed is set in the visible cells, so x and
# inducing_points will differ between runs.
dimension = 2
space, x, y = get_space_and_data(dimension)
num_spherical_harmonics = 100

# The helper returns the degree together with a harmonic count; the requested
# count is overwritten with the returned one. It is called with dimension + 1,
# the same value later passed as `dimension` to SHApproximateGP.
degree, num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension + 1)

batch_shape = torch.Size([])

# Model with spherical harmonic features
mean = gpytorch.means.ZeroMean()
# NOTE(review): the kernel is built with num_eigenfunctions=4 while the
# variational strategy works with `num_spherical_harmonics` harmonics — confirm
# this difference is intended.
covar_module = GeometricMaternKernel(nu=2.5, space=space, num_eigenfunctions=4, batch_shape=batch_shape)
covar_module.initialize(lengthscale=0.001)
# Passed to SHApproximateGP below, which accepts but does not use it.
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=num_spherical_harmonics, batch_shape=batch_shape
)

model_sh = SHApproximateGP(
    dimension=dimension + 1,
    num_spherical_harmonics=num_spherical_harmonics,
    mean=mean,
    covar_module=covar_module,
    variational_distribution=variational_distribution,
)

# Model with inducing points
# mean / covar_module / variational_distribution are rebound to fresh objects
# so the two models do not share modules.
mean = gpytorch.means.ZeroMean()
covar_module = GeometricMaternKernel(nu=2.5, space=space, num_eigenfunctions=4, batch_shape=batch_shape)
covar_module.initialize(lengthscale=0.001)
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=num_spherical_harmonics, batch_shape=batch_shape
)

# Inducing points: standard-normal draws in dimension + 1 coordinates, each
# rescaled to unit Euclidean norm.
inducing_points = torch.randn(num_spherical_harmonics, dimension + 1)
inducing_points = inducing_points / inducing_points.norm(dim=-1, keepdim=True)
model_ip = IPApproximateGP(
    mean=mean,
    covar_module=covar_module,
    inducing_points=inducing_points,
    variational_distribution=variational_distribution
)

In [20]:
# Eigenvalues of the covariance returned by the spherical-harmonic strategy at
# the training inputs, before any optimisation step.
torch.linalg.eigvalsh(model_sh.variational_strategy(x).covariance_matrix)

tensor([9.6762e-07, 9.7083e-07, 9.7321e-07, 9.7488e-07, 9.7569e-07, 9.7603e-07,
        9.7730e-07, 9.7806e-07, 9.7934e-07, 9.7995e-07, 9.8102e-07, 9.8120e-07,
        9.8155e-07, 9.8251e-07, 9.8279e-07, 9.8317e-07, 9.8366e-07, 9.8457e-07,
        9.8465e-07, 9.8473e-07, 9.8554e-07, 9.8619e-07, 9.8637e-07, 9.8660e-07,
        9.8715e-07, 9.8754e-07, 9.8787e-07, 9.8832e-07, 9.8868e-07, 9.8906e-07,
        9.8939e-07, 9.8980e-07, 9.9045e-07, 9.9052e-07, 9.9095e-07, 9.9154e-07,
        9.9184e-07, 9.9194e-07, 9.9244e-07, 9.9314e-07, 9.9349e-07, 9.9371e-07,
        9.9450e-07, 9.9485e-07, 9.9490e-07, 9.9527e-07, 9.9551e-07, 9.9605e-07,
        9.9652e-07, 9.9661e-07, 9.9684e-07, 9.9716e-07, 9.9758e-07, 9.9769e-07,
        9.9808e-07, 9.9813e-07, 9.9826e-07, 9.9854e-07, 9.9856e-07, 9.9863e-07,
        9.9890e-07, 9.9913e-07, 9.9930e-07, 9.9939e-07, 9.9950e-07, 9.9955e-07,
        9.9963e-07, 9.9967e-07, 9.9982e-07, 9.9983e-07, 9.9987e-07, 9.9992e-07,
        9.9993e-07, 9.9995e-07, 9.9996e-

In [21]:
# Same check for the inducing-point model, for comparison with model_sh.
torch.linalg.eigvalsh(model_ip.variational_strategy(x).covariance_matrix)

tensor([1.3028e-07, 1.3183e-07, 1.3498e-07, 1.4423e-07, 1.4432e-07, 1.4533e-07,
        1.4788e-07, 1.4938e-07, 1.5051e-07, 1.5134e-07, 1.5226e-07, 1.5409e-07,
        1.5455e-07, 1.5520e-07, 1.5596e-07, 1.5681e-07, 1.5780e-07, 1.5934e-07,
        1.6063e-07, 1.6133e-07, 1.6196e-07, 1.6255e-07, 1.6295e-07, 1.6376e-07,
        1.6419e-07, 1.6480e-07, 1.6625e-07, 1.6706e-07, 1.6787e-07, 1.6816e-07,
        1.6893e-07, 1.6927e-07, 1.6963e-07, 1.7013e-07, 1.7099e-07, 1.7208e-07,
        1.7271e-07, 1.7332e-07, 1.7363e-07, 1.7393e-07, 1.7451e-07, 1.7562e-07,
        1.7609e-07, 1.7728e-07, 1.7786e-07, 1.7849e-07, 1.7955e-07, 1.7984e-07,
        1.8131e-07, 1.8271e-07, 1.8447e-07, 1.8589e-07, 1.8637e-07, 1.8755e-07,
        1.8770e-07, 1.8844e-07, 1.8911e-07, 1.8918e-07, 1.9030e-07, 1.9316e-07,
        1.9392e-07, 1.9409e-07, 1.9488e-07, 1.9593e-07, 1.9711e-07, 1.9905e-07,
        1.9995e-07, 2.0121e-07, 2.0346e-07, 2.0400e-07, 2.0598e-07, 2.0771e-07,
        2.1002e-07, 2.1149e-07, 2.1248e-

In [22]:
# Optimise the spherical-harmonic model and watch the smallest eigenvalue of
# the covariance it returns at every step.
model = model_sh
# NOTE(review): batch_size is assigned but not used anywhere in this cell; the
# loop below runs on the full (x, y) at every step.
batch_size = 256

# Only the variational covariance factor is optimised; every other parameter of
# the model keeps its initial value.
parameters = [
    model.variational_strategy._variational_distribution.chol_variational_covar,
]
optimizer = torch.optim.Adam(parameters, lr=0.01)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, y.size(0))

for i in range(100):
    optimizer.zero_grad()
    # The strategy is called directly (not model(x)), so `output` is exactly
    # the distribution whose covariance is being inspected.
    output = model.variational_strategy(x)
    print_smallest_eigenvalues(output.covariance_matrix)
    # print_smallest_eigenvalues(model.variational_strategy.variational_distribution.covariance_matrix)
    loss = -mll(output, y)
    loss.backward()
    optimizer.step()

    # In the recorded output the loss decreases while the smallest eigenvalue
    # turns negative after the first step.
    print(loss.item())

Smallest eigenvalue: 9.676153182452867e-07
1.6437344249419936
Smallest eigenvalue: -0.1455644387943823
1.6316795987210282
Smallest eigenvalue: -0.2838455303028324
1.6200684666969047
Smallest eigenvalue: -0.4166342255635759
1.6088985964061877
Smallest eigenvalue: -0.5442706061651829
1.5981651146588145
Smallest eigenvalue: -0.6666336179574633
1.587861863542639
Smallest eigenvalue: -0.783391764420869
1.5779812116732486
Smallest eigenvalue: -0.8942948397339103
1.568514472683726
Smallest eigenvalue: -0.9992557623660254
1.5594520779824301
Smallest eigenvalue: -1.0982853183919623
1.5507836256219896
Smallest eigenvalue: -1.191477528995732
1.542497976418569
Smallest eigenvalue: -1.2790041448088605
1.534583351147792
Smallest eigenvalue: -1.3610866134388289
1.527027397760234
Smallest eigenvalue: -1.437963993427625
1.519817288662135
Smallest eigenvalue: -1.509875988176104
1.5129398639603933
Smallest eigenvalue: -1.5770599207438294
1.5063817826283201
Smallest eigenvalue: -1.6397553906860596
1.50012

In [15]:
# Same training loop for the inducing-point model, with a larger learning rate.
# NOTE(review): the recorded execution count of this cell (15) is lower than
# that of the cells above it, so the saved output stems from out-of-order
# execution — re-run from a fresh kernel before trusting it.
model = model_ip
# NOTE(review): batch_size is assigned but not used anywhere in this cell.
batch_size = 256

# Only the variational covariance factor is optimised.
parameters = [
    model.variational_strategy._variational_distribution.chol_variational_covar,
]
optimizer = torch.optim.Adam(parameters, lr=0.1)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, y.size(0))

for i in range(100):
    optimizer.zero_grad()
    # The strategy is called directly (not model(x)).
    output = model.variational_strategy(x)
    print_smallest_eigenvalues(output.covariance_matrix)
    # print_smallest_eigenvalues(model.variational_strategy.variational_distribution.covariance_matrix)
    loss = -mll(output, y)
    loss.backward()
    optimizer.step()

    # In the recorded output the smallest eigenvalue stays near 9.28e-8 while
    # the loss fluctuates over several orders of magnitude.
    print(loss.item())

Smallest eigenvalue: 9.275705871771942e-08
66.2402015974144
Smallest eigenvalue: 9.275705892549164e-08
207786.2552360287
Smallest eigenvalue: 9.275705893491027e-08
15083.591265096342
Smallest eigenvalue: 9.275705886112875e-08
44909.40672470607
Smallest eigenvalue: 9.275705885928685e-08
118830.74573333697
Smallest eigenvalue: 9.275705889898926e-08
90288.40738159583
Smallest eigenvalue: 9.275705894768268e-08
26945.780456647564
Smallest eigenvalue: 9.275705912624109e-08
1297.5724322773274
Smallest eigenvalue: 9.275705885692302e-08
25332.60668872481
Smallest eigenvalue: 9.275705864967637e-08
55730.93495361714
Smallest eigenvalue: 9.275705908779709e-08
54781.34677493532
Smallest eigenvalue: 9.275705921880414e-08
28478.15057929962
Smallest eigenvalue: 9.27570588146154e-08
4856.33781605401
Smallest eigenvalue: 9.275705837313666e-08
2964.018910319306
Smallest eigenvalue: 9.275705768159898e-08
18239.906292469357
Smallest eigenvalue: 9.275705928431523e-08
30729.674002484437
Smallest eigenvalue: 

### Debug

In [14]:
import plotly.express as px 
import pandas as pd 


def plot_kernel_vs_angle(kernel, dimension: int, n: int = 100): 
    """Plot the covariance with a reference point against the angle to it.

    ``n`` points are placed on a half great circle: all coordinates are zero
    except the last two, which are ``(cos(theta), sin(theta))`` for ``theta``
    evenly spaced in ``[0, pi]``. The first point (``theta = 0``) acts as the
    reference, and column 0 of the covariance is drawn as a line over
    ``theta``.

    Args:
        kernel: Callable applied to the points; its result must expose
            ``lazy_covariance_matrix``.
        dimension: Number of coordinates per point; must be at least 2 because
            two coordinates carry the cosine and sine.
        n: Number of angles to evaluate.
    """
    theta = torch.linspace(0, torch.pi, n)
    x = torch.cat([torch.zeros(n, dimension - 2), theta.cos().unsqueeze(-1), theta.sin().unsqueeze(-1)], dim=-1)
    with torch.no_grad():
        # Column 0: covariance of every point with the reference point x[0].
        y = kernel(x).lazy_covariance_matrix[..., 0]
        if y.ndim == 2: 
            # A 2-D result is averaged over its leading axis.
            y = y.mean(0)

    data = pd.DataFrame({'theta': theta.squeeze().numpy(), 'y': y})
    fig = px.line(data, x='theta', y='y')
    fig.show()
