# Load in data
Currently supported datasets: power, protein, kin8nm

In [1]:
import torch 
import geometric_kernels.torch 

# Make float64 the default floating-point dtype for tensors created from here on
# (e.g. the torch.eye / torch.zeros calls in the cells below).
torch.set_default_dtype(torch.float64)

### Create data on S^D

In [2]:
from geometric_kernels.spaces import Hypersphere
from mdgp.utils.sphere import sphere_uniform_grid, spherical_harmonic


def get_space_and_data(n=200):
    """Build the toy regression problem used in this notebook.

    Args:
        n: Forwarded unchanged to ``sphere_uniform_grid``; presumably controls
            the grid resolution / number of points -- TODO confirm against
            ``mdgp.utils.sphere``.

    Returns:
        A ``(space, x, y)`` tuple: ``Hypersphere(2)``, the points returned by
        ``sphere_uniform_grid(n)``, and ``spherical_harmonic(x, 2, 3)``
        evaluated at those points.
    """
    inputs = sphere_uniform_grid(n)
    # The constants (2, 3) pick which harmonic is used as the regression target
    # (presumably degree and order -- TODO confirm against mdgp.utils.sphere).
    targets = spherical_harmonic(inputs, 2, 3)
    space = Hypersphere(2)
    return space, inputs, targets

INFO: Using numpy backend


# Instantiate model

This is done using the same model arguments as for the benchmarking experiment

In [3]:
from torch import Tensor 
from gpytorch.means import Mean
from gpytorch.kernels import ScaleKernel
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from mdgp.kernels import GeometricMaternKernel


import torch 
from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
# TODO Maybe just move the functions from spherical_harmonic_features.py into this file?
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree, matern_Kuu, matern_LT_Phi

In [4]:
class VariationalDistribution(torch.nn.Module):
    """Zero-mean Gaussian over the inducing variables with a learnable covariance factor.

    The covariance is parameterised as ``L @ L^T`` where ``L`` is the square
    ``chol_variational_covar`` parameter, initialised to the identity. The mean
    is always zero and is not a learnable parameter.
    """

    def __init__(self, num_inducing: int):
        """
        Args:
            num_inducing: Number of inducing variables, i.e. the side length of
                the covariance factor and the length of the zero mean.
        """
        super().__init__()
        self.num_inducing = num_inducing
        identity = torch.eye(num_inducing)
        self.chol_variational_covar = torch.nn.Parameter(identity)

    def forward(self) -> MultivariateNormal:
        """Return the current distribution: zero mean, covariance ``L @ L^T``."""
        factor = self.chol_variational_covar
        zero_mean = torch.zeros(self.num_inducing)
        return MultivariateNormal(zero_mean, factor @ factor.mT)
    

class SHVariationalStrategy(torch.nn.Module):
    """Variational strategy whose inducing variables are spherical-harmonic features.

    This variant uses an identity prior covariance over the inducing variables
    and the ``matern_LT_Phi`` cross term (judging by the variable names, a
    whitened parameterisation -- TODO confirm).

    NOTE(review): a later cell defines another class with this same name; once
    that cell has run, it shadows this definition.
    """

    def __init__(self,
        covar_module: GeometricMaternKernel,
        dimension: int,
        num_spherical_harmonics: int,
    ):
        """
        Args:
            covar_module: Kernel used for the input-input prior covariance.
            dimension: Passed as ``d`` to the spherical-harmonic helpers.
            num_spherical_harmonics: Requested number of harmonics. The value
                returned by ``num_spherical_harmonics_to_degree`` (which may
                differ from the request) is what gets stored on the instance.
        """
        super().__init__()
        self.covar_module = covar_module
        self.dimension = dimension
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        # Size q(u) with the *adjusted* count so it always matches the prior built
        # in `prior_distribution`; using the raw argument breaks the shapes whenever
        # the helper changes the requested number.
        self._variational_distribution = VariationalDistribution(self.num_spherical_harmonics)

    @property
    def kappa(self) -> Tensor:
        """Fixed value passed as ``kappa`` to ``matern_LT_Phi`` (not learnable)."""
        return torch.tensor([[0.001]])
    
    @property
    def nu(self) -> Tensor | float:
        """Fixed value passed as ``nu`` to ``matern_LT_Phi`` (not learnable)."""
        return torch.tensor([[2.5]])

    @property 
    def sigma(self) -> Tensor: 
        """Fixed value passed as ``sigma`` to ``matern_LT_Phi`` (not learnable)."""
        return torch.tensor(1.0)

    @property
    def prior_distribution(self) -> MultivariateNormal:
        """Prior over the inducing variables: zero mean, identity covariance."""
        covariance_matrix = torch.eye(self.num_spherical_harmonics)
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
        
    @property
    def variational_distribution(self) -> MultivariateNormal:
        """Current variational distribution q(u) over the inducing variables."""
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        """Return the approximate marginal at ``x`` with zero mean.

        The covariance is ``Kxx + A^T (S - K) A`` where ``A`` is the output of
        ``matern_LT_Phi``, ``S`` the variational covariance and ``K`` the
        (identity) prior covariance. Extra keyword arguments are ignored.
        """
        # inducing-inducing prior
        pu = self.prior_distribution
        invL_Kuu_invLt = pu.lazy_covariance_matrix

        # input-input prior
        Kxx = self.covar_module(x)

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_S_invLt = qu.lazy_covariance_matrix

        # inducing-input prior  
        LT_Phi = matern_LT_Phi(x, max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma,) # [..., O, num_harmonics, N]

        # Update the covariance matrix
        # Shapes (taking the author's note above at face value):
        # [..., O, N, num_harmonics] @ [num_harmonics, num_harmonics] @ [..., O, num_harmonics, N] -> [..., O, N, N]
        covariance_matrix_update = LT_Phi.mT @ (invL_S_invLt - invL_Kuu_invLt) @ LT_Phi
        updated_covariance_matrix = Kxx + covariance_matrix_update # [..., O, N, N] + [..., O, N, N] -> [..., O, N, N]

        # The mean is always zero; its shape is x's shape without the last axis.
        return MultivariateNormal(mean=torch.zeros(x.shape[:-1]), covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        """KL(q(u) || p(u)), evaluated with gpytorch's ``max_preconditioner_size`` set to 0."""
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

In [5]:
from mdgp.utils.spherical_harmonic_features import matern_Kux, matern_Kuu








class SHVariationalStrategy(torch.nn.Module):
    """Variational strategy whose inducing variables are spherical-harmonic features.

    This variant works with ``Kuu`` from ``matern_Kuu`` directly (rather than an
    identity prior) and re-defines the class of the same name from the
    previous cell.

    NOTE(review): ``kappa``, ``nu`` and ``sigma`` below are fixed constants and
    do not follow the (trainable) hyperparameters of ``covar_module``. If the two
    describe different kernels, ``Kxx - Kux^T Kuu^{-1} Kux`` need not be PSD,
    which may be why training drives the smallest eigenvalue of the returned
    covariance negative -- confirm.
    """

    def __init__(self,
        covar_module: GeometricMaternKernel,
        dimension: int,
        num_spherical_harmonics: int,
    ):
        """
        Args:
            covar_module: Kernel used for the input-input prior covariance.
            dimension: Passed as ``d`` to the spherical-harmonic helpers.
            num_spherical_harmonics: Requested number of harmonics. The value
                returned by ``num_spherical_harmonics_to_degree`` (which may
                differ from the request) is what gets stored on the instance.
        """
        super().__init__()
        self.covar_module = covar_module
        self.dimension = dimension
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
        # Size q(u) with the *adjusted* count: its zero mean has length
        # `num_inducing`, which must agree with the factor assigned below and with
        # the prior, both of which are derived from the adjusted degree.
        self._variational_distribution = VariationalDistribution(self.num_spherical_harmonics)
        # Replace the identity initialisation with sqrt(Kuu) so that q(u) starts
        # with covariance sqrt(Kuu) @ sqrt(Kuu)^T.
        self._variational_distribution.chol_variational_covar = torch.nn.Parameter(
            matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma).sqrt().to_dense()
        )

    @property
    def kappa(self) -> Tensor:
        """Fixed value passed as ``kappa`` to ``matern_Kuu`` (not learnable)."""
        return torch.tensor([[0.001]])
    
    @property
    def nu(self) -> Tensor | float:
        """Fixed value passed as ``nu`` to ``matern_Kuu`` (not learnable)."""
        return torch.tensor([[2.5]])

    @property 
    def sigma(self) -> Tensor: 
        """Fixed value passed as ``sigma`` to ``matern_Kuu`` (not learnable)."""
        return torch.tensor(1.0)

    @property
    def prior_distribution(self) -> MultivariateNormal:
        """Prior over the inducing variables: zero mean, covariance ``matern_Kuu(...)``."""
        covariance_matrix = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
        
    @property
    def variational_distribution(self) -> MultivariateNormal:
        """Current variational distribution q(u) over the inducing variables."""
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        """Return the approximate marginal at ``x`` with zero mean.

        The covariance is ``Kxx + Kux^T Kuu^{-T} (S - Kuu) Kuu^{-1} Kux``, with a
        jitter of 1e-6 added to both ``Kuu`` and ``Kxx``. Extra keyword
        arguments are ignored.
        """
        jitter = 1e-6
        # inducing-inducing prior
        Kuu = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        Kuu = Kuu.add_jitter(jitter)
        Kuu_inv = Kuu.to_dense().inverse()

        # input-input prior
        Kxx = self.covar_module(x)
        Kxx = Kxx.add_jitter(jitter)

        # inducing-input prior
        Kux = matern_Kux(x, max_ell=self.degree, d=self.dimension)

        # inducing-inducing variational
        S = self.variational_distribution.covariance_matrix

        # Covariance correction contributed by q(u).
        covariance_matrix_update = Kux.mT @ Kuu_inv.mT @ (S - Kuu) @ Kuu_inv @ Kux # Since Kxx is always PSD, if the updated cov is not PSD, then the update must not be PSD
        # covariance_matrix_update = covariance_matrix_update + torch.eye(covariance_matrix_update.shape[-1]).mul(jitter)
        updated_covariance_matrix = Kxx + covariance_matrix_update 

        # The mean is always zero; its shape is x's shape without the last axis.
        return MultivariateNormal(mean=torch.zeros(x.shape[:-1]), covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        """KL(q(u) || p(u)), evaluated with gpytorch's ``max_preconditioner_size`` set to 0."""
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

# Simple shallow variational GP

In [6]:
import gpytorch


class SHApproximateGP(gpytorch.models.ApproximateGP):
    """Shallow approximate GP built on ``SHVariationalStrategy`` with a Gaussian likelihood."""

    def __init__(self, dimension: int, num_spherical_harmonics: int,
                 mean: Mean, covar_module: GeometricMaternKernel, variational_distribution: CholeskyVariationalDistribution):
        """
        Args:
            dimension: Forwarded to ``SHVariationalStrategy``.
            num_spherical_harmonics: Forwarded to ``SHVariationalStrategy``.
            mean: Mean module evaluated in ``forward``.
            covar_module: Kernel shared by the strategy and by ``forward``.
            variational_distribution: Accepted but not used; the strategy
                creates its own variational distribution.
        """
        strategy = SHVariationalStrategy(covar_module, dimension, num_spherical_harmonics)
        super().__init__(strategy)
        self.mean_module = mean
        self.covar_module = covar_module
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Return the prior at ``x`` built from the mean and covariance modules."""
        prior_mean = self.mean_module(x)
        prior_covar = self.covar_module(x)
        return MultivariateNormal(prior_mean, prior_covar)
    

class IPApproximateGP(gpytorch.models.ApproximateGP):
    """Shallow approximate GP using gpytorch's ``UnwhitenedVariationalStrategy`` with fixed inducing points."""

    def __init__(self, mean: Mean, covar_module: GeometricMaternKernel, inducing_points: Tensor, variational_distribution: CholeskyVariationalDistribution):
        """
        Args:
            mean: Mean module evaluated in ``forward``.
            covar_module: Kernel evaluated in ``forward``.
            inducing_points: Inducing locations; they are not learned
                (``learn_inducing_locations=False``).
            variational_distribution: Variational distribution handed to the strategy.
        """
        strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
            self,
            inducing_points=inducing_points,
            variational_distribution=variational_distribution,
            learn_inducing_locations=False,
        )
        super().__init__(strategy)
        self.mean_module = mean
        self.covar_module = covar_module
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Return the prior at ``x`` built from the mean and covariance modules."""
        prior_mean = self.mean_module(x)
        prior_covar = self.covar_module(x)
        return MultivariateNormal(prior_mean, prior_covar)

# Train model
Training has to be done differently from the benchmarking experiment, because the larger datasets require minibatch SGD and minibatch metrics.

In [7]:
def print_smallest_eigenvalues(covar):
    """Print the smallest eigenvalue of a symmetric matrix (a quick PSD check).

    Args:
        covar: Tensor accepted by ``torch.linalg.eigvalsh``. The minimum is taken
            over all returned eigenvalues, so a batched input yields one number.
    """
    eigenvalues = torch.linalg.eigvalsh(covar)
    minimum = eigenvalues.min().item()
    print(f"Smallest eigenvalue: {minimum}")

In [8]:
from geometric_kernels.spaces import Hypersphere
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree


# Build the two models compared in this notebook on the same data:
# `model_sh` (spherical-harmonic features) and `model_ip` (inducing points).

# Same value as the Hypersphere(2) created in get_space_and_data; the helpers
# below receive `dimension + 1` (presumably the ambient dimension -- TODO confirm).
dimension = 2
space, x, y = get_space_and_data(200)
num_spherical_harmonics = 500

# The helper may change the requested count; in the recorded run it reports
# raising 500 to 529 (all harmonics up to degree 23).
degree, num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension + 1)

# Empty batch shape: a single, non-batched GP.
batch_shape = torch.Size([])

# Model with spherical harmonic features
mean = gpytorch.means.ZeroMean()
covar_module = GeometricMaternKernel(nu=2.5, space=space, num_eigenfunctions=30, batch_shape=batch_shape)
covar_module.initialize(lengthscale=1.0)
# NOTE(review): only this model wraps its kernel in a ScaleKernel; the
# inducing-point model below uses the bare kernel -- confirm this is intended.
covar_module = ScaleKernel(covar_module)
# SHApproximateGP accepts this argument but does not use it.
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=num_spherical_harmonics, batch_shape=batch_shape
)

model_sh = SHApproximateGP(
    dimension=dimension + 1,
    num_spherical_harmonics=num_spherical_harmonics,
    mean=mean,
    covar_module=covar_module,
    variational_distribution=variational_distribution,
)

# Model with inducing points
# `mean`, `covar_module` and `variational_distribution` are rebound to fresh
# objects here, so the two models do not share modules.
mean = gpytorch.means.ZeroMean()
covar_module = GeometricMaternKernel(nu=2.5, space=space, num_eigenfunctions=30, batch_shape=batch_shape)
covar_module.initialize(lengthscale=1.0)
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=num_spherical_harmonics, batch_shape=batch_shape
)

# Random inducing points: Gaussian samples scaled to unit norm, as many as there
# are spherical harmonics. No random seed is set in the visible cells, so these
# differ between runs.
inducing_points = torch.randn(num_spherical_harmonics, dimension + 1)
inducing_points = inducing_points / inducing_points.norm(dim=-1, keepdim=True)
model_ip = IPApproximateGP(
    mean=mean,
    covar_module=covar_module,
    inducing_points=inducing_points,
    variational_distribution=variational_distribution
)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 529, which includes all spherical harmonics up to degree 23 (incl.)


In [9]:
# Train the spherical-harmonic model on the full dataset (no minibatching).
model = model_sh

# Optimise every parameter of the model with Adam on the negative ELBO.
parameters = model.parameters()
optimizer = torch.optim.Adam(parameters, lr=0.01)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, y.size(0))

for i in range(1000):
    optimizer.zero_grad()
    # Call the variational strategy directly to get the approximate marginal at x.
    output = model.variational_strategy(x)
    # Diagnostic: in the recorded output the smallest eigenvalue becomes negative
    # from the second iteration on, i.e. the covariance stops being PSD.
    print_smallest_eigenvalues(output.covariance_matrix)
    loss = -mll(output, y)
    loss.backward()
    optimizer.step()

    print(loss.item())

Smallest eigenvalue: 0.000138436906963486
1.2643801216225734
Smallest eigenvalue: -0.05670803228225404
1.2461014181538244
Smallest eigenvalue: -0.11158823229680875
1.2288470769104518
Smallest eigenvalue: -0.16427468676872647
1.212423815069876
Smallest eigenvalue: -0.21479486884665155
1.1965878577848377
Smallest eigenvalue: -0.26320631844133957
1.181235522393201
Smallest eigenvalue: -0.3095788170674781
1.1663263047864552
Smallest eigenvalue: -0.35395804260083913
1.1518324214004605
Smallest eigenvalue: -0.396357346097838
1.1377288360970554
Smallest eigenvalue: -0.4367697429263466
1.1239917991715833
Smallest eigenvalue: -0.47519045600118875
1.1105985933241764
Smallest eigenvalue: -0.511633643576327
1.0975272229438278
Smallest eigenvalue: -0.5461348873312788
1.0847561555432066
Smallest eigenvalue: -0.5787491663370447
1.0722642800419695
Smallest eigenvalue: -0.6095455730285996
1.06003100088277
Smallest eigenvalue: -0.638598516848669
1.0480363013234315


KeyboardInterrupt: 

In [34]:
# Train the inducing-point model on the full dataset for comparison.
model = model_ip
# NOTE(review): batch_size is not used anywhere in this cell.
batch_size = 256

# Optimise only the covariance factor of the variational distribution; all other
# parameters (kernel, likelihood, variational mean) are left out of the optimizer.
parameters = [
    model.variational_strategy._variational_distribution.chol_variational_covar,
]
optimizer = torch.optim.Adam(parameters, lr=0.1)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, y.size(0))

for i in range(100):
    optimizer.zero_grad()
    # Call the variational strategy directly to get the approximate marginal at x.
    output = model.variational_strategy(x)
    # Diagnostic: in the recorded output this stays at about 1.148e-07 on every
    # iteration while the loss fluctuates in the 1e4-1e5 range.
    print_smallest_eigenvalues(output.covariance_matrix)
    # print_smallest_eigenvalues(model.variational_strategy.variational_distribution.covariance_matrix)
    loss = -mll(output, y)
    loss.backward()
    optimizer.step()

    print(loss.item())

Smallest eigenvalue: 1.1481051897019095e-07
418.7588268497319
Smallest eigenvalue: 1.1481051896040093e-07
135206.41611263787
Smallest eigenvalue: 1.1481051931463858e-07
28516.621114224683
Smallest eigenvalue: 1.1481051833897677e-07
45890.731124151054
Smallest eigenvalue: 1.1481051960847277e-07
70374.32070225292
Smallest eigenvalue: 1.1481051681177052e-07
50411.19093092461
Smallest eigenvalue: 1.1481051870067949e-07
24583.84176166236
Smallest eigenvalue: 1.1481051827941582e-07
17814.314417809554
Smallest eigenvalue: 1.1481051978054234e-07
26099.420989398466
Smallest eigenvalue: 1.1481051856745828e-07
31718.77870604916
Smallest eigenvalue: 1.1481051885604307e-07
26540.292297390377
Smallest eigenvalue: 1.1481051846563987e-07
16712.760769188648
Smallest eigenvalue: 1.14810519520386e-07
10931.868048022894
Smallest eigenvalue: 1.148105188095749e-07
11693.621856448437
Smallest eigenvalue: 1.1481051872180677e-07
15096.564575818524
Smallest eigenvalue: 1.1481051965177757e-07
15992.366172978744


### Debug

In [14]:
import plotly.express as px 
import pandas as pd 


def plot_kernel_vs_angle(kernel, dimension: int, n: int = 100): 
    """Plot the covariance between the first point of a half great circle and all its points.

    The points are ``(0, ..., 0, cos(theta), sin(theta))`` for ``theta`` evenly
    spaced in ``[0, pi]``, so the first point corresponds to ``theta = 0``.

    Args:
        kernel: Callable applied to the points; its result must expose
            ``lazy_covariance_matrix`` (so presumably something returning a
            ``MultivariateNormal`` rather than a bare kernel -- TODO confirm).
        dimension: Number of coordinates per point; ``dimension - 2`` leading
            zero columns are prepended to the (cos, sin) pair.
        n: Number of angles to evaluate.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    theta = torch.linspace(0, torch.pi, n)
    x = torch.cat([torch.zeros(n, dimension - 2), theta.cos().unsqueeze(-1), theta.sin().unsqueeze(-1)], dim=-1)
    with torch.no_grad():
        # Last-axis index 0: covariance of every point with the first point.
        y = kernel(x).lazy_covariance_matrix[..., 0]
        # A 2-D result means there is a leading batch axis; average over it.
        if y.ndim == 2: 
            y = y.mean(0)

    data = pd.DataFrame({'theta': theta.squeeze().numpy(), 'y': y})
    fig = px.line(data, x='theta', y='y')
    fig.show()
