# Load in data
Currently supported datasets: power, protein, kin8nm

In [1]:
# Types 
from torch import Tensor 

# Backends 
import geometric_kernels.torch 


import os
import torch 
import pandas as pd 
from math import comb 
from spherical_harmonics import SphericalHarmonics
from torch.utils.data import TensorDataset, DataLoader, Dataset
from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile


torch.set_default_dtype(torch.float64)

In [2]:
def num_harmonics_single(ell: int, d: int) -> int:
    r"""
    Number of spherical harmonics of degree ell on S^{d - 1}.
    """
    if ell == 0:
        return 1
    if d == 3:
        return 2 * ell + 1
    else:
        return (2 * ell + d - 2) * comb(ell + d - 3, ell - 1) // ell


def num_harmonics(ell: Tensor | int, d: int) -> Tensor:
    """
    Vectorized version of num_harmonics_single
    """
    if isinstance(ell, int):
        return num_harmonics_single(ell, d)
    return ell.apply_(lambda e: num_harmonics_single(ell=e, d=d))


def total_num_harmonics(max_ell: int, d: int) -> int:
    """
    Total number of spherical harmonics on S^{d-1} with degree <= max_ell
    """
    return int(sum(num_harmonics(ell=torch.arange(max_ell + 1), d=d)))


def eigenvalue_laplacian(ell: Tensor, d: int) -> Tensor:
    """
    Eigenvalue of the Laplace-Beltrami operator for a spherical harmonic of degree ell on S_{d-1}
    ell: [...]
    d: []
    return: [...]
    """
    return ell * (ell + d - 2)


def unnormalized_matern_spectral_density(n: Tensor, d: int, kappa: Tensor | float, nu: Tensor | float) -> Tensor | float: 
    """
    compute (unnormalized) spectral density of the matern kernel on S_{d-1}
    n: [N]
    d: []
    kappa: [O, 1, 1]
    nu: [O, 1, 1]
    return: [O, 1, N]
    """
    # Squared exponential kernel 
    if torch.all(nu.isinf()):
        exponent = -kappa ** 2 / 2 * eigenvalue_laplacian(ell=n, d=d) # [O, N, 1]
        return torch.exp(exponent)
    # Matern kernel
    else:
        base = (
            2.0 * nu / kappa**2 + # [O, 1, 1]
            eigenvalue_laplacian(ell=n, d=d).unsqueeze(-1) # [N, 1]
        ) # [O, N, 1]
        exponent = -nu - (d - 1) / 2.0 # [O, 1, 1]
        return base ** exponent # [O, N, 1]


def matern_spectral_density_normalizer(d: int, max_ell: int, kappa: Tensor | float, nu: Tensor | float) -> Tensor:
    """
    Normalizing constant for the spectral density of the Matern kernel on S^{d-1}. 
    Depends on kappa and nu. Also depends on max_ell, as truncation of the infinite 
    sum from Karhunen-Loeve decomposition. 
    """
    n = torch.arange(max_ell + 1)
    spectral_values = unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu) # [O, max_ell + 1, 1]
    num_harmonics_per_level = num_harmonics(torch.arange(max_ell + 1), d=d).type(spectral_values.dtype) # [max_ell + 1]
    normalizer = spectral_values.mT @ num_harmonics_per_level # [O, 1, max_ell + 1] @ [max_ell + 1] -> [O, 1]
    return normalizer.unsqueeze(-2) # [O, 1, 1]


def matern_spectral_density(n: Tensor, d: int, kappa: Tensor, nu: Tensor, max_ell: int, sigma: float = 1.0) -> Tensor:
    """
    Spectral density of the Matern kernel on S^{d-1}
    """
    return (
        unnormalized_matern_spectral_density(n=n, d=d, kappa=kappa, nu=nu) / # [O, N, 1]
        matern_spectral_density_normalizer(d=d, max_ell=max_ell, kappa=kappa, nu=nu) * # [O, 1, 1]
        (sigma ** 2)[..., *(None,) * (kappa.ndim - 1)] # [O, 1, 1]
    ) # [O, N, 1] / [O, 1, 1] * [O, 1, 1] -> [O, N, 1]


def matern_ahat(ell: Tensor, d: int, max_ell: int, kappa: Tensor | float, nu: Tensor | float, 
                m: int | None = None, sigma: Tensor | float = 1.0) -> float:
    """
    :math: `\hat{a} = \rho(\ell)` where :math: `\rho` is the spectral density on S^{d-1}
    """
    return matern_spectral_density(n=ell, d=d, kappa=kappa, nu=nu, max_ell=max_ell, sigma=sigma) # [O, N, 1]


def matern_repeated_ahat(max_ell: int, d: int, kappa: Tensor | float, nu: Tensor | float, sigma: Tensor | float = 1.0) -> Tensor:
    """
    Returns a tensor of repeated ahat values for each ell. 
    """
    ells = torch.arange(max_ell + 1) # [max_ell + 1]
    ahat = matern_ahat(ell=ells, d=d, max_ell=max_ell, kappa=kappa, nu=nu, sigma=sigma) # [O, max_ell + 1, 1]
    repeats = num_harmonics(ell=ells, d=d) # [max_ell + 1]
    return torch.repeat_interleave(ahat, repeats=repeats, dim=-2) # [O, num_harmonics, 1]


def matern_Kuu(max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor: 
    """
    Returns the covariance matrix, which is a diagonal matrix with entries 
    equal to inv_ahat of the corresponding ell. 
    """
    return torch.diag(1 / matern_repeated_ahat(max_ell, d, kappa, nu, sigma=sigma).squeeze(-1)) # [O, num_harmonics, num_harmonics]


def spherical_harmonics(x: Tensor, max_ell: int, d: int) -> Tensor: 
    # Make sure that x is at least 2d and flatten it
    x = torch.atleast_2d(x)
    batch_shape, n = x.shape[:-2], x.shape[-2]
    x = x.flatten(0, -2)

    # Get spherical harmonics callable
    f = SphericalHarmonics(dimension=d, degrees=max_ell + 1) # [... * O, N, num_harmonics]

    # Evaluate x and reintroduce batch dimensions
    return f(x).reshape(*batch_shape, n, total_num_harmonics(max_ell, d)) # [..., O, N, num_harmonics]


def matern_Kux(x: Tensor, max_ell: int, d: int) -> Tensor: 
    return spherical_harmonics(x, max_ell=max_ell, d=d).mT # [..., O, num_harmonics, N]


def num_spherical_harmonics_to_degree(num_spherical_harmonics: int, dimension: int) -> tuple[int, int]:
    """
    Returns the minimum degree for which there are at least
    `num_eigenfunctions` in the collection.
    """
    n, degree = 0, 0  # n: number of harmonics, d: degree (or level)
    while n < num_spherical_harmonics:
        n += num_harmonics(d=dimension, ell=degree)
        degree += 1

    if n > num_spherical_harmonics:
        print(
            "The number of spherical harmonics requested does not lead to complete "
            "levels of spherical harmonics. We have thus increased the number to "
            f"{n}, which includes all spherical harmonics up to degree {degree} (incl.)"
        )
    return degree - 1, n


from gpytorch.kernels import Kernel, ScaleKernel


def matern_LT_Phi(x: Tensor, max_ell: int, d: int, kappa: float, nu: float, sigma: float = 1.0) -> Tensor: 
    Kux = matern_Kux(x, max_ell=max_ell, d=d) # [..., O, num_harmonics, N]
    ahat_sqrt = matern_repeated_ahat(max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma).sqrt() # [O, num_harmonics, 1]
    return Kux * ahat_sqrt # [..., O, num_harmonics, N]


def matern_LT_Phi_from_kernel(x: Tensor, covar_module: Kernel, num_levels: int) -> Tensor: 
    # Extract kernel parameters  
    if isinstance(covar_module, ScaleKernel):
        sigma = covar_module.outputscale.sqrt()
        base_kernel = covar_module.base_kernel
    else:
        sigma = torch.tensor(1.0, dtype=x.dtype, device=x.device)
        base_kernel = covar_module
    kappa = base_kernel.lengthscale
    nu = base_kernel.nu 

    # Extract constants 
    d = base_kernel.space.dimension + 1
    max_ell = num_levels
    
    return matern_LT_Phi(x, max_ell=max_ell, d=d, kappa=kappa, nu=nu, sigma=sigma)

In [3]:
class UCIDataset:

    UCI_BASE_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'

    def __init__(self, name: str, url: str, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None): 
        self.name = name 
        self.url = url
        self.path = path 
        self.csv_path = os.path.join(self.path, self.name + '.csv')

        # Set generator if seed is provided.
        self.generator = torch.Generator()
        if seed is not None: 
            self.generator.manual_seed(seed)

        # Load, shuffle, split, and normalize data. TODO except for load, these don't need to be object methods. 
        # We keep the standard deviation of the test set for log-likelihood evaluation.
        x, y = self.load_data()
        x, y = self.shuffle(x, y, generator=self.generator)     
        self.train_x, self.train_y, self.test_x, self.test_y = self.split(x, y)
        self.test_y_std = self.test_y.std(dim=0, keepdim=True)
        self.train_x, self.train_y, self.test_x, self.test_y = map(
            self.normalize, (self.train_x, self.train_y, self.test_x, self.test_y))

    @property
    def dimension(self) -> int:
        return self.train_x.shape[-1]

    @property 
    def train_dataset(self) -> Dataset:
        return TensorDataset(self.train_x, self.train_y)
    
    @property
    def test_dataset(self) -> Dataset:
        return TensorDataset(self.test_x, self.test_y)

    def read_data(self) -> tuple[Tensor, Tensor]:
        xy = torch.from_numpy(pd.read_csv(self.csv_path).values)
        return xy[:, :-1], xy[:, -1]

    def download_data(self) -> None:
        NotImplementedError

    def load_data(self, overwrite: bool = False) -> tuple[Tensor, Tensor]:
        if overwrite or not os.path.isfile(self.csv_path):
            self.download_data()
        return self.read_data()

    def normalize(self, x: Tensor) -> Tensor:
        return (x - x.mean(dim=0)) / x.std(dim=0, keepdim=True)
    
    def shuffle(self, x: Tensor, y: Tensor, generator: torch.Generator) -> tuple[Tensor, Tensor]:
        perm_idx = torch.randperm(x.size(0), generator=generator)
        return x[perm_idx], y[perm_idx]
    
    def split(self, x: Tensor, y: Tensor, test_size: float = 0.1) -> tuple[Tensor, Tensor, Tensor, Tensor]: 
        """
        Split the dataset into train and test sets.
        """
        split_idx = int(test_size * x.size(0))
        return x[split_idx:], y[split_idx:], x[:split_idx], y[:split_idx]


class Kin8mn(UCIDataset):

    DEFAULT_URL = 'https://raw.githubusercontent.com/liusiyan/UQnet/master/datasets/UCI_datasets/kin8nm/dataset_2175_kin8nm.csv'

    def __init__(self, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None, url: str | None = None):
        url = url or Kin8mn.DEFAULT_URL
        super().__init__(name='kin8nm', path=path, normalize=normalize, url=url, seed=seed)

    def download_data(self) -> None:
        df = pd.read_csv(self.url)
        os.makedirs(self.path, exist_ok=True)
        df.to_csv(self.csv_path, index=False)


class Power(UCIDataset):

    DEFAULT_URL = UCIDataset.UCI_BASE_URL + "00294/CCPP.zip"

    def __init__(self, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None, url: str | None = None):
        url = url or Power.DEFAULT_URL
        super().__init__(name='power', path=path, normalize=normalize, url=url, seed=seed)

    def download_data(self):
        with urlopen(self.url) as zipresp:
            with ZipFile(BytesIO(zipresp.read())) as zfile:
                zfile.extractall('/tmp/')

        df = pd.read_excel('/tmp/CCPP//Folds5x2_pp.xlsx')
        os.makedirs(self.path, exist_ok=True)
        df.to_csv(self.csv_path, index=False)


class Concrete(UCIDataset):

    DEFAULT_URL = UCIDataset.UCI_BASE_URL + 'concrete/compressive/Concrete_Data.xls'

    def __init__(self, path: str = '../../data/uci/', normalize: bool = True, seed: int | None = None, url: str | None = None):
        url = url or Concrete.DEFAULT_URL
        super().__init__(name='concrete', path=path, normalize=normalize, url=url, seed=seed)

    def download_data(self):
        df = pd.read_excel(self.url)
        os.makedirs(self.path, exist_ok=True)
        df.to_csv(self.csv_path, index=False)


default_lr = {
    'kin8nm': 0.01,
    'power': 0.01,
    'concrete': 0.01,
}

default_num_epochs = {
    'kin8nm': 20,
    'power': 20,
    'concrete': 125,
}

dimension_to_prior_num_eigenfunctions = {
    4: 336,
    6: 294,
    8: 210, 
}

dimension_to_num_inducing = {
    4: 336,
    6: 294,
    8: 210, 
}

# Instantiate model

This is done using the same model arguments as for the benchmarking experiment

In [4]:
from linear_operator.operators import DiagLinearOperator
from abc import ABC, abstractmethod
from gpytorch.distributions import MultivariateNormal


class Projector(torch.nn.Module, ABC):

    @abstractmethod
    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        pass

    @abstractmethod 
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        pass


class IdentityProjector(Projector):

    def __init__(self, *args, **kwargs): 
        super().__init__()

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor]:
        if y is None:
            return x
        else:
            return x, y
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        return mvn


class ConstantProjector(Projector):
    def __init__(self, b: float = 2.0):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(b))
        self.norm = None 

    def forward(self, x: Tensor, y: Tensor | None = None) -> tuple[Tensor, Tensor] | Tensor:
        b = self.b.expand(*x.shape[:-1], 1)
        x_cat_b = torch.cat([x, b], dim=-1)
        self.norm = x_cat_b.norm(dim=-1, keepdim=True)
        if y is None:
            return x_cat_b / self.norm
        else:
            return x_cat_b / self.norm, y / self.norm.squeeze(-1)
    
    def inverse(self, mvn: MultivariateNormal) -> MultivariateNormal:
        L = DiagLinearOperator(self.norm.squeeze(-1))
        mean = mvn.mean @ L
        cov = L @ mvn.lazy_covariance_matrix @ L
        return MultivariateNormal(mean=mean, covariance_matrix=cov)

In [5]:
import gpytorch 
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution


class SphericalHarmonicsVariationalStrategy(gpytorch.Module):
    def __init__(self, model: gpytorch.Module, variational_distribution: CholeskyVariationalDistribution, num_levels, jitter_val: float | None = None):
        super().__init__()
        object.__setattr__(self, "_model", model)
        self._variational_distribution = variational_distribution
        self.jitter_val = jitter_val or gpytorch.settings.cholesky_jitter.value(variational_distribution.dtype)
        self.num_levels = num_levels

    @property
    def model(self) -> gpytorch.Module:
        return self._model
    
    @property
    def variational_distribution(self) -> MultivariateNormal:
        return self._variational_distribution()
    
    @property
    def prior_distribution(self) -> MultivariateNormal:
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res

    def forward(self, x) -> MultivariateNormal:
        # prior at x 
        px = self.model.forward(x)
        mu_x, Kxx = px.mean, px.lazy_covariance_matrix

        # whitened prior at u
        Linv_pu = self.prior_distribution
        Linv_mu, Linv_Kuu_LTinv = Linv_pu.mean, Linv_pu.lazy_covariance_matrix

        # whitened variational posterior at u
        Linv_qu = self.variational_distribution
        Linv_m, Linv_S_LTinv = Linv_qu.mean, Linv_qu.lazy_covariance_matrix

        # unwhitening + projection and vice-versa
        LT_Phiux = matern_LT_Phi_from_kernel(x, self.model.covar_module, num_levels=self.num_levels)
        Phixu_L = LT_Phiux.mT

        # posterior at x 
        qx_sigma = Kxx + Phixu_L @ (Linv_S_LTinv - Linv_Kuu_LTinv) @ LT_Phiux
        qx_sigma = qx_sigma.add_jitter(self.jitter_val)
        qx_mu = mu_x + Phixu_L @ (Linv_m - Linv_mu)
        return MultivariateNormal(qx_mu, qx_sigma)
    
    def kl_divergence(self):
        with gpytorch.settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence
    
    def __call__(self, x, prior: bool = False, **kwargs) -> MultivariateNormal:
        if prior:
            return self.model.forward(x)
        return super().__call__(x, **kwargs)

In [6]:
from geometric_kernels.spaces import Space
from gpytorch.variational import VariationalStrategy
from geometric_kernels.spaces import Hypersphere
from mdgp.utils.spherical_harmonic_features import num_spherical_harmonics_to_degree
from mdgp.variational.inducing_points import initialize_kmeans



class VariationalStrategyFactory:
    def __init__(self, name: str, num_inducing: int, inputs: Tensor | None = None, learn_inducing_locations: bool = False) -> None: 
        self.name = name 
        self.learn_inducing_locations = learn_inducing_locations
        self.num_inducing_variables = num_inducing

        if name == 'points': 
            assert inputs is not None, 'Must provide inputs to use points variational strategy.'
        self.inputs = inputs 

    @staticmethod
    def make_variational_strategy(
        name: str, model: ApproximateGP, space: Space, num_inducing_variables: int, inputs: Tensor, 
        learn_inducing_locations: bool = False, batch_shape: torch.Size = torch.Size([]), jitter_val: float | None = None
    ) -> VariationalStrategy | SphericalHarmonicsVariationalStrategy: 
        if name == 'points': 
            inducing_points = initialize_kmeans(x=inputs, n=num_inducing_variables, space=space)
            variational_distribution = CholeskyVariationalDistribution(
                num_inducing_points=num_inducing_variables,
                batch_shape=batch_shape,
            )
            variational_strategy = VariationalStrategy(
                model=model, inducing_points=inducing_points, 
                variational_distribution=variational_distribution, 
                learn_inducing_locations=learn_inducing_locations,
            )
        elif name == 'harmonics': 
            assert isinstance(space, Hypersphere), f'Harmonic features only implemented for hyperspheres, not {space}.'
            dimension = space.dim + 1
            degree, num_spherical_harmonics = num_spherical_harmonics_to_degree(
                num_spherical_harmonics=num_inducing_variables, dimension=dimension
            )
            variational_distribution = CholeskyVariationalDistribution(
                num_inducing_points=num_spherical_harmonics,
                batch_shape=batch_shape,
            )
            variational_strategy = SphericalHarmonicsVariationalStrategy(
                model=model, 
                variational_distribution=variational_distribution, 
                num_levels=degree,
                jitter_val=jitter_val,
            )
        else: 
            raise ValueError(f'Variational strategy {name} not recognized. Must be one of ["points", "harmonics"].')
        
        return variational_strategy
    
    def __call__(self, model: ApproximateGP, space: Space, batch_shape: torch.Size = torch.Size([])) -> VariationalStrategy:
        return self.make_variational_strategy(
            name=self.name, model=model, space=space, num_inducing_variables=self.num_inducing_variables, 
            inputs=self.inputs, learn_inducing_locations=self.learn_inducing_locations, batch_shape=batch_shape,
        )


INFO: Using numpy backend


In [101]:
# Imports 
from dataclasses import dataclass, field
from mdgp.models import *
from geometric_kernels.spaces import Hypersphere, Euclidean
from gpytorch.priors import GammaPrior



@dataclass
class ModelArguments:
    dataset: UCIDataset

    # deep model specs 
    model_name: str = field(default='residual_geometric', 
                            metadata={'help': 'Name of the model. Must be one of ["residual_geometric", "euclidean", "geometric_head"]'})
    num_hidden: int = field(default=1, metadata={'help': 'Number of hidden layers'})
    hidden_dims: int | None = field(default=None, metadata={'help': 'Number of output dimensions of the hidden layers.'})
    output_dims: int | None = field(default=None, metadata={'help': 'Number of output dimensions of the final layer.'})
    to_tangent: str = field(default='project', init=False, repr=False)

    # kernel specs 
    nu: float = field(default=1.5, metadata={'help': 'Smoothness parameter'})
    optimize_nu: bool = field(default=True, metadata={'help': 'Whether to optimize the smoothness parameter'})
    lengthscale: float = field(default=1.0, metadata={'help': 'Lengthscale of the kernel'})
    outputscale_mean: float = field(default=1.0, metadata={'help': 'Mean of the outputscale'})
    prior_num_eigenfunctions: int | None = field(default=None)

    # variational specs 
    variational_strategy_name: str = field(default='points', metadata={'help': 'Name of the variational strategy. Must be one of ["points", "harmonics"]'})
    learn_inducing_locations: bool = field(default=False, metadata={'help': 'Whether to learn the inducing locations'})
    num_inducing: int | None = field(default=None)

    # sampler specs
    sampler_inv_jitter: None = field(default=None, init=False, repr=False)


    def __post_init__(self):
        assert self.model_name in ['residual_geometric', 'residual_euclidean', 'euclidean', 'geometric_head', 'exact']
        assert self.variational_strategy_name in ['points', 'harmonics']

        if self.variational_strategy_name == 'harmonics': 
            assert self.model_name in ['residual_geometric', 'residual_euclidean', 'geometric_head', 'exact']

        # If number of eigenfunctions is None, then use the default for the dataset dimension
        self.num_inducing = self.num_inducing or dimension_to_num_inducing[self.dataset.dimension]
        self.prior_num_eigenfunctions = self.prior_num_eigenfunctions or dimension_to_prior_num_eigenfunctions[self.dataset.dimension]

        # Set the space for the model
        if self.model_name in ['residual_geometric', 'geometric_head']:
            self.space = Hypersphere(dim=self.dataset.dimension)
        if self.model_name in ['euclidean']:
            self.space = Euclidean(dim=self.dataset.dimension)
    
    @property
    def outputscale_prior(self) -> GammaPrior:
        return GammaPrior(concentration=1.0, rate=1 / self.outputscale_mean)
    

def get_projector(model_args: ModelArguments) -> Projector:
    if model_args.model_name in ['residual_geometric', 'geometric_head']:
        return ConstantProjector()
    else:
        return IdentityProjector()


def create_model(model_args: ModelArguments, dataset: UCIDataset, projector: Projector | None = None):
    inputs = dataset.train_x
    degree = num_spherical_harmonics_to_degree(model_args.prior_num_eigenfunctions, model_args.space.dim + 1)[0] + 1
    if projector is not None:
        inputs = projector(inputs)
    variational_strategy_factory = VariationalStrategyFactory(
        name=model_args.variational_strategy_name, 
        num_inducing=model_args.num_inducing,
        learn_inducing_locations=model_args.learn_inducing_locations,
        inputs=inputs,
    )
    if model_args.model_name == 'residual_geometric':
        return ResidualGeometricDeepGP(
            space=model_args.space, 
            num_hidden=model_args.num_hidden,
            variational_strategy_factory=variational_strategy_factory,
            output_dims=model_args.output_dims,
            to_tangent=model_args.to_tangent,
            nu=model_args.nu,
            optimize_nu=model_args.optimize_nu,
            outputscale_prior=model_args.outputscale_prior,
            num_eigenfunctions=degree,
            sampler_inv_jitter=model_args.sampler_inv_jitter,
        )
    if model_args.model_name == 'geometric_head':
        return GeometricHeadDeepGP(
            space=model_args.space, 
            num_hidden=model_args.num_hidden,
            variational_strategy_factory=variational_strategy_factory,
            output_dims=model_args.output_dims,
            nu=model_args.nu,
            optimize_nu=model_args.optimize_nu,
            outputscale_prior=model_args.outputscale_prior,
            num_eigenfunctions=degree,
            sampler_inv_jitter=model_args.sampler_inv_jitter,
        )
    if model_args.model_name == 'euclidean': 
        return EuclideanDeepGP(
            space=model_args.space,
            num_hidden=model_args.num_hidden,
            variational_strategy_factory=variational_strategy_factory,
            output_dims=model_args.output_dims,
            nu=model_args.nu,
            outputscale_prior=model_args.outputscale_prior,
            num_eigenfunctions=degree,
            sampler_inv_jitter=model_args.sampler_inv_jitter,
        )
    raise ValueError((
        f'Unknown model name: {model_args.model_name}.'
        f'Must be one of ["residual_geometric", "euclidean", "geometric_head"]'
    ))

# Train model
Training has to be done differently from the bechmarking experiment, because we need minibatch SGD with the larger datasets and minibatch metrics. 

In [102]:
from linear_operator.operators import LinearOperator


def test_psd(S: LinearOperator | Tensor, tol=1e-8):
    if isinstance(S, LinearOperator):
        S = S.to_dense()

    if S.ndim > 2:
        S = S.flatten(end_dim=-3)
        for S_row in S:
            test_psd(S_row, tol=tol)
        return

    assert torch.allclose(S, S.mT), "K should be symmetric."

    eigs = torch.linalg.eigvalsh(S)
    assert (eigs > -tol).all(), f"K should be positive definite. K has shape {S.shape}"

In [103]:
from gpytorch import settings 
from torch import no_grad 
from dataclasses import dataclass, field
from tqdm.autonotebook import tqdm 
from gpytorch.metrics import mean_squared_error
from mdgp.experiments.experiment_utils.logging import log 
from gpytorch.mlls import VariationalELBO, DeepApproximateMLL


@dataclass
class FitArguments: 
    num_epochs: int = field(default=20, metadata={'help': 'Number of epochs to train for'})
    sample_hidden: str = field(default='elementwise', init=False, repr=False)
    batch_size: int = field(default=64, metadata={'help': 'Batch size'})
    lr: float = field(default=0.01, init=False, repr=False)
    test_num_samples: int = field(default=100, metadata={'help': 'Number of likelihood samples for test set evaluation'})
    train_num_samples: int = field(default=3, metadata={'help': 'Number of likelihood samples used when training deep models.'})
    optimize_projector: bool = field(default=False, metadata={'help': 'Whether to optimize the projection bias'})


def get_mll(model, dataset: UCIDataset): 
    return DeepApproximateMLL(VariationalELBO(model.likelihood, model, num_data=dataset.train_y.size(0)))


def get_optimizer(model, projector: Projector, fit_args: FitArguments):
    params = [
        {'params': model.parameters()},
    ]
    if fit_args.optimize_projector is True: 
        params.append({'params': projector.parameters()})
    return torch.optim.Adam(params, lr=fit_args.lr, maximize=True)


def train_step(x, y, model, projector, elbo, optimizer, sample_hidden):
    optimizer.zero_grad(set_to_none=True)
    x, y = projector(x, y)
    y_hat = model(x, sample_hidden=sample_hidden)

    test_psd(y_hat.lazy_covariance_matrix)

    loss = elbo(y_hat, y)
    loss.backward()
    optimizer.step()
    return loss.item()


def train_epoch(model, projector: Projector, dataloader: DataLoader, elbo, optimizer, sample_hidden, num_samples): 
    model.train()
    epoch_loss = 0.0
    with settings.num_likelihood_samples(num_samples):
        for batch_x, batch_y in dataloader:
            epoch_loss += train_step(batch_x, batch_y, model, projector, elbo, optimizer, sample_hidden)
    return epoch_loss 


from gpytorch.metrics import negative_log_predictive_density
from scipy.special import logsumexp


def test_log_likelihood(outputs: MultivariateNormal, targets: Tensor, y_std: Tensor) -> Tensor:
    mean, stddev = outputs.mean, outputs.stddev
    logpdf = torch.distributions.Normal(loc=mean, scale=stddev).log_prob(targets) - torch.log(y_std)
    # average over likelihood samples 
    logpdf = logsumexp(logpdf.numpy(), axis=0, b=1 / mean.size(0))
    # average over data points
    return torch.mean(torch.from_numpy(logpdf))


def mean_squared_error(outputs: MultivariateNormal, targets: Tensor, y_std: Tensor) -> Tensor:
    # TODO add handling for multiple samples 
    mean = outputs.mean.mean(0)
    return y_std ** 2 * ((mean - targets) ** 2).mean(0)


def test(model, projector: Projector, uci_dataset: UCIDataset, sample_hidden='elementwise', fit_args: FitArguments = None):
    with no_grad(), settings.num_likelihood_samples(fit_args.test_num_samples):
        total_mse = 0.0
        total_tll = 0.0
        total_nlpd = 0.0
        model.eval() 
        dataloader = DataLoader(uci_dataset.test_dataset, batch_size=fit_args.batch_size)
        for batch_x, batch_y in dataloader:
            # When testing we rescale, or "invert", the predictive distribution
            batch_x = projector(batch_x)
            batch_y_hat = model.likelihood(model(batch_x, sample_hidden=sample_hidden))
            batch_y_hat = projector.inverse(batch_y_hat)
            total_mse += mean_squared_error(batch_y_hat, batch_y, y_std=uci_dataset.test_y_std).item()
            total_tll += test_log_likelihood(batch_y_hat, batch_y, y_std=uci_dataset.test_y_std).item()
            total_nlpd += negative_log_predictive_density(batch_y_hat, batch_y).mean(0).item()
        return {
            'mse': total_mse / len(dataloader),
            'rmse': (total_mse / len(dataloader)) ** 0.5,
            'tll': total_tll / len(dataloader),
            'nlpd': total_nlpd / len(dataloader),
        }
    

def fit(model, projector, optimizer, elbo, uci_dataset: UCIDataset, train_loggers=None, fit_args: FitArguments = None): 
    metrics = {'elbo': None}
    pbar = tqdm(range(1, fit_args.num_epochs + 1), desc="Fitting")
    for step in pbar:
        metrics['elbo'] = train_epoch(
            model=model, 
            projector=projector,
            dataloader=DataLoader(uci_dataset.train_dataset, batch_size=fit_args.batch_size, shuffle=True),
            elbo=elbo,
            optimizer=optimizer,
            sample_hidden=fit_args.sample_hidden,
            num_samples=fit_args.train_num_samples,
        )
        # Update, log, and display metrics 
        log(train_loggers, metrics=metrics, step=step)
        pbar.set_postfix(metrics)
    return model 

# Test Euclidean deep GP

In [122]:
dataset = Kin8mn(normalize=False)
model_args = ModelArguments(
    dataset=dataset,
    model_name='euclidean',
    optimize_nu=False,
    learn_inducing_locations=False, 
    num_hidden=0, 
    variational_strategy_name='points',
    outputscale_mean=1.,
    num_inducing=100,
    nu=torch.inf,
)
fit_args = FitArguments(
    num_epochs=50,
    batch_size=512,
    test_num_samples=1,
    train_num_samples=1,
    optimize_projector=False,
)

In [123]:
projector = get_projector(model_args)
model = create_model(model_args, dataset, projector)
elbo = get_mll(model, dataset)
optimizer = get_optimizer(model, projector, fit_args)

fit(model, projector, optimizer, elbo, dataset, fit_args=fit_args)
test(model, projector, dataset, fit_args=fit_args)

Fitting: 100%|██████████| 50/50 [00:20<00:00,  2.43it/s, elbo=-13.8]


{'mse': 0.021189513122918436,
 'rmse': 0.14556618124728846,
 'tll': 0.5019268527251693,
 'nlpd': 0.810272167794537}

# Test shallow GP with spherical harmonic variational inference

In [107]:
dataset = Kin8mn()
model_args = ModelArguments(
    dataset=dataset,
    model_name='residual_geometric',
    optimize_nu=False,
    learn_inducing_locations=False, 
    num_hidden=0, 
    variational_strategy_name='harmonics',
    outputscale_mean=1.,
)
fit_args = FitArguments(
    num_epochs=20,
    batch_size=256,
    test_num_samples=1,
    train_num_samples=1,
    optimize_projector=False,
)

In [108]:
projector = get_projector(model_args)
model = create_model(model_args, dataset, projector)
elbo = get_mll(model, dataset)
optimizer = get_optimizer(model, projector, fit_args)

In [109]:
fit(model, projector, optimizer, elbo, dataset, fit_args=fit_args)
test(model, projector, dataset, fit_args=fit_args)

Fitting: 100%|██████████| 20/20 [01:36<00:00,  4.80s/it, elbo=15.1]


{'mse': 0.014120858337625566,
 'rmse': 0.11883121785804253,
 'tll': 0.731592262108463,
 'nlpd': 0.6019806334492757}

In [35]:
def train_step(x, y, model, projector, optimizer, mll) -> float:
    optimizer.zero_grad(set_to_none=True)
    x, y = projector(x, y)
    output = model(x)
    test_psd(output.lazy_covariance_matrix)
    loss = mll(output, y)
    loss.backward()
    optimizer.step()
    return loss.item()


def train(dataset, model, projector, num_epochs=20, lr=0.01) -> list[float]: 
    # optimizer and criterion
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, maximize=True)
    mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, dataset.train_y.size(0))

    # data
    train_loader = DataLoader(dataset.train_dataset, batch_size=256, shuffle=True)

    # Training loop
    losses = []
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        epoch_loss = 0
        for x_batch, y_batch in train_loader:
            loss = train_step(x=x_batch, y=y_batch, model=model, projector=projector, optimizer=optimizer, mll=mll)
            epoch_loss += loss
        losses.append(epoch_loss)
        pbar.set_postfix({'ELBO': losses[-1]})

    return losses 


def evaluate(dataset, model, projector):
    with torch.no_grad():
        test_x, test_y = projector(dataset.test_x), dataset.test_y
        out = model.likelihood(model(test_x))
        out = projector.inverse(out)
        nlpd = negative_log_predictive_density(out, test_y)
        mse = mean_squared_error(out, test_y)
        metrics = {
            'nlpd': nlpd.item(), 
            'mse': mse.item(),
        }
        print(f"NLPD: {metrics['nlpd']}, MSE: {metrics['mse']}")
    return metrics 


def get_model_and_projector(dataset: UCIDataset):
    sphere_dimension = dataset.dimension + 1

    # number of levels for variational inference 
    num_spherical_harmonics = dimension_to_num_inducing[dataset.dimension]
    max_ell, _ = num_spherical_harmonics_to_degree(num_spherical_harmonics, sphere_dimension)

    # number of levels for prior
    num_spherical_harmonics_prior = dimension_to_prior_num_eigenfunctions[dataset.dimension]
    max_ell_prior, _ = num_spherical_harmonics_to_degree(num_spherical_harmonics_prior, sphere_dimension)

    model = gpytorchSGP(max_ell=max_ell, d=sphere_dimension, max_ell_prior=max_ell_prior, kappa=1.0, nu=1.5, optimize_nu=False, batch_shape=torch.Size([]))
    projector = ConstantProjector()
    return model, projector


def reproduce_results(dataset, num_runs: int = 5, num_epochs=20, lr=0.01):
    print(f"Reproducing results for {dataset.name}".center(80, '-') + '\n')

    metrics = []
    for run in range(num_runs):
        print(f"Run {run + 1}".center(80, '-'))

        torch.random.manual_seed(run)
        model, projector = get_model_and_projector(dataset)
        train(dataset, model, projector, num_epochs=num_epochs, lr=lr)
        run_metrics = evaluate(dataset, model, projector)
        metrics.append(run_metrics)
    df = pd.DataFrame(metrics)

    print("Metrics mean".center(80, '-'))
    print(df.mean())

    print("Metrics STD".center(80, '-'))
    print(df.std())

    return df 

In [36]:
reproduce_results(dataset)

-------------------------Reproducing results for kin8nm-------------------------

-------------------------------------Run 1--------------------------------------


  5%|▌         | 1/20 [00:05<01:46,  5.59s/it, ELBO=-37.3]


KeyboardInterrupt: 