In [2]:
import geometric_kernels.torch 
import torch 
import gpytorch 
from mdgp.kernels import GeometricMaternKernel
from gpytorch.kernels import ScaleKernel
from geometric_kernels.spaces import Hypersphere
from gpytorch import Module
from gpytorch.utils.memoize import cached, clear_cache_hook
from gpytorch import settings 
from gpytorch.kernels import ScaleKernel

torch.set_default_dtype(torch.float64)

Idea: We could reuse the spherical harmonics computed in the variational strategy for the kernel itself.

There actually seems to be a simplified form for the update when using the same number of spherical harmonics for the kernel approximation and for the inducing variables. But that might be too many inducing variables. 

In [3]:
import warnings
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
from linear_operator import to_dense
from linear_operator.operators import (
    CholLinearOperator,
    DiagLinearOperator,
    LinearOperator,
    MatmulLinearOperator,
    RootLinearOperator,
    SumLinearOperator,
    TriangularLinearOperator,
    DenseLinearOperator,
)
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.errors import NotPSDError
from torch import Tensor

from gpytorch.variational._variational_strategy import _VariationalStrategy
from gpytorch.variational.cholesky_variational_distribution import CholeskyVariationalDistribution

from gpytorch.distributions import MultivariateNormal, Distribution
from gpytorch.models import ApproximateGP
from gpytorch.settings import _linalg_dtype_cholesky, trace_mode
from gpytorch.utils.errors import CachingError
from gpytorch.utils.memoize import cached, pop_from_cache_ignore_args
from gpytorch.variational import _VariationalDistribution
from abc import ABC, abstractproperty


class _VariationalStrategy(Module, ABC):
    """
    Abstract base class for all Variational Strategies.
    """

    has_fantasy_strategy = False

    def __init__(
        self,
        model: Union[ApproximateGP, "_VariationalStrategy"],
        variational_distribution: _VariationalDistribution,
        jitter_val: Optional[float] = None,
    ):
        super().__init__()

        self._jitter_val = jitter_val

        # Model
        object.__setattr__(self, "model", model)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=torch.get_default_dtype())
        return self._jitter_val

    @jitter_val.setter
    def jitter_val(self, jitter_val: float):
        self._jitter_val = jitter_val

    @abstractproperty
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the
        GP prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))`. Most commonly,
        this is done simply by calling the user defined GP prior on the inducing point data directly.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
        :return: The distribution :math:`p( \mathbf u)`
        """
        raise NotImplementedError

    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> Distribution:
        return self._variational_distribution()

    def forward(
        self,
        x: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> MultivariateNormal:
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
        inducing point function values. Specifically, forward defines how to transform a variational distribution
        over the inducing point values, :math:`q(u)`, in to a variational distribution over the function values at
        specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`

        :param x: Locations :math:`\mathbf X` to get the
            variational posterior of the function values at.
        :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
        :param inducing_values: Samples of the inducing function values :math:`\mathbf u`
            (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian.
        :param variational_inducing_covar: If
            the distribuiton :math:`q(\mathbf u)` is
            Gaussian, then this variable is the covariance matrix of that Gaussian.
            Otherwise, it will be None.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal`
        :return: The distribution :math:`q( \mathbf f(\mathbf X))`
        """
        raise NotImplementedError

    def kl_divergence(self) -> Tensor:
        r"""
        Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
        and the prior inducing distribution :math:`p(\mathbf u)`.
        """
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence
    
    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        # Get p(u)/q(u)
        variational_dist_u = self.variational_distribution

        # Get q(f)
        if isinstance(variational_dist_u, MultivariateNormal):
            return super().__call__(
                x,
                inducing_values=variational_dist_u.mean,
                variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
                **kwargs,
            )
        raise RuntimeError

In [104]:
from mdgp.utils.spherical_harmonic_features import matern_Kuu, matern_Kux
from gpytorch.variational import VariationalStrategy


def _ensure_updated_strategy_flag_set(
    state_dict: Dict[str, Tensor],
    prefix: str,
    local_metadata: Dict[str, Any],
    strict: bool,
    missing_keys: Iterable[str],
    unexpected_keys: Iterable[str],
    error_msgs: Iterable[str],
):
    device = state_dict[list(state_dict.keys())[0]].device
    if prefix + "updated_strategy" not in state_dict:
        state_dict[prefix + "updated_strategy"] = torch.tensor(False, device=device)
        warnings.warn(
            "You have loaded a variational GP model (using `VariationalStrategy`) from a previous version of "
            "GPyTorch. We have updated the parameters of your model to work with the new version of "
            "`VariationalStrategy` that uses whitened parameters.\nYour model will work as expected, but we "
            "recommend that you re-save your model.",
        )


class VariationalStrategy(_VariationalStrategy):
    def __init__(
        self,
        model: ApproximateGP,
        variational_distribution: _VariationalDistribution,
        max_ell: int, 
        jitter_val: Optional[float] = None,
    ):
        super().__init__(
            model, variational_distribution, jitter_val=jitter_val
        )
        self._max_ell = max_ell

        self.register_buffer("updated_strategy", torch.tensor(True))
        self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
        self.has_fantasy_strategy = True

    @property
    def max_ell(self):
        return self._max_ell
    
    @property
    def d(self):
        return self.model.covar_module.base_kernel.space.dimension + 1
    
    @property
    def kappa(self):
        return self.model.covar_module.base_kernel.lengthscale
    
    @property
    def nu(self):
        return self.model.covar_module.base_kernel.nu
    
    @property
    def sigma(self):
        return self.model.covar_module.outputscale.sqrt()

    @cached(name="cholesky_factor", ignore_args=True)
    def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
        return induc_induc_covar.cholesky()

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res
    
    def Kuu(self) -> DiagLinearOperator: 
        return matern_Kuu(max_ell=self.max_ell, d=self.d, kappa=self.kappa, nu=self.nu, sigma=self.sigma)
    
    def Kux(self, x: Tensor) -> DenseLinearOperator:
        return DenseLinearOperator(matern_Kux(x, max_ell=self.max_ell, d=self.d))

    def forward(
        self,
        x: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> MultivariateNormal:
        
        px = self.model.forward(x, **kwargs)

        # Prior covariance terms
        test_mean = px.mean
        data_data_covar = px.lazy_covariance_matrix
        return MultivariateNormal(test_mean, data_data_covar)

        induc_induc_covar = self.Kuu()#.add_jitter(self.jitter_val)
        induc_data_covar = self.Kux(x).to_dense()

        # Compute interpolation terms
        # K_ZZ^{-1/2} K_ZX
        # K_ZZ^{-1/2} \mu_Z
        # L = self._cholesky_factor(induc_induc_covar)
        # interp_term = L.solve(induc_data_covar.type(_linalg_dtype_cholesky.value())).to(torch.get_default_dtype()) # L^-1 @ Kux 
        L = induc_induc_covar.inverse().cholesky()
        interp_term = L @ induc_data_covar

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
        predictive_mean = (interp_term.transpose(-1, -2) @ inducing_values.unsqueeze(-1)).squeeze(-1) + test_mean

        # Compute the covariance of q(f)
        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1) # -S
        middle_term = SumLinearOperator(variational_inducing_covar, middle_term) # I - S 

        predictive_covar = SumLinearOperator(
            data_data_covar.add_jitter(self.jitter_val),
            MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term), # Kxu @ L.T @ (I - S) @ L @ Kux
        )

        # Return the distribution
        return MultivariateNormal(predictive_mean, predictive_covar)

In [105]:
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid, spherical_harmonic


def target_fnc(x):
    f = spherical_harmonic(x, 2, 3)
    return f + 0.01 * torch.randn_like(f)


test_inputs = sphere_meshgrid(100, 100)
test_targets = target_fnc(test_inputs)

train_inputs = sphere_uniform_grid(200)
train_targets = target_fnc(train_inputs)

In [106]:
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree, total_num_harmonics
from gpytorch.models import ApproximateGP


class SimpleApproximateGP(ApproximateGP):
    def __init__(self, mean_module, covar_module, likelihood, max_ell, jitter_val=None):
        num_inducing = total_num_harmonics(max_ell, d=covar_module.base_kernel.space.dimension + 1)
        variational_strategy = VariationalStrategy(
            self, 
            variational_distribution=CholeskyVariationalDistribution(num_inducing_points=num_inducing),
            max_ell=max_ell,
            jitter_val=jitter_val,
        )
        super().__init__(variational_strategy=variational_strategy)
        self.covar_module = covar_module
        self.mean_module = mean_module
        self.likelihood = likelihood

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        mvn = MultivariateNormal(mean, covar)
        return mvn

In [107]:
num_spherical_harmonics = 80
dimension = 3
degree, num_inducing = num_spherical_harmonics_to_degree(num_spherical_harmonics=num_spherical_harmonics, dimension=dimension)

# Variational distribution 
batch_shape = torch.Size([])

# covar
space = Hypersphere(dim=dimension - 1)
base_kernel = GeometricMaternKernel(space=space, trainable_nu=False, num_eigenfunctions=35, nu=2.5, batch_shape=batch_shape)
covar_module = ScaleKernel(base_kernel, batch_shape=batch_shape)

# mean
mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)

# model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SimpleApproximateGP(
    covar_module=covar_module,
    mean_module=mean_module,
    max_ell=degree,
    likelihood=likelihood,
)

The number of spherical harmonics requested does not lead to complete levels of spherical harmonics. We have thus increased the number to 81, which includes all spherical harmonics up to degree 9 (incl.)


## Test on UCI

In [108]:
import os
import pandas as pd 
import torch 
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader, Dataset


class UCIDataset:

    UCI_BASE_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'

    def __init__(self, name: str, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None): 
        self.name = name 
        self.path = path 
        self.csv_path = os.path.join(self.path, self.name + '.csv')

        # Set generator if seed is provided.
        self.generator = torch.Generator()
        if seed is not None: 
            self.generator.manual_seed(seed)

        # Load, shuffle, split, and normalize data. TODO except for load, these don't need to be object methods. 
        # We keep the standard deviation of the test set for log-likelihood evaluation.
        x, y = self.load_data()
        x, y = self.shuffle(x, y, generator=self.generator)     
        self.train_x, self.train_y, self.test_x, self.test_y = self.split(x, y)
        self.test_y_std = self.test_y.std(dim=0, keepdim=True)
        self.train_x, self.train_y, self.test_x, self.test_y = map(
            self.normalize, (self.train_x, self.train_y, self.test_x, self.test_y))

    @property
    def dimension(self) -> int:
        return self.train_x.shape[-1]

    @property 
    def train_dataset(self) -> Dataset:
        return TensorDataset(self.train_x, self.train_y)
    
    @property
    def test_dataset(self) -> Dataset:
        return TensorDataset(self.test_x, self.test_y)


    def read_data(self) -> tuple[Tensor, Tensor]:
        xy = torch.from_numpy(pd.read_csv(self.csv_path).values)
        return xy[:, :-1], xy[:, -1]

    def download_data(self) -> None:
        NotImplementedError

    def load_data(self, overwrite: bool = False) -> tuple[Tensor, Tensor]:
        if overwrite or not os.path.isfile(self.csv_path):
            self.download_data()
        return self.read_data()

    def normalize(self, x: Tensor) -> Tensor:
        return (x - x.mean(dim=0)) / x.std(dim=0, keepdim=True)
    
    def shuffle(self, x: Tensor, y: Tensor, generator: torch.Generator) -> tuple[Tensor, Tensor]:
        perm_idx = torch.randperm(x.size(0), generator=generator)
        return x[perm_idx], y[perm_idx]
    
    def split(self, x: Tensor, y: Tensor, test_size: float = 0.1) -> tuple[Tensor, Tensor, Tensor, Tensor]: 
        """
        Split the dataset into train and test sets.
        """
        split_idx = int(test_size * x.size(0))
        return x[split_idx:], y[split_idx:], x[:split_idx], y[:split_idx]


class Kin8mn(UCIDataset):

    DEFAULT_URL = 'https://raw.githubusercontent.com/liusiyan/UQnet/master/datasets/UCI_datasets/kin8nm/dataset_2175_kin8nm.csv'

    def __init__(self, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None, url: str = DEFAULT_URL):
        super().__init__(name='kin8nm', path=path, normalize=normalize, seed=seed)
        self.url = url 

    def download_data(self) -> None:
        df = pd.read_csv(self.url)
        os.makedirs(self.path, exist_ok=True)
        df.to_csv(self.csv_path, index=False)

In [109]:
from linear_operator.operators import DiagLinearOperator
from abc import ABC, abstractmethod
from gpytorch.distributions import MultivariateNormal


class Projector(torch.nn.Module, ABC):

    @abstractmethod
    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        pass

    @abstractmethod 
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        pass


class IdentityProjector(Projector):

    def __init__(self, *args, **kwargs): 
        super().__init__()

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor]:
        if y is None:
            return x
        else:
            return x, y
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        return mvn


class SphereProjector(torch.nn.Module):
    def __init__(self):
        super().__init__()
        b = 1. + torch.exp(torch.randn(torch.Size([])))
        self.register_parameter('b', torch.nn.Parameter(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y / self.norm.squeeze(-1)
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)
    

from typing import Callable


class FunctionalProjector(torch.nn.Module):
    def __init__(self, f: Callable[[Tensor], Tensor]):
        super().__init__()
        self.f = f
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.f(x).expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y / self.norm.squeeze(-1)
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)

In [98]:
from gpytorch import settings 
from torch import no_grad 
from dataclasses import dataclass, field
from tqdm.autonotebook import tqdm 
from gpytorch.metrics import mean_squared_error
from mdgp.experiments.experiment_utils.logging import log 
from gpytorch.mlls import VariationalELBO, DeepApproximateMLL
from gpytorch.models.deep_gps import DeepGP


@dataclass
class FitArguments: 
    num_steps: int = field(default=1000, metadata={'help': 'Number of steps to train for'})
    sample_hidden: str = field(
        default='elementwise', 
        metadata={'help': 'Sampling method from hidden layers. Must be one of ["elementwise", "pathwise"]'}
    )
    batch_size: int = field(default=64, metadata={'help': 'Batch size'})
    lr: float = field(default=0.01, metadata={'help': 'Learning rate'})
    test_num_likelihood_samples: int = field(default=100, metadata={'help': 'Number of likelihood samples for test set evaluation'})
    deep_train_num_likelihood_samples: int = field(default=3, metadata={'help': 'Number of likelihood samples used when training deep models.'})
    optimize_projector: bool = field(default=False, metadata={'help': 'Whether to optimize the projection bias'})

    def train_num_likelihood_samples(self, model) -> int:
        if isinstance(model, DeepGP): 
            return self.deep_train_num_likelihood_samples
        else:
            return 1


def get_mll(model, dataset: UCIDataset): 
    return VariationalELBO(model.likelihood, model, num_data=dataset.train_y.size(0))


def get_optimizer(model, projector: Projector, fit_args: FitArguments):
    params = [
        {'params': model.parameters()},
    ]
    if fit_args.optimize_projector is True: 
        params.append({'params': projector.parameters()})
    return torch.optim.Adam(params, lr=fit_args.lr, maximize=True)


def train_step(model, projector: Projector, dataloader: DataLoader, elbo, optimizer, num_likelihood_samples: int | None = None, prior=False): 
    model.train() 
    total_loss = 0.0
    if num_likelihood_samples is not None: 
        init_num_likelihood_samples = settings.num_likelihood_samples.value() 
        settings.num_likelihood_samples._set_value(num_likelihood_samples)
    for batch_x, batch_y in dataloader:
        # When training we don't rescale, or "invert", the predictive distribution
        # before evaluating the ELBO.
        pbatch_x, pbatch_y = projector(batch_x, batch_y)
        optimizer.zero_grad(set_to_none=True)
        poutputs = model(pbatch_x, prior=prior)
        loss = elbo(poutputs, pbatch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if num_likelihood_samples is not None:
        settings.num_likelihood_samples._set_value(init_num_likelihood_samples)
    return {'elbo': total_loss / len(dataloader)}


from gpytorch.distributions import MultivariateNormal
from scipy.stats import multivariate_normal


def negative_log_predictive_density(outputs: MultivariateNormal, targets: Tensor) -> Tensor:
    mean, cov = outputs.mean.detach().numpy(), outputs.covariance_matrix.detach().numpy()
    targets = targets.detach().numpy()
    nlpd = -multivariate_normal.logpdf(targets, mean=mean, cov=cov)
    return torch.tensor(nlpd) / targets.shape[0]


def test_log_likelihood(outputs: MultivariateNormal, targets: Tensor, y_std: Tensor) -> Tensor:
    mean, stddev = outputs.mean, outputs.stddev
    logpdf = torch.distributions.Normal(loc=mean, scale=stddev).log_prob(targets) - torch.log(y_std)
    # average over likelihood samples 
    # logpdf = logsumexp(logpdf.numpy(), axis=0, b=1 / mean.size(0))
    # average over data points
    return torch.tensor(logpdf.mean())


def mean_squared_error(outputs: MultivariateNormal, targets: Tensor, y_std: Tensor) -> Tensor:
    mean = outputs.mean.mean(0)
    return y_std ** 2 * ((mean - targets) ** 2).mean(0)


def test(model, projector: Projector, uci_dataset: UCIDataset, fit_args: FitArguments = None, prior=False):
    with no_grad(), settings.num_likelihood_samples(fit_args.test_num_likelihood_samples):
        total_mse = 0.0
        total_tll = 0.0
        total_nlpd = 0.0
        model.eval() 
        dataloader = DataLoader(uci_dataset.test_dataset, batch_size=fit_args.batch_size)
        for batch_x, batch_y in dataloader:
            # When testing we rescale, or "invert", the predictive distribution
            pbatch_x, _ = projector(batch_x, batch_y)
            poutputs = model.likelihood(model(pbatch_x, prior=prior))
            outputs = projector.inverse(poutputs)
            total_mse += mean_squared_error(outputs, batch_y, y_std=uci_dataset.test_y_std).item()
            total_tll += test_log_likelihood(outputs, batch_y, y_std=uci_dataset.test_y_std).item()
            total_nlpd += negative_log_predictive_density(outputs, batch_y).mean(0).item()
        return {
            'rmse': (total_mse / len(dataloader)) ** 0.5,
            'tll': total_tll / len(dataloader),
            'nlpd': total_nlpd / len(dataloader),
        }
    

def fit(model, projector, optimizer, elbo, uci_dataset: UCIDataset, train_loggers=None, fit_args: FitArguments = None, prior=False): 
    metrics = {'elbo': None}
    pbar = tqdm(range(1, fit_args.num_steps + 1), desc="Fitting")
    for step in pbar:
        metrics_step = train_step(
            model=model, 
            projector=projector,
            dataloader=DataLoader(uci_dataset.train_dataset, batch_size=fit_args.batch_size, shuffle=True),
            elbo=elbo,
            optimizer=optimizer,
            num_likelihood_samples=fit_args.train_num_likelihood_samples(model),
            prior=prior
        )
        # Update, log, and display metrics 
        metrics.update(metrics_step)
        log(train_loggers, metrics=metrics, step=step)
        pbar.set_postfix(metrics)
        if step % 10 == 0: 
            print(test(model, projector, uci_dataset, fit_args=fit_args, prior=prior))
            print(f"kappa: {model.variational_strategy.kappa=}, nu: {model.variational_strategy.nu=}, sigma: {model.variational_strategy.sigma=}")
    return model 
    

In [99]:
dataset = Kin8mn()

In [128]:
# Variational distribution 
batch_shape = torch.Size([])

# covar
space = Hypersphere(dim=dataset.dimension)
base_kernel = GeometricMaternKernel(space=space, trainable_nu=False, num_eigenfunctions=5, nu=2.5)
base_kernel.initialize(lengthscale=0.1)
covar_module = ScaleKernel(base_kernel)
covar_module.initialize(outputscale=1.0)

# mean
mean_module = gpytorch.means.ZeroMean()

# likelihood 
likelihood = gpytorch.likelihoods.GaussianLikelihood()

# model
model = SimpleApproximateGP(
    covar_module=covar_module,
    mean_module=mean_module,
    likelihood=likelihood, 
    max_ell=5,
    jitter_val=0.0,
)

In [129]:
torch.random.manual_seed(0)
projector = SphereProjector()

batch_size = dataset.train_y.size(0)
num_likelihood_samples = 3
fit_args = FitArguments(
    num_steps=1000, 
    batch_size=batch_size,
    optimize_projector=False,
    deep_train_num_likelihood_samples=num_likelihood_samples,
)
prior=True

elbo = get_mll(model, dataset)
optimizer = get_optimizer(model, projector, fit_args=fit_args)
optimizer = torch.optim.Adam([p for m, p in model.named_parameters() if not m.endswith('raw_outputscale')], maximize=True, lr=0.01)

In [130]:
with torch.no_grad():
    model.train()
    model(projector(dataset.train_x))

In [131]:
torch.manual_seed(1)
projector = SphereProjector()

with torch.no_grad():
    model.eval()
    x = projector(dataset.test_x)
    K = model.forward(x).covariance_matrix
    # K = model.variational_strategy(x, prior=True).covariance_matrix
    print(torch.linalg.eigvalsh(K)[:100])

tensor([-2.9133e-15, -2.4821e-15, -2.1286e-15, -2.0159e-15, -1.8958e-15,
        -1.7735e-15, -1.6801e-15, -1.6749e-15, -1.6060e-15, -1.5665e-15,
        -1.4716e-15, -1.4037e-15, -1.3444e-15, -1.3232e-15, -1.2868e-15,
        -1.2755e-15, -1.2440e-15, -1.1890e-15, -1.1308e-15, -1.1156e-15,
        -1.0712e-15, -1.0500e-15, -1.0334e-15, -1.0072e-15, -9.6068e-16,
        -9.3174e-16, -9.1668e-16, -8.9751e-16, -8.7090e-16, -8.6363e-16,
        -8.3948e-16, -8.2650e-16, -8.0557e-16, -7.8097e-16, -7.5251e-16,
        -7.3913e-16, -7.1350e-16, -7.0879e-16, -6.7423e-16, -6.5745e-16,
        -6.4198e-16, -6.1202e-16, -6.0459e-16, -5.8388e-16, -5.6523e-16,
        -5.5353e-16, -5.3081e-16, -5.1574e-16, -5.0293e-16, -4.8369e-16,
        -4.7478e-16, -4.6069e-16, -4.4407e-16, -4.3078e-16, -3.9733e-16,
        -3.7383e-16, -3.6706e-16, -3.4725e-16, -3.4166e-16, -3.2276e-16,
        -3.1322e-16, -2.9707e-16, -2.8926e-16, -2.6475e-16, -2.2897e-16,
        -2.2287e-16, -2.1290e-16, -1.9903e-16, -1.8

In [132]:
torch.manual_seed(1)

covar_module = GeometricMaternKernel(
    space=space, trainable_nu=False, num_eigenfunctions=6, nu=2.5
)
covar_module.initialize(lengthscale=0.1)

projector = SphereProjector()

with torch.no_grad():
    model.eval()
    x = projector(dataset.test_x)
    K = covar_module(x).evaluate()
    print(torch.linalg.eigvalsh(K)[:100])

tensor([0.0005, 0.0006, 0.0007, 0.0008, 0.0008, 0.0008, 0.0009, 0.0009, 0.0009,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0011, 0.0011, 0.0011, 0.0012, 0.0012,
        0.0012, 0.0012, 0.0012, 0.0013, 0.0013, 0.0013, 0.0014, 0.0014, 0.0014,
        0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016, 0.0016,
        0.0017, 0.0017, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019, 0.0019, 0.0019,
        0.0020, 0.0020, 0.0020, 0.0020, 0.0021, 0.0021, 0.0021, 0.0022, 0.0022,
        0.0022, 0.0023, 0.0023, 0.0023, 0.0024, 0.0024, 0.0024, 0.0025, 0.0025,
        0.0025, 0.0026, 0.0026, 0.0027, 0.0027, 0.0027, 0.0028, 0.0028, 0.0028,
        0.0029, 0.0029, 0.0030, 0.0030, 0.0030, 0.0031, 0.0031, 0.0031, 0.0032,
        0.0032, 0.0032, 0.0033, 0.0034, 0.0034, 0.0034, 0.0035, 0.0035, 0.0035,
        0.0036, 0.0036, 0.0037, 0.0037, 0.0037, 0.0038, 0.0038, 0.0038, 0.0039,
        0.0039])


In [14]:
from geometric_kernels.kernels import MaternGeometricKernel


kernel = MaternGeometricKernel(space=space, num=6, normalize=True)
params = kernel.init_params()
params['lengthscale'] = torch.tensor(0.1)
params['nu'] = torch.tensor(2.5)

with torch.no_grad():
    K = kernel.K(params, x)
    print(torch.linalg.eigvalsh(K)[:50])

tensor([0.0005, 0.0006, 0.0007, 0.0008, 0.0008, 0.0008, 0.0009, 0.0009, 0.0009,
        0.0009, 0.0010, 0.0010, 0.0010, 0.0011, 0.0011, 0.0011, 0.0011, 0.0012,
        0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0013, 0.0014, 0.0014, 0.0014,
        0.0014, 0.0014, 0.0014, 0.0015, 0.0015, 0.0015, 0.0016, 0.0016, 0.0016,
        0.0016, 0.0017, 0.0017, 0.0017, 0.0017, 0.0018, 0.0018, 0.0019, 0.0019,
        0.0019, 0.0019, 0.0020, 0.0020, 0.0020])


In [23]:
torch.set_grad_enabled(True)
fit(model, projector, optimizer, elbo, dataset, fit_args=fit_args, prior=False)

Fitting:   0%|          | 0/1000 [00:01<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x10 and 9x7373)

In [92]:
model.variational_strategy.variational_distribution

MultivariateNormal(loc: Parameter containing:
tensor([0.0035], requires_grad=True))

In [15]:
from scipy.stats._multivariate import _PSD 


def test_psd(K): 
    if isinstance(K, LinearOperator):
        K = to_dense(K)
    if isinstance(K, Tensor):
        K = K.clone().detach().numpy()
    _PSD(K)

# For only the 0-th harmonic level, and two test points, manual_seed = 0 gives non-PSD matrix

In [16]:
torch.set_grad_enabled(False)
for i in tqdm(range(10000)):
    try:
        torch.random.manual_seed(i)
        x = torch.randn(2, 9)
        x = x / x.norm(dim=-1, keepdim=True)
        test_psd(model(x).covariance_matrix)
    except Exception as e:
        print(f"Exception at manual seed {i}")
        print(e)
        break

  0%|          | 0/10000 [00:02<?, ?it/s]

Exception at manual seed 0
mat1 and mat2 shapes cannot be multiplied (1x10 and 9x2)





In [103]:
torch.random.manual_seed(0)
x = torch.randn(2, 9)
x = x / x.norm(dim=-1, keepdim=True)
test_psd(model(x).covariance_matrix)

ValueError: The input matrix must be symmetric positive semidefinite.

In [108]:
test_psd(model(x, prior=True).covariance_matrix)
test_psd(model.variational_strategy.prior_distribution.covariance_matrix)
test_psd(model.variational_strategy.variational_distribution.covariance_matrix)

print(model.covar_module.outputscale, model.covar_module.base_kernel.lengthscale, model.covar_module.base_kernel.nu)

tensor(0.4128) tensor([[1.0000]]) tensor([[2.5000]])


In [96]:
torch.linalg.eigvalsh(model(x).covariance_matrix)

tensor([-0.0014,  0.0081])