In [1]:
import os
os.environ['GEOMSTATS_BACKEND'] = 'pytorch'
from torch import Tensor 

import torch 
import gpytorch 
import geometric_kernels.torch
import geomstats._backend as gs 
from geometric_kernels.spaces import Hypersphere 
from mdgp.utils import sphere_meshgrid, sphere_uniform_grid, spherical_harmonic, spherical_antiharmonic
from mdgp.experiments.uci.model.geometric import SHFDeepGPLayer

import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from mdgp.utils import test_psd
from tqdm.autonotebook import tqdm


torch.set_default_dtype(torch.float64)  # newly created floating-point tensors default to float64
pio.templates.default = "plotly_dark"  # dark theme for every plotly figure in this notebook

INFO: Using pytorch backend
  from tqdm.autonotebook import tqdm


In [2]:
import torch 
from torch import nn 
from gpytorch import settings 
from gpytorch.models import ApproximateGP
from gpytorch.distributions import MultivariateNormal
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.utils.memoize import cached
from linear_operator.operators import DiagLinearOperator
from mdgp.variational.spherical_harmonic_features.utils import matern_LT_Phi_from_kernel


class SphericalHarmonicFeaturesVariationalStrategy(nn.Module):
    """Whitened variational strategy whose inducing variables are spherical harmonic features.

    Both the prior p(u) (a standard normal, see ``prior_distribution``) and the variational
    distribution q(u) live in the whitened space. ``matern_LT_Phi_from_kernel`` supplies the
    (unwhitening + projection) matrix that maps them to function values at the inputs.

    Args:
        model: the owning GP. It must expose ``forward``, ``covar_module`` and ``max_ell_prior``.
        variational_distribution: whitened variational distribution q(u).
        num_levels: number of spherical-harmonic levels used for the inducing features.
        jitter_val: diagonal jitter. ``None`` falls back to gpytorch's Cholesky jitter for the
            default dtype. An explicit ``0.0`` is respected (previously it was silently
            replaced by the default because of a truthiness check).
    """

    def __init__(self, model: ApproximateGP, variational_distribution: CholeskyVariationalDistribution, num_levels: int, jitter_val: float | None = None):
        super().__init__()
        # object.__setattr__ bypasses nn.Module.__setattr__, so the model is NOT registered as a
        # child module. The model already owns this strategy as a submodule; registering it back
        # would (presumably) create a parent/child cycle.
        object.__setattr__(self, "_model", model)
        self._variational_distribution = variational_distribution
        self.jitter_val = (
            jitter_val
            if jitter_val is not None
            else settings.cholesky_jitter.value(torch.get_default_dtype())
        )
        self.num_levels = num_levels

    @property
    def model(self) -> ApproximateGP:
        """The GP that owns this strategy (stored unregistered, see ``__init__``)."""
        return self._model

    @property
    def variational_distribution(self) -> MultivariateNormal:
        """Current whitened variational distribution q(u)."""
        return self._variational_distribution()

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        """Whitened prior N(0, I) matching the variational distribution's shape, dtype and device.

        NOTE(review): the result is memoised and never invalidated here, so moving the module to a
        different device/dtype after the first call would keep the stale prior. Confirm this is fine.
        """
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res

    def forward(self, x) -> MultivariateNormal:
        """Approximate posterior q(f(x)) under the whitened SHF parameterisation.

        With whitened prior N(mu_w, K_w) and whitened posterior N(m_w, S_w) over u:
            mean = mu_x + Phi_xu L (m_w - mu_w)
            cov  = K_xx + Phi_xu L (S_w - K_w) L^T Phi_ux   (+ jitter)

        Args:
            x: inputs of shape ``[..., N, D]``.

        Returns:
            MultivariateNormal over the N function values.
        """

        # prior at x 
        px = self.model.forward(x)
        mu_x, Kxx = px.mean, px.lazy_covariance_matrix # [..., N], [..., N, N]
        Kxx = Kxx.add_jitter(self.jitter_val)

        # whitened prior at u
        Linv_pu = self.prior_distribution
        Linv_mu, Linv_Kuu_LTinv = Linv_pu.mean, Linv_pu.lazy_covariance_matrix # [..., M], [..., M, M]

        # whitened variational posterior at u
        Linv_qu = self.variational_distribution
        Linv_m, Linv_S_LTinv = Linv_qu.mean, Linv_qu.lazy_covariance_matrix # [..., M], [..., M, M]

        # unwhitening + projection and vice-versa
        LT_Phiux = matern_LT_Phi_from_kernel(x, self.model.covar_module, num_levels=self.num_levels, num_levels_prior=self.model.max_ell_prior)
        Phixu_L = LT_Phiux.mT # [..., N, M]

        # posterior at x 
        qx_sigma = Kxx + Phixu_L @ (Linv_S_LTinv - Linv_Kuu_LTinv) @ LT_Phiux
        # NOTE(review): jitter was already added to Kxx above, so the diagonal receives
        # 2 * jitter_val in total. Confirm this is intended.
        qx_sigma = qx_sigma.add_jitter(self.jitter_val)
        qx_mu = mu_x + (Phixu_L @ (Linv_m - Linv_mu).unsqueeze(-1)).squeeze(-1) # [..., N, M] @ [..., M, 1] -> [..., N, 1]
        return MultivariateNormal(qx_mu, qx_sigma)

    def kl_divergence(self):
        """KL(q(u) || p(u)) between the whitened variational distribution and the whitened prior."""
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    def __call__(self, x, prior: bool = False, **kwargs) -> MultivariateNormal:
        """Return the model's prior at ``x`` when ``prior=True``; otherwise run ``forward`` via nn.Module."""
        if prior:
            return self.model.forward(x)
        return super().__call__(x, **kwargs)


In [3]:
from torch import Tensor 


import torch 
from gpytorch import settings
from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.means import ConstantMean, LinearMean
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood
from geometric_kernels.spaces import Hypersphere
from mdgp.experiments.uci.data.datasets import UCIDataset
from mdgp.utils.sphere import sphere_kmeans_centers
from mdgp.kernels import GeometricMaternKernel
from mdgp.variational.spherical_harmonic_features.utils import total_num_harmonics


# Hyperparameters taken from the paper (except OUTPUT_LAYER_VARIANCE, see below).
LIKELIHOOD_VARIANCE: float = 0.01
LENGTHSCALE: float = 2.0
INNER_LAYER_VARIANCE: float = 1e-5
OUTPUT_LAYER_VARIANCE: float = 1.0  # not from the paper: a (reasonable) guess
NUM_INDUCING_POINTS: int = 100
MAX_HIDDEN_DIMS: int = 30  # NOTE(review): not referenced anywhere in the visible code, including get_hidden_dims; confirm intent


def get_hidden_dims(dataset: UCIDataset) -> int:
    """Hidden-layer width for ``dataset``: its ``dimension`` plus one.

    NOTE(review): the +1 presumably accounts for embedding S^d in R^{d+1}. MAX_HIDDEN_DIMS is
    not applied here; confirm that is intended.
    """
    input_dimension = dataset.dimension
    return input_dimension + 1


def get_inducing_points(dataset: UCIDataset, num_inducing_points: int) -> Tensor:
    """Initial inducing locations: spherical k-means centres of the training inputs (as in the paper).

    Args:
        dataset: dataset whose ``train_x`` is clustered.
        num_inducing_points: number of centres to return.
    """
    training_inputs = dataset.train_x
    centres = sphere_kmeans_centers(training_inputs, num_inducing_points)
    return centres


class SHFDeepGPLayer(DeepGPLayer):
    """Deep GP layer on the hypersphere S^d with spherical-harmonic-feature inducing variables.

    NOTE(review): this definition shadows the ``SHFDeepGPLayer`` imported from
    ``mdgp.experiments.uci.model.geometric`` in the first cell.

    Args:
        max_ell: number of harmonic levels for the variational (inducing) features.
            The number of inducing features is ``total_num_harmonics(max_ell, d)``.
        d: dimension of the sphere S^d. The layer's input dimension is ``d + 1``.
        max_ell_prior: number of eigenfunctions used by the Matérn kernel.
        kappa: initial kernel lengthscale.
        nu: initial Matérn smoothness.
        sigma: initial kernel amplitude. The outputscale is set to ``sigma ** 2``.
        jitter_val: diagonal jitter. ``None`` falls back to gpytorch's Cholesky jitter for the
            default dtype. An explicit ``0.0`` is kept (previously replaced by the default).
        optimize_nu: whether ``nu`` is trainable.
        output_dims: number of independent outputs (a batch of GPs); ``None`` for a single output.
    """

    def __init__(self, max_ell, d, max_ell_prior, kappa=1.0, nu=2.5, sigma=1.0, jitter_val=1e-6, optimize_nu: bool = True, output_dims: int | None = None):
        m = total_num_harmonics(max_ell, d)
        batch_shape = torch.Size([]) if output_dims is None else torch.Size([output_dims])
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=m, batch_shape=batch_shape)
        # Built before DeepGPLayer.__init__: the strategy only keeps an unregistered reference to `self`.
        variational_strategy = SphericalHarmonicFeaturesVariationalStrategy(self, variational_distribution, num_levels=max_ell, jitter_val=jitter_val)
        super().__init__(variational_strategy, d + 1, output_dims)
        self.batch_shape = batch_shape

        # constants
        self.jitter_val = (
            jitter_val
            if jitter_val is not None
            else settings.cholesky_jitter.value(torch.get_default_dtype())
        )
        self.max_ell = max_ell
        self.max_ell_prior = max_ell_prior  # read by the variational strategy in forward
        self.d = d

        # modules
        base_kernel = GeometricMaternKernel(
            space=Hypersphere(d),
            lengthscale=kappa,
            nu=nu,
            trainable_nu=optimize_nu,
            num_eigenfunctions=max_ell_prior,
            batch_shape=batch_shape,
        )
        base_kernel.lengthscale = kappa
        self.covar_module = ScaleKernel(base_kernel, batch_shape=batch_shape)
        self.covar_module.outputscale = sigma ** 2
        self.mean_module = ConstantMean(batch_shape=batch_shape)

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Prior over f(x): constant mean and scaled geometric Matérn covariance."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

In [4]:
from gpytorch.models.deep_gps import DeepGP
from gpytorch.likelihoods import GaussianLikelihood
from mdgp.samplers import sample_elementwise


class SHFDeepGP(DeepGP):
    """Deep GP on S^d built from ``SHFDeepGPLayer`` layers.

    The ``num_layers - 1`` hidden layers each output ``dimension + 1`` values (the embedding
    dimension of S^d). Their samples are mapped back onto the sphere by projecting onto the
    tangent space at the layer input and applying the exponential map. The final layer has a
    single output.

    Args:
        num_layers: total number of layers (hidden layers + one output layer).
        num_levels_var: harmonic levels for the variational features of every layer.
        num_levels_ker: eigenfunctions used by every layer's Matérn kernel.
        dimension: dimension d of the sphere S^d.
        lengthscale: initial kernel lengthscale for every layer.
        nu: initial Matérn smoothness for every layer.
        outputscale: initial kernel outputscale (variance) for every layer.
        jitter_val: diagonal jitter forwarded to every layer (``None`` = gpytorch default).
        optimize_nu: whether ``nu`` is trainable.
        noise: initial Gaussian likelihood noise.
    """

    def __init__(
        self, 
        num_layers: int, 
        num_levels_var: int, 
        num_levels_ker: int, 
        dimension: int, 
        lengthscale: float = 1.0, 
        nu: float = 2.5, 
        outputscale: float = 1.0, 
        jitter_val: float | None = None, 
        optimize_nu: bool = True, 
        noise: float = 1.0, 
    ) -> None:
        super().__init__()
        self.space = Hypersphere(dimension)

        def make_layer(output_dims):
            # Every layer shares the same hyperparameters; only the number of outputs differs.
            return SHFDeepGPLayer(
                max_ell=num_levels_var,
                max_ell_prior=num_levels_ker,
                d=dimension,
                kappa=lengthscale,
                nu=nu,
                sigma=outputscale ** 0.5,  # the layer sets outputscale = sigma ** 2
                jitter_val=jitter_val,
                optimize_nu=optimize_nu,
                output_dims=output_dims,
            )

        # Hidden layers output points in the embedding space R^{d+1} of S^d.
        hidden_layers = [make_layer(dimension + 1) for _ in range(num_layers - 1)]
        # For UCI datasets the output is always 1-dimensional.
        output_layer = make_layer(None)
        self.layers = torch.nn.ModuleList(hidden_layers + [output_layer])
        self.likelihood = GaussianLikelihood()
        self.likelihood.noise = noise

    def forward(self, x: Tensor, are_samples: bool = False) -> MultivariateNormal:
        """Propagate ``x`` through the layers and return the output layer's distribution.

        Args:
            x: inputs on the sphere (ambient coordinates).
            are_samples: whether ``x`` already carries a leading sample dimension.
        """
        for layer in self.layers[:-1]:
            ambient = sample_elementwise(layer(x, are_samples=are_samples))
            tangent = self.space.to_tangent(ambient, x)
            x = self.space.metric.exp(tangent, x)
            # After the first hidden layer the inputs are samples.
            are_samples = True
        return self.layers[-1](x, are_samples=are_samples)

In [5]:
def smooth_target(x):
    """Smooth regression target: ``spherical_harmonic(x, 2, 3)`` evaluated at ``x``."""
    values = spherical_harmonic(x, 2, 3)
    return values


def nonsmooth_target(x):
    """Non-smooth regression target: ``spherical_antiharmonic(x, 2, 3)`` evaluated at ``x``."""
    values = spherical_antiharmonic(x, 2, 3)
    return values


def get_data(num_train: int = 400, smooth: bool = True, noise_std: float = 1e-2, seed: int = 0) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Build a noisy synthetic regression problem on the sphere.

    Seeds torch's global RNG with ``seed`` (a deliberate global side effect). Training inputs
    come from ``sphere_uniform_grid(num_train)`` and test inputs from ``sphere_meshgrid(50, 50)``.
    Targets are ``smooth_target`` or ``nonsmooth_target`` plus Gaussian noise of std ``noise_std``.

    Returns:
        ``(x_train, y_train, x_test, y_test)``.
    """
    torch.manual_seed(seed)
    target_fn = smooth_target if smooth else nonsmooth_target

    def add_noise(clean: Tensor) -> Tensor:
        return clean + noise_std * torch.randn_like(clean)

    # Train data is generated (and its noise drawn) before test data, as the RNG order matters.
    x_train = sphere_uniform_grid(num_train)
    y_train = add_noise(target_fn(x_train))
    x_test = sphere_meshgrid(50, 50)
    y_test = add_noise(target_fn(x_test))
    return x_train, y_train, x_test, y_test

In [6]:
from mdgp.variational.spherical_harmonic_features.utils import num_spherical_harmonics_to_num_levels


# Experiment configuration
smooth = False
dimension = 2
num_layers = 2
num_epochs = 1000
num_inducing_points = 100
num_levels_ker = 30
# Convert the inducing-feature budget into a number of harmonic levels (first element of the helper's result).
num_levels_var = num_spherical_harmonics_to_num_levels(num_inducing_points, dimension)[0]

model = SHFDeepGP(
    num_layers=num_layers,
    num_levels_var=num_levels_var,
    num_levels_ker=num_levels_ker,
    dimension=dimension,
    optimize_nu=True,
)
x_train, y_train, x_test, y_test = get_data(800, smooth=smooth)

# The ELBO is an objective to maximise, hence `maximize=True` rather than negating it.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
elbo = gpytorch.mlls.DeepApproximateMLL(
    gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data=len(y_train))
)

model.train()
pbar = tqdm(range(num_epochs), desc="Training")
for i in pbar:
    optimizer.zero_grad()
    output = model(x_train)
    loss = elbo(output, y_train)  # ELBO value (higher is better) despite the name
    loss.backward()
    pbar.set_postfix(loss=loss.item())
    optimizer.step()


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [18]:
from linear_operator.operators import DiagLinearOperator


# Forward pass over the flattened test grid without building an autograd graph.
with torch.set_grad_enabled(False):
    out = model(x_test.view(-1, 3))

In [20]:
# Sanity check: sandwich the predictive covariance between identity diagonal operators and
# rebuild a MultivariateNormal from the result. The operator is sized from `out` (previously
# hard-coded as 2500 = 50 * 50 test points) and matches its dtype and device.
L = DiagLinearOperator(out.mean.new_ones((out.mean.shape[-1],)))
mean = out.mean @ L
covar = L @ out.lazy_covariance_matrix @ L
MultivariateNormal(mean, covar)

MultivariateNormal(loc: torch.Size([10, 2500]))

In [22]:
# Shapes of the predictive mean and its (lazy) covariance.
(out.mean.shape, out.lazy_covariance_matrix.shape)

(torch.Size([10, 2500]), torch.Size([10, 2500, 2500]))

In [21]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots


def plot_predictions(model, x_test, y_test, num_samples: int = 10):
    """Plot predictive mean, predictive standard deviation and targets on the sphere.

    The deep GP predictive at each test point is an equal-weight mixture of the Gaussians
    obtained for the ``num_samples`` likelihood samples (leading batch dimension of ``pred``).
    Its moments follow from the law of total variance:
        mean = E_s[mu_s],    var = E_s[var_s] + Var_s[mu_s].
    Averaging the per-sample standard deviations (the previous behaviour) ignores the spread
    of the sample means and understates predictive uncertainty.

    Args:
        model: deep GP exposing ``likelihood``; it is switched to eval mode.
        x_test: sphere points whose last dimension holds (x, y, z), e.g. from ``sphere_meshgrid``.
        y_test: target values on the same grid (presumably shaped ``x_test.shape[:-1]``; TODO confirm).
        num_samples: number of likelihood samples forming the mixture (default 10, the value
            previously hard-coded).
    """
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Mean", "Std", "True"), 
                        specs=[[{'type': 'surface'}, {'type': 'surface'}, {'type': 'surface'}]])
    model.eval()
    with torch.no_grad(), gpytorch.settings.num_likelihood_samples(num_samples):
        x, y, z = x_test.unbind(-1)
        pred = model.likelihood(model(x_test.view(-1, 3)))
        sample_means = pred.mean  # presumably [num_samples, N]
        mixture_mean = sample_means.mean(dim=0)
        # Population variance (unbiased=False) of the sample means: equal-weight mixture, and
        # stays finite when num_samples == 1.
        mixture_var = pred.variance.mean(dim=0) + sample_means.var(dim=0, unbiased=False)
        mixture_std = mixture_var.sqrt()

    mean = go.Surface(x=x, y=y, z=z, surfacecolor=mixture_mean.view_as(x))
    std = go.Surface(x=x, y=y, z=z, surfacecolor=mixture_std.view_as(x), colorscale="Viridis")
    true = go.Surface(x=x, y=y, z=z, surfacecolor=y_test)
    fig.add_trace(mean, row=1, col=1)
    fig.add_trace(std, row=1, col=2)
    fig.add_trace(true, row=1, col=3)
    fig.show()

In [22]:
plot_predictions(model=model, x_test=x_test, y_test=y_test)

In [25]:
# Number of harmonic features that `total_num_harmonics(max_ell, d)` gives for max_ell=25, d=2.
# This is the inducing-feature count SHFDeepGPLayer would use for that setting.
total_num_harmonics(25, 2)

625