In [1]:
import plotly.io as pio
from plotly import graph_objects as go
# Make Plotly's "plotly_dark" template the default for every figure created later in this notebook.
pio.templates.default = "plotly_dark"

In [2]:
import torch 
import gpytorch
import geometric_kernels.torch 
from torch import Tensor 
from torch import nn 
from geomstats.geometry.hyperboloid import Hyperboloid

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using pytorch backend


In [3]:
import torch 
import geometric_kernels.torch
from geometric_kernels.frontends.pytorch.gpytorch import GPytorchGeometricKernel
from geometric_kernels.kernels.geometric_kernels import MaternFeatureMapKernel
from geometric_kernels.spaces import Hyperbolic, SymmetricPositiveDefiniteMatrices, DiscreteSpectrumSpace, Space
from geometric_kernels.kernels.feature_maps import deterministic_feature_map_compact, rejection_sampling_feature_map_spd, rejection_sampling_feature_map_hyperbolic


def get_feature_map(space, num_random_phases, num_eigenfunctions):
    """Select and build the feature map that matches the type of ``space``.

    The type checks run in a fixed order: ``Hyperbolic``, then
    ``SymmetricPositiveDefiniteMatrices``, then ``DiscreteSpectrumSpace``.
    The first match wins.

    :param space: space object the feature map is built for.
    :param num_random_phases: forwarded to the two rejection-sampling
        constructors (hyperbolic and SPD); not used for discrete-spectrum spaces.
    :param num_eigenfunctions: forwarded to the deterministic constructor used
        for discrete-spectrum spaces; not used otherwise.
    :return: whatever the selected constructor returns.
    :raises NotImplementedError: if ``space`` matches none of the three types.
    """
    if isinstance(space, Hyperbolic):
        feature_map = rejection_sampling_feature_map_hyperbolic(
            space=space,
            num_random_phases=num_random_phases,
        )
    elif isinstance(space, SymmetricPositiveDefiniteMatrices):
        feature_map = rejection_sampling_feature_map_spd(
            space=space,
            num_random_phases=num_random_phases,
        )
    elif isinstance(space, DiscreteSpectrumSpace):
        feature_map = deterministic_feature_map_compact(
            space=space,
            num_eigenfunctions=num_eigenfunctions,
        )
    else:
        raise NotImplementedError(f"Feature map for space {space} not implemented")
    return feature_map


class GeometricMaternKernel(GPytorchGeometricKernel):
    """GPyTorch wrapper around a ``MaternFeatureMapKernel`` on ``space``.

    The feature map is chosen by ``get_feature_map`` from the type of
    ``space``; the wrapped kernel is always created with ``normalize=True``.
    """

    def __init__(
            self,
            space: Space,
            lengthscale=1.0,
            nu=2.5,
            trainable_nu=True,
            num_eigenfunctions=35,
            num_random_phases=3000,
            **kwargs
        ):
        """
        :param space: space the kernel lives on.
        :param lengthscale: forwarded unchanged to ``GPytorchGeometricKernel``.
        :param nu: forwarded unchanged to ``GPytorchGeometricKernel``.
        :param trainable_nu: forwarded unchanged to ``GPytorchGeometricKernel``.
        :param num_eigenfunctions: passed to ``get_feature_map``.
        :param num_random_phases: passed to ``get_feature_map``.
        :param kwargs: any further keyword arguments for the parent constructor
            (the notebook uses this for ``batch_shape``).
        """
        random_state = torch.Generator()  # torch random state used as the kernel's key
        matern_kernel = MaternFeatureMapKernel(
            space=space,
            feature_map=get_feature_map(
                space=space,
                num_eigenfunctions=num_eigenfunctions,
                num_random_phases=num_random_phases,
            ),
            key=random_state,
            normalize=True,
        )
        super().__init__(
            matern_kernel,
            lengthscale=lengthscale,
            nu=nu,
            trainable_nu=trainable_nu,
            **kwargs,
        )


def hyperbolic_grid(space: Hyperbolic, num_points=100): 
    """Build a regular grid of points on ``space`` from intrinsic coordinates.

    A ``num_points x num_points`` lattice over the square ``[-5, 5] x [-5, 5]``
    is flattened to ``num_points ** 2`` rows of 2 intrinsic coordinates and
    converted with ``space.from_coordinates(points, "intrinsic")``.

    :param space: space providing ``from_coordinates``.
    :param num_points: number of grid nodes along each of the two axes.
    :return: the converted points, one row per grid node, in the row order of
        an "ij"-indexed meshgrid (first coordinate varies slowest).
    """
    s = torch.linspace(-5, 5, num_points)
    # "ij" is the layout torch.meshgrid produces when `indexing` is omitted, so
    # the result is unchanged. Naming it explicitly avoids the UserWarning that
    # torch emits for the implicit default and keeps the grid stable once the
    # argument becomes mandatory.
    xx, yy = torch.meshgrid(s, s, indexing="ij")
    points = torch.stack([xx, yy], dim=-1).reshape(-1, 2)
    points = space.from_coordinates(points, "intrinsic")
    return points

In [4]:
import gpytorch 
import math 


class RFFSampler(torch.nn.Module):
    """Samples functions as random linear combinations of kernel features.

    A draw is ``sqrt(outputscale) * (features(inputs) @ weights) + mean(inputs)``,
    where ``weights`` are standard-normal and ``features`` come from the
    supplied feature map evaluated with the base kernel's current parameters.
    """

    def __init__(self, covar_module, mean_module, feature_map, num_features=None) -> None:
        """
        :param covar_module: must be a ``gpytorch.kernels.ScaleKernel``; its
            ``base_kernel`` is expected to expose ``geometric_kernel``.
        :param mean_module: callable applied to the inputs and added to every draw.
        :param feature_map: callable invoked as
            ``feature_map(inputs, params, key=..., normalize=...)``.
        :param num_features: size of the feature axis of the weights. It has to
            be supplied by the caller for now (it could be inferred on first use).
        """
        super().__init__()

        # The sampler reads `outputscale` and `base_kernel` from the covariance
        # module, so anything other than a ScaleKernel is rejected up front.
        assert isinstance(covar_module, gpytorch.kernels.ScaleKernel), "RFFSampler only implemented for ScaleKernel."

        self.covar_module = covar_module
        self.mean_module = mean_module
        self.feature_map = feature_map
        self._num_features = num_features

        inner_kernel = covar_module.base_kernel
        geometric_kernel = inner_kernel.geometric_kernel
        self.base_kernel = inner_kernel
        self.geometric_kernel = geometric_kernel
        self.space = geometric_kernel.space

        # Cached weights from the most recent draw; None until the first draw.
        self._weights = None

    @property
    def num_features(self):
        """Number of features, exactly as passed to the constructor."""
        return self._num_features

    # TODO: Change the weights shape to go by the broadcasted batch shapes of the inputs and the kernel
    def weights(self, num_samples, inputs=None, resample=True) -> torch.Tensor:
        """Return the random weights, drawing new ones when required.

        :param num_samples: size of the trailing (sample) axis of the weights.
        :param inputs: optional inputs; all but their last two dimensions are
            broadcast against the covariance module's batch shape.
        :param resample: if truthy, always draw fresh weights; otherwise reuse
            the cached ones (a first call with nothing cached still draws).
        :return: tensor of shape ``[*batch, num_features, num_samples]``.
        """
        kernel_batch = self.covar_module.batch_shape
        if inputs is None:
            batch_shape = kernel_batch
        else:
            batch_shape = torch.broadcast_shapes(kernel_batch, inputs.shape[:-2])

        must_draw = resample or self._weights is None
        if not must_draw:
            # Reusing cached weights only makes sense for the same sample count.
            assert self._weights.shape[-1] == num_samples, f"Sample shape mismatch. Resample or use sample_shape with product {self._weights.shape[-1]}."
            return self._weights

        self._weights = torch.randn(*batch_shape, self.num_features, num_samples)  # [M, O]
        return self._weights

    def compute_features(self, inputs, normalize=True):
        """Evaluate the feature map at ``inputs`` with the kernel's current parameters.

        The features are multiplied by the square root of the base kernel's
        ``batch_shape_scaling_factor``. A new ``torch.Generator`` is created on
        every call and passed as ``key``; the second value returned by the
        feature map is discarded.
        """
        kernel = self.base_kernel
        generator = torch.Generator()
        feats, _ = self.feature_map(inputs, kernel.geometric_kernel_params, key=generator, normalize=normalize)
        return feats * kernel.batch_shape_scaling_factor.sqrt()

    def sample(self, inputs, weights, sample_shape: torch.Size = torch.Size([]), normalize=True):
        """
        Combine features and weights into function draws.

        :param inputs: [..., D]
        :param weights: weights as returned by :meth:`weights`.
        :param sample_shape: shape that replaces the leading (flattened) sample axis.
        :param normalize: forwarded to :meth:`compute_features`.
        :return: draws with the sample axes first and the data axis last.
        """
        feats = self.compute_features(inputs=inputs, normalize=normalize)  # [..., batch_shape, N, num_eigenfunctions]
        # Contract the feature axis and move the sample axis to the front.
        draws = torch.einsum('...ne, ...es -> s...n', feats, weights)
        amplitude = self.covar_module.outputscale.sqrt().unsqueeze(-1)
        draws = amplitude * draws + self.mean_module(inputs)
        trailing_shape = draws.shape[1:]
        return draws.view(*sample_shape, *trailing_shape)

    def forward(self, inputs: torch.Tensor, sample_shape: torch.Size = torch.Size([]), resample_weights=True, normalize=True) -> torch.Tensor:
        """
        Draw (or reuse) weights for ``prod(sample_shape)`` samples and evaluate them at ``inputs``.

        :return: A sample from the model. [S, O, N]
        """
        flat_num_samples = math.prod(sample_shape)
        drawn_weights = self.weights(num_samples=flat_num_samples, inputs=inputs, resample=resample_weights)
        return self.sample(inputs=inputs, weights=drawn_weights, sample_shape=sample_shape, normalize=normalize)

In [5]:
from mdgp.models.deep_gps.layers import DeepGPLayer
from mdgp.samplers import VISampler, PosteriorSampler


class GeometricDeepGPLayer(DeepGPLayer):
    """Deep-GP layer with a scaled geometric kernel and a pathwise posterior sampler.

    The constructor builds the mean and ``ScaleKernel`` modules, hands them to
    ``DeepGPLayer`` and then attaches a ``PosteriorSampler`` (as ``self.sampler``)
    composed of an ``RFFSampler`` and a ``VISampler``.
    """

    def __init__(
        self, 
        output_dims: int,
        inducing_points: torch.Tensor,
        base_kernel, 
        feature_map: str = 'deterministic', 
        learn_inducing_locations: bool = False,
        whitened_variational_strategy=False, 
        sampler_inv_jitter=10e-8,
        outputscale_prior=None,
        zero_mean=True, 
        num_features=None, 
    ) -> None: 
        """
        :param output_dims: number of outputs; ``None`` yields an empty batch shape.
        :param inducing_points: passed to the parent layer and to the posterior sampler.
        :param base_kernel: kernel wrapped in a ``gpytorch.kernels.ScaleKernel``.
        :param feature_map: passed unchanged to ``RFFSampler``.
            NOTE(review): annotated as ``str`` with default ``'deterministic'``,
            but ``RFFSampler.compute_features`` calls it as a function and the
            notebook always passes a callable - the default looks unusable; confirm.
        :param learn_inducing_locations: forwarded to the parent layer.
        :param whitened_variational_strategy: forwarded to the parent layer and to the posterior sampler.
        :param sampler_inv_jitter: forwarded to ``PosteriorSampler`` as ``inv_jitter``.
        :param outputscale_prior: optional prior for the output scale; when given,
            the output scale is initialised to the prior's mean.
        :param zero_mean: use ``ZeroMean`` if true, ``ConstantMean`` otherwise.
        :param num_features: forwarded to ``RFFSampler``.
        """
        batch_shape = torch.Size([]) if output_dims is None else torch.Size([output_dims])

        # Mean module: zero or learnable constant, same batch shape as the kernel.
        mean_cls = gpytorch.means.ZeroMean if zero_mean else gpytorch.means.ConstantMean
        mean_module = mean_cls(batch_shape=batch_shape)

        # Kernel module: the supplied base kernel with a (batched) output scale on top.
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel=base_kernel,
            batch_shape=batch_shape,
            outputscale_prior=outputscale_prior,
        )
        if outputscale_prior is not None:
            covar_module.initialize(outputscale=outputscale_prior.mean)

        super().__init__(
            mean_module=mean_module,
            covar_module=covar_module,
            inducing_points=inducing_points,
            output_dims=output_dims,
            learn_inducing_locations=learn_inducing_locations,
            whitened_variational_strategy=whitened_variational_strategy,
        )

        # Posterior sampler. VISampler is given the VariationalDistribution object
        # itself so that changes to its parameters are tracked properly.
        prior_sampler = RFFSampler(
            covar_module=covar_module,
            mean_module=mean_module,
            feature_map=feature_map,
            num_features=num_features,
        )
        variational_sampler = VISampler(
            variational_distribution=self.variational_strategy._variational_distribution,
        )
        self.sampler = PosteriorSampler(
            rff_sampler=prior_sampler,
            vi_sampler=variational_sampler,
            inducing_points=inducing_points,
            whitened_variational_strategy=whitened_variational_strategy,
            inv_jitter=sampler_inv_jitter,
        )

    def sample_pathwise(self, inputs, are_samples=False, resample_weights=True):
        """Draw from ``self.sampler`` at ``inputs`` and return the draw with its last two dims swapped.

        :param inputs: points to evaluate the draw at.
        :param are_samples: if true, an extra axis is inserted at position -3 of
            ``inputs`` and an empty sample shape is used; otherwise the sample
            shape is ``[gpytorch.settings.num_likelihood_samples.value()]``.
        :param resample_weights: forwarded to the sampler as ``resample``.
        :return: ``sample.mT`` (the sampler output is annotated ``[S, O, N]`` in the original code).
        """
        strategy = self.variational_strategy

        # Clear cache if training, since otherwise we risk
        # "trying to backward through the graph a second time" errors.
        if self.training:
            strategy._clear_cache()

        # Lazily initialise the variational distribution from the prior
        # (taken from gpytorch.variational._VariationalStrategy.__call__).
        if not strategy.variational_params_initialized.item():
            prior_dist = strategy.prior_distribution
            strategy._variational_distribution.initialize_variational_distribution(prior_dist)
            strategy.variational_params_initialized.fill_(1)

        # Only the sample shape and the input layout differ between the two modes.
        if are_samples:
            sample_shape = torch.Size([])
            sampler_inputs = inputs.unsqueeze(-3)
        else:
            sample_shape = torch.Size([gpytorch.settings.num_likelihood_samples.value()])
            sampler_inputs = inputs

        sample = self.sampler(sampler_inputs, sample_shape=sample_shape, resample=resample_weights)  # [S, O, N]
        return sample.mT

In [6]:
class ToTangent(nn.Module):
    """Module form of ``space.to_tangent``."""

    def __init__(self, space: Hyperboloid):
        """
        :param space: object whose ``to_tangent`` method performs the projection.
        """
        super().__init__()
        self.space = space

    def forward(self, base_points, ambient_vectors):
        """Project ``ambient_vectors`` to tangent vectors at ``base_points``.

        Note the argument order: this module takes the base points first, while
        ``space.to_tangent`` is called with the vectors first.
        """
        tangent_vectors = self.space.to_tangent(ambient_vectors, base_points)
        return tangent_vectors
    

class ExpMap(nn.Module):
    """Module form of ``space.metric.exp``."""

    def __init__(self, space: Hyperboloid):
        """
        :param space: object whose ``metric.exp`` method is applied in ``forward``.
        """
        super().__init__()
        self.space = space

    def forward(self, base_points, tangent_vectors):
        """Apply ``space.metric.exp`` to ``tangent_vectors`` at ``base_points``.

        Note the argument order: this module takes the base points first, while
        ``metric.exp`` is called with the tangent vectors first.
        """
        metric = self.space.metric
        return metric.exp(tangent_vectors, base_points)

In [7]:
class ManifoldToManifoldDeepGPLayer(torch.nn.Module):
    """Wraps a GP layer so that it maps points on ``space`` to points on ``space``.

    The GP output is projected to a tangent vector at the input point
    (``ToTangent``) and that vector is then pushed through ``ExpMap`` at the
    same input point.
    """

    def __init__(self, gp, space):
        """
        :param gp: callable layer invoked as
            ``gp(x, mean=..., sample=..., are_samples=..., resample_weights=...)``.
        :param space: space handed to both ``ToTangent`` and ``ExpMap``.
        """
        super().__init__()
        self.gp = gp
        self.project_to_tangent = ToTangent(space=space)
        self.tangent_to_manifold = ExpMap(space=space)

    def forward(self, x, are_samples=False, return_hidden=False, mean=False, sample='naive', resample_weights=True):
        """Move the points ``x`` along the tangent vectors produced by the GP.

        :param x: base points; also the input to the GP.
        :param are_samples: forwarded to the GP.
        :param return_hidden: if true, return a dict with the intermediate values
            under ``'coefficients'`` (raw GP output), ``'tangent'`` and ``'manifold'``.
        :param mean: forwarded to the GP.
        :param sample: forwarded to the GP.
        :param resample_weights: forwarded to the GP.
        :return: the resulting points, or the dict described above.
        """
        coefficients = self.gp(
            x,
            mean=mean,
            sample=sample,
            are_samples=are_samples,
            resample_weights=resample_weights,
        )
        tangent = self.project_to_tangent(x, coefficients)
        # Note carried over from the original code: a vector of norm zero cannot
        # be projected onto the hyperboloid. Nothing here guards against it.
        on_manifold = self.tangent_to_manifold(x, tangent)

        if not return_hidden:
            return on_manifold
        return {'coefficients': coefficients, 'tangent': tangent, 'manifold': on_manifold}

In [8]:
class ManifoldDeepGP(gpytorch.models.deep_gps.DeepGP):
    """Deep GP made of manifold-to-manifold hidden layers followed by an output GP layer.

    A ``GaussianLikelihood`` is attached as ``self.likelihood``.
    """

    def __init__(self, hidden_gps, output_gp, space):
        """
        :param hidden_gps: iterable of GP layers; each one is wrapped in a
            ``ManifoldToManifoldDeepGPLayer`` on ``space``.
        :param output_gp: final layer, stored as ``self.output_layer`` and called directly.
        :param space: space shared by all hidden-layer wrappers.
        """
        super().__init__()
        wrapped_layers = [ManifoldToManifoldDeepGPLayer(gp=gp, space=space) for gp in hidden_gps]
        self.hidden_layers = torch.nn.ModuleList(wrapped_layers)
        self.output_layer = output_gp
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

    def forward_return_hidden(self, x: Tensor, are_samples: bool = False, sample_hidden: str = 'naive', sample_output=False, mean=False, 
                              resample_weights=True):
        """Same pass as :meth:`forward`, additionally returning every hidden layer's intermediates.

        :return: ``(hidden_factors, y)`` where ``hidden_factors`` holds one dict
            per hidden layer (as produced with ``return_hidden=True``) and ``y``
            is the output layer's result.
        """
        hidden_factors = []
        for layer in self.hidden_layers:
            layer_outputs = layer(
                x=x,
                are_samples=are_samples,
                sample=sample_hidden,
                mean=mean,
                return_hidden=True,
                resample_weights=resample_weights,
            )
            hidden_factors.append(layer_outputs)
            x = layer_outputs['manifold']
            # After the first hidden layer the inputs count as samples, unless means are propagated.
            are_samples = not mean
        y = self.output_layer(x, are_samples=are_samples, sample=sample_output, mean=mean, resample_weights=resample_weights)
        return hidden_factors, y

    def forward(self, x: Tensor, are_samples: bool = False, sample_hidden: str = 'naive', sample_output=False, mean=False, resample_weights: bool = True):
        """Pass ``x`` through every hidden layer and then the output layer.

        :param x: input points.
        :param are_samples: forwarded to the first layer; for every later layer
            it is ``not mean``.
        :param sample_hidden: forwarded to the hidden layers as ``sample``.
        :param sample_output: forwarded to the output layer as ``sample``.
        :param mean: forwarded to every layer.
        :param resample_weights: forwarded to every layer.
        :return: the output layer's result.
        """
        for layer in self.hidden_layers:
            x = layer(x, are_samples=are_samples, sample=sample_hidden, mean=mean, resample_weights=resample_weights)
            are_samples = not mean
        return self.output_layer(x, are_samples=are_samples, sample=sample_output, mean=mean, resample_weights=resample_weights)

In [36]:
# --- Shared configuration for both layers -----------------------------------
space = Hyperbolic(2)
# Number of GP outputs of the hidden layer.
# presumably 3 = number of ambient coordinates of a point on Hyperbolic(2) - TODO confirm
hidden_output_dims = 3
# A 10 x 10 grid of points serves as inducing points for both layers.
inducing_points = hyperbolic_grid(space, 10)
num_eigenfunctions = 35
num_random_phases = 100
# presumably the hyperbolic feature map yields two features per random phase - TODO confirm
num_features = num_random_phases * 2
# This feature map object is shared by the samplers of both layers; each
# GeometricMaternKernel below builds its own feature map internally.
feature_map = get_feature_map(space=space, num_eigenfunctions=num_eigenfunctions, num_random_phases=num_random_phases)

# --- Hidden layer: batched kernel, one batch entry per hidden output --------
base_kernel = GeometricMaternKernel(
    space=space, 
    num_eigenfunctions=num_eigenfunctions, 
    num_random_phases=num_random_phases,
    batch_shape=torch.Size([hidden_output_dims]),
)
hidden_layer = GeometricDeepGPLayer(
    output_dims=hidden_output_dims, 
    inducing_points=inducing_points, 
    num_features=num_features, 
    base_kernel=base_kernel,
    feature_map=feature_map,
)
# Scale the hidden layer's output scale and variational mean down by 1000x.
# presumably so that the hidden layer initially moves points very little - TODO confirm
hidden_layer.covar_module.outputscale *= 0.001
with torch.no_grad():
    hidden_layer.variational_strategy._variational_distribution.variational_mean *= 0.001

# --- Output layer: unbatched kernel, constant (non-zero) mean ---------------
# NOTE: `base_kernel` is rebound here; the hidden layer keeps its own kernel object.
base_kernel = GeometricMaternKernel(
    space=space, 
    num_eigenfunctions=num_eigenfunctions, 
    num_random_phases=num_random_phases
)
output_layer = GeometricDeepGPLayer(
    output_dims=None, 
    inducing_points=inducing_points, 
    num_features=num_features, 
    base_kernel=base_kernel, 
    feature_map=feature_map,
    zero_mean=False, 
)
# Scale the output layer's output scale down by 10x.
output_layer.covar_module.outputscale *= 0.1

In [37]:
# Assemble the deep GP from the single hidden layer and the output layer built above.
model = ManifoldDeepGP(hidden_gps=[hidden_layer], output_gp=output_layer, space=space)
# Set the number of likelihood samples to 2 for the rest of the notebook
# (GeometricDeepGPLayer.sample_pathwise reads this value for its sample shape).
gpytorch.settings.num_likelihood_samples._set_value(2)
# Evaluation grid of num_inputs x num_inputs points. The plotting cells reshape
# with `num_inputs`, so the grid is built from the same variable instead of a
# repeated literal to keep the two in sync.
num_inputs = 50
inputs = hyperbolic_grid(space, num_inputs)

In [38]:
# Plot the model output obtained with mean=True, as surface colour over the evaluation grid.
with torch.no_grad():
    torch.manual_seed(0)
    outputs = model(inputs, mean=True)
    c = outputs
    # Split the grid points into their three coordinate components, each reshaped to the grid layout.
    x, y, z = inputs.view(num_inputs, num_inputs, -1).unbind(-1)

# The first coordinate component is drawn on the vertical (z) axis of the figure.
fig = go.Figure(data=[go.Surface(x=y, y=z, z=x, surfacecolor=c.view_as(x))])
fig.show()

In [41]:
# Plot one pathwise draw from the (still untrained) model over the evaluation grid.
with torch.no_grad():
    torch.manual_seed(2)
    outputs = model(inputs, sample_hidden='pathwise', sample_output='pathwise')
    # Keep only the first entry along the leading axis.
    # presumably the first of the likelihood samples - TODO confirm
    c = outputs[0]
    # Split the grid points into their three coordinate components, each reshaped to the grid layout.
    x, y, z = inputs.view(num_inputs, num_inputs, -1).unbind(-1)

# The first coordinate component is drawn on the vertical (z) axis of the figure.
fig = go.Figure(data=[go.Surface(x=y, y=z, z=x, surfacecolor=c.view_as(x))])
fig.show()

In [44]:
def target_function(tensor, m=5):
    """Evaluate ``sin(m * theta)`` where ``theta = atan2(second, first)`` of each point.

    :param tensor: points whose last dimension has exactly three components;
        only the first two are used, the third is ignored.
    :param m: frequency multiplier applied to the angle.
    :return: tensor with the last dimension of ``tensor`` removed.

    NOTE(review): the plotting cells of this notebook draw the first component
    on the vertical axis, so the angle here is taken in the plane of the first
    and second components rather than the second and third - confirm that this
    is the intended target.
    """
    first, second, _unused = tensor.unbind(dim=-1)
    # Angle of the (first, second) pair, as in a conversion to cylindrical coordinates.
    angle = torch.atan2(second, first)
    return (m * angle).sin()


# Training set: a 50 x 50 grid of points, viewed as rows of 3 coordinates,
# with targets given by target_function (default m=5).
train_inputs = hyperbolic_grid(space=space, num_points=50).view(-1, 3)
train_targets = target_function(train_inputs)

In [47]:
# Training objective: a variational ELBO over all training targets, wrapped in DeepApproximateMLL.
base_mll = gpytorch.mlls.VariationalELBO(model.likelihood, model, num_data=train_targets.size(0))
mll = gpytorch.mlls.DeepApproximateMLL(base_mll)

# A single Adam parameter group covering every parameter returned by model.parameters().
optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=0.01)

In [48]:
# Full-batch training: every epoch evaluates the model on all training inputs
# and takes one Adam step on the negative approximate marginal log-likelihood.
num_epochs = 200
model.train()
# NOTE(review): `dim` is not referenced anywhere else in this cell.
dim = 0
# Loss value of every epoch, in order.
losses = []
for epoch in range(num_epochs):
    # Run each step with num_likelihood_samples set to 10 via the settings context manager.
    with gpytorch.settings.num_likelihood_samples(10):
        model.zero_grad()
        output = model(train_inputs, sample_hidden='naive')
        loss = -mll(output, train_targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # end='\r' rewrites the same output line instead of printing one line per epoch.
        print(f"Epoch {epoch+1}/{num_epochs}, loss = {loss.item():.4f}", end='\r')

Epoch 153/200, loss = 0.9694

KeyboardInterrupt: 

In [49]:
# Plot one pathwise draw from the model after the training cell, using the
# same seed and call as the pre-training sample plot.
with torch.no_grad():
    torch.manual_seed(2)
    outputs = model(inputs, sample_hidden='pathwise', sample_output='pathwise')
    # Keep only the first entry along the leading axis.
    # presumably the first of the likelihood samples - TODO confirm
    c = outputs[0]
    # Split the grid points into their three coordinate components, each reshaped to the grid layout.
    x, y, z = inputs.view(num_inputs, num_inputs, -1).unbind(-1)

# The first coordinate component is drawn on the vertical (z) axis of the figure.
fig = go.Figure(data=[go.Surface(x=y, y=z, z=x, surfacecolor=c.view_as(x))])
fig.show()