# Extending the hyperbolic kernel to batch dimensions for use in deep GPs

In [2]:
import torch 
import geometric_kernels.torch
from geometric_kernels.frontends.pytorch.gpytorch import GPytorchGeometricKernel
from geometric_kernels.kernels.geometric_kernels import MaternFeatureMapKernel
from geometric_kernels.spaces import Hyperbolic, SymmetricPositiveDefiniteMatrices, DiscreteSpectrumSpace, Space
from geometric_kernels.kernels.feature_maps import deterministic_feature_map_compact, rejection_sampling_feature_map_spd, rejection_sampling_feature_map_hyperbolic


def get_feature_map(space, num_random_phases, num_eigenfunctions):
    """Select and build the feature map that matches the type of ``space``.

    Hyperbolic and SPD-matrix spaces get a rejection-sampling feature map
    sized by ``num_random_phases``; any ``DiscreteSpectrumSpace`` gets the
    deterministic compact feature map sized by ``num_eigenfunctions``. The
    size argument that does not apply to the chosen map is ignored.

    Raises:
        NotImplementedError: if ``space`` matches none of the supported types.
    """
    # Checked in this order, before the DiscreteSpectrumSpace fallback below.
    rejection_samplers = (
        (Hyperbolic, rejection_sampling_feature_map_hyperbolic),
        (SymmetricPositiveDefiniteMatrices, rejection_sampling_feature_map_spd),
    )
    for space_type, build_map in rejection_samplers:
        if isinstance(space, space_type):
            return build_map(space=space, num_random_phases=num_random_phases)

    if not isinstance(space, DiscreteSpectrumSpace):
        raise NotImplementedError(f"Feature map for space {space} not implemented")
    return deterministic_feature_map_compact(space=space, num_eigenfunctions=num_eigenfunctions)


class GeometricMaternKernel(GPytorchGeometricKernel):
    """GPyTorch-compatible Matérn kernel on a geometric ``space``.

    Wraps a GeometricKernels ``MaternFeatureMapKernel`` built from the feature
    map returned by ``get_feature_map`` and exposes it through
    ``GPytorchGeometricKernel`` so it can be used like any GPyTorch kernel
    (including with ``batch_shape``).
    """

    def __init__(
            self,
            space: Space,
            lengthscale=1.0,
            nu=2.5,
            trainable_nu=True,
            num_eigenfunctions=35,
            num_random_phases=3000,
            *,
            key=None,
            normalize=True,
            **kwargs
        ):
        """Build the feature map, the geometric kernel and the GPyTorch wrapper.

        Args:
            space: the space the kernel lives on; must be a type supported by
                ``get_feature_map``.
            lengthscale: initial lengthscale, forwarded to GPytorchGeometricKernel.
            nu: initial smoothness parameter, forwarded to GPytorchGeometricKernel.
            trainable_nu: whether ``nu`` is trainable, forwarded to GPytorchGeometricKernel.
            num_eigenfunctions: size of the deterministic feature map (used only
                for discrete-spectrum spaces).
            num_random_phases: number of random phases (used only for
                rejection-sampling feature maps).
            key: ``torch.Generator`` handed to ``MaternFeatureMapKernel`` as its
                random state. When ``None`` (the default, matching the previous
                behaviour), a new ``torch.Generator()`` is created. A new
                generator starts from PyTorch's fixed default seed, so all
                kernels built without ``key`` share the same initial random
                state. Pass distinct, seeded generators to avoid this, for
                example for the kernels of different deep-GP layers.
            normalize: forwarded to ``MaternFeatureMapKernel`` (default ``True``,
                as before).
            **kwargs: forwarded to GPytorchGeometricKernel (e.g. ``batch_shape``,
                ``ard_num_dims``).
        """
        feature_map = get_feature_map(
            space=space,
            num_eigenfunctions=num_eigenfunctions,
            num_random_phases=num_random_phases,
        )
        if key is None:
            key = torch.Generator()  # torch random state (fixed default seed)

        geometric_kernel = MaternFeatureMapKernel(
            space=space,
            feature_map=feature_map,
            key=key,
            normalize=normalize,
        )
        super().__init__(
            geometric_kernel,
            lengthscale=lengthscale,
            nu=nu,
            trainable_nu=trainable_nu,
            **kwargs,
        )

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using pytorch backend


In [3]:
# Unbatched kernel on the 2-dimensional hyperbolic space.
space = Hyperbolic(dim=2)
kernel = GeometricMaternKernel(
    space=space,
    num_eigenfunctions=35,
    num_random_phases=500,
    ard_num_dims=None,
)
# 100 random 3-coordinate points, mapped through space.projection
# (presumably onto the hyperboloid model — confirm against GeometricKernels).
x = space.projection(torch.randn(100, 3))
# The Gram matrix of 100 points must be 100 x 100.
assert kernel(x, x).evaluate().shape == torch.Size([100, 100])

In [4]:
# Same kernel, but with batch_shape forwarded to the GPyTorch kernel so that
# inputs carry two leading batch dimensions (13 x 7).
batch_shape = torch.Size([13, 7])
kernel = GeometricMaternKernel(
    space=space,
    num_eigenfunctions=35,
    num_random_phases=500,
    batch_shape=batch_shape,
)
# 100 random 3-coordinate points per batch entry, mapped through space.projection.
x = space.projection(torch.randn(*batch_shape, 100, 3))

In [5]:
# Evaluate the batched Gram matrix once. Check that each of the batch entries
# gets its own 100 x 100 block, in the same way the unbatched cell asserts
# its shape, then display the shape.
gram = kernel(x, x).evaluate()
assert gram.shape == (*batch_shape, 100, 100)
print(gram.shape)

torch.Size([13, 7, 100, 100])
