# Demonstration of how to use SphericalHarmonicFeaturesVariationalStrategy in a shallow ApproximateGP

In [43]:
import geometric_kernels.torch


import torch 
import gpytorch 
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.utils.memoize import cached 
from geometric_kernels.spaces import Hypersphere
from linear_operator.operators import DiagLinearOperator
from mdgp.variational.spherical_harmonic_features.utils import * 
from mdgp.variational.spherical_harmonic_features_variational_strategy import SphericalHarmonicFeaturesVariationalStrategy
from mdgp.kernels import GeometricMaternKernel
from mdgp.utils import test_psd, sphere_uniform_grid, sphere_meshgrid, spherical_antiharmonic, spherical_harmonic
from tqdm.autonotebook import tqdm 


# Make float64 the default floating-point dtype for every tensor and module
# created from this point on in the notebook.
DEFAULT_DTYPE = torch.float64
torch.set_default_dtype(DEFAULT_DTYPE)

In [44]:
class gpytorchSGP(ApproximateGP):
    """Single-layer approximate GP using SphericalHarmonicFeaturesVariationalStrategy.

    The prior is a constant mean plus a scaled ``GeometricMaternKernel`` on
    ``Hypersphere(d)``; a ``GaussianLikelihood`` is stored on the model as
    ``self.likelihood`` so that its parameters are part of ``model.parameters()``.

    Args:
        max_ell: Passed to the variational strategy as ``num_levels`` and to
            ``total_num_harmonics`` to size the variational distribution.
        d: Dimension argument forwarded to ``Hypersphere`` and
            ``total_num_harmonics``.
        max_ell_prior: Forwarded to the kernel as ``num_eigenfunctions``.
        epsilon_sigma: Initial likelihood noise is set to ``epsilon_sigma ** 2``.
        kappa: Initial kernel lengthscale.
        nu: Initial kernel smoothness value.
        sigma: Initial kernel outputscale is set to ``sigma ** 2``.
        batch_shape: Batch shape of the variational distribution only; the mean
            and kernel modules are created without a batch shape.
        jitter_val: Forwarded to the variational strategy and stored on the model.
        optimize_nu: Forwarded to the kernel as ``trainable_nu``.
    """

    def __init__(self, max_ell, d, max_ell_prior, epsilon_sigma=1.0, kappa=1.0, nu=2.5, sigma=1.0, batch_shape=torch.Size([]), jitter_val=1e-6, optimize_nu: bool = True):
        # The variational distribution gets one inducing variable per harmonic
        # counted by total_num_harmonics (presumably all harmonics up to level
        # max_ell in dimension d -- see mdgp.variational...utils to confirm).
        num_inducing = total_num_harmonics(max_ell, d)
        q_u = CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_shape=batch_shape,
        )
        # The strategy takes a reference to this model before the parent
        # constructor has run; the parent is initialised right after.
        strategy = SphericalHarmonicFeaturesVariationalStrategy(
            self,
            q_u,
            num_levels=max_ell,
            jitter_val=jitter_val,
        )
        super().__init__(variational_strategy=strategy)

        # Plain (non-learnable) bookkeeping attributes.
        self.jitter_val = jitter_val
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior
        self.d = d

        # Learnable sub-modules: scaled Matern kernel, constant mean, Gaussian noise.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            GeometricMaternKernel(
                space=Hypersphere(d),
                lengthscale=kappa,
                nu=nu,
                trainable_nu=optimize_nu,
                num_eigenfunctions=max_ell_prior,
            )
        )
        self.mean_module = gpytorch.means.ConstantMean()
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

        # Initial hyperparameter values (variances, hence the squares).
        self.covar_module.outputscale = sigma ** 2
        self.likelihood.noise = epsilon_sigma ** 2

    def forward(self, x) -> MultivariateNormal:
        """Return the prior distribution at ``x`` from the mean and covariance modules."""
        prior_covar = self.covar_module(x)
        prior_mean = self.mean_module(x)
        return MultivariateNormal(prior_mean, prior_covar)

In [45]:
def get_data(num_train: int = 400, smooth: bool = True, noise_std: float = 1e-2, seed: int = 0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a noisy regression data set on the sphere.

    Args:
        num_train: Argument passed to ``sphere_uniform_grid`` to create the
            training inputs.
        smooth: If True the targets come from ``spherical_harmonic``,
            otherwise from ``spherical_antiharmonic``.
        noise_std: Standard deviation of the Gaussian noise added to both the
            training and the test targets.
        seed: Value passed to ``torch.manual_seed``. Note that this reseeds
            torch's *global* random number generator as a side effect.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The test inputs come from
        ``sphere_meshgrid(50, 50)``; the test targets are noisy as well.
    """
    torch.manual_seed(seed)

    target_fn = spherical_harmonic if smooth else spherical_antiharmonic

    def noisy_targets(x: torch.Tensor) -> torch.Tensor:
        # The two integers are passed positionally to the target function;
        # their meaning is defined in mdgp.utils (presumably the indices that
        # select which harmonic is used -- confirm there).
        clean = target_fn(x, 2, 3)
        return clean + noise_std * torch.randn_like(clean)

    # Train data is generated before test data so the random draws happen in
    # the same order for a given seed.
    x_train = sphere_uniform_grid(num_train)
    y_train = noisy_targets(x_train)
    x_test = sphere_meshgrid(50, 50)
    y_test = noisy_targets(x_test)
    return x_train, y_train, x_test, y_test

In [46]:
# --- Experiment configuration -------------------------------------------------
smooth = False             # False -> get_data uses spherical_antiharmonic targets
d = 2                      # dimension argument for Hypersphere(d)
num_inducing_points = 100  # requested number of inducing harmonics
# The helper returns an indexable result; its first element is used as the
# number of levels for the variational strategy.
max_ell = num_spherical_harmonics_to_num_levels(num_inducing_points, d)[0]
max_ell_prior = 30         # forwarded to the kernel as num_eigenfunctions
num_epochs = 1000

# --- Model and data -----------------------------------------------------------
model = gpytorchSGP(max_ell=max_ell, d=d, max_ell_prior=max_ell_prior, optimize_nu=True)
x_train, y_train, x_test, y_test = get_data(smooth=smooth)

# --- Optimisation -------------------------------------------------------------
# The objective below is the ELBO itself (not its negative), so Adam is told to
# maximize. model.parameters() includes the likelihood, which is a sub-module.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
elbo = gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data=len(y_train))

model.train()
for _ in (pbar := tqdm(range(num_epochs), desc="Training")):
    optimizer.zero_grad()
    output = model(x_train)
    # Sanity check on the dense covariance at every step (helper from mdgp.utils).
    test_psd(output.covariance_matrix)
    elbo_value = elbo(output, y_train)
    elbo_value.backward()
    optimizer.step()
    # Report the quantity under its real name: it is maximized, so calling it
    # "loss" would suggest the wrong direction of improvement.
    pbar.set_postfix(elbo=elbo_value.item())


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

In [47]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots


def plot_predictions(model, x_test, y_test):
    """Show predictive mean, predictive std and the test targets on the sphere.

    Args:
        model: Model exposing ``likelihood`` and callable on an ``(N, 3)``
            tensor. It is switched to eval mode and left in eval mode.
        x_test: Tensor whose last dimension holds the three Cartesian
            coordinates; the leading dimensions define the surface mesh.
        y_test: Target values with as many elements as there are mesh points.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Mean", "Std", "True"), 
                        specs=[[{'type': 'surface'}, {'type': 'surface'}, {'type': 'surface'}]])

    def to_numpy(values, shape):
        # Hand plotly plain NumPy arrays laid out like the mesh.
        return torch.as_tensor(values).detach().reshape(shape).cpu().numpy()

    model.eval()
    with torch.no_grad():
        x, y, z = x_test.unbind(-1)
        grid_shape = x.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        pred = model.likelihood(model(x_test.reshape(-1, 3)))
        # Read mean and stddev while autograd is still disabled: these
        # properties may trigger computation when accessed, and the result
        # should not carry gradient history into the plotting code.
        pred_mean = to_numpy(pred.mean, grid_shape)
        pred_std = to_numpy(pred.stddev, grid_shape)

    x, y, z = (to_numpy(coord, grid_shape) for coord in (x, y, z))
    # Targets are brought to the same mesh layout as the predictions.
    targets = to_numpy(y_test, grid_shape)

    mean = go.Surface(x=x, y=y, z=z, surfacecolor=pred_mean)
    std = go.Surface(x=x, y=y, z=z, surfacecolor=pred_std, colorscale="Viridis")
    true = go.Surface(x=x, y=y, z=z, surfacecolor=targets)
    fig.add_trace(mean, row=1, col=1)
    fig.add_trace(std, row=1, col=2)
    fig.add_trace(true, row=1, col=3)
    fig.show()

In [48]:
# Visualise the trained model against the held-out mesh-grid data.
plot_predictions(model=model, x_test=x_test, y_test=y_test)