In [16]:
import plotly.io as pio
from plotly import graph_objects as go
# Use the dark plotly theme for every figure in this notebook.
pio.templates.default = "plotly_dark"

import os 
# Select the PyTorch backend for geomstats. This is set before the geomstats /
# geometric_kernels imports below — presumably because the backend is chosen at
# import time; keep this line above those imports.
os.environ['GEOMSTATS_BACKEND'] = 'pytorch'

import geomstats._backend as gs 
import geometric_kernels.torch 

In [17]:
import torch 
from geometric_kernels.frontends.pytorch.gpytorch import GPytorchGeometricKernel
from geometric_kernels.kernels.geometric_kernels import MaternFeatureMapKernel
from geometric_kernels.spaces import Hyperbolic, SymmetricPositiveDefiniteMatrices, DiscreteSpectrumSpace, Space
from geometric_kernels.kernels.feature_maps import deterministic_feature_map_compact, rejection_sampling_feature_map_spd, rejection_sampling_feature_map_hyperbolic


def get_feature_map(space, num_random_phases, num_eigenfunctions):
    """Pick the feature-map constructor that matches the type of ``space``.

    :param space: a geometric_kernels space instance.
    :param num_random_phases: used by the rejection-sampling maps
        (``Hyperbolic`` and ``SymmetricPositiveDefiniteMatrices``); ignored otherwise.
    :param num_eigenfunctions: used by the deterministic map for
        ``DiscreteSpectrumSpace``; ignored otherwise.
    :raises NotImplementedError: if ``space`` matches none of the supported types.
    """
    # Ordered: the first space type that ``space`` is an instance of wins.
    builders = (
        (Hyperbolic, lambda: rejection_sampling_feature_map_hyperbolic(space=space, num_random_phases=num_random_phases)),
        (SymmetricPositiveDefiniteMatrices, lambda: rejection_sampling_feature_map_spd(space=space, num_random_phases=num_random_phases)),
        (DiscreteSpectrumSpace, lambda: deterministic_feature_map_compact(space=space, num_eigenfunctions=num_eigenfunctions)),
    )
    for space_type, build in builders:
        if isinstance(space, space_type):
            return build()
    raise NotImplementedError(f"Feature map for space {space} not implemented")


class GeometricMaternKernel(GPytorchGeometricKernel):
    """GPyTorch wrapper around a normalized, feature-map based Matérn kernel on ``space``.

    The feature map is chosen by ``get_feature_map`` from the type of ``space``;
    the remaining hyperparameters and ``**kwargs`` are forwarded to
    ``GPytorchGeometricKernel``.
    """

    def __init__(
            self,
            space: Space,
            lengthscale=1.0,
            nu=2.5,
            trainable_nu=True,
            num_eigenfunctions=35,
            num_random_phases=3000,
            **kwargs
        ):
        matern = MaternFeatureMapKernel(
            space=space,
            feature_map=get_feature_map(
                space=space,
                num_eigenfunctions=num_eigenfunctions,
                num_random_phases=num_random_phases,
            ),
            key=torch.Generator(),  # dedicated torch random state for the kernel
            normalize=True,
        )
        super().__init__(
            matern,
            lengthscale=lengthscale,
            nu=nu,
            trainable_nu=trainable_nu,
            **kwargs,
        )

In [18]:
import gpytorch 
import math 


class RFFSampler(torch.nn.Module):
    """Draws approximate prior function samples from a ``ScaleKernel``-wrapped geometric kernel.

    A sample is ``sqrt(outputscale) * (features(inputs) @ w) + mean(inputs)``, where
    ``features`` come from ``feature_map`` and ``w`` are standard-normal weights of
    shape ``[*batch_shape, num_features, num_samples]``.
    """

    def __init__(self, covar_module, mean_module, feature_map, num_features=None) -> None:
        super().__init__()

        assert isinstance(covar_module, gpytorch.kernels.ScaleKernel), "RFFSampler only implemented for ScaleKernel."
        self.covar_module = covar_module
        self.mean_module = mean_module
        self.feature_map = feature_map
        # Number of feature columns produced by ``feature_map``. It currently has to be
        # supplied by the caller; it is not inferred from the first use.
        self._num_features = num_features

        self.base_kernel = covar_module.base_kernel
        self.geometric_kernel = self.base_kernel.geometric_kernel
        self.space = self.geometric_kernel.space

        # Weights drawn by the most recent resampling call to ``weights``.
        self._weights = None

    @property
    def num_features(self):
        return self._num_features

    # TODO: Change the weights shape to go by the broadcasted batch shapes of the inputs and the kernel
    def weights(self, num_samples, inputs=None, resample=True) -> torch.Tensor:
        """Return standard-normal weights, drawing new ones if requested or none are cached.

        :param num_samples: number of function samples (last weight dimension).
        :param inputs: optional ``[..., N, D]`` tensor; its leading batch dims are
            broadcast with the kernel's batch shape.
        :param resample: draw fresh weights even if cached weights exist.
        :return: tensor of shape ``[*batch_shape, num_features, num_samples]``.
        :raises ValueError: if new weights are needed but ``num_features`` was never set.
        """
        if inputs is None:
            batch_shape = self.covar_module.batch_shape
        else:
            batch_shape = torch.broadcast_shapes(self.covar_module.batch_shape, inputs.shape[:-2])
        if resample or self._weights is None:
            if self.num_features is None:
                raise ValueError("RFFSampler needs `num_features` to draw weights; pass it to the constructor.")
            self._weights = torch.randn(*batch_shape, self.num_features, num_samples)  # [M, O]
        else:
            assert self._weights.shape[-1] == num_samples, f"Sample shape mismatch. Resample or use sample_shape with product {self._weights.shape[-1]}."
        return self._weights

    def compute_features(self, inputs, normalize=True):
        """Evaluate the feature map at ``inputs`` with the base kernel's current parameters."""
        params = self.base_kernel.geometric_kernel_params
        # A new torch.Generator is created on every call (independent of the global
        # RNG) — presumably so random-phase feature maps reuse the same phases; confirm.
        key = torch.Generator()
        _, features = self.feature_map(inputs, params, key=key, normalize=normalize)
        return features * self.base_kernel.batch_shape_scaling_factor.sqrt()

    def sample(self, inputs, weights, sample_shape: torch.Size = torch.Size([]), normalize=True):
        """
        :param inputs: [..., D]
        :param weights: weights as returned by ``weights``.
        :param sample_shape: shape that the leading sample dimension is reshaped into.
        :return: samples with leading dims ``sample_shape``.
        """
        features = self.compute_features(inputs=inputs, normalize=normalize)  # [..., batch_shape, N, num_eigenfunctions]
        # ``weights`` are drawn on the CPU in torch's default dtype; match the features
        # so double-precision or GPU inputs do not fail inside einsum.
        weights = weights.to(features)
        res = torch.einsum('...ne, ...es -> s...n', features, weights)
        res = self.covar_module.outputscale.sqrt().unsqueeze(-1) * res
        res = res + self.mean_module(inputs)
        return res.view(*sample_shape, *res.shape[1:])

    def forward(self, inputs: torch.Tensor, sample_shape: torch.Size = torch.Size([]), resample_weights=True, normalize=True) -> torch.Tensor:
        """
        :return: A sample from the model. [S, O, N]
        """
        weights = self.weights(num_samples=math.prod(sample_shape), inputs=inputs, resample=resample_weights)
        return self.sample(inputs=inputs, weights=weights, sample_shape=sample_shape, normalize=normalize)

In [19]:
from mdgp.models.deep_gps.layers import DeepGPLayer
from mdgp.samplers import VISampler, PosteriorSampler


class GeometricDeepGPLayer(DeepGPLayer):
    """Variational deep-GP layer with a scaled geometric kernel and a pathwise posterior sampler.

    The layer builds a zero or constant mean and a ``ScaleKernel`` around
    ``base_kernel``. Both get batch shape ``[output_dims]``, or ``[]`` when
    ``output_dims`` is None. They are handed to ``DeepGPLayer``, and the layer then
    attaches a ``PosteriorSampler`` that combines an ``RFFSampler`` with a ``VISampler``.
    """

    def __init__(
        self,
        output_dims: int,
        inducing_points: torch.Tensor,
        base_kernel,
        feature_map: str = 'deterministic',
        learn_inducing_locations: bool = False,
        whitened_variational_strategy=False,
        sampler_inv_jitter=10e-8,
        outputscale_prior=None,
        zero_mean=True,
        num_features=None,
    ) -> None:
        # NOTE(review): the default ``feature_map='deterministic'`` is a string, but
        # RFFSampler calls ``feature_map(...)``. Pathwise sampling therefore needs a
        # callable feature map (e.g. one from get_feature_map) to be passed explicitly.
        # NOTE(review): ``10e-8`` equals ``1e-7``; confirm that is the intended jitter.
        batch_shape = torch.Size([] if output_dims is None else [output_dims])

        mean_cls = gpytorch.means.ZeroMean if zero_mean else gpytorch.means.ConstantMean
        mean_module = mean_cls(batch_shape=batch_shape)

        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=base_kernel,
            batch_shape=batch_shape,
            outputscale_prior=outputscale_prior,
        )
        if outputscale_prior is not None:
            # Initialise the outputscale at the prior's mean.
            covar_module.initialize(outputscale=outputscale_prior.mean)

        super().__init__(
            mean_module=mean_module,
            covar_module=covar_module,
            inducing_points=inducing_points,
            output_dims=output_dims,
            learn_inducing_locations=learn_inducing_locations,
            whitened_variational_strategy=whitened_variational_strategy,
        )

        # VISampler is given the VariationalDistribution object itself so that changes
        # to its parameters are tracked by the sampler.
        self.sampler = PosteriorSampler(
            rff_sampler=RFFSampler(
                covar_module=covar_module,
                mean_module=mean_module,
                feature_map=feature_map,
                num_features=num_features,
            ),
            vi_sampler=VISampler(variational_distribution=self.variational_strategy._variational_distribution),
            inducing_points=inducing_points,
            whitened_variational_strategy=whitened_variational_strategy,
            inv_jitter=sampler_inv_jitter,
        )

    def sample_pathwise(self, inputs, are_samples=False, resample_weights=True):
        """Draw pathwise posterior samples of this layer at ``inputs``.

        :param inputs: evaluation points. With ``are_samples=True`` an extra dimension
            is inserted at position -3 and a single sample is requested (empty sample
            shape). Otherwise ``gpytorch.settings.num_likelihood_samples`` samples are drawn.
        :param resample_weights: forwarded to the sampler as ``resample``.
        :return: the sampler output with its last two dimensions transposed.
        """
        strategy = self.variational_strategy
        # While training, drop cached quantities to avoid "trying to backward through
        # the graph a second time" errors.
        if self.training:
            strategy._clear_cache()
        # Lazily initialise the variational distribution from the prior
        # (mirrors gpytorch.variational._VariationalStrategy.__call__).
        if not strategy.variational_params_initialized.item():
            prior = strategy.prior_distribution
            strategy._variational_distribution.initialize_variational_distribution(prior)
            strategy.variational_params_initialized.fill_(1)

        if are_samples:
            sample = self.sampler(inputs.unsqueeze(-3), sample_shape=torch.Size([]), resample=resample_weights)  # [S, O, N]
        else:
            num_samples = gpytorch.settings.num_likelihood_samples.value()
            sample = self.sampler(inputs, sample_shape=torch.Size([num_samples]), resample=resample_weights)  # [S, O, N]
        return sample.mT

In [20]:
def hyperbolic_grid(space: "Hyperbolic", num_points=100, limit=5.0):
    """Square grid of points on a 2-D hyperbolic space, built in intrinsic coordinates.

    A ``num_points x num_points`` grid over ``[-limit, limit]^2`` is flattened in
    row-major ("ij") order and converted with ``space.from_coordinates(..., "intrinsic")``.

    :param space: the ``Hyperbolic`` space providing ``from_coordinates``
        (the grid has 2 intrinsic coordinates, so ``dim=2`` is assumed).
    :param num_points: grid resolution per axis; ``num_points**2`` points are produced.
    :param limit: half-width of the grid in intrinsic coordinates
        (default 5.0, the previously hard-coded range).
    :return: the converted points.
    """
    axis = torch.linspace(-limit, limit, num_points)
    # Pass ``indexing`` explicitly. "ij" is the historical default, and omitting the
    # argument triggers a deprecation warning in recent PyTorch versions.
    xx, yy = torch.meshgrid(axis, axis, indexing="ij")
    intrinsic = torch.stack([xx, yy], dim=-1).reshape(-1, 2)
    return space.from_coordinates(intrinsic, "intrinsic")

In [21]:
# Experiment setup: a 2-D hyperbolic space, a 3-output geometric GP layer, and
# square grids of inducing points and evaluation inputs.
space = Hyperbolic(dim=2)
num_random_phases = 500 
num_eigenfunctions = 35  # only used by get_feature_map for DiscreteSpectrumSpace; ignored for Hyperbolic
num_features = num_random_phases * 2  # NOTE(review): assumes the hyperbolic feature map yields 2 features per random phase — confirm
output_dims = 3 
num_inducing = 10**2
# math.isqrt rounds down, so a non-square num_inducing gives fewer inducing points.
inducing_points = hyperbolic_grid(space, num_points=math.isqrt(num_inducing))

# base_kernel builds its own feature map internally; this separate feature_map,
# built with the same settings, is the one handed to the layer's RFF sampler.
base_kernel = GeometricMaternKernel(space=space, num_eigenfunctions=num_eigenfunctions, num_random_phases=num_random_phases)
feature_map = get_feature_map(space=space, num_eigenfunctions=num_eigenfunctions, num_random_phases=num_random_phases)

layer = GeometricDeepGPLayer(
    output_dims=output_dims, 
    inducing_points=inducing_points, 
    num_features=num_features, 
    base_kernel=base_kernel,
    feature_map=feature_map,
)
num_inputs = 27
inputs = hyperbolic_grid(space, num_inputs)  # num_inputs**2 evaluation points


In [22]:
# NOTE: _set_value changes the global gpytorch setting for the rest of the session
# (unlike the context-manager form), so it also applies to the cells below.
gpytorch.settings.num_likelihood_samples._set_value(2)
with torch.no_grad():
    torch.manual_seed(0)
    outputs = layer(inputs)
    # Index 1 on the leading dim, 0 on the last — presumably sample 1, output 0.
    c = outputs.mean[1, ..., 0]
    # Split the flattened grid back into per-coordinate num_inputs x num_inputs arrays.
    grid = inputs.view(num_inputs, num_inputs, -1)
    x, y, z = grid.unbind(-1)

surface = go.Surface(x=y, y=z, z=x, surfacecolor=c.view_as(x))
fig = go.Figure(data=[surface])
fig.show()

In [23]:
with torch.no_grad():
    torch.manual_seed(0)
    # mean=True is interpreted by DeepGPLayer.__call__, defined outside this notebook.
    outputs = layer(inputs, mean=True)
    c = outputs[..., 2]  # last-dim index 2 — presumably the third output
    grid = inputs.view(num_inputs, num_inputs, -1)
    x, y, z = grid.unbind(-1)

surface = go.Surface(x=y, y=z, z=x, surfacecolor=c.view_as(x))
fig = go.Figure(data=[surface])
fig.show()

In [24]:
with torch.no_grad():
    torch.manual_seed(0)
    # sample='pathwise' is interpreted by DeepGPLayer.__call__, defined outside this notebook.
    outputs = layer(inputs, sample='pathwise')
    c = outputs[0, ..., 2]  # first entry on the leading dim, index 2 on the last
    grid = inputs.view(num_inputs, num_inputs, -1)
    x, y, z = grid.unbind(-1)

surface = go.Surface(x=y, y=z, z=x, surfacecolor=c.view_as(x))
fig = go.Figure(data=[surface])
fig.show()