# Demonstration of how to use SphericalHarmonicFeaturesVariationalStrategy in a shallow ApproximateGP

In [1]:
from torch import Tensor 


import geometric_kernels.torch


import torch 
import gpytorch 
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from geometric_kernels.spaces import Hypersphere
from mdgp.variational.spherical_harmonic_features.utils import total_num_harmonics, num_spherical_harmonics_to_num_levels
from mdgp.variational.spherical_harmonic_features_variational_strategy import SphericalHarmonicFeaturesVariationalStrategy
from mdgp.kernels import GeometricMaternKernel
from mdgp.utils import test_psd, sphere_uniform_grid, sphere_meshgrid, spherical_antiharmonic, spherical_harmonic
from tqdm.autonotebook import tqdm 


# Work in double precision for every tensor created below (torch.double is torch.float64).
torch.set_default_dtype(torch.double)

INFO: Using numpy backend
  from tqdm.autonotebook import tqdm


In [19]:
import torch 
from torch import nn 
from gpytorch import settings 
from gpytorch.models import ApproximateGP
from gpytorch.distributions import MultivariateNormal
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.utils.memoize import cached
from linear_operator.operators import DiagLinearOperator, DenseLinearOperator
from mdgp.variational.spherical_harmonic_features.utils import matern_LT_Phi_from_kernel


class SphericalHarmonicFeaturesVariationalStrategy(nn.Module):
    """Whitened variational strategy whose inducing variables are spherical-harmonic features.

    Both the prior ``p(u)`` and the variational distribution ``q(u)`` over the inducing
    variables are expressed in whitened coordinates, so ``p(u)`` is a standard normal
    (zero mean, identity covariance). The map from whitened inducing space to function
    values at ``x`` is ``L^T Phi(u, x)``, obtained from ``matern_LT_Phi_from_kernel``.

    Args:
        model: The ``ApproximateGP`` that owns this strategy. It must expose
            ``covar_module`` and ``max_ell_prior`` (both are read in ``forward``).
        variational_distribution: Variational distribution over the whitened inducing variables.
        num_levels: Number of spherical-harmonic levels used for the inducing features.
        jitter_val: Diagonal jitter added to covariance matrices. ``None`` selects gpytorch's
            default Cholesky jitter for the current default dtype. An explicit value,
            including ``0.0``, is used as given.
    """

    def __init__(self, model: ApproximateGP, variational_distribution: CholeskyVariationalDistribution, num_levels: int, jitter_val: float | None = None):
        super().__init__()
        # Bypass nn.Module.__setattr__ so the model is NOT registered as a child module of
        # its own strategy. The model already registers this strategy as its child, and a
        # two-way link would make recursive Module methods (e.g. ``train()``) recurse forever.
        object.__setattr__(self, "_model", model)
        self._variational_distribution = variational_distribution
        # Use ``is None`` rather than ``or`` so that an explicit jitter of 0.0 is not
        # silently replaced by the default.
        if jitter_val is None:
            jitter_val = settings.cholesky_jitter.value(torch.get_default_dtype())
        self.jitter_val = jitter_val
        self.num_levels = num_levels

    @property
    def model(self) -> ApproximateGP:
        """The owning model (stored without submodule registration)."""
        return self._model

    @property
    def variational_distribution(self) -> MultivariateNormal:
        """Current whitened variational distribution ``q(u)``."""
        return self._variational_distribution()

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        """Whitened prior ``p(u) = N(0, I)``, shaped like the variational distribution."""
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res

    def forward(self, x) -> MultivariateNormal:
        """Return the approximate posterior ``q(f(x))``.

        With whitened ``q(u) = N(m, S)`` and ``p(u) = N(mu_u, I)``:

            mean = mu_x + Phi_L (m - mu_u)
            cov  = K_xx + Phi_L (S - I) Phi_L^T

        where ``Phi_L = (L^T Phi_ux)^T`` has shape ``[..., N, M]``.
        """
        # prior at x
        px = self.model.forward(x)
        mu_x, Kxx = px.mean, px.lazy_covariance_matrix
        Kxx = Kxx.add_jitter(self.jitter_val)

        # whitened prior at u
        Linv_pu = self.prior_distribution
        Linv_mu, Linv_Kuu_LTinv = Linv_pu.mean, Linv_pu.lazy_covariance_matrix  # [..., M], [..., M, M]

        # whitened variational posterior at u
        Linv_qu = self.variational_distribution
        Linv_m, Linv_S_LTinv = Linv_qu.mean, Linv_qu.lazy_covariance_matrix  # [..., M], [..., M, M]

        # unwhitening + projection and vice-versa
        LT_Phiux = matern_LT_Phi_from_kernel(x, self.model.covar_module, num_levels=self.num_levels, num_levels_prior=self.model.max_ell_prior)
        Phixu_L = LT_Phiux.mT  # [..., N, M]

        # posterior at x
        qx_sigma = Kxx + Phixu_L @ (Linv_S_LTinv - Linv_Kuu_LTinv) @ LT_Phiux  # [..., N, N] + [..., N, M] @ [..., M, M] @ [..., M, N] -> [..., N, N]
        # Jitter is applied a second time here, to the final covariance (K_xx was already jittered above).
        qx_sigma = qx_sigma.add_jitter(self.jitter_val)

        # qx_mu = mu_x + Phixu_L @ (Linv_m - Linv_mu); einsum handles the matrix-vector product with batch dims.
        qx_mu = mu_x + torch.einsum("...ij,...j->...i", Phixu_L, Linv_m - Linv_mu)  # [..., N, M] @ [..., M] -> [..., N]
        return MultivariateNormal(qx_mu, qx_sigma)

    def kl_divergence(self):
        """Return ``KL(q(u) || p(u))`` between the whitened variational and prior distributions."""
        # Disable preconditioning for this KL computation (as gpytorch's own
        # variational strategies do).
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    def __call__(self, x, prior: bool = False, **kwargs) -> MultivariateNormal:
        """Evaluate the strategy at ``x``; with ``prior=True``, return the model prior ``p(f(x))`` instead."""
        if prior:
            return self.model.forward(x)
        return super().__call__(x, **kwargs)


In [20]:
class gpytorchSGP(ApproximateGP):
    """Single-output sparse variational GP on the hypersphere with spherical-harmonic features.

    Args:
        max_ell: Number of spherical-harmonic levels for the inducing features; the number of
            inducing variables is ``total_num_harmonics(max_ell, d)``.
        d: Dimension passed to ``Hypersphere(d)`` and ``total_num_harmonics``.
        max_ell_prior: Passed as ``num_eigenfunctions`` to the Matérn kernel; the variational
            strategy also reads it as ``num_levels_prior``.
        epsilon_sigma: Initial likelihood noise standard deviation (noise variance is its square).
        kappa: Initial Matérn lengthscale.
        nu: Initial Matérn smoothness.
        sigma: Initial output-scale standard deviation (outputscale is its square).
        batch_shape: Batch shape shared by the variational distribution, kernel, mean and likelihood.
        jitter_val: Diagonal jitter forwarded to the variational strategy.
        optimize_nu: Whether ``nu`` is trainable.
    """

    def __init__(self, max_ell, d, max_ell_prior, epsilon_sigma=1.0, kappa=1.0, nu=2.5, sigma=1.0, batch_shape=torch.Size([]), jitter_val=1e-6, optimize_nu: bool = True):
        # The strategy is handed to ApproximateGP.__init__, so it is built first.
        num_inducing = total_num_harmonics(max_ell, d)
        q_u = CholeskyVariationalDistribution(num_inducing_points=num_inducing, batch_shape=batch_shape)
        super().__init__(
            variational_strategy=SphericalHarmonicFeaturesVariationalStrategy(
                self, q_u, num_levels=max_ell, jitter_val=jitter_val
            )
        )

        # Plain (non-parameter) configuration values.
        self.jitter_val = jitter_val
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior
        self.d = d

        # Prior mean/covariance and observation model.
        matern = GeometricMaternKernel(
            space=Hypersphere(d),
            lengthscale=kappa,
            nu=nu,
            trainable_nu=optimize_nu,
            num_eigenfunctions=max_ell_prior,
            batch_shape=batch_shape,
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(matern, batch_shape=batch_shape)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=batch_shape)

        # Initial hyperparameters: both arguments are standard deviations, stored as variances.
        self.covar_module.outputscale = sigma ** 2
        self.likelihood.noise = epsilon_sigma ** 2

    def forward(self, x) -> MultivariateNormal:
        """Return the GP prior ``N(mean_module(x), covar_module(x))``."""
        covar_x = self.covar_module(x)
        mean_x = self.mean_module(x)
        return MultivariateNormal(mean_x, covar_x)

In [21]:
def get_data(num_train: int = 400, smooth: bool = True, noise_std: float = 1e-2, seed: int = 0) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Generate noisy single-output train/test data on the sphere.

    The target is ``spherical_harmonic(x, 2, 3)`` when ``smooth`` is True, otherwise
    ``spherical_antiharmonic(x, 2, 3)``.

    Args:
        num_train: Number of training points, passed to ``sphere_uniform_grid``.
        smooth: Selects the smooth or non-smooth target.
        noise_std: Standard deviation of the additive Gaussian noise on both splits.
        seed: Seed passed to ``torch.manual_seed`` before any data is generated.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The test inputs come from ``sphere_meshgrid(50, 50)``.
    """
    torch.manual_seed(seed)
    target = spherical_harmonic if smooth else spherical_antiharmonic

    def observe(points: Tensor) -> Tensor:
        # Evaluate the clean target, then add i.i.d. Gaussian noise of scale noise_std.
        clean = target(points, 2, 3)
        return clean + noise_std * torch.randn_like(clean)

    x_train = sphere_uniform_grid(num_train)
    y_train = observe(x_train)
    x_test = sphere_meshgrid(50, 50)
    y_test = observe(x_test)
    return x_train, y_train, x_test, y_test

In [22]:
# Experiment configuration (d, max_ell, max_ell_prior and num_epochs are reused by the multi-output cell below).
smooth = False  # False -> get_data uses `spherical_antiharmonic` as the target
d = 2  # passed to Hypersphere(d) and total_num_harmonics(max_ell, d) inside the model
num_inducing_points = 100
# Convert the requested inducing-feature budget into a number of harmonic levels.
# NOTE(review): `[0]` presumably selects the level count from a returned tuple — confirm in mdgp.
max_ell = num_spherical_harmonics_to_num_levels(num_inducing_points, d)[0]
max_ell_prior = 30  # num_eigenfunctions of the Matérn kernel / num_levels_prior of the strategy
num_epochs = 1000


model = gpytorchSGP(max_ell=max_ell, d=d, max_ell_prior=max_ell_prior, optimize_nu=True)
x_train, y_train, x_test, y_test = get_data(smooth=smooth)

# `maximize=True` because the VariationalELBO value is optimized directly (no sign flip).
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
elbo = gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data=len(y_train))

model.train()
for i in (pbar := tqdm(range(num_epochs), desc="Training")):
    optimizer.zero_grad()
    output = model(x_train)
    # Debug check on the predictive covariance at every step (densifies the N x N matrix).
    # NOTE(review): presumably raises if the matrix is not PSD — confirm in mdgp.utils.test_psd.
    test_psd(output.covariance_matrix)
    loss = elbo(output, y_train)  # despite the name, this is the ELBO (higher is better)
    loss.backward()
    pbar.set_postfix(loss=loss.item())
    optimizer.step()


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

In [23]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots


def plot_predictions(model, x_test, y_test):
    """Plot predictive mean, predictive std and the test targets as coloured sphere surfaces.

    Args:
        model: Model with a ``likelihood`` attribute. It is called on the flattened test points.
        x_test: Test points with a trailing axis of 3 coordinates (e.g. a ``sphere_meshgrid``
            grid). The leading dimensions define the surface grid.
        y_test: Test targets with the same grid shape as ``x_test[..., 0]``.

    The model's train/eval mode is restored before returning.
    """
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Mean", "Std", "True"),
                        specs=[[{'type': 'surface'}, {'type': 'surface'}, {'type': 'surface'}]])

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x, y, z = x_test.unbind(-1)
            # reshape (not view) so non-contiguous grids also work.
            pred = model.likelihood(model(x_test.reshape(-1, 3)))
            # Read mean/stddev inside no_grad: the predictive covariance may be lazy, so
            # `stddev` can trigger computation when accessed.
            pred_mean = pred.mean.reshape(x.shape)
            pred_std = pred.stddev.reshape(x.shape)
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)

    def to_numpy(t):
        # plotly consumes numpy arrays; move to CPU and drop any autograd history first.
        return torch.as_tensor(t).detach().cpu().numpy()

    x, y, z = to_numpy(x), to_numpy(y), to_numpy(z)
    mean = go.Surface(x=x, y=y, z=z, surfacecolor=to_numpy(pred_mean))
    std = go.Surface(x=x, y=y, z=z, surfacecolor=to_numpy(pred_std), colorscale="Viridis")
    # NOTE: y_test as produced by get_data includes observation noise.
    true = go.Surface(x=x, y=y, z=z, surfacecolor=to_numpy(y_test))
    fig.add_trace(mean, row=1, col=1)
    fig.add_trace(std, row=1, col=2)
    fig.add_trace(true, row=1, col=3)
    fig.show()

In [24]:
# Visualise the single-output model's predictions on the test grid.
plot_predictions(model=model, x_test=x_test, y_test=y_test)

# Test that the model works with multiple outputs

In [25]:
class gpytorchSGP(ApproximateGP):
    """Multi-output sparse variational GP on the hypersphere with independent outputs.

    Each output (task) is an independent GP with spherical-harmonic inducing features. The
    LAST entry of ``batch_shape`` is the number of tasks. Any preceding entries are ordinary
    batch dimensions and are also applied to the likelihood.

    Args:
        max_ell: Number of spherical-harmonic levels for the inducing features; the number of
            inducing variables per task is ``total_num_harmonics(max_ell, d)``.
        d: Dimension passed to ``Hypersphere(d)`` and ``total_num_harmonics``.
        max_ell_prior: Passed as ``num_eigenfunctions`` to the Matérn kernel; the variational
            strategy also reads it as ``num_levels_prior``.
        epsilon_sigma: Initial likelihood (global) noise standard deviation.
        kappa: Initial Matérn lengthscale.
        nu: Initial Matérn smoothness.
        sigma: Initial output-scale standard deviation.
        batch_shape: Must be non-empty; ``batch_shape[-1]`` is the number of tasks.
        jitter_val: Diagonal jitter forwarded to the variational strategy.
        optimize_nu: Whether ``nu`` is trainable.

    Raises:
        ValueError: If ``batch_shape`` is empty (there is no task dimension).
    """

    def __init__(self, max_ell, d, max_ell_prior, epsilon_sigma=1.0, kappa=1.0, nu=2.5, sigma=1.0, batch_shape=torch.Size([]), jitter_val=1e-6, optimize_nu: bool = True):
        batch_shape = torch.Size(batch_shape)
        if len(batch_shape) == 0:
            # Previously this surfaced as a bare IndexError from batch_shape[0].
            raise ValueError(
                "multi-output gpytorchSGP needs a non-empty batch_shape whose last entry is "
                f"the number of tasks, got batch_shape={tuple(batch_shape)}"
            )
        # IndependentMultitaskVariationalStrategy treats the last batch dim (task_dim=-1) as tasks.
        num_tasks = batch_shape[-1]

        m = total_num_harmonics(max_ell, d)
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=m, batch_shape=batch_shape)
        # IndependentMultitaskVariationalStrategy is the current name of the deprecated
        # MultitaskVariationalStrategy.
        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            SphericalHarmonicFeaturesVariationalStrategy(self, variational_distribution, num_levels=max_ell, jitter_val=jitter_val),
            num_tasks=num_tasks,
            task_dim=-1,
        )
        super().__init__(variational_strategy=variational_strategy)

        # constants
        self.jitter_val = jitter_val
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior
        self.d = d

        # modules
        base_kernel = GeometricMaternKernel(
            space=Hypersphere(d),
            lengthscale=kappa,
            nu=nu,
            trainable_nu=optimize_nu,
            num_eigenfunctions=max_ell_prior,
            batch_shape=batch_shape
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel, batch_shape=batch_shape)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        # Tasks are handled by the likelihood's task axis; only the remaining dims are batch dims.
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks, batch_shape=batch_shape[:-1])

        # prior hyperparams (sigma and epsilon_sigma are standard deviations)
        self.covar_module.outputscale = sigma ** 2
        self.likelihood.noise = epsilon_sigma ** 2

    def forward(self, x) -> MultivariateNormal:
        """Return the batched (one GP per task) prior ``N(mean_module(x), covar_module(x))``."""
        p_sigma = self.covar_module(x)
        p_mu = self.mean_module(x)
        return MultivariateNormal(p_mu, p_sigma)

In [26]:
def smooth_target(x: Tensor) -> Tensor:
    """Two-output smooth target on the sphere.

    Output 0 is ``spherical_harmonic(x, 2, 3)`` and output 1 is ``spherical_harmonic(x, 1, 2)``.
    They are stacked along a new trailing axis.
    """
    outputs = [spherical_harmonic(x, 2, 3), spherical_harmonic(x, 1, 2)]
    return torch.stack(outputs, dim=-1)

def nonsmooth_target(x: Tensor) -> Tensor:
    """Two-output non-smooth target on the sphere.

    Output 0 is ``spherical_antiharmonic(x, 2, 3)`` and output 1 is
    ``spherical_antiharmonic(x, 1, 2)``. They are stacked along a new trailing axis, the same
    layout ``smooth_target`` produces.
    """
    # Must return a tensor, not a Python list: callers apply tensor ops such as
    # torch.randn_like and arithmetic to the result.
    return torch.stack([spherical_antiharmonic(x, 2, 3), spherical_antiharmonic(x, 1, 2)], dim=-1)


def get_data(num_train: int = 400, smooth: bool = True, noise_std: float = 1e-2, seed: int = 0) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Generate noisy multi-output train/test data on the sphere.

    Args:
        num_train: Number of training points, passed to ``sphere_uniform_grid``.
        smooth: Use ``smooth_target`` when True, otherwise ``nonsmooth_target``.
        noise_std: Standard deviation of the additive Gaussian noise on both splits.
        seed: Seed passed to ``torch.manual_seed`` before any data is generated.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The test inputs come from ``sphere_meshgrid(50, 50)``.
    """
    torch.manual_seed(seed)
    if smooth:
        target = smooth_target
    else:
        target = nonsmooth_target

    def noisy(points: Tensor) -> Tensor:
        # Clean target values plus i.i.d. Gaussian noise of scale noise_std.
        clean = target(points)
        return clean + noise_std * torch.randn_like(clean)

    x_train = sphere_uniform_grid(num_train)
    y_train = noisy(x_train)
    x_test = sphere_meshgrid(50, 50)
    y_test = noisy(x_test)
    return x_train, y_train, x_test, y_test

In [53]:
# Seed the global RNG before building the model (get_data re-seeds with its own `seed` argument).
torch.random.manual_seed(1)

smooth = True
# max_ell, d, max_ell_prior and num_epochs are reused from the single-output experiment cell.
# batch_shape=[2]: two independent outputs, matching the two targets stacked by smooth_target.
model = gpytorchSGP(max_ell=max_ell, d=d, max_ell_prior=max_ell_prior, optimize_nu=True, batch_shape=torch.Size([2]))
x_train, y_train, x_test, y_test = get_data(400, smooth=smooth)

# `maximize=True` because the VariationalELBO value is optimized directly (no sign flip).
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
# The model builds its own MultitaskGaussianLikelihood (model.likelihood); no separate likelihood is needed.
# num_data: assumes y_train is (N, num_tasks), so len() is the number of training inputs — confirm.
elbo = gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data=len(y_train))


model.train()
for i in (pbar := tqdm(range(num_epochs), desc="Training")):
    optimizer.zero_grad()
    output = model(x_train)
    # Debug check on the predictive covariance at every step (densifies the full covariance).
    # NOTE(review): presumably raises if the matrix is not PSD — confirm in mdgp.utils.test_psd.
    test_psd(output.covariance_matrix)
    loss = elbo(output, y_train)  # despite the name, this is the ELBO (higher is better)
    loss.backward()
    pbar.set_postfix(loss=loss.item())
    optimizer.step()

Training:   0%|          | 0/1000 [00:00<?, ?it/s]

In [54]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots


def plot_predictions(model, x_test, y_test, index: int = 0):
    """Plot predictive mean, predictive std and the test targets of one output as sphere surfaces.

    Args:
        model: Multi-output model with a ``likelihood`` attribute. It is called on the flattened
            test points and its predictive mean/stddev are indexed on the last (task) axis.
        x_test: Test points with a trailing axis of 3 coordinates (e.g. a ``sphere_meshgrid``
            grid). The leading dimensions define the surface grid.
        y_test: Test targets of shape ``x_test.shape[:-1] + (num_tasks,)``.
        index: Which output (task) to plot.

    The model's train/eval mode is restored before returning.
    """
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Mean", "Std", "True"),
                        specs=[[{'type': 'surface'}, {'type': 'surface'}, {'type': 'surface'}]])

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x, y, z = x_test.unbind(-1)
            # reshape (not view) so non-contiguous grids also work.
            pred = model.likelihood(model(x_test.reshape(-1, 3)))
            # Read mean/stddev inside no_grad: the predictive covariance may be lazy, so
            # `stddev` can trigger computation when accessed.
            pred_mean = pred.mean[..., index].reshape(x.shape)
            pred_std = pred.stddev[..., index].reshape(x.shape)
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)

    def to_numpy(t):
        # plotly consumes numpy arrays; move to CPU and drop any autograd history first.
        return torch.as_tensor(t).detach().cpu().numpy()

    x, y, z = to_numpy(x), to_numpy(y), to_numpy(z)
    mean = go.Surface(x=x, y=y, z=z, surfacecolor=to_numpy(pred_mean))
    std = go.Surface(x=x, y=y, z=z, surfacecolor=to_numpy(pred_std), colorscale="Viridis")
    # NOTE: y_test as produced by get_data includes observation noise.
    true = go.Surface(x=x, y=y, z=z, surfacecolor=to_numpy(y_test[..., index]))
    fig.add_trace(mean, row=1, col=1)
    fig.add_trace(std, row=1, col=2)
    fig.add_trace(true, row=1, col=3)
    fig.show()

In [55]:
# Visualise the first output (task 0) of the multi-output model.
plot_predictions(model, x_test, y_test, index=0)

In [56]:
# Visualise the second output (task 1) of the multi-output model.
plot_predictions(model, x_test, y_test, index=1)