In [1]:
import torch 
import gpytorch
import geometric_kernels.torch
from mdgp.kernels import GeometricMaternKernel
from geometric_kernels.spaces import Hypersphere

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using numpy backend


In [2]:
# --- Configuration -----------------------------------------------------------
DIM = 2                           # dimension handed to Hypersphere
SPACE = Hypersphere(DIM)
NU = 2.5                          # forwarded as `nu` to the geometric Matern kernel
NUM_EIGENFUNCTIONS = 30           # forwarded as `num_eigenfunctions` (presumably the truncation level -- confirm)
BATCH_SHAPE = torch.Size([2])     # batch shape shared by the kernels below

# --- Kernels -----------------------------------------------------------------
# Batched geometric Matern kernel on SPACE, and the same kernel wrapped in a
# ScaleKernel with a matching batch shape.
base_kernel = GeometricMaternKernel(
    space=SPACE,
    nu=NU,
    num_eigenfunctions=NUM_EIGENFUNCTIONS,
    batch_shape=BATCH_SHAPE,
)
kernel = gpytorch.kernels.ScaleKernel(base_kernel, batch_shape=BATCH_SHAPE)

# Scaled RBF kernel with a batched inner kernel.
# NOTE(review): unlike `kernel`, the outer ScaleKernel here is built without
# `batch_shape` -- confirm this asymmetry is intentional.
rbf = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=BATCH_SHAPE)
)

In [3]:
import pytest


def test_kernel_output_shape(kernel):
    x1 = torch.randn(13, 3)
    x2 = torch.randn(17, 3)
    x3 = torch.randn(10, 2, 19, 3)
    x4 = torch.randn(10, 2, 11, 3)

    # Evaluate is necessary here, since sometimes lazy shape will appear correct in spite of an incorrect evaluated shape 
    with torch.no_grad():
        assert kernel(x1, x2).evaluate().shape == torch.Size([2, 13, 17])
        assert kernel(x2, x3).evaluate().shape == torch.Size([10, 2, 17, 19])
        assert kernel(x3, x4).evaluate().shape == torch.Size([10, 2, 19, 11])

        assert kernel(x1).evaluate().shape == torch.Size([2, 13, 13])
        assert kernel(x3).evaluate().shape == torch.Size([10, 2, 19, 19])

        assert kernel(x1, diag=True).shape == torch.Size([2, 13])
        assert kernel(x3, diag=True).shape == torch.Size([10, 2, 19])

# Run the shape checks on `kernel`, the ScaleKernel-wrapped geometric Matern kernel.
test_kernel_output_shape(kernel)

In [4]:
def sphere_randn(*size, **kwargs):
    """Draw standard-normal 3-vectors and rescale each one to unit Euclidean norm.

    Args:
        *size: Leading dimensions of the sample; a trailing axis of length 3
            is appended.
        **kwargs: Forwarded unchanged to ``torch.randn``.

    Returns:
        Tensor of shape ``(*size, 3)`` whose last-axis vectors have norm 1.
    """
    gaussian = torch.randn(*size, 3, **kwargs)
    lengths = gaussian.norm(dim=-1, keepdim=True)
    return gaussian / lengths


def test_kernel_normalization(base_kernel):
    x1 = sphere_randn(10, *base_kernel.batch_shape, 19)
    x2 = sphere_randn(10, *base_kernel.batch_shape, 17)
    with torch.no_grad():
        diag = base_kernel(x1, diag=True)
        k = base_kernel(x1, x2).evaluate()
    
    assert torch.allclose(diag, torch.ones_like(diag))
    assert torch.all(k <= 1.)


# Run the normalisation checks on `base_kernel` (the kernel without the ScaleKernel wrapper).
test_kernel_normalization(base_kernel)