In [28]:
import math
from pathlib import Path

import torch
import geometric_kernels.torch
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid, spherical_harmonic, spherical_antiharmonic, rotate

import plotly.io as pio
from plotly import graph_objects as go
from plotly.subplots import make_subplots

# Configuration: plotting theme, tensor precision, and RNG seed.
SEED = 0

pio.templates.default = "plotly"
# Tensors created without an explicit dtype from here on are float64.
torch.set_default_dtype(torch.float64)
# Seed torch's global RNG so the Gaussian noise drawn in `get_data` (torch.randn_like)
# and therefore the saved report figures are reproducible across Restart & Run All.
torch.manual_seed(SEED)

In [9]:
def get_target_function(name='smooth', *args, **kwargs):
    """Return one of the named scalar target functions on the sphere.

    Args:
        name: Which target to build.
            * ``'smooth'``   -- ``spherical_harmonic(x, m=2, n=3)``.
            * ``'singular'`` -- ``spherical_antiharmonic(x, m=1, n=2)`` plus
              ``spherical_antiharmonic`` (``m=1, n=1``) evaluated on ``x`` rotated
              with ``roll=pi/2``.
        *args, **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        A callable ``target_function(x)`` that forwards ``x`` to the ``mdgp.utils``
        helpers above. NOTE(review): ``x`` is presumably an ``(N, 3)`` tensor of
        points on the sphere -- confirm against ``mdgp.utils``.

    Raises:
        NotImplementedError: if ``name`` is not one of the names listed above.
    """
    def smooth(x):
        return spherical_harmonic(x, m=2, n=3)

    def singular(x):
        return (
            spherical_antiharmonic(x, m=1, n=2)
            + spherical_antiharmonic(rotate(x, roll=math.pi / 2), m=1, n=1)
        )

    targets = {'smooth': smooth, 'singular': singular}
    if name not in targets:
        # Name the valid options so a typo is easy to spot.
        raise NotImplementedError(
            f"Unknown target function {name!r}; expected one of {sorted(targets)}."
        )
    return targets[name]

In [115]:
# Shared 3-D scene styling applied to every scene in `plot_sphere` via `update_scenes`.
# All axes span [-1, 1] (presumably the bounding box of the unit sphere -- confirm).
# The default camera below is overridden per scene in `plot_sphere`.
# NOTE(review): the x-axis keeps its tick labels while y/z hide theirs, and only the
# z-axis draws a background plane; preserved as-is -- confirm this is intentional.
scene_kwargs = dict(
    camera=dict(eye=dict(x=1.5, y=1.5, z=0.2)),
    xaxis=dict(showbackground=False, gridcolor='lightgrey', title_text="X", range=[-1, 1]),
    yaxis=dict(showbackground=False, gridcolor='lightgrey', showticklabels=False, title_text="Y", range=[-1, 1]),
    zaxis=dict(showbackground=True, gridcolor='lightgrey', showticklabels=False, title_text="Z", range=[-1, 1]),
)


def _sphere_surface(points, values, grid_shape):
    """Build a ``go.Surface`` colouring ``points`` (reshaped to ``grid_shape``) by ``values``."""
    # `reshape` (unlike `view`) also accepts non-contiguous tensors; detach/cpu/numpy
    # hands plotly plain arrays even for grad-tracking or GPU tensors.
    grid = points.reshape(*grid_shape, 3)
    x, y, z = (coord.detach().cpu().numpy() for coord in grid.unbind(-1))
    colors = values.reshape(tuple(grid_shape)).detach().cpu().numpy()
    return go.Surface(x=x, y=y, z=z, surfacecolor=colors, coloraxis='coloraxis1')


def plot_sphere(inputs, outputs, height=500, width=500, grid_shape=None):
    """Plot a scalar field on the sphere as two side-by-side 3-D surfaces.

    The left panel draws ``inputs`` as given. The right panel draws
    ``rotate(inputs, pitch=math.pi)`` with a camera mirrored in ``y``. Both panels
    share a single colour axis using the 'plasma' colour scale. Scene styling comes
    from the module-level ``scene_kwargs``.

    Args:
        inputs: Points with a trailing dimension of 3, laid out as a regular grid
            (e.g. ``sphere_meshgrid(s, s, eps).view(-1, 3)`` as in ``get_data``).
        outputs: One value per point, reshaped onto the same grid.
        height: Figure height in pixels.
        width: Width of one panel in pixels; the figure is ``2 * width`` wide.
        grid_shape: ``(rows, cols)`` of the point grid. When omitted, a square grid
            is inferred from the number of points (``(100, 100)`` for 10 000 points).

    Returns:
        The plotly ``Figure``.

    Raises:
        ValueError: If ``grid_shape`` is omitted and the point count is not a
            perfect square.
    """
    if grid_shape is None:
        n_points = inputs.reshape(-1, 3).shape[0]
        side = math.isqrt(n_points)
        if side * side != n_points:
            raise ValueError(
                f"Cannot infer a square grid from {n_points} points; pass grid_shape=(rows, cols)."
            )
        grid_shape = (side, side)

    # One row, two scene columns: cols must match the 1x2 `specs`.
    fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scene'}] * 2])
    views = (inputs, rotate(inputs, pitch=math.pi))
    for col, points in enumerate(views, start=1):
        fig.add_trace(_sphere_surface(points, outputs, grid_shape), row=1, col=col)

    fig.update_scenes(**scene_kwargs)
    fig.update_layout(
        # Per-scene cameras override the shared default camera from `scene_kwargs`.
        scene1=dict(camera=dict(eye=dict(x=0.1, y=2.0, z=2.0))),
        scene2=dict(camera=dict(eye=dict(x=0.1, y=-2.0, z=2.0))),
        font_family="Serif",
        height=height,
        width=width * 2,
        coloraxis_colorscale='plasma',
    )
    return fig

In [37]:
def get_data(n, target_fnc, noise_std=0.01, arrangement='uniform', meshgrid_eps=10e-6):
    """Sample points on the sphere and return noisy evaluations of ``target_fnc``.

    Args:
        n: Requested number of points. For ``'meshgrid'`` the grid is
            ``isqrt(n) x isqrt(n)``. That is exactly ``n`` points only when ``n`` is a
            perfect square; otherwise there are fewer.
        target_fnc: Callable mapping the sampled points to target values
            (e.g. the result of ``get_target_function``).
        noise_std: Standard deviation of the additive Gaussian noise.
        arrangement: ``'uniform'`` uses ``sphere_uniform_grid(n=n)``. ``'meshgrid'``
            uses ``sphere_meshgrid`` flattened to ``(-1, 3)``.
        meshgrid_eps: Forwarded to ``sphere_meshgrid``. Note ``10e-6 == 1e-5``.
            NOTE(review): presumably keeps the grid away from the poles -- confirm.

    Returns:
        ``(inputs, noisy_outputs)``, where the noise is drawn from torch's global RNG.

    Raises:
        NotImplementedError: If ``arrangement`` is not ``'uniform'`` or ``'meshgrid'``.
    """
    if arrangement == 'uniform':
        inputs = sphere_uniform_grid(n=n)
    elif arrangement == 'meshgrid':
        side = math.isqrt(n)
        inputs = sphere_meshgrid(side, side, meshgrid_eps).reshape(-1, 3)
    else:
        # Without this branch `inputs` would be unbound and fail with UnboundLocalError.
        raise NotImplementedError(
            f"Unknown arrangement {arrangement!r}; expected 'uniform' or 'meshgrid'."
        )
    outputs = target_fnc(inputs)
    noise = torch.randn_like(outputs) * noise_std
    return inputs, outputs + noise

# Smooth Target Function

In [116]:
# Noisy observations of the smooth target on a 100 x 100 meshgrid (isqrt(10000) = 100),
# shown and exported as PNG and SVG under ./report_plots.
target_function = get_target_function('smooth')
inputs, targets = get_data(10000, target_function, arrangement='meshgrid')

fig = plot_sphere(inputs, targets)
report_dir = Path('./report_plots')
report_dir.mkdir(parents=True, exist_ok=True)  # write_image does not create missing directories
for ext in ('png', 'svg'):
    # Export at the figure's own layout size so it stays in sync with plot_sphere.
    fig.write_image(
        report_dir / f'target_function_smooth.{ext}',
        scale=4, width=fig.layout.width, height=fig.layout.height,
    )
fig

# Target Function with Singularities

In [117]:
# Noisy observations of the singular target on a 100 x 100 meshgrid (isqrt(10000) = 100),
# shown and exported as PNG and SVG under ./report_plots.
target_function = get_target_function('singular')
inputs, targets = get_data(10000, target_function, arrangement='meshgrid')

fig = plot_sphere(inputs, targets)
report_dir = Path('./report_plots')
report_dir.mkdir(parents=True, exist_ok=True)  # write_image does not create missing directories
for ext in ('png', 'svg'):
    # Export at the figure's own layout size so it stays in sync with plot_sphere.
    fig.write_image(
        report_dir / f'target_function_singular.{ext}',
        scale=4, width=fig.layout.width, height=fig.layout.height,
    )
fig