In [1]:
import math 
import torch 
import geometric_kernels.torch
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid, spherical_harmonic, spherical_antiharmonic, rotate 

import plotly.io as pio
from plotly import graph_objects as go

# Use double precision for every tensor created without an explicit dtype.
torch.set_default_dtype(torch.float64)
# Seed torch's global RNG so the observation noise drawn in `get_data`
# (via `torch.randn_like`) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Render plotly figures with the stock "plotly" template (kept last so the
# cell does not echo the Generator returned by `torch.manual_seed`).
pio.templates.default = "plotly"

INFO: Using numpy backend


In [2]:
def get_target_function(name='smooth', *args, **kwargs):
    """Return a target function on the sphere, selected by name.

    Parameters
    ----------
    name : str
        ``'smooth'``   -- ``spherical_harmonic(x, m=2, n=3)``.
        ``'singular'`` -- ``spherical_antiharmonic(x, m=1, n=2)`` plus
        ``spherical_antiharmonic(..., m=1, n=1)`` evaluated on ``x`` rotated
        with ``roll=pi/2``.
    *args, **kwargs
        Accepted for call-site compatibility; currently ignored.

    Returns
    -------
    callable
        A function ``x -> values``. ``x`` is presumably a tensor of points on
        the sphere such as the ``(N, 3)`` inputs built by ``get_data``
        (TODO confirm against ``mdgp.utils``).

    Raises
    ------
    NotImplementedError
        If ``name`` is not one of the supported targets.
    """
    if name == 'smooth':
        def target_function(x):
            return spherical_harmonic(x, m=2, n=3)
    elif name == 'singular':
        def target_function(x):
            return (
                spherical_antiharmonic(x, m=1, n=2)
                + spherical_antiharmonic(rotate(x, roll=math.pi / 2), m=1, n=1)
            )
    else:
        # Report the offending value and the valid choices rather than a bare error.
        raise NotImplementedError(
            f"Unknown target function {name!r}; expected one of 'smooth', 'singular'."
        )

    return target_function

In [75]:
# Shared 3-D scene styling: camera looking straight down the z axis, light
# grey grid lines, no tick labels or axis titles, and every axis spanning
# [-1, 1]. Only the z axis draws its background plane.
scene_kwargs = {
    'camera': {'eye': {'x': 0.0, 'y': 0.0, 'z': 2.0}},
    **{
        f'{axis}axis': {
            'showbackground': axis == 'z',
            'gridcolor': 'lightgrey',
            'showticklabels': False,
            'title_text': "",
            'range': [-1, 1],
        }
        for axis in 'xyz'
    },
}


def plot(inputs=None, outputs=None, surface=None, surfacecolor=None, height=500, width=500, marker_size=5,
         scatter_radius=1.04):
    """Build a 3-D plotly figure of a sphere surface and/or scattered observations.

    Parameters
    ----------
    inputs : tensor, optional
        Scatter locations; reshaped to ``(-1, 3)``, so a trailing axis of size 3
        is expected. Points are scaled by ``scatter_radius`` before plotting so
        they sit slightly outside a unit-radius surface.
    outputs : tensor, optional
        Per-point values colouring the markers (``'plasma'`` colorscale);
        markers are black when omitted.
    surface : tensor, optional
        Surface coordinates with a trailing axis of size 3, e.g. ``(rows, cols, 3)``.
    surfacecolor : tensor, optional
        Values colouring ``surface``; when omitted the surface is drawn in
        uniform light grey without a colour bar.
    height, width : int
        Figure size in pixels.
    marker_size : int
        Scatter marker size.
    scatter_radius : float
        Factor applied to ``inputs`` before plotting (default 1.04).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()

    if surface is not None:
        x, y, z = surface.unbind(-1)
        if surfacecolor is None:
            # Constant grey colouring; a colour bar for it would carry no information.
            colour_kwargs = dict(colorscale=[(0, 'lightgrey'), (1, 'lightgrey')], showscale=False)
        else:
            colour_kwargs = dict(surfacecolor=surfacecolor)
        fig.add_trace(go.Surface(x=x, y=y, z=z, **colour_kwargs))

    fig.update_scenes(**scene_kwargs)

    if inputs is not None:
        # reshape rather than squeeze: squeeze would collapse a single point
        # (1, 3) to (3,), making unbind(-1) yield scalars instead of coordinate vectors.
        x, y, z = inputs.reshape(-1, 3).mul(scatter_radius).unbind(-1)
        fig.add_trace(
            go.Scatter3d(
                x=x,
                y=y,
                z=z,
                mode="markers",
                marker=dict(
                    size=marker_size,
                    color=outputs if outputs is not None else 'black',
                    colorscale='plasma',
                    opacity=1.0,
                    # Black outline: set the colour directly; a line colorscale only
                    # takes effect when line.color is a numeric array.
                    line=dict(width=2, color='black'),
                ),
            ),
        )
        # Widen the axes just past the scattered points so markers are not clipped.
        r = scatter_radius * 1.01
        fig.update_scenes(
            xaxis_range=[-r, r],
            yaxis_range=[-r, r],
            zaxis_range=[-r, r],
        )

    fig.update_layout(height=height, width=width)
    return fig

In [29]:
def get_data(n, target_fnc, noise_std=0.01, arrangement='uniform', meshgrid_eps=10e-6):
    """Generate sphere inputs and noisy observations of ``target_fnc`` at them.

    Parameters
    ----------
    n : int
        Requested number of points. For ``arrangement='meshgrid'`` the grid is
        ``isqrt(n) x isqrt(n)``, so only ``isqrt(n) ** 2`` points are returned
        when ``n`` is not a perfect square.
    target_fnc : callable
        Maps the input tensor to function values (e.g. from ``get_target_function``).
    noise_std : float
        Standard deviation of the i.i.d. Gaussian noise added to the outputs.
    arrangement : {'uniform', 'meshgrid'}
        ``'uniform'`` uses ``sphere_uniform_grid(n=n)``; ``'meshgrid'`` uses
        ``sphere_meshgrid`` flattened to shape ``(-1, 3)``.
    meshgrid_eps : float
        Passed as the third argument of ``sphere_meshgrid``. Note that the
        default ``10e-6`` equals ``1e-5``.

    Returns
    -------
    tuple
        ``(inputs, outputs)``: the noise-free inputs and
        ``target_fnc(inputs)`` plus Gaussian noise.

    Raises
    ------
    ValueError
        If ``arrangement`` is not recognised.
    """
    if arrangement == 'uniform':
        inputs = sphere_uniform_grid(n=n)
    elif arrangement == 'meshgrid':
        side = math.isqrt(n)
        # reshape (not view) also copes with a non-contiguous meshgrid tensor.
        inputs = sphere_meshgrid(side, side, meshgrid_eps).reshape(-1, 3)
    else:
        # Fail loudly here; otherwise `inputs` would be unbound below.
        raise ValueError(
            f"Unknown arrangement {arrangement!r}; expected 'uniform' or 'meshgrid'."
        )
    outputs = target_fnc(inputs)
    return inputs, outputs + torch.randn_like(outputs) * noise_std

In [36]:
# Dense evaluation of the target on a GRID_SIZE x GRID_SIZE meshgrid. It is
# used as the backdrop surface (and, in the last plot, its colouring), so no
# observation noise is added.
GRID_SIZE = 100
target_function = get_target_function('singular')
surface, surfacecolor = get_data(
    GRID_SIZE ** 2, target_function, noise_std=0.0, arrangement='meshgrid'
)
surface = surface.reshape(GRID_SIZE, GRID_SIZE, 3)
surfacecolor = surfacecolor.reshape(GRID_SIZE, GRID_SIZE)

In [81]:
# 100 noisy observations on a uniform grid, drawn over the reference surface.
inputs, targets = get_data(n=100, target_fnc=target_function)
fig = plot(inputs=inputs, outputs=targets, surface=surface)
fig

In [80]:
# 200 noisy observations on a uniform grid, drawn over the reference surface.
inputs, targets = get_data(n=200, target_fnc=target_function)
fig = plot(inputs=inputs, outputs=targets, surface=surface)
fig

In [82]:
# 400 noisy observations on a uniform grid, drawn over the reference surface.
inputs, targets = get_data(n=400, target_fnc=target_function)
fig = plot(inputs=inputs, outputs=targets, surface=surface)
fig

In [83]:
# The target function itself, shown as the colouring of the reference surface.
fig = plot(surface=surface, surfacecolor=surfacecolor)
fig