In [1]:
import math
from pathlib import Path

import torch
import gpytorch
import geometric_kernels.torch
from geometric_kernels.spaces import Hypersphere
import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots

from mdgp.utils import sphere_uniform_grid, sphere_meshgrid, spherical_antiharmonic, rotate
from mdgp.models.deep_gps import GeometricManifoldDeepGP

# Use the stock "plotly" template for every figure created in this notebook.
pio.templates.default = "plotly"
# Make float64 the default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float64)
# Seed torch's global RNG so the random noise drawn later with
# torch.randn_like repeats across fresh kernel runs.
torch.manual_seed(0)

INFO: Using numpy backend


In [2]:
# Single source of truth for the manifold handed to the model.
space = Hypersphere(dim=2)
# Inducing points from sphere_uniform_grid(40); they are not trained
# (learn_inducing_locations=False below).
inducing_points = sphere_uniform_grid(40)
# GammaPrior(1.0, 100.0), passed to the model as its outputscale_prior.
outputscale_prior = gpytorch.priors.GammaPrior(1.0, 1 / 0.01)
model = GeometricManifoldDeepGP(
    space=space,  # reuse the object above instead of building a second Hypersphere
    num_hidden=1,
    num_eigenfunctions=20,
    learn_inducing_locations=False,
    optimize_nu=True,
    inducing_points=inducing_points,
    project_to_tangent='intrinsic',
    outputscale_prior=outputscale_prior,
    tangent_to_manifold='exp',
    parametrised_frame=True,
)

In [114]:
# Camera distance multiplier read by plot_coordinate_frame.
z = 1.18
# Per-axis settings for the 3D scene: the three axes share every option
# except showbackground, which is switched on for the z axis only.
scene_kwargs = {
    axis_name: dict(
        showbackground=(axis_name == 'zaxis'),
        gridcolor='lightgrey',
        showticklabels=False,
        title_text="",
        range=[-1, 1],
    )
    for axis_name in ('xaxis', 'yaxis', 'zaxis')
}


def plot_vector_valued_gp(inputs, outputs, surface=None, surfacecolor=None, height=500, width=500, color='black', scale=1.0):
    """Draw a vector field as cones, optionally on top of a surface.

    Args:
        inputs: tensor whose last dimension holds the (x, y, z) cone anchors.
        outputs: tensor whose last dimension holds the (u, v, w) vector
            components attached to ``inputs``.
        surface: optional tensor whose last dimension holds the (x, y, z)
            coordinates of a surface drawn underneath the cones.
        surfacecolor: optional values used to colour the surface; when None
            the surface is drawn in uniform light grey.
        height: figure height passed to the layout.
        width: figure width passed to the layout.
        color: single colour used for every cone.
        scale: cone size multiplier. The cone ``sizeref`` is ``scale`` divided
            by the largest absolute component of ``outputs``.

    Returns:
        The assembled ``go.Figure``.
    """
    fig = go.Figure()

    if surface is not None:
        # Distinct names keep the surface coordinates apart from the cone
        # anchors unpacked further down.
        surf_x, surf_y, surf_z = surface.unbind(-1)
        if surfacecolor is None:
            surface_style = dict(colorscale=[(0, 'lightgrey'), (1, 'lightgrey')])
        else:
            surface_style = dict(surfacecolor=surfacecolor)
        fig.add_trace(go.Surface(x=surf_x, y=surf_y, z=surf_z, **surface_style))

    x, y, z = inputs.unbind(-1)
    u, v, w = outputs.unbind(-1)

    # Normalise cone sizes by the largest absolute component. An empty or
    # identically-zero field has nothing to normalise by, so fall back to the
    # plain scale instead of failing on max() or dividing by zero.
    max_abs = outputs.abs().max().item() if outputs.numel() > 0 else 0.0
    sizeref = scale / max_abs if max_abs > 0 else scale

    fig.add_trace(
        go.Cone(
            x=x,
            y=y,
            z=z,
            u=u,
            v=v,
            w=w,
            sizeref=sizeref,
            anchor="tail",
            showscale=False,
            colorscale=[color, color],
        ),
    )
    # Axis ranges and styling of the 3D scene come from scene_kwargs.
    # Layout-level xaxis/yaxis ranges address 2D cartesian axes, which this
    # figure does not have, so they are not set here.
    fig.update_scenes(**scene_kwargs)
    fig.update_layout(height=height, width=width)
    return fig
    

def plot_coordinate_frame(inputs, outputs, surface=None, surfacecolor=None, height=500, width=500, pole=True, zoom=1.18):
    """Plot a two-vector frame as two overlaid cone fields.

    Args:
        inputs: tensor whose last dimension holds the (x, y, z) cone anchors.
        outputs: tensor whose last dimension has size 2; its two slices are
            the frame vector fields, drawn in black and light green.
        surface: optional surface coordinates forwarded to
            ``plot_vector_valued_gp`` for the first field.
        surfacecolor: optional surface colouring forwarded likewise.
        height: figure height.
        width: figure width.
        pole: if True the camera sits on the positive z axis; otherwise it
            sits near the xy-plane along the positive x axis.
        zoom: multiplier applied to the camera eye position. It is a parameter
            (instead of a notebook-level variable) so that reassigning a
            global cannot silently move the camera.

    Returns:
        The ``go.Figure`` holding both cone fields (and the surface, if any).
    """
    e1, e2 = outputs.unbind(-1)
    fig = plot_vector_valued_gp(inputs, e1, surface=surface, surfacecolor=surfacecolor, height=height, width=width, color='black', scale=0.5)
    # Only the cone trace of the second field is needed; it is the sole trace
    # of that figure because no surface is passed.
    second_field = plot_vector_valued_gp(inputs, e2, height=height, width=width, color='lightgreen', scale=0.5)
    fig.add_trace(second_field.data[0])

    if pole:
        eye = dict(x=0.0, y=0.0, z=1.7 * zoom)
    else:
        eye = dict(x=1.5 * zoom, y=0.0, z=0.2 * zoom)
    fig.update_scenes(camera=dict(eye=eye))

    return fig


def target_function(x):
    """Sum of two spherical_antiharmonic terms, the second on rotated inputs.

    The first term uses (m=1, n=2) on ``x``; the second uses (m=1, n=1) on
    ``x`` rotated with ``roll=pi/2``.
    """
    first_term = spherical_antiharmonic(x, m=1, n=2)
    rolled = rotate(x, roll=math.pi / 2)
    second_term = spherical_antiharmonic(rolled, m=1, n=1)
    return first_term + second_term

In [115]:
class ExpandAs(torch.nn.Module):
    """Module that ignores the values of its input and returns a constant.

    ``forward(x)`` returns the row ``[1, 0, 0]`` expanded to the shape of
    ``x``, so the last dimension of ``x`` must have size 3.
    """

    def forward(self, x):
        # new_tensor builds the constant with x's dtype and device, so the
        # result matches the input regardless of the global default dtype or
        # where x lives.
        return x.new_tensor([[1.0, 0.0, 0.0]]).expand_as(x)

In [116]:
# Swap the frame's get_normal_vector for ExpandAs, which returns the constant
# row [1, 0, 0] for every input point.
# NOTE(review): presumably this pins the frame's normal direction instead of
# using the model's own one -- confirm against the mdgp frame implementation.
model.hidden_layers[0].project_to_tangent.frame.get_normal_vector = ExpandAs()

In [117]:
# Put the model in evaluation mode before querying the frame.
model.eval()
# Points at which the frame is evaluated; reused later as cone anchors.
cone_inputs = sphere_uniform_grid(100)
# No gradients are needed for plotting.
with torch.no_grad():
    frame = model.hidden_layers[0].project_to_tangent.frame.frame(cone_inputs)

In [119]:
# Sphere surface coloured by the regression target, with the frame on top.
surface = sphere_meshgrid(100, 100)
surfacecolor = target_function(surface)
fig = plot_coordinate_frame(cone_inputs, frame, surface=surface, surfacecolor=surfacecolor, pole=False)
fig.show()
# Create the output directory first so the export also works on a fresh
# checkout where ./report_plots does not exist yet.
figure_path = Path('./report_plots/coordinate_frame_fixed_equator.svg')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(figure_path))


In [145]:
from gpytorch.metrics import negative_log_predictive_density


def test_step(model, inputs, targets, sample_hidden='naive'):
    """Evaluate the model on (inputs, targets) and return its metrics.

    Puts ``model`` into evaluation mode (it is left in that mode on return),
    runs it and its likelihood without gradient tracking, and returns a dict
    with the single key ``'negative_log_predictive_density'`` holding the
    metric averaged over dimension 0.
    """
    model.eval()
    with torch.no_grad():
        latent = model(inputs, sample_hidden=sample_hidden)
        predictive = model.likelihood(latent)
        nlpd = negative_log_predictive_density(predictive, targets).mean(0)
    return {'negative_log_predictive_density': nlpd}

In [205]:
# Fresh model for the training experiment. Compared with the first model in
# this notebook, num_eigenfunctions and optimize_nu are not passed here, so
# the constructor's own defaults apply.
model = GeometricManifoldDeepGP(
    space=Hypersphere(dim=2),
    num_hidden=1,
    inducing_points=inducing_points,
    learn_inducing_locations=False,
    outputscale_prior=outputscale_prior,
    project_to_tangent='intrinsic',
    tangent_to_manifold='exp',
    parametrised_frame=True,
)
# Optional: replace the frame's normal vector with the constant ExpandAs
# module, as done for the first model.
# model.hidden_layers[0].project_to_tangent.frame.get_normal_vector = ExpandAs()

In [206]:
def get_data(n, target_fnc, noise_std=0.01):
    """Build a noisy regression data set on the sphere grid.

    Args:
        n: forwarded to ``sphere_uniform_grid`` as ``n``.
        target_fnc: callable applied to the grid to produce clean targets.
        noise_std: scale of the standard-normal noise added to the targets.

    Returns:
        Tuple ``(inputs, noisy_targets)``.
    """
    grid = sphere_uniform_grid(n=n)
    clean_targets = target_fnc(grid)
    noise = torch.randn_like(clean_targets) * noise_std
    return grid, clean_targets + noise


train_inputs, train_targets = get_data(400, target_function)
# The training loop evaluates on test_inputs/test_targets every 100 epochs, so
# the held-out set has to exist before that loop runs on a fresh kernel.
test_inputs, test_targets = get_data(2000, target_function)
# ELBO objective for the deep GP; the third argument is the number of
# training targets.
criterion = gpytorch.mlls.DeepApproximateMLL(
    gpytorch.mlls.VariationalELBO(model.likelihood, model, train_targets.size(0))
)
# maximize=True: the criterion is a lower bound to be maximised, so Adam
# performs ascent on it rather than descent.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)

In [207]:
# Full-batch training: every epoch evaluates the criterion on all of
# train_inputs and takes one optimizer step.
model.train()
for epoch in range(1000): 
    optimizer.zero_grad()
    outputs = model(train_inputs)
    # `loss` holds the criterion value (printed below as "elbo"); the
    # optimizer was created with maximize=True, so stepping increases it.
    loss = criterion(outputs, train_targets)
    loss.backward()
    optimizer.step()
    # '\r' rewrites the same output line each epoch.
    print(f"{epoch=}, elbo: {loss.item():.3f}", end='\r')
    # Every 100 epochs report the metrics from test_step on the test set.
    if epoch % 100 == 0: 
        print(test_step(model, test_inputs, test_targets))
        # test_step switches the model to eval mode; switch back to training.
        model.train()

{'negative_log_predictive_density': tensor(0.7806)}
{'negative_log_predictive_density': tensor(0.3052)}
{'negative_log_predictive_density': tensor(-0.2180)}
{'negative_log_predictive_density': tensor(-0.7245)}
{'negative_log_predictive_density': tensor(-1.1748)}
{'negative_log_predictive_density': tensor(-1.5303)}
{'negative_log_predictive_density': tensor(-1.6413)}
{'negative_log_predictive_density': tensor(-1.7507)}
{'negative_log_predictive_density': tensor(-1.8025)}
{'negative_log_predictive_density': tensor(-1.8226)}
epoch=999, elbo: 1.725

In [176]:
# Evaluation set built by get_data(2000, target_function).
# NOTE(review): the training-loop cell earlier in the notebook reads
# test_inputs/test_targets, so these names must exist before that cell runs.
test_inputs, test_targets = get_data(2000, target_function)




{'negative_log_predictive_density': tensor(-1.6199)}

In [162]:
# Same visualisation as for the first model, now for the current `model`.
surface = sphere_meshgrid(100, 100)
surfacecolor = target_function(surface)
model.eval()
with torch.no_grad():
    frame = model.hidden_layers[0].project_to_tangent.frame.frame(cone_inputs)
fig = plot_coordinate_frame(cone_inputs, frame, surface=surface, surfacecolor=surfacecolor, pole=False)
fig.show()
# Create the output directory first so the export also works on a fresh
# checkout where ./report_plots does not exist yet.
figure_path = Path('./report_plots/coordinate_frame_parametrised_equator.svg')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(str(figure_path))