In [11]:
import math 
from pathlib import Path

import torch 
import gpytorch 
import geometric_kernels.torch
from mdgp.models.deep_gps.layers import GeometricDeepGPLayer
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid, spherical_harmonic
from geometric_kernels.spaces import Hypersphere


import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots

# Reproducibility: seed torch's global RNG before anything random happens
# (e.g. the observation noise drawn in get_data via torch.randn_like).
SEED = 0
torch.manual_seed(SEED)

pio.templates.default = "plotly"
# Work in double precision: tensors created later default to float64.
torch.set_default_dtype(torch.float64)

In [76]:
# GP layer with inducing points on the 2-sphere (inputs are 3D unit vectors).
space = Hypersphere(dim=2)
inducing_points = sphere_uniform_grid(60)  # 60 inducing locations on the sphere
model = GeometricDeepGPLayer(
    space=space,
    num_eigenfunctions=20,
    output_dims=None,
    inducing_points=inducing_points,
    nu=0.5,  # kernel smoothness parameter
)
# The layer is trained on its own below, so attach a Gaussian observation
# likelihood for the variational ELBO.
model.likelihood = gpytorch.likelihoods.GaussianLikelihood()

In [77]:
def target_function(x):
    """Noise-free ground truth: ``spherical_harmonic`` with indices m=2, n=3, evaluated at ``x``."""
    harmonic_values = spherical_harmonic(x, n=3, m=2)
    return harmonic_values

def get_data(n, target_fnc, noise_std=0.01, arrangement='uniform', meshgrid_eps=10e-6):
    """Sample points on the sphere and noisy observations of ``target_fnc`` at them.

    Args:
        n: Requested number of points. With ``arrangement='meshgrid'`` the grid is
            ``s x s`` with ``s = isqrt(n)``, so ``isqrt(n)**2`` points are returned
            (exactly ``n`` only when ``n`` is a perfect square).
        target_fnc: Callable mapping an ``(N, 3)`` tensor of points to noise-free values.
        noise_std: Standard deviation of the additive Gaussian noise.
        arrangement: ``'uniform'`` (``sphere_uniform_grid``) or ``'meshgrid'``
            (``sphere_meshgrid`` flattened to ``(-1, 3)``).
        meshgrid_eps: Passed through to ``sphere_meshgrid``. Note ``10e-6 == 1e-5``.

    Returns:
        ``(inputs, noisy_outputs)``.

    Raises:
        ValueError: If ``arrangement`` is not one of the supported values.
    """
    if arrangement == 'uniform':
        inputs = sphere_uniform_grid(n=n)
    elif arrangement == 'meshgrid':
        side = math.isqrt(n)
        inputs = sphere_meshgrid(side, side, meshgrid_eps).view(-1, 3)
    else:
        # Without this branch `inputs` would be unbound and fail with an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown arrangement {arrangement!r}; expected 'uniform' or 'meshgrid'."
        )
    outputs = target_fnc(inputs)
    return inputs, outputs + torch.randn_like(outputs) * noise_std


# Dense evaluation grid (meshgrid arrangement, reshaped later for surface plots)
# and a smaller noisy training set. Keep this call order: both calls draw noise
# from torch's global RNG.
inputs, targets = get_data(10000, target_function, arrangement='meshgrid')
train_inputs, train_targets = get_data(400, target_function)

# Variational ELBO scaled by the number of training points, wrapped for the deep-GP
# API. Adam is configured to *maximize* this objective.
criterion = gpytorch.mlls.DeepApproximateMLL(
    gpytorch.mlls.VariationalELBO(
        model.likelihood,
        model,
        num_data=train_targets.size(0),
    )
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2, maximize=True)

In [78]:
# Fit variational parameters and hyperparameters. The optimizer was created with
# maximize=True, so each step moves *up* the ELBO (it is not a loss to minimize).
model.train()
for epoch in range(1000):
    optimizer.zero_grad()
    outputs = model(train_inputs)
    elbo = criterion(outputs, train_targets)
    elbo.backward()
    optimizer.step()
    # '\r' only returns the cursor; pad the message so a shorter line fully overwrites
    # a longer previous one (otherwise stale characters linger, e.g. "1.2354" for ".3f").
    print(f"{epoch=}, elbo: {elbo.item():.3f}".ljust(50), end='\r')
print()  # terminate the progress line so later output starts on a fresh line

epoch=999, elbo: 1.2354

In [87]:
# Draw one posterior function sample on the dense evaluation grid `inputs`,
# formed as prior sample + data-dependent update (posterior_sample = Phi_w_x + update).
# NOTE(review): this reads like pathwise conditioning (Matheron's rule); confirm
# against the sampler implementation in mdgp.
with torch.no_grad():
    sampler = model.sampler
    # Empty shape: presumably a single sample with no extra leading sample dim — confirm.
    sample_shape = torch.Size([])

    # x: evaluation points; z: the sampler's inducing points. These names and the
    # results below (Phi_w_x, u, update, posterior_sample) are used by the plotting cell.
    x = inputs 
    z = sampler.inducing_points

    # Step 1. Get prior samples from RFF, evaluated at x (Phi_w_x) and at z (Phi_w_z);
    # presumably the same prior function at both point sets.
    # NOTE(review): the effect of normalize_kernel=True is not visible here.
    Phi_w_x, Phi_w_z = sampler.sample_prior(x=x, z=z, sample_shape=sample_shape, normalize_kernel=True) 

    # Step 2. Sample the inducing values u from the variational distribution
    # (per `sample_variational`; this is not a prior sample).
    u = sampler.sample_variational(sample_shape=sample_shape)

    # Step 3. Compute the correction at x from u, z and the prior values at z.
    update = sampler.compute_posterior_update(x=x, z=z, u=u, Phi_w_z=Phi_w_z)

    # Posterior sample = prior sample + update.
    posterior_sample = Phi_w_x + update 

In [88]:
# Shared 3D scene styling: fixed camera, light grid, no tick labels or axis titles.
# Only the z-axis keeps its background pane.
scene_kwargs = {
    'camera': {'eye': {'x': 1.5, 'y': 1.5, 'z': 0.2}},
    **{
        axis_key: {
            'showbackground': show_background,
            'gridcolor': 'lightgrey',
            'showticklabels': False,
            'title_text': "",
        }
        for axis_key, show_background in (('xaxis', False), ('yaxis', False), ('zaxis', True))
    },
}

def plot_scalar_valued_gp(surface_inputs, prior, scatter_inputs, variational, posterior, update, height=500, width=500):
    """Plot prior, update and posterior of a scalar function side by side in 3D.

    Builds a 1x3 grid of 3D scenes sharing one color axis (colorscale 'plasma'):
      1. ``prior`` painted on the surface given by ``surface_inputs``,
      2. ``update`` painted on the same surface,
      3. ``posterior`` painted on the same surface, overlaid with markers at
         ``scatter_inputs`` (scaled by 1.02) colored by ``variational``.

    Args:
        surface_inputs: Tensor whose last dim has size 3; unbinding it gives the
            x/y/z coordinate grids of the surface (e.g. shape ``(H, W, 3)``).
        prior, posterior, update: Tensors with one value per surface grid point,
            reshaped to the grid with ``view_as``.
        scatter_inputs: Tensor ``(..., 3)`` of marker locations.
        variational: Tensor with one value per marker.
        height: Figure height in pixels.
        width: Width in pixels of each panel; the figure is ``3 * width`` wide.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    def _as_plot_data(tensor):
        # Plotly converts array-likes through numpy; detach and move to CPU so
        # tensors that require grad or live on an accelerator still plot.
        return tensor.detach().cpu()

    fig = make_subplots(rows=1, cols=3, specs=[[{'type': 'scene'}] * 3])

    x, y, z = _as_plot_data(surface_inputs).unbind(-1)
    # prior | update | posterior, all on the shared color axis so colors are comparable.
    for col, values in enumerate((prior, update, posterior), start=1):
        fig.add_trace(
            go.Surface(
                x=x,
                y=y,
                z=z,
                surfacecolor=_as_plot_data(values).view_as(x),
                coloraxis='coloraxis1',
            ),
            col=col, row=1,
        )

    # Scale marker positions by 1.02 so that (for unit-norm inputs) they sit just
    # outside the surface instead of being hidden inside it.
    scatter_radius = 1.02
    sx, sy, sz = _as_plot_data(scatter_inputs).mul(scatter_radius).unbind(-1)
    fig.add_trace(
        go.Scatter3d(
            x=sx,
            y=sy,
            z=sz,
            mode="markers",
            marker=dict(
                size=5,
                # Fill color comes from the shared color axis ('plasma' set on the layout).
                color=_as_plot_data(variational).view_as(sx),
                opacity=1.0,
                # Explicit black outline; a line colorscale alone has no effect
                # without line.color, which left plotly's default outline color.
                line=dict(width=2, color='black'),
                coloraxis='coloraxis1',
            ),
        ),
        col=3, row=1,
    )

    fig.update_scenes(**scene_kwargs)
    # Identical cubic ranges in every scene. `update_layout(scene=...)` would only
    # reach the first subplot and leave scenes 2 and 3 auto-ranged.
    r = scatter_radius * 1.01
    fig.update_scenes(xaxis_range=[-r, r], yaxis_range=[-r, r], zaxis_range=[-r, r])
    fig.update_layout(
        height=height,
        width=width * 3,
        coloraxis_colorscale='plasma',
    )
    return fig

In [91]:
# `x` is the flattened square meshgrid from get_data(..., arrangement='meshgrid');
# recover its (side, side, 3) layout for go.Surface instead of hard-coding 100.
grid_side = math.isqrt(x.shape[0])
assert grid_side * grid_side == x.shape[0], "evaluation inputs are not a square grid"
fig = plot_scalar_valued_gp(
    x.view(grid_side, grid_side, 3), Phi_w_x, z, u, posterior_sample, update=update, width=400
)

# write_image fails if the target directory is missing, so create it first; export
# at the figure's own size rather than repeating the width/height numbers.
output_path = Path('./report_plots/posterior_sample.svg')
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(output_path, scale=4, width=fig.layout.width, height=fig.layout.height)
fig