# Load in data
Currently supported datasets: power, protein, kin8nm

In [2]:
import torch 
import geometric_kernels.torch 

torch.set_default_dtype(torch.float64)

In [3]:
import os
import pandas as pd 
import torch 
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader, Dataset


class UCIDataset:

    UCI_BASE_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'

    def __init__(self, name: str, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None): 
        self.name = name 
        self.path = path 
        self.csv_path = os.path.join(self.path, self.name + '.csv')

        # Set generator if seed is provided.
        self.generator = torch.Generator()
        if seed is not None: 
            self.generator.manual_seed(seed)

        # Load, shuffle, split, and normalize data. TODO except for load, these don't need to be object methods. 
        # We keep the standard deviation of the test set for log-likelihood evaluation.
        x, y = self.load_data()
        x, y = self.shuffle(x, y, generator=self.generator)     
        self.train_x, self.train_y, self.test_x, self.test_y = self.split(x, y)
        self.test_y_std = self.test_y.std(dim=0, keepdim=True)
        self.train_x, self.train_y, self.test_x, self.test_y = map(
            self.normalize, (self.train_x, self.train_y, self.test_x, self.test_y))

    @property
    def dimension(self) -> int:
        return self.train_x.shape[-1]

    @property 
    def train_dataset(self) -> Dataset:
        return TensorDataset(self.train_x, self.train_y)
    
    @property
    def test_dataset(self) -> Dataset:
        return TensorDataset(self.test_x, self.test_y)

    def read_data(self) -> tuple[Tensor, Tensor]:
        xy = torch.from_numpy(pd.read_csv(self.csv_path).values)
        return xy[:, :-1], xy[:, -1]

    def download_data(self) -> None:
        NotImplementedError

    def load_data(self, overwrite: bool = False) -> tuple[Tensor, Tensor]:
        if overwrite or not os.path.isfile(self.csv_path):
            self.download_data()
        return self.read_data()

    def normalize(self, x: Tensor) -> Tensor:
        return (x - x.mean(dim=0)) / x.std(dim=0, keepdim=True)
    
    def shuffle(self, x: Tensor, y: Tensor, generator: torch.Generator) -> tuple[Tensor, Tensor]:
        perm_idx = torch.randperm(x.size(0), generator=generator)
        return x[perm_idx], y[perm_idx]
    
    def split(self, x: Tensor, y: Tensor, test_size: float = 0.1) -> tuple[Tensor, Tensor, Tensor, Tensor]: 
        """
        Split the dataset into train and test sets.
        """
        split_idx = int(test_size * x.size(0))
        return x[split_idx:], y[split_idx:], x[:split_idx], y[:split_idx]


class Kin8mn(UCIDataset):

    DEFAULT_URL = 'https://raw.githubusercontent.com/liusiyan/UQnet/master/datasets/UCI_datasets/kin8nm/dataset_2175_kin8nm.csv'

    def __init__(self, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None, url: str = DEFAULT_URL):
        super().__init__(name='kin8nm', path=path, normalize=normalize, seed=seed)
        self.url = url 

    def download_data(self) -> None:
        df = pd.read_csv(self.url)
        os.makedirs(self.path, exist_ok=True)
        df.to_csv(self.csv_path, index=False)

# Instantiate model

This is done using the same model arguments as for the benchmarking experiment

In [4]:
from linear_operator.operators import DiagLinearOperator
from gpytorch.distributions import MultivariateNormal


class SphereProjector(torch.nn.Module):
    def __init__(self):
        super().__init__()
        b = torch.tensor(2.0)
        self.register_parameter('b', torch.nn.Parameter(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y / self.norm.squeeze(-1)
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)

In [5]:
from torch import Tensor 
from gpytorch.means import Mean
from gpytorch.kernels import ScaleKernel
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from mdgp.kernels import GeometricMaternKernel


import torch 
from gpytorch import Module, settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.utils.memoize import cached, clear_cache_hook
from linear_operator.operators import DiagLinearOperator
from functools import cached_property
# TODO Maybe just move the functions from spherical_harmonic_features.py into this file?
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree, matern_Kuu, matern_LT_Phi


class SphericalHarmonicFeaturesVariationalStrategy(Module):
    def __init__(
        self,
        model: ApproximateGP,
        variational_distribution: CholeskyVariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
        jitter_val: float | None = None,
    ):
        super().__init__()
        self._jitter_val = jitter_val

        # model, set via object.__setattr__ to avoid treatment as a module, parameter, or buffer
        object.__setattr__(self, "_model", model)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=torch.get_default_dtype())
        return self._jitter_val

    @property
    def model(self) -> ApproximateGP:
        return self._model
    
    @property
    def covar_module(self) -> ScaleKernel | GeometricMaternKernel:
        return self.model.covar_module
    
    @cached_property
    def base_kernel(self) -> GeometricMaternKernel:
        if isinstance(self.covar_module, ScaleKernel):
            return self.covar_module.base_kernel
        else:
            return self.covar_module
    
    @property
    def mean_module(self) -> Mean:
        return self.model.mean_module

    @property
    def kappa(self) -> Tensor:
        return self.base_kernel.lengthscale
    
    @property
    def nu(self) -> Tensor | float:
        return self.base_kernel.nu
    
    @property 
    def outputscale(self) -> Tensor:
        return self.covar_module.outputscale if hasattr(self.covar_module, "outputscale") else torch.tensor(1.0)

    @property 
    def sigma(self) -> Tensor: 
        return self.outputscale.sqrt()

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        covariance_matrix = DiagLinearOperator(torch.ones(self.num_spherical_harmonics))
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
    
    @property 
    @cached(name="cholesky_factor_prior_memo")
    def cholesky_factor_prior(self) -> DiagLinearOperator:
        Kuu = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        return Kuu.cholesky() # Kuu is a DiagLinearOperator, so .cholesky() is equivalent to .sqrt() 
        
    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        # inducing-inducing prior
        pu = self.prior_distribution
        invL_muu, invL_Kuu_invLt = pu.mean, pu.lazy_covariance_matrix

        # input-input prior
        px = self.model.forward(x)
        mux, Kxx = px.mean, px.lazy_covariance_matrix

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_m, invL_S_invLt = qu.mean, qu.lazy_covariance_matrix

        # Add jitter to Kxx and invL_Kuu_invLt for numerical stability
        Kxx = Kxx.add_jitter(self.jitter_val)
        invL_Kuu_invLt = invL_Kuu_invLt.add_jitter(self.jitter_val)

        # inducing-input prior  
        LT_Phi = matern_LT_Phi(
            x, max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma,
        ) # [..., O, num_harmonics, N]
        
        # Update the mean
        mean_update = torch.einsum('...ij,...i->...j', LT_Phi, invL_m - invL_muu) # [..., O, num_harmonics, N] @ [O, num_harmonics] -> [..., O, N]
        updated_mean = mux + mean_update # [..., O, N] + [..., O, N] -> [..., O, N]

        # Update the covariance matrix
        covariance_matrix_update = LT_Phi.mT @ (invL_S_invLt - invL_Kuu_invLt) @ LT_Phi # [O, num_harmonics, num_harmonics] @ [O, num_harmonics, num_harmonics] @ [O, num_harmonics, N] -> [O, num_harmonics, N]
        updated_covariance_matrix = Kxx + covariance_matrix_update # [..., O, N, N] + [..., O, N, N] -> [..., O, N, N]

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence
    
    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x, **kwargs)

INFO: Using numpy backend


In [6]:
from torch import Tensor 
from gpytorch.means import Mean
from gpytorch.kernels import ScaleKernel
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from mdgp.kernels import GeometricMaternKernel


import torch 
from gpytorch import Module, settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.utils.memoize import cached, clear_cache_hook
from linear_operator.operators import DiagLinearOperator
from functools import cached_property
# TODO Maybe just move the functions from spherical_harmonic_features.py into this file?
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree, matern_Kuu, matern_LT_Phi


class SphericalHarmonicFeaturesVariationalStrategy(Module):
    def __init__(
        self,
        model: ApproximateGP,
        variational_distribution: CholeskyVariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
        jitter_val: float | None = None,
    ):
        super().__init__()
        self._jitter_val = jitter_val

        # model, set via object.__setattr__ to avoid treatment as a module, parameter, or buffer
        object.__setattr__(self, "_model", model)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self._variational_distribution.initialize_variational_distribution(self.prior_distribution)
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=torch.get_default_dtype())
        return self._jitter_val

    @property
    def model(self) -> ApproximateGP:
        return self._model
    
    @property
    def covar_module(self) -> ScaleKernel | GeometricMaternKernel:
        return self.model.covar_module
    
    @cached_property
    def base_kernel(self) -> GeometricMaternKernel:
        if isinstance(self.covar_module, ScaleKernel):
            return self.covar_module.base_kernel
        else:
            return self.covar_module
    
    @property
    def mean_module(self) -> Mean:
        return self.model.mean_module

    @property
    def kappa(self) -> Tensor:
        return self.base_kernel.lengthscale
    
    @property
    def nu(self) -> Tensor | float:
        return self.base_kernel.nu
    
    @property 
    def outputscale(self) -> Tensor:
        return self.covar_module.outputscale if hasattr(self.covar_module, "outputscale") else torch.tensor(1.0)

    @property 
    def sigma(self) -> Tensor: 
        return self.outputscale.sqrt()

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        covariance_matrix = DiagLinearOperator(torch.ones(self.num_spherical_harmonics))
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
    
    @property 
    @cached(name="cholesky_factor_prior_memo")
    def cholesky_factor_prior(self) -> DiagLinearOperator:
        Kuu = matern_Kuu(max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
        return Kuu.cholesky() # Kuu is a DiagLinearOperator, so .cholesky() is equivalent to .sqrt() 
        
    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        # inducing-inducing prior
        pu = self.prior_distribution
        invL_Kuu_invLt = pu.lazy_covariance_matrix

        # input-input prior
        mux, Kxx = self.model.mean_module(x), self.model.covar_module(x)

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_S_invLt = qu.lazy_covariance_matrix

        # Add jitter to Kxx and invL_Kuu_invLt for numerical stability
        Kxx = Kxx.add_jitter(self.jitter_val)
        invL_Kuu_invLt = invL_Kuu_invLt.add_jitter(self.jitter_val)

        # inducing-input prior  
        LT_Phi = matern_LT_Phi(
            x, max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma,
        ) # [..., O, num_harmonics, N]
        
        # Update the mean
        updated_mean = mux

        # Update the covariance matrix
        covariance_matrix_update = LT_Phi.mT @ (invL_S_invLt - invL_Kuu_invLt) @ LT_Phi # [O, num_harmonics, num_harmonics] @ [O, num_harmonics, num_harmonics] @ [O, num_harmonics, N] -> [O, num_harmonics, N]
        updated_covariance_matrix = Kxx + covariance_matrix_update # [..., O, N, N] + [..., O, N, N] -> [..., O, N, N]

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence
    
    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x, **kwargs)

In [7]:
class SphericalHarmonicFeaturesVariationalStrategy(Module):
    def __init__(
        self,
        model: ApproximateGP,
        variational_distribution: CholeskyVariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
        jitter_val: float | None = None,
    ):
        super().__init__()
        self._jitter_val = jitter_val

        # model, set via object.__setattr__ to avoid treatment as a module, parameter, or buffer
        object.__setattr__(self, "_model", model)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=torch.get_default_dtype())
        return self._jitter_val

    @property
    def model(self) -> ApproximateGP:
        return self._model

    @property
    def kappa(self) -> Tensor:
        return torch.tensor([[0.001]])
    
    @property
    def nu(self) -> Tensor | float:
        return torch.tensor([[2.5]])

    @property 
    def sigma(self) -> Tensor: 
        return torch.tensor(1.0)

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    def prior_distribution(self) -> MultivariateNormal:
        covariance_matrix = DiagLinearOperator(torch.ones(self.num_spherical_harmonics))
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
        
    @property
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        # inducing-inducing prior
        pu = self.prior_distribution
        invL_Kuu_invLt = pu.lazy_covariance_matrix

        # input-input prior
        mux, Kxx = self.model.mean_module(x), self.model.covar_module(x)

        # inducing-inducing variational
        qu = self.variational_distribution
        invL_S_invLt = qu.lazy_covariance_matrix

        # Add jitter to Kxx and invL_Kuu_invLt for numerical stability
        Kxx = Kxx.add_jitter(self.jitter_val)
        invL_Kuu_invLt = invL_Kuu_invLt.add_jitter(self.jitter_val)

        # inducing-input prior  
        LT_Phi = matern_LT_Phi(
            x, max_ell=self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma,
        ) # [..., O, num_harmonics, N]
        
        # Update the mean
        updated_mean = mux

        # Update the covariance matrix
        covariance_matrix_update = LT_Phi.mT @ (invL_S_invLt - invL_Kuu_invLt) @ LT_Phi # [O, num_harmonics, num_harmonics] @ [O, num_harmonics, num_harmonics] @ [O, num_harmonics, N] -> [O, num_harmonics, N]
        updated_covariance_matrix = Kxx + covariance_matrix_update # [..., O, N, N] + [..., O, N, N] -> [..., O, N, N]

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence
    
    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x, **kwargs)

In [8]:
from mdgp.utils.spherical_harmonic_features import matern_Kux, matern_repeated_ahat

def matern_Phi(x: Tensor, max_ell: int, d: int, kappa: float, nu: float, rsh: bool = True, sigma: float = 1.0) -> Tensor: 
    """
    Returns the feature vector of spherical harmonics evaluated at x. 
    """
    Kux = matern_Kux(x, max_ell=max_ell, d=d) # [*B, num_harmonics, n]
    ahat = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma).unsqueeze(-1) # [num_harmonics, 1]
    return Kux * ahat # [*B, num_harmonics, n]


class SphericalHarmonicFeaturesVariationalStrategy(Module):
    def __init__(
        self,
        model: ApproximateGP,
        variational_distribution: CholeskyVariationalDistribution,
        dimension: int, 
        num_spherical_harmonics: int, 
        jitter_val: float | None = None,
    ):
        super().__init__()
        self._jitter_val = jitter_val

        # model, set via object.__setattr__ to avoid treatment as a module, parameter, or buffer
        object.__setattr__(self, "_model", model)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

        # spherical harmonics 
        self.dimension = dimension 
        self.degree, self.num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=torch.get_default_dtype())
        return self._jitter_val

    @property
    def model(self) -> ApproximateGP:
        return self._model

    @property
    def kappa(self) -> Tensor:
        return torch.tensor([[0.001]])
    
    @property
    def nu(self) -> Tensor | float:
        return torch.tensor([[2.5]])

    @property 
    def sigma(self) -> Tensor: 
        return torch.tensor(1.0)

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    def prior_distribution(self) -> MultivariateNormal:
        covariance_matrix = matern_Kuu(
            max_ell=self.degree, 
            d=self.dimension,
            kappa=self.kappa,
            nu=self.nu,
            sigma=self.sigma,
        )
        mean = torch.zeros(self.num_spherical_harmonics)
        return MultivariateNormal(mean=mean, covariance_matrix=covariance_matrix)
        
    @property
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()

    def forward(
        self,
        x: Tensor,
        **kwargs,
    ) -> MultivariateNormal:
        
        # inducing-inducing prior
        qu = self.variational_distribution
        S = qu.lazy_covariance_matrix

        # inducing variables prior
        fu_mvn = self.prior_distribution
        muu, Kuu = fu_mvn.mean, fu_mvn.lazy_covariance_matrix

        # input points prior 
        mux, Kxx = self.model.mean_module(x), self.model.covar_module(x)

        # compute Phi, consdering that matern_Phi does not accept batch dimensions 
        Phi = matern_Phi(x, self.degree, d=self.dimension, kappa=self.kappa, nu=self.nu, sigma=self.sigma) 
        updated_covariance_matrix = Kxx + Phi.mT @ (S - Kuu) @ Phi
        updated_mean = mux

        return MultivariateNormal(mean=updated_mean, covariance_matrix=updated_covariance_matrix)

    def kl_divergence(self) -> Tensor:
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence
    
    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        return super().__call__(x, **kwargs)

# Simple shallow variational GP

In [9]:
import gpytorch


class SHApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, dimension: int, num_spherical_harmonics: int, 
                 mean: Mean, covar_module: GeometricMaternKernel, variational_distribution: CholeskyVariationalDistribution, jitter_val: float | None = None):
        variational_strategy = SphericalHarmonicFeaturesVariationalStrategy(
            model=self,
            variational_distribution=variational_distribution,
            dimension=dimension,
            num_spherical_harmonics=num_spherical_harmonics,
            jitter_val=jitter_val
        )
        super().__init__(variational_strategy)
        self.mean_module = mean
        self.covar_module = covar_module
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()


    def forward(self, x: Tensor) -> MultivariateNormal:
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
    

class IPApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, mean: Mean, covar_module: GeometricMaternKernel, inducing_points: Tensor, variational_distribution: CholeskyVariationalDistribution):
        variational_strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
            self, 
            inducing_points=inducing_points,
            variational_distribution=variational_distribution, 
            learn_inducing_locations=False
        )
        super().__init__(variational_strategy)
        self.mean_module = mean
        self.covar_module = covar_module
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, x: Tensor) -> MultivariateNormal:
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

# Train model
Training has to be done differently from the bechmarking experiment, because we need minibatch SGD with the larger datasets and minibatch metrics. 

In [10]:
def print_smallest_eigenvalues(covar):
    # mvn = projector.inverse(mvn)
    smallest_eigenvalues = torch.linalg.eigvalsh(covar).min()
    print(f"Smallest eigenvalue: {smallest_eigenvalues.item()}")

In [11]:
from geometric_kernels.spaces import Hypersphere
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree


dataset = Kin8mn()

# Generic parameters
num_spherical_harmonics = 50
dimension = dataset.dimension + 1

degree, num_spherical_harmonics = num_spherical_harmonics_to_degree(num_spherical_harmonics, dimension)
space = Hypersphere(dimension)

batch_shape = torch.Size([])

# Model with spherical harmonic features
mean = gpytorch.means.ZeroMean()
covar_module = GeometricMaternKernel(nu=2.5, space=space, num_eigenfunctions=4, batch_shape=batch_shape)
covar_module.initialize(lengthscale=0.001)
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=num_spherical_harmonics, batch_shape=batch_shape
)

model_sh = SHApproximateGP(
    dimension=dimension,
    num_spherical_harmonics=num_spherical_harmonics,
    mean=mean,
    covar_module=covar_module,
    variational_distribution=variational_distribution,
)

# Model with inducing points 
mean = gpytorch.means.ZeroMean()
covar_module = GeometricMaternKernel(nu=2.5, space=space, num_eigenfunctions=4, batch_shape=batch_shape)
covar_module.initialize(lengthscale=0.001)
variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=num_spherical_harmonics, batch_shape=batch_shape
)

inducing_points = torch.randn(num_spherical_harmonics, dimension)
inducing_points = inducing_points / inducing_points.norm(dim=-1, keepdim=True)
model_ip = IPApproximateGP(
    mean=mean,
    covar_module=covar_module,
    inducing_points=inducing_points,
    variational_distribution=variational_distribution
)


# Arbitrary projector 
projector = SphereProjector()

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 54, which includes all spherical harmonics up to degree 3 (incl.)


In [12]:
model = model_sh

batch_size = 256
train_loader = DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)

# 
parameters = [
    model.variational_strategy._variational_distribution.chol_variational_covar,
]
optimizer = torch.optim.Adam(parameters, lr=0.1)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_x.size(0))

for x_batch, y_batch in train_loader:
    x_batch, y_batch = projector(x_batch, y_batch)
    optimizer.zero_grad()
    output = model.variational_strategy(x_batch)
    print_smallest_eigenvalues(output.covariance_matrix)
    # print_smallest_eigenvalues(model.variational_strategy.variational_distribution.covariance_matrix)
    loss = -mll(output, y_batch)
    loss.backward()
    optimizer.step()

    print(loss.item())

Smallest eigenvalue: -9.077262895129363e-16


RuntimeError: grad can be implicitly created only for scalar outputs

In [89]:
model = model_ip

batch_size = 256
train_loader = DataLoader(dataset.train_dataset, batch_size=batch_size, shuffle=True)

# 
parameters = [
    model.variational_strategy._variational_distribution.chol_variational_covar,
]
optimizer = torch.optim.Adam(parameters, lr=0.1)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_x.size(0))

for x_batch, y_batch in train_loader:
    x_batch, y_batch = projector(x_batch, y_batch)
    optimizer.zero_grad()
    output = model(x_batch)
    print_smallest_eigenvalues(output.covariance_matrix)
    loss = -mll(output, y_batch)
    loss.backward()
    optimizer.step()

    print(loss.item())

Smallest eigenvalue: 9.999999988798501e-07
1.5182695768211714
Smallest eigenvalue: 9.999999989562332e-07
1.4947433100617336
Smallest eigenvalue: 9.99999998908287e-07
1.4414993314065088
Smallest eigenvalue: 9.999999988947179e-07
1.4041950480701977
Smallest eigenvalue: 9.999999990619052e-07
1.3796913437218985
Smallest eigenvalue: 9.999999991144903e-07
1.3637745642671706
Smallest eigenvalue: 9.999999990971003e-07
1.3576423747766195
Smallest eigenvalue: 9.999999990506075e-07
1.3314937268999756
Smallest eigenvalue: 9.9999999894546e-07
1.3163540589840574
Smallest eigenvalue: 9.99999999111952e-07
1.3129889125774872
Smallest eigenvalue: 9.999999990953514e-07
1.3158454601627185
Smallest eigenvalue: 9.999999990809995e-07
1.323727298242035
Smallest eigenvalue: 9.999999991838973e-07
1.321055025913
Smallest eigenvalue: 9.999999990535063e-07
1.3224700873685378
Smallest eigenvalue: 9.999999991874656e-07
1.314687115369776
Smallest eigenvalue: 9.999999991031666e-07
1.3197032585534363
Smallest eigenvalu

### Debug

In [14]:
import plotly.express as px 
import pandas as pd 


def plot_kernel_vs_angle(kernel, dimension: int, n: int = 100): 
    pole = torch.zeros(1, dimension)
    pole[:, -1] = 1.
    theta = torch.linspace(0, torch.pi, n)
    x = torch.cat([torch.zeros(n, dimension - 2), theta.cos().unsqueeze(-1), theta.sin().unsqueeze(-1)], dim=-1)
    with torch.no_grad():
        y = kernel(x).lazy_covariance_matrix[..., 0]
        if y.ndim == 2: 
            y = y.mean(0)

    data = pd.DataFrame({'theta': theta.squeeze().numpy(), 'y': y})
    fig = px.line(data, x='theta', y='y')
    fig.show()
