# Illustration of the projected construction of Gaussian vector fields on the sphere: an ambient R^3-valued GP sample is projected onto the sphere's tangent spaces

# In [33]:
import torch 
import gpytorch 
import geometric_kernels.torch
from mdgp.models.deep_gps.layers import GeometricDeepGPLayer
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid
from geometric_kernels.spaces import Hypersphere
from mdgp.models.projectors import ProjectToTangentExtrinsic


import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots

# Make float64 the default floating-point dtype for tensors created from here on
# (inducing points, evaluation grids and GP samples below).
torch.set_default_dtype(torch.float64)
# Render every figure with plotly's stock "plotly" template.
pio.templates.default = "plotly"

# In [34]:
# The 2-sphere, whose points are represented in ambient R^3 coordinates.
space = Hypersphere(dim=2)
inducing_points = sphere_uniform_grid(40)

# GP layer on the sphere with three outputs, i.e. one value per ambient
# coordinate, so a sample is an R^3 vector at each input point.
model_extrinsic = GeometricDeepGPLayer(
    space=space,
    num_eigenfunctions=20,
    output_dims=3,
    inducing_points=inducing_points,
)

# Presumably projects ambient R^3 vectors onto the tangent plane at each input
# (per its name); applied to the samples in the next cell.
projector_extrinsic = ProjectToTangentExtrinsic(space=space)

# In [171]:
# Evaluation points: a flattened 100x100 mesh on the sphere (for surface plots)
# followed by a separate point set for the cone (arrow) plots.
inputs_mesh = sphere_meshgrid(100, 100).view(-1, 3)
inputs_cone = sphere_uniform_grid(100)
inputs = torch.cat([inputs_mesh, inputs_cone])

# Derive the split from the actual point counts rather than a hard-coded 10000,
# so changing the mesh resolution above cannot silently misalign the slices.
mesh_cone_sizes = [inputs_mesh.size(0), inputs_cone.size(0)]

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(1):
    # A single pathwise sample over all inputs at once, so the mesh and cone
    # plots show the same function. squeeze(0) presumably drops the size-1
    # sample dimension implied by num_likelihood_samples(1).
    ambient = model_extrinsic(inputs, sample='pathwise').squeeze(0)
    ambient_mesh, ambient_cone = ambient.split(mesh_cone_sizes)
    # Project the ambient R^3 sample at every input point.
    tangent = projector_extrinsic(inputs, ambient)
    tangent_mesh, tangent_cone = tangent.split(mesh_cone_sizes)

# In [192]:
# Shared 3-D scene styling: a fixed camera, every axis limited to [-1, 1] with
# no tick labels or titles, and a background panel drawn only for the z axis.
scene_kwargs = {
    'camera': {'eye': {'x': 1.5, 'y': 1.5, 'z': 0.2}},
    **{
        axis_name: {
            'showbackground': axis_name == 'zaxis',
            'gridcolor': 'lightgrey',
            'showticklabels': False,
            'title_text': "",
            'range': [-1, 1],
        }
        for axis_name in ('xaxis', 'yaxis', 'zaxis')
    },
}

def plot_scalar_valued_gp(inputs, outputs, height=500, width=500, grid_shape=(100, 100)):
    """Plot each output channel of a GP sample as a coloured surface.

    One 3-D subplot is created per output channel (the last dimension of
    ``outputs``). All subplots share a single ``plasma`` colour axis, so their
    colours are directly comparable.

    Args:
        inputs: Surface points, flattened from a ``grid_shape`` mesh, i.e. of
            shape ``(grid_shape[0] * grid_shape[1], 3)`` — e.g.
            ``sphere_meshgrid(100, 100).view(-1, 3)``.
        outputs: Values at ``inputs`` of shape ``(num_points, num_channels)``;
            each column becomes one subplot.
        height: Figure height in pixels.
        width: Width of a single subplot in pixels; the figure is
            ``width * num_channels`` pixels wide.
        grid_shape: ``(rows, cols)`` of the mesh that ``inputs`` was flattened
            from. Defaults to the 100x100 mesh used in this notebook.

    Returns:
        The plotly figure.
    """
    cols = outputs.size(-1)
    # One scene spec per output channel, so any number of channels works.
    fig = make_subplots(rows=1, cols=cols, specs=[[{'type': 'scene'}] * cols])

    # go.Surface wants 2-D coordinate grids, so undo the flattening.
    x, y, z = inputs.reshape(*grid_shape, 3).unbind(-1)
    for i, output in enumerate(outputs.unbind(-1), 1):
        fig.add_trace(
            go.Surface(
                x=x,
                y=y,
                z=z,
                surfacecolor=output.reshape(grid_shape),
                # Every surface references the same colour axis -> one shared scale.
                coloraxis='coloraxis',
            ),
            row=1,
            col=i,
        )
    fig.update_scenes(**scene_kwargs)
    fig.update_layout(
        height=height,
        width=width * cols,
        coloraxis_colorscale='plasma',
    )
    return fig


def plot_vector_valued_gp(inputs, outputs, surface=None, height=500, width=500):
    """Plot a vector field as black cones, optionally over a light-grey surface.

    Args:
        inputs: Cone positions of shape ``(N, 3)``; each cone's tail sits at
            its input point.
        outputs: Vectors at ``inputs`` of shape ``(N, 3)``, drawn as cones.
        surface: Optional ``(rows, cols, 3)`` grid of points drawn as a
            uniformly light-grey surface, e.g. ``inputs_mesh.view(100, 100, 3)``.
        height: Figure height in pixels.
        width: Figure width in pixels.

    Returns:
        The plotly figure.
    """
    fig = go.Figure()

    if surface is not None:
        x, y, z = surface.unbind(-1)
        fig.add_trace(
            go.Surface(
                x=x,
                y=y,
                z=z,
                colorscale=[(0, 'lightgrey'), (1, 'lightgrey')],
                # A constant colour carries no information, so hide its
                # colorbar (the cone trace hides its colorbar too).
                showscale=False,
            )
        )

    # Scale cones by the largest absolute vector component. An all-zero field
    # would otherwise raise ZeroDivisionError, so fall back to 1.0 there.
    max_component = outputs.abs().max().item()
    sizeref = 1.0 / max_component if max_component > 0 else 1.0

    x, y, z = inputs.unbind(-1)
    u, v, w = outputs.unbind(-1)
    fig.add_trace(
        go.Cone(
            x=x,
            y=y,
            z=z,
            u=u,
            v=v,
            w=w,
            sizeref=sizeref,
            anchor="tail",
            showscale=False,
            colorscale=['black', 'black'],
        ),
    )
    # 3-D axis ranges come from scene_kwargs; layout.xaxis/yaxis only affect
    # 2-D cartesian plots, so they are not set here.
    fig.update_scenes(**scene_kwargs)
    fig.update_layout(height=height, width=width)
    return fig
    

# In [196]:
# One panel per column of the ambient sample on the mesh, exported as SVG and PNG.
fig = plot_scalar_valued_gp(inputs_mesh, ambient_mesh, width=450)
for out_path in ('./report_plots/gvf_projected_scalar.svg', './report_plots/gvf_projected_scalar.png'):
    fig.write_image(out_path, scale=4, height=500, width=450 * 3)
fig

# In [197]:
# The ambient (pre-projection) vector field drawn over the sphere mesh.
fig = plot_vector_valued_gp(
    inputs_cone,
    ambient_cone,
    surface=inputs_mesh.view(100, 100, 3),
)
for out_path in ('./report_plots/gvf_projected_ambient.svg', './report_plots/gvf_projected_ambient.png'):
    fig.write_image(out_path, scale=4, height=500, width=500)
fig

# In [198]:
# The same sample after projection by projector_extrinsic, drawn over the sphere mesh.
fig = plot_vector_valued_gp(
    inputs_cone,
    tangent_cone,
    surface=inputs_mesh.view(100, 100, 3),
)
for out_path in ('./report_plots/gvf_projected_tangent.svg', './report_plots/gvf_projected_tangent.png'):
    fig.write_image(out_path, scale=4, height=500, width=500)
fig