In [1]:
import torch 
import gpytorch 
import geometric_kernels.torch 
from geometric_kernels.frontends.pytorch.gpytorch import GPytorchGeometricKernel
from geometric_kernels.kernels.geometric_kernels import MaternKarhunenLoeveKernel
from geometric_kernels.spaces import Hypersphere

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using numpy backend


In [2]:
class GeometricMaternKernel(GPytorchGeometricKernel):
    """GPyTorch-compatible Matérn kernel on a geometric_kernels ``space``.

    Builds a ``MaternKarhunenLoeveKernel`` and hands it to
    ``GPytorchGeometricKernel``.

    Args:
        space: geometric_kernels space (e.g. ``Hypersphere``).
        lengthscale: lengthscale value forwarded to ``GPytorchGeometricKernel``.
        nu: smoothness value forwarded to ``GPytorchGeometricKernel``.
        trainable_nu: forwarded to ``GPytorchGeometricKernel``.
        num_eigenfunctions: number of eigenfunctions used by the KL kernel.
        normalize: forwarded to ``MaternKarhunenLoeveKernel``.
        **kwargs: forwarded to ``GPytorchGeometricKernel`` (e.g. ``batch_shape``).
            NOTE(review): unrecognised keywords may be silently ignored upstream.
    """

    def __init__(self, space, lengthscale=1.0, nu=2.5, trainable_nu=True, num_eigenfunctions=20, normalize=True, **kwargs):
        kl_kernel = MaternKarhunenLoeveKernel(
            space=space,
            num_eigenfunctions=num_eigenfunctions,
            normalize=normalize,
        )
        super().__init__(
            kl_kernel,
            lengthscale=lengthscale,
            nu=nu,
            trainable_nu=trainable_nu,
            **kwargs,
        )

In [4]:
space = Hypersphere(2)
batch_shape = torch.Size([2])
ard_num_dims = None
nu = 2.5
kernel = GeometricMaternKernel(space=space, batch_shape=batch_shape, ard_num_dims=ard_num_dims, nu=nu)
# Euclidean reference kernel; despite the variable name this is a Matérn kernel, not an RBF.
rbf = gpytorch.kernels.MaternKernel(nu=nu, batch_shape=batch_shape, ard_num_dims=ard_num_dims)

# Seed so the random inputs below are reproducible across Restart & Run All.
torch.manual_seed(0)

# Inputs with assorted batch shapes (presumably for checking kernel broadcasting).
# NOTE: torch.randn samples points in R^3; they are NOT projected onto the sphere.
x1 = torch.randn(2, 11, 3)        # [2, 11, 3]
x2 = torch.randn(12, 3)           # [12, 3]
x3 = torch.randn(3, 7, 3)         # [3, 7, 3]
x4 = x1.expand(10, *x1.shape)     # [10, 2, 11, 3]
x5 = x2.expand(10, 2, *x2.shape)  # [10, 2, 12, 3]

In [6]:
import torch 
from geometric_kernels.spaces import Space
from mdgp.frames import HypersphereFrame


class ProjectToTangentIntrinsic(torch.nn.Module):
    """Turn intrinsic frame coefficients into tangent vectors on a hypersphere.

    A ``HypersphereFrame`` of dimension ``space.dim`` is built once; ``forward``
    delegates to its ``coeff_to_tangent``.
    """

    def __init__(self, space: Space, get_normal_vector=None) -> None:
        super().__init__()
        hypersphere_frame = HypersphereFrame(dim=space.dim, get_normal_vector=get_normal_vector)
        self.frame = hypersphere_frame

    def forward(self, x, coeff):
        """Return the tangent vector at ``x`` whose frame coordinates are ``coeff``."""
        frame = self.frame
        return frame.coeff_to_tangent(x=x, coeff=coeff)
    

class ProjectToTangentExtrinsic(torch.nn.Module):
    """Project ambient-space coefficients onto the tangent space at ``x``.

    Uses the geoopt manifold returned by ``space_to_manifold`` and its ``proju``.
    """

    def __init__(self, space: Space) -> None:
        super().__init__()
        geoopt_manifold = space_to_manifold(space)
        self.manifold = geoopt_manifold

    def forward(self, x: torch.Tensor, coeff: torch.Tensor) -> torch.Tensor:
        """Return ``coeff`` projected onto the tangent space at ``x``."""
        projected = self.manifold.proju(x=x, u=coeff)
        return projected
    

def space_to_manifold(space: Space):
    """Return the geoopt manifold corresponding to a geometric_kernels ``space``.

    Only ``Hypersphere`` is supported; it maps to a geoopt ``Sphere`` constructed
    from an identity matrix of size ``space.dim + 1``.

    Raises:
        NotImplementedError: for any other space type.
    """
    # geoopt is imported lazily, before the type check, so it is only required when used.
    from geoopt import Sphere

    if not isinstance(space, Hypersphere):
        raise NotImplementedError
    ambient_size = space.dim + 1
    return Sphere(torch.eye(ambient_size))


class ExponentialMap(torch.nn.Module):
    """Map tangent vectors back to the manifold with the exponential map.

    The manifold comes from ``space_to_manifold`` (a geoopt manifold).
    """

    def __init__(self, space: Space) -> None:
        super().__init__()
        geoopt_manifold = space_to_manifold(space)
        self.manifold = geoopt_manifold

    def forward(self, x, u):
        """Return the point reached from ``x`` along tangent vector ``u`` (``expmap``)."""
        manifold = self.manifold
        return manifold.expmap(x=x, u=u)
    

class Retraction(torch.nn.Module):
    """Map tangent vectors back to the manifold with the manifold's retraction.

    The manifold comes from ``space_to_manifold`` (a geoopt manifold).
    """

    def __init__(self, space: Space) -> None:
        super().__init__()
        geoopt_manifold = space_to_manifold(space)
        self.manifold = geoopt_manifold

    def forward(self, x, u):
        """Return the retraction of tangent vector ``u`` at ``x`` (``retr``)."""
        manifold = self.manifold
        return manifold.retr(x=x, u=u)
    


from torch import Tensor 
from gpytorch.distributions import MultivariateNormal
from gpytorch.variational import UnwhitenedVariationalStrategy, VariationalStrategy
from mdgp.samplers import RFFSampler, VISampler, PosteriorSampler, sample_naive


class DeepGPLayer(gpytorch.models.deep_gps.DeepGPLayer):
    """Variational deep-GP layer with selectable output mode on call.

    Uses a ``CholeskyVariationalDistribution`` over ``inducing_points`` and either
    ``VariationalStrategy`` (whitened) or ``UnwhitenedVariationalStrategy``.

    Args:
        mean_module: GPyTorch mean used in ``forward``.
        covar_module: GPyTorch kernel used in ``forward``.
        inducing_points: 2-D tensor ``[num_inducing, input_dims]``.
        output_dims: number of outputs (becomes the variational batch shape),
            or ``None`` for an unbatched layer.
        learn_inducing_locations: whether inducing locations are optimised.
        whitened_variational_strategy: select the whitened strategy if True.
    """

    def __init__(self, mean_module, covar_module, inducing_points, output_dims, learn_inducing_locations=False, whitened_variational_strategy=True):
        if output_dims is None:
            batch_shape = torch.Size([])
        else:
            batch_shape = torch.Size([output_dims])
        num_inducing_points, input_dims = inducing_points.shape

        # Variational parameters. Note: the strategy receives `self` before the
        # parent __init__ below has run.
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=num_inducing_points,
            batch_shape=batch_shape,
        )
        if whitened_variational_strategy:
            strategy_cls = VariationalStrategy
        else:
            strategy_cls = UnwhitenedVariationalStrategy
        variational_strategy = strategy_cls(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=learn_inducing_locations,
        )

        super().__init__(variational_strategy, input_dims, output_dims)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Prior at ``x``: mean from ``mean_module``, covariance from ``covar_module``."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

    def sample_naive(self, inputs, are_samples=False, **kwargs):
        """Evaluate the parent layer and draw from it with ``mdgp.samplers.sample_naive``."""
        layer_output = super().__call__(inputs, are_samples=are_samples, **kwargs)
        return sample_naive(layer_output)

    def sample_pathwise(self, inputs, are_samples=False):
        """Pathwise sampling hook; not implemented in this base class."""
        raise NotImplementedError

    def __call__(self, inputs, are_samples=False, sample=False, mean=False, **kwargs):
        """Evaluate the layer in one of several modes, checked in this order.

        * ``mean`` truthy: the mean of the layer output computed with a single
          likelihood sample, with the leading sample dimension indexed away (``[0]``).
        * ``sample`` is ``None`` or ``False``: the parent layer's output.
        * ``sample == 'naive'``: ``self.sample_naive``.
        * ``sample == 'pathwise'``: ``self.sample_pathwise``.

        Raises:
            NotImplementedError: for any other ``sample`` value.
        """
        if mean:
            with gpytorch.settings.num_likelihood_samples(1):
                layer_output = super().__call__(inputs, are_samples=are_samples, **kwargs)
                return layer_output.mean[0]

        # Identity checks on purpose: only the literal None/False mean "no sampling".
        no_sampling = sample is None or sample is False
        if no_sampling:
            return super().__call__(inputs, are_samples=are_samples, **kwargs)
        if sample == 'naive':
            return self.sample_naive(inputs=inputs, are_samples=are_samples, **kwargs)
        if sample == 'pathwise':
            return self.sample_pathwise(inputs=inputs, are_samples=are_samples)
        raise NotImplementedError(f"Expected sample argument to be either 'naive', 'pathwise', False, or None. Got {sample}")
        

class GeometricDeepGPLayer(DeepGPLayer):
    """Deep-GP layer with a scaled geometric Matérn kernel on ``space``.

    Builds a batched ``ConstantMean`` and a ``ScaleKernel`` around a
    ``GeometricMaternKernel``, then delegates variational setup to ``DeepGPLayer``.

    Args:
        space: geometric_kernels space the kernel is defined on.
        num_eigenfunctions: number of eigenfunctions for the kernel.
        output_dims: number of outputs (used as batch shape), or ``None``.
        inducing_points: inducing locations ``[num_inducing, input_dims]``.
        nu: Matérn smoothness.
        optimize_nu: whether ``nu`` is trainable; forwarded to the kernel as
            ``trainable_nu``.
        feature_map: only used by the (currently disabled) RFF sampler.
        learn_inducing_locations: whether inducing locations are optimised.
        whitened_variational_strategy: whitened vs. unwhitened strategy.
        sampler_inv_jitter: jitter for the (currently disabled) posterior sampler.
            Note that ``10e-8 == 1e-7``.
        outputscale_prior: optional outputscale prior; when given, the outputscale
            is initialised to ``outputscale_prior.mean``.
    """

    def __init__(
        self,
        space,
        num_eigenfunctions: int,
        output_dims: int,
        inducing_points: torch.Tensor,
        nu: float = 2.5,
        optimize_nu: bool = False,
        feature_map: str = 'deterministic',
        learn_inducing_locations: bool = False,
        whitened_variational_strategy=False,
        sampler_inv_jitter=10e-8,
        outputscale_prior=None,
    ) -> None:
        batch_shape = torch.Size([output_dims]) if output_dims is not None else torch.Size([])

        # Initialize mean and kernel modules
        mean_module = gpytorch.means.ConstantMean(
            batch_shape=batch_shape,
        )
        # GeometricMaternKernel exposes nu-trainability as `trainable_nu`. Passing
        # `optimize_nu=` instead would only land in its **kwargs and leave
        # `trainable_nu` at its default of True, ignoring optimize_nu=False.
        base_kernel = GeometricMaternKernel(
            space=space,
            nu=nu,
            num_eigenfunctions=num_eigenfunctions,
            batch_shape=batch_shape,
            trainable_nu=optimize_nu,
        )
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=base_kernel,
            batch_shape=batch_shape,
            outputscale_prior=outputscale_prior,
        )
        if outputscale_prior is not None:
            covar_module.initialize(outputscale=outputscale_prior.mean)

        super().__init__(mean_module=mean_module, covar_module=covar_module, inducing_points=inducing_points, output_dims=output_dims, learn_inducing_locations=learn_inducing_locations, whitened_variational_strategy=whitened_variational_strategy)

        # TODO: pathwise sampling is disabled until the posterior sampler is wired up, e.g.
        #   rff_sampler = RFFSampler(covar_module=covar_module, mean_module=mean_module, feature_map=feature_map)
        #   vi_sampler = VISampler(variational_distribution=self.variational_strategy._variational_distribution)
        #   self.sampler = PosteriorSampler(rff_sampler=rff_sampler, vi_sampler=vi_sampler, inducing_points=inducing_points,
        #                                   whitened_variational_strategy=whitened_variational_strategy, inv_jitter=sampler_inv_jitter)
        # VISampler needs the VariationalDistribution object itself so that changing parameters are tracked properly.
        self.sampler = None

    def sample_pathwise(self, inputs, are_samples=False):
        """Draw a pathwise posterior sample at ``inputs`` using ``self.sampler``.

        Args:
            inputs: ``[*N, D]`` points, or ``[S, *N, D]`` when ``are_samples`` is
                True (one input set per sample; one draw is taken per set).
            are_samples: whether the leading dimension of ``inputs`` indexes samples.

        Returns:
            ``[S, *N]`` for a single-output layer, ``[S, *N, O]`` otherwise.

        Raises:
            NotImplementedError: if no posterior sampler has been configured.
        """
        if self.sampler is None:
            raise NotImplementedError("Pathwise sampling requires a posterior sampler, but self.sampler is None.")

        # Clear cache if training, since otherwise we risk "trying to backward through the graph a second time" errors
        if self.training:
            self.variational_strategy._clear_cache()
        # Maybe initialize variational distribution (Taken from gpytorch.variational._VariationalStrategy.__call__)
        if not self.variational_strategy.variational_params_initialized.item():
            prior_dist = self.variational_strategy.prior_distribution
            self.variational_strategy._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_strategy.variational_params_initialized.fill_(1)

        # Take sample
        if are_samples:  # [S, *N, D]
            # One draw per input set, so the sample dimension is the inputs' leading
            # dimension, not necessarily num_likelihood_samples.
            sample_shape = inputs.shape[:1]
            inputs_head_shape = inputs.shape[1:-1]
            inputs = inputs.flatten(start_dim=1, end_dim=-2)
            sample = torch.stack([self.sampler(inputs_, sample_shape=torch.Size([])) for inputs_ in inputs.unbind(0)], dim=0)
        else:
            sample_shape = torch.Size([gpytorch.settings.num_likelihood_samples.value()])
            inputs_head_shape = inputs.shape[:-1]
            inputs = inputs.flatten(start_dim=0, end_dim=-2)
            sample = self.sampler(inputs, sample_shape=sample_shape)

        # Reshape to [S, N, O]
        if sample.dim() == 1:  # [*N]
            return sample.reshape(*inputs_head_shape)  # [N]
        if sample.dim() == 2:  # [S, *N]. Sidenote: [O, *N] cannot happen because sample_shape always has one entry
            return sample.reshape(*sample_shape, *inputs_head_shape)  # [S, N]
        return sample.mT.reshape(*sample_shape, *inputs_head_shape, -1)  # [S, O, *N] -> [S, *N, O] -> [S, N, O]


class EuclideanDeepGPLayer(DeepGPLayer):
    """Deep-GP layer with a scaled Euclidean Matérn kernel (ARD over input dims).

    Args:
        inducing_points: inducing locations; the last dimension is the input size.
        output_dims: number of outputs (batch shape), or ``None``.
        mean_type: ``'constant'`` gives a batched ``ConstantMean`` (with optional
            ``constant_prior``); any other value gives ``LinearMean(input_dims)``.
        learn_inducing_locations: whether inducing locations are optimised.
        nu: Matérn smoothness.
        constant_prior: prior for ``ConstantMean`` (ignored for linear mean).
        whitened_variational_strategy: whitened vs. unwhitened strategy.
        outputscale_prior: optional outputscale prior; when given, the outputscale
            is initialised to ``outputscale_prior.mean``.
    """

    def __init__(
            self,
            inducing_points,
            output_dims,
            mean_type='constant',
            learn_inducing_locations=False,
            nu=2.5,
            constant_prior=None,
            whitened_variational_strategy=True,
            outputscale_prior=None,
        ) -> None:
        if output_dims is None:
            batch_shape = torch.Size([])
        else:
            batch_shape = torch.Size([output_dims])
        n_input_dims = inducing_points.size(-1)

        mean_module = (
            gpytorch.means.ConstantMean(batch_shape=batch_shape, constant_prior=constant_prior)
            if mean_type == 'constant'
            else gpytorch.means.LinearMean(n_input_dims)
        )

        matern = gpytorch.kernels.MaternKernel(nu=nu, batch_shape=batch_shape, ard_num_dims=n_input_dims)
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=matern,
            batch_shape=batch_shape,
            ard_num_dims=None,
            outputscale_prior=outputscale_prior,
        )
        if outputscale_prior is not None:
            covar_module.initialize(outputscale=outputscale_prior.mean)

        super().__init__(
            mean_module=mean_module,
            covar_module=covar_module,
            inducing_points=inducing_points,
            output_dims=output_dims,
            learn_inducing_locations=learn_inducing_locations,
            whitened_variational_strategy=whitened_variational_strategy,
        )


class ManifoldToManifoldDeepGPLayer(torch.nn.Module):
    """Hidden layer mapping manifold points to manifold points.

    The wrapped ``gp`` produces coefficients at ``x``. These become a tangent
    vector at ``x`` (intrinsic frame or extrinsic projection), which is then
    mapped back onto the manifold (exponential map or retraction).

    Args:
        gp: GP layer called as ``gp(x, mean=..., sample=..., are_samples=...)``.
        space: geometric_kernels space.
        project_to_tangent: ``'intrinsic'`` or ``'extrinsic'``.
        tangent_to_manifold: ``'exp'`` or ``'retr'``.
        get_normal_vector: forwarded to the intrinsic frame only.
    """

    def __init__(self, gp, space, project_to_tangent: str = 'intrinsic', tangent_to_manifold: str = 'exp', get_normal_vector='nn'):
        assert project_to_tangent in {'intrinsic', 'extrinsic'}
        assert tangent_to_manifold in {'exp', 'retr'}
        super().__init__()
        self.gp = gp

        if project_to_tangent == 'extrinsic':
            self.project_to_tangent = ProjectToTangentExtrinsic(space=space)
        elif project_to_tangent == 'intrinsic':
            self.project_to_tangent = ProjectToTangentIntrinsic(space=space, get_normal_vector=get_normal_vector)

        if tangent_to_manifold == 'retr':
            self.tangent_to_manifold = Retraction(space=space)
        elif tangent_to_manifold == 'exp':
            self.tangent_to_manifold = ExponentialMap(space=space)

    def forward(self, x, are_samples=False, return_hidden=False, mean=False, sample='naive'):
        """Apply gp -> tangent projection -> manifold map at ``x``.

        Returns the new manifold points, or, when ``return_hidden`` is True, a dict
        with keys ``'coefficients'``, ``'tangent'`` and ``'manifold'``.
        """
        coefficients = self.gp(x, mean=mean, sample=sample, are_samples=are_samples)
        tangent_vectors = self.project_to_tangent(x=x, coeff=coefficients)
        manifold_points = self.tangent_to_manifold(x=x, u=tangent_vectors)
        if not return_hidden:
            return manifold_points
        return {'coefficients': coefficients, 'tangent': tangent_vectors, 'manifold': manifold_points}


from geometric_kernels.spaces import Space
from mdgp.utils import extrinsic_dimension


class ManifoldDeepGP(gpytorch.models.deep_gps.DeepGP):
    """Deep GP of manifold-to-manifold hidden layers followed by an output GP.

    Args:
        hidden_gps: GP layers; each is wrapped in a ``ManifoldToManifoldDeepGPLayer``.
        output_gp: final GP layer, applied to the last hidden layer's manifold points.
        space: geometric_kernels space shared by all hidden layers.
        project_to_tangent: ``'intrinsic'`` or ``'extrinsic'``.
        tangent_to_manifold: ``'exp'`` or ``'retr'``.
        parametrised_frame: if exactly ``True``, intrinsic frames are built with
            ``get_normal_vector='nn'``; otherwise with ``None``.

    A ``GaussianLikelihood`` is attached as ``self.likelihood``.
    """

    def __init__(self, hidden_gps, output_gp, space, project_to_tangent='intrinsic', tangent_to_manifold='exp', parametrised_frame=False):
        get_normal_vector = 'nn' if parametrised_frame is True else None
        super().__init__()
        self.hidden_layers = torch.nn.ModuleList([
            ManifoldToManifoldDeepGPLayer(gp=gp, space=space, project_to_tangent=project_to_tangent, tangent_to_manifold=tangent_to_manifold, get_normal_vector=get_normal_vector)
            for gp in hidden_gps
        ])
        self.output_layer = output_gp
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward_return_hidden(self, x: Tensor, are_samples: bool = False, sample_hidden: str = 'naive', sample_output=False, mean=False):
        """Like ``forward``, but also return every hidden layer's intermediates.

        Returns:
            ``(hidden_factors, y)``: ``hidden_factors`` holds one dict per hidden layer
            (keys ``'coefficients'``, ``'tangent'``, ``'manifold'``), and ``y`` is the
            output layer's result.
        """
        hidden_factors = []
        for hidden_layer in self.hidden_layers:
            hidden_dict = hidden_layer(x=x, are_samples=are_samples, sample=sample_hidden, mean=mean, return_hidden=True)
            hidden_factors.append(hidden_dict)
            x = hidden_dict['manifold']
            # After the first layer the inputs are samples, unless running in mean mode.
            are_samples = False if mean else True
        y = self.output_layer(x, are_samples=are_samples, sample=sample_output, mean=mean)
        return hidden_factors, y

    def forward(self, x: Tensor, are_samples: bool = False, sample_hidden: str = 'naive', sample_output=False, mean=False):
        """Propagate ``x`` through all hidden layers and return the output layer's result."""
        for hidden_layer in self.hidden_layers:
            x = hidden_layer(x, are_samples=are_samples, sample=sample_hidden, mean=mean)
            # After the first layer the inputs are samples, unless running in mean mode.
            are_samples = False if mean else True
        return self.output_layer(x, are_samples=are_samples, sample=sample_output, mean=mean)
    

class GeometricManifoldDeepGP(ManifoldDeepGP):
    """Manifold deep GP whose hidden and output layers use geometric Matérn kernels.

    Hidden layers output ``space.dim`` coefficients when ``project_to_tangent`` is
    ``'intrinsic'`` and ``extrinsic_dimension(space)`` when it is ``'extrinsic'``.
    The output layer has ``output_dims`` outputs. Every layer receives the same
    ``inducing_points``; ``outputscale_prior`` is passed to hidden layers only.

    Raises:
        NotImplementedError: if ``project_to_tangent`` is not one of the two values.
    """

    def __init__(
        self,
        space,
        num_hidden: int,
        inducing_points,  # [N, 3]
        output_dims=None,
        num_eigenfunctions: int = 20,
        nu: float = 2.5,
        feature_map='deterministic',
        learn_inducing_locations: bool = False,
        project_to_tangent: str = 'intrinsic',
        tangent_to_manifold: str = 'exp',
        optimize_nu: bool = False,
        whitened_variational_strategy=True,
        sampler_inv_jitter=10e-8,
        outputscale_prior=None,
        parametrised_frame=False,
        ) -> None:

        # Hidden layers predict the coordinates of a tangent vector, so their output
        # size is set by the chosen tangent representation.
        if project_to_tangent not in ('intrinsic', 'extrinsic'):
            raise NotImplementedError(f"Expected project_to_tangent either 'intrinsic' or 'extrinsic'. Got {project_to_tangent}.")
        hidden_output_dims = space.dim if project_to_tangent == 'intrinsic' else extrinsic_dimension(space)

        # Arguments common to every layer.
        layer_kwargs = dict(
            space=space,
            num_eigenfunctions=num_eigenfunctions,
            inducing_points=inducing_points,
            nu=nu,
            feature_map=feature_map,
            learn_inducing_locations=learn_inducing_locations,
            optimize_nu=optimize_nu,
            whitened_variational_strategy=whitened_variational_strategy,
            sampler_inv_jitter=sampler_inv_jitter,
        )
        hidden_gps = [
            GeometricDeepGPLayer(output_dims=hidden_output_dims, outputscale_prior=outputscale_prior, **layer_kwargs)
            for _ in range(num_hidden)
        ]
        output_gp = GeometricDeepGPLayer(output_dims=output_dims, **layer_kwargs)

        super().__init__(
            hidden_gps=hidden_gps,
            output_gp=output_gp,
            project_to_tangent=project_to_tangent,
            tangent_to_manifold=tangent_to_manifold,
            space=space,
            parametrised_frame=parametrised_frame,
        )


class EuclideanManifoldDeepGP(ManifoldDeepGP):
    """Manifold deep GP whose layers use Euclidean Matérn kernels on ambient inputs.

    Hidden layers output ``space.dim`` values when ``project_to_tangent`` is
    ``'intrinsic'`` and ``extrinsic_dimension(space)`` when it is ``'extrinsic'``.
    All layers use a constant mean and the same ``inducing_points``;
    ``outputscale_prior`` is passed to hidden layers only.

    Raises:
        NotImplementedError: if ``project_to_tangent`` is not one of the two values.
    """

    def __init__(
        self,
        space: Space,
        num_hidden: int,
        inducing_points,
        output_dims=None,
        nu: float = 2.5,
        learn_inducing_locations: bool = False,
        project_to_tangent='intrinsic',
        tangent_to_manifold='exp',
        outputscale_prior=None,
        parametrised_frame=False,
        ) -> None:
        if project_to_tangent not in ('intrinsic', 'extrinsic'):
            raise NotImplementedError(f"Expected project_to_tangent either 'intrinsic' or 'extrinsic'. Got {project_to_tangent}.")
        hidden_output_dims = space.dim if project_to_tangent == 'intrinsic' else extrinsic_dimension(space)

        # Arguments common to every layer.
        layer_kwargs = dict(
            inducing_points=inducing_points,
            nu=nu,
            learn_inducing_locations=learn_inducing_locations,
            mean_type='constant',
        )
        hidden_gps = [
            EuclideanDeepGPLayer(output_dims=hidden_output_dims, outputscale_prior=outputscale_prior, **layer_kwargs)
            for _ in range(num_hidden)
        ]
        output_gp = EuclideanDeepGPLayer(output_dims=output_dims, **layer_kwargs)

        super().__init__(
            hidden_gps=hidden_gps,
            output_gp=output_gp,
            project_to_tangent=project_to_tangent,
            tangent_to_manifold=tangent_to_manifold,
            space=space,
            parametrised_frame=parametrised_frame,
        )


class EuclideanDeepGP(gpytorch.models.deep_gps.DeepGP):
    """Euclidean deep-GP baseline: linear-mean hidden layers plus a constant-mean output.

    Hidden layers output ``inducing_points.shape[-1]`` values, the same as the input
    size, so each hidden output can be fed to the next layer, which reuses the same
    ``inducing_points``.

    Args:
        num_hidden: number of hidden layers.
        inducing_points: inducing locations shared by every layer.
        output_dims: number of outputs of the final layer, or ``None``.
        nu: Matérn smoothness for every layer.
        learn_inducing_locations: whether inducing locations are optimised.
        outputscale_prior: outputscale prior for hidden layers only.
    """

    def __init__(
        self,
        num_hidden: int,
        inducing_points,
        output_dims = None,
        nu: float = 2.5,
        learn_inducing_locations: bool = False,
        outputscale_prior=None,
        ) -> None:
        super().__init__()
        hidden_output_dims = inducing_points.shape[-1]

        # ModuleList (not a plain list) so the layers' parameters are registered on
        # this model and seen by .parameters(), .train()/.eval() and .to().
        self.hidden_gp_layers = torch.nn.ModuleList([
            EuclideanDeepGPLayer(
                output_dims=hidden_output_dims,
                inducing_points=inducing_points,
                nu=nu,
                learn_inducing_locations=learn_inducing_locations,
                mean_type='linear',
                outputscale_prior=outputscale_prior,
            )
            for _ in range(num_hidden)
        ])
        self.output_gp_layer = EuclideanDeepGPLayer(
            output_dims=output_dims,
            inducing_points=inducing_points,
            nu=nu,
            learn_inducing_locations=learn_inducing_locations,
            mean_type='constant',
        )
        # Same likelihood as ManifoldDeepGP so both model families can be trained alike.
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward(self, inputs: Tensor, are_samples: bool = False, sample_hidden: str = 'naive', sample_output=False, mean=False, **kwargs):
        """Propagate ``inputs`` through the hidden layers, then the output layer.

        Extra ``**kwargs`` are accepted and ignored.
        """
        for gp_layer in self.hidden_gp_layers:
            inputs = gp_layer(inputs, are_samples=are_samples, sample=sample_hidden, mean=mean)
            # After the first layer the inputs are samples, unless running in mean mode.
            are_samples = False if mean else True
        return self.output_gp_layer(inputs, are_samples=are_samples, sample=sample_output, mean=mean)
    

In [8]:
from mdgp.utils import sphere_uniform_grid


space = Hypersphere(2)
inducing_points = sphere_uniform_grid(60)  # presumably ~60 roughly uniform points on S^2 — see mdgp.utils

# Geometric manifold deep GPs with 0, 1, 2 and 3 hidden layers, sharing `space` and `inducing_points`.
models = list(
    GeometricManifoldDeepGP(space=space, num_hidden=num_hidden, inducing_points=inducing_points)
    for num_hidden in range(4)
)

In [9]:
x = sphere_uniform_grid(400)
torch.set_grad_enabled(False)

for h, model in enumerate(models):
    print(f"Hidden layers: {h}")
    %timeit -n 10 model(x)

Hidden layers: 0
13.4 ms ± 3.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hidden layers: 1
64.4 ms ± 4.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hidden layers: 2
130 ms ± 9.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hidden layers: 3
173 ms ± 9.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
x = sphere_uniform_grid(400)
torch.set_grad_enabled(True)

for h, model in enumerate(models):
    print(f"Hidden layers: {h}")
    %timeit -n 10 model(x).mean.mean().backward()

Hidden layers: 0
13.2 ms ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hidden layers: 1
The slowest run took 13.95 times longer than the fastest. This could mean that an intermediate result is being cached.
338 ms ± 531 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hidden layers: 2
329 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hidden layers: 3


: 