In [10]:
import math 
import pytest
import torch 
import gpytorch
import geometric_kernels.torch
from geometric_kernels.spaces import Hypersphere
from mdgp.kernels import GeometricMaternKernel

import plotly.io as pio 
from plotly import graph_objects as go 
from plotly.subplots import make_subplots


# Global notebook configuration.
pio.templates.default = "plotly_dark"  # default plotly template for every figure in this notebook
torch.set_default_dtype(torch.float64)  # new floating-point tensors are created as float64
# Seed torch's global RNG: the test cells draw their inputs with torch.randn, so without a
# fixed seed a failing run could not be reproduced after Restart Kernel -> Run All.
torch.manual_seed(0)

# Tests 

## Output shape

In [3]:
# Shared fixtures used by the test and plotting cells of this notebook.
DIM = 2                           # dimension handed to Hypersphere
NU = 2.5                          # `nu` argument of the Matern kernel
NUM_EIGENFUNCTIONS = 30           # `num_eigenfunctions` argument of the kernel
BATCH_SHAPE = torch.Size((2,))    # batch shape shared by the base kernel and its scale wrapper
SPACE = Hypersphere(DIM)

# Batched geometric Matern kernel, and the same kernel wrapped in a gpytorch ScaleKernel.
base_kernel = GeometricMaternKernel(
    space=SPACE,
    nu=NU,
    num_eigenfunctions=NUM_EIGENFUNCTIONS,
    batch_shape=BATCH_SHAPE,
)
kernel = gpytorch.kernels.ScaleKernel(base_kernel, batch_shape=BATCH_SHAPE)

In [4]:
def test_kernel_output_shape(kernel):
    x1 = torch.randn(13, 3)
    x2 = torch.randn(17, 3)
    x3 = torch.randn(10, 2, 19, 3)
    x4 = torch.randn(10, 2, 11, 3)

    # Evaluate is necessary here, since sometimes lazy shape will appear correct in spite of an incorrect evaluated shape 
    with torch.no_grad():
        assert kernel(x1, x2).evaluate().shape == torch.Size([2, 13, 17])
        assert kernel(x2, x3).evaluate().shape == torch.Size([10, 2, 17, 19])
        assert kernel(x3, x4).evaluate().shape == torch.Size([10, 2, 19, 11])

        assert kernel(x1).evaluate().shape == torch.Size([2, 13, 13])
        assert kernel(x3).evaluate().shape == torch.Size([10, 2, 19, 19])

        assert kernel(x1, diag=True).shape == torch.Size([2, 13])
        assert kernel(x3, diag=True).shape == torch.Size([10, 2, 19])


# Run the shape check on `kernel` (the ScaleKernel-wrapped batched kernel); raises AssertionError on a mismatch.
test_kernel_output_shape(kernel)

## Normalization

In [7]:
def sphere_randn(*size, **kwargs):
    """Sample random unit vectors in R^3.

    A standard-normal tensor of shape ``(*size, 3)`` is drawn with
    ``torch.randn`` and each length-3 vector along the last axis is divided by
    its own norm.

    Args:
        *size: leading dimensions of the returned tensor.
        **kwargs: forwarded unchanged to ``torch.randn`` (e.g. ``dtype``).

    Returns:
        Tensor of shape ``(*size, 3)`` whose last-axis vectors have unit norm.
    """
    gaussian = torch.randn(*size, 3, **kwargs)
    radii = gaussian.norm(dim=-1, keepdim=True)
    return torch.div(gaussian, radii)


def test_kernel_normalization(base_kernel):
    """Check that ``base_kernel`` called with ``normalize=True`` is bounded by a unit diagonal.

    Two batches of random unit vectors are drawn with shapes
    ``(10, *base_kernel.batch_shape, 19, 3)`` and
    ``(10, *base_kernel.batch_shape, 17, 3)``. The test asserts that the
    diagonal of the first batch is all ones and that no cross-covariance entry
    exceeds 1.

    Args:
        base_kernel: kernel exposing ``batch_shape`` and accepting the
            ``normalize`` and ``diag`` keyword arguments.

    Raises:
        AssertionError: if either property is violated.
    """
    batch = base_kernel.batch_shape
    lhs = sphere_randn(10, *batch, 19)
    rhs = sphere_randn(10, *batch, 17)

    with torch.no_grad():
        variances = base_kernel(lhs, diag=True, normalize=True)
        cross = base_kernel(lhs, rhs, normalize=True).evaluate()

    # The self-covariance must be 1 everywhere (up to allclose's default tolerance) ...
    expected_variances = torch.ones_like(variances)
    assert torch.allclose(variances, expected_variances)
    # ... and no pair of points may have a covariance above 1.
    assert (cross <= 1.).all()


# Run the normalization check on the unscaled `base_kernel`; raises AssertionError on failure.
test_kernel_normalization(base_kernel)

In [8]:
def compute_kernel_vs_spherical_distance(kernel, normalize=True, n=500):
    """Evaluate ``kernel`` between the north pole and points on a pole-to-pole arc.

    The arc lies in the x-z plane: ``n`` points ``(cos t, 0, sin t)`` with ``t``
    running linearly from ``pi/2`` (north pole) to ``-pi/2`` (south pole).

    Args:
        kernel: callable taking ``(x1, x2, normalize=...)`` and returning an
            object with ``.evaluate()``.
        normalize: forwarded to ``kernel`` as the ``normalize`` keyword.
        n: number of points on the arc.

    Returns:
        Tuple ``(values, distances)``. ``values`` is the evaluated kernel with
        size-1 entries among its last two dimensions squeezed out.
        ``distances`` is a flat tensor of length ``n`` holding the arccosine of
        the inner product between the north pole and each arc point.
    """
    with torch.no_grad():
        latitudes = torch.linspace(math.pi / 2, -math.pi / 2, n)
        pole = torch.tensor([[0., 0., 1.]])
        # Stack the x, y, z coordinates as columns -> shape (n, 3).
        arc = torch.stack((latitudes.cos(), torch.zeros(n), latitudes.sin()), dim=1)

        cosines = torch.inner(pole, arc)
        distances = torch.acos(cosines).ravel()

        dense = kernel(pole, arc, normalize=normalize).evaluate()
        values = dense.squeeze(dim=(-1, -2))
    return values, distances


# Evaluate the batched base kernel along the pole-to-pole arc, once normalized and once not.
# The returned distances do not depend on `normalize`, so the second copy is discarded.
k_normalized, spherical_distance = compute_kernel_vs_spherical_distance(base_kernel, normalize=True)
k_unnormalized, _ = compute_kernel_vs_spherical_distance(base_kernel, normalize=False)

In [11]:
# One subplot row per entry of the kernel batch, each overlaying the unnormalized and
# normalized kernel value against spherical distance.
colorway = pio.templates['plotly_dark'].layout.colorway
rows = BATCH_SHAPE[0]

fig = make_subplots(
    rows=rows,
    cols=1,
    shared_xaxes=True,
    row_titles=[f"Output dim {i}" for i in range(1, rows + 1)],
    x_title="Spherical Distance",
    y_title="Kernel Value",
)

for row in range(rows):
    # Unnormalized first, then normalized, so the trace order is the same in every row.
    for label, values, color in (
        ("Unnormalized", k_unnormalized, colorway[0]),
        ("Normalized", k_normalized, colorway[1]),
    ):
        curve = go.Scatter(
            x=spherical_distance,
            y=values[row],
            name=label,
            legendgroup=label,
            marker=dict(color=color),
            # Only the first row contributes legend entries; later rows share them via legendgroup.
            showlegend=row == 0,
        )
        fig.add_trace(curve, row=row + 1, col=1)

fig.update_layout(
    height=400 * rows,
    title="Matern Kernel Value vs Spherical Distance Between Points",
)

fig.show()

### Behaviour across number of eigenfunctions

In [12]:
# Interactive figure: unnormalized kernel value vs spherical distance, with a slider
# selecting the number of eigenfunctions (10 through 40).
nums_eigenfunctions = torch.arange(10, 41)
fig = go.Figure()

# Add one hidden trace per truncation level; the slider below toggles their visibility.
for num_eigenfunctions in nums_eigenfunctions:
    geometric_kernel = GeometricMaternKernel(
        space=SPACE,
        nu=NU,
        num_eigenfunctions=num_eigenfunctions.item(),
    )
    geometric_kernel.lengthscale = 0.2
    k, d = compute_kernel_vs_spherical_distance(geometric_kernel, normalize=False)
    fig.add_trace(go.Scatter(x=d, y=k, visible=False))

# Each slider step shows exactly one trace and updates the title to match.
steps = []
for i, num_eigenfunctions in enumerate(nums_eigenfunctions):
    only_this_trace = [trace_index == i for trace_index in range(len(nums_eigenfunctions))]
    step = dict(
        method="update",
        args=[
            {"visible": only_this_trace},
            {"title": f"Matern Kernel Value vs Spherical Distance. lengthscale=0.2, num_eigenfunctions={num_eigenfunctions:d}"},
        ],
        label=f"{num_eigenfunctions}",
    )
    steps.append(step)

# Index 20 of arange(10, 41) is 30 eigenfunctions, matching the initial title below.
fig.data[20].visible = True

sliders = [
    dict(
        active=20,
        currentvalue={"prefix": "num_eigenfunctions: "},
        pad={"t": 50},
        steps=steps,
    )
]

fig.update_layout(
    title="GeometricKernels Matern Kernel vs Spherical Distance. lengthscale=0.2, num_eigenfunctions=30",
    xaxis_title="Spherical Distance",
    yaxis_title="Kernel Value",
    sliders=sliders,
)

fig.show()


In [13]:
# Interactive figure: normalized kernel value vs spherical distance at lengthscale 0.2,
# with a slider selecting the number of eigenfunctions (10 through 50).
LENGTHSCALE = 0.2
nums_eigenfunctions = torch.arange(10, 51)

fig = go.Figure()

# Add one hidden trace per truncation level; the slider below toggles their visibility.
for num_eigenfunctions in nums_eigenfunctions:
    geometric_kernel = GeometricMaternKernel(
        space=SPACE,
        nu=NU,
        num_eigenfunctions=num_eigenfunctions.item(),
    )
    geometric_kernel.lengthscale = LENGTHSCALE
    k, d = compute_kernel_vs_spherical_distance(geometric_kernel, normalize=True)
    fig.add_trace(go.Scatter(x=d, y=k, visible=False))

# Each slider step shows exactly one trace and updates the title to match.
steps = []
for i, num_eigenfunctions in enumerate(nums_eigenfunctions):
    shown = [False] * len(nums_eigenfunctions)
    shown[i] = True
    step = dict(
        method="update",
        args=[
            {"visible": shown},
            {"title": f"Matern Kernel Value vs Spherical Distance. lengthscale={LENGTHSCALE:.3f}, num_eigenfunctions={num_eigenfunctions:d}"},
        ],
        label=f"{num_eigenfunctions}",
    )
    steps.append(step)

# Index 35 of arange(10, 51) is 45 eigenfunctions, matching the initial title below.
fig.data[35].visible = True

sliders = [
    dict(
        active=35,
        currentvalue={"prefix": "num_eigenfunctions: "},
        pad={"t": 50},
        steps=steps,
    )
]

fig.update_layout(
    title=f"GeometricKernels Matern Kernel vs Spherical Distance. lengthscale={LENGTHSCALE:.3f}, num_eigenfunctions=45",
    xaxis_title="Spherical Distance",
    yaxis_title="Kernel Value",
    sliders=sliders,
)

fig.show()

In [14]:
# Interactive figure: normalized kernel value vs spherical distance at lengthscale 0.02,
# with a slider selecting the number of eigenfunctions (10 through 50).
nums_eigenfunctions = torch.arange(10, 51)

fig = go.Figure()
LENGTHSCALE = 0.02
# Trace (and slider position) shown when the figure first renders. The initial title is
# derived from this index so that it cannot disagree with the trace actually displayed:
# index 25 of arange(10, 51) is 35 eigenfunctions, not the 30 a hard-coded title claimed.
ACTIVE_INDEX = 25

# Add one hidden trace per truncation level; the slider below toggles their visibility.
for num_eigenfunctions in nums_eigenfunctions:
    geometric_kernel = GeometricMaternKernel(space=SPACE, nu=NU, num_eigenfunctions=num_eigenfunctions.item())
    geometric_kernel.lengthscale = LENGTHSCALE
    k, d = compute_kernel_vs_spherical_distance(geometric_kernel, normalize=True)
    fig.add_trace(
        go.Scatter(
            visible=False,
            x=d,
            y=k,
        )
    )

# Each slider step shows exactly one trace and updates the title to match.
steps = []
for i, num_eigenfunctions in enumerate(nums_eigenfunctions):
    step = dict(
        method="update",
        args=[{"visible": [False] * len(nums_eigenfunctions)},
                {"title": f"Matern Kernel Value vs Spherical Distance. lengthscale={LENGTHSCALE:.3f}, num_eigenfunctions={num_eigenfunctions:d}"}],
        label=f"{num_eigenfunctions}"
    )
    step["args"][0]["visible"][i] = True
    steps.append(step)

fig.data[ACTIVE_INDEX].visible = True

sliders = [dict(
    active=ACTIVE_INDEX,
    currentvalue={"prefix": "num_eigenfunctions: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    sliders=sliders,
    title=(
        f"GeometricKernels Matern Kernel vs Spherical Distance. lengthscale={LENGTHSCALE:.3f}, "
        f"num_eigenfunctions={nums_eigenfunctions[ACTIVE_INDEX].item():d}"
    ),
    xaxis_title="Spherical Distance",
    yaxis_title="Kernel Value",
)

fig.show()