In [1]:
import torch 
import geometric_kernels.torch
import gpytorch
from mdgp.bo_experiment import ModelArguments, create_model
from mdgp.utils import sphere_uniform_grid, sphere_meshgrid

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using numpy backend
INFO: Created a temporary directory at /tmp/tmpjung500v
INFO: Writing /tmp/tmpjung500v/_remote_module_non_scriptable.py


In [2]:
import torch 
from plotly import graph_objects as go

from mdgp.utils import sphere_meshgrid


def base_trace(num_meshgrid=50):
    """Return a uniformly grey sphere surface to use as a plot backdrop.

    Args:
        num_meshgrid: number of meshgrid points per axis passed to
            ``sphere_meshgrid``.

    Returns:
        A ``go.Surface`` with the 'Greys' colorscale and no colour bar.
    """
    xs, ys, zs = sphere_meshgrid(num_meshgrid, num_meshgrid).unbind(-1)
    return go.Surface(
        x=xs,
        y=ys,
        z=zs,
        colorscale='Greys',
        showscale=False,
    )


def observations_trace(x, y=None):
    """Return a 3D scatter trace of observation locations.

    Args:
        x: observation locations; the last dimension is unbound into three
            coordinate tensors (so it must have size 3), the first dimension
            indexes points.
        y: optional observed values used as marker colours; a trailing
            singleton dimension is squeezed away. When ``None``, every marker
            gets colour value 0.

    Returns:
        A ``go.Scatter3d`` in marker mode with the 'Viridis' colorscale.
    """
    # Scale coordinates by 1.02 so markers sit slightly above the sphere
    # surface and are not hidden inside it.
    lifted = x.mul(1.02)
    xs, ys, zs = lifted.unbind(-1)
    colors = torch.zeros(x.shape[0]) if y is None else y.squeeze(-1)
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker=dict(size=5, color=colors, colorscale='Viridis'),
    )


def prediction_trace(prediction_function, num_meshgrid=50):
    """Return a sphere surface coloured by the values of ``prediction_function``.

    The whole meshgrid is flattened into one batch of points and passed to
    ``prediction_function`` in a single call. The function must therefore
    accept a ``(num_points, d)`` tensor and return a tensor with exactly
    ``num_points`` elements, which is reshaped back onto the grid.

    Args:
        prediction_function: callable mapping a batch of points to one value
            per point (returned as a torch tensor).
        num_meshgrid: number of meshgrid points per axis passed to
            ``sphere_meshgrid``.

    Returns:
        A ``go.Surface`` coloured by the predictions, with a colour bar.
    """
    meshgrid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    x_meshgrid, y_meshgrid, z_meshgrid = meshgrid.unbind(-1)
    # reshape (not view): works whether or not the meshgrid tensor is contiguous.
    points = meshgrid.reshape(-1, meshgrid.shape[-1])
    # detach + cpu: plotly converts the colour array through numpy, which
    # raises for tensors that require grad or live on an accelerator.
    y = prediction_function(points).detach().cpu().reshape_as(x_meshgrid)
    sphere_trace = go.Surface(x=x_meshgrid, y=y_meshgrid, z=z_meshgrid, surfacecolor=y, colorscale='Viridis', showscale=True)
    return sphere_trace


def target_trace(target_function, num_meshgrid=50):
    """Return a sphere surface coloured by ``target_function``.

    ``target_function`` is evaluated one meshgrid point at a time, so it only
    needs to handle a single point. Each result is written into a tensor with
    the meshgrid's dtype. This presumably requires a scalar-like return
    value; TODO confirm.

    Args:
        target_function: callable evaluated on each individual meshgrid point.
        num_meshgrid: number of meshgrid points per axis passed to
            ``sphere_meshgrid``.

    Returns:
        A ``go.Surface`` coloured by the target values, with a colour bar.
    """
    grid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    xs, ys, zs = grid.unbind(-1)

    # Fill the colour grid point by point (no batching assumed for target_function).
    values = torch.zeros_like(xs)
    for row in range(num_meshgrid):
        for col in range(num_meshgrid):
            values[row, col] = target_function(grid[row, col])

    return go.Surface(
        x=xs,
        y=ys,
        z=zs,
        surfacecolor=values.squeeze(-1),
        colorscale='Viridis',
        showscale=True,
    )


def plot_traces(*traces):
    """Build a figure containing every trace that is not ``None``.

    ``None`` entries are skipped so callers can pass optional traces directly.

    Returns:
        A new ``go.Figure``.
    """
    kept = []
    for candidate in traces:
        if candidate is not None:
            kept.append(candidate)
    fig = go.Figure()
    fig.add_traces(kept)
    return fig


def plot_observations(x, y=None, num_meshgrid=50):
    """Plot observation locations as markers on top of a grey sphere.

    Args:
        x: observation locations; the last dimension is unbound into
            three coordinates.
        y: optional observed values. When given, the markers are coloured by
            them; when ``None``, all markers share one colour.
        num_meshgrid: number of meshgrid points per axis for the backdrop sphere.

    Returns:
        A ``go.Figure`` with the grey sphere and the observation markers.
    """
    # 1. Plot a grey sphere with meshgrid
    grey_sphere_trace = base_trace(num_meshgrid=num_meshgrid)

    # 2. Plot a scatter plot of observations slightly above the sphere.
    # Forward the caller's y; previously y=None was hard-coded here, so the
    # markers were never coloured by the observed values.
    scatter_trace = observations_trace(x, y=y)

    # 3. Add all traces to figure
    fig = plot_traces(grey_sphere_trace, scatter_trace)
    return fig


def plot_prediction(prediction_function, x=None, y=None, num_meshgrid=50):
    """Plot a sphere coloured by ``prediction_function``, optionally with observations.

    Args:
        prediction_function: callable evaluated on the flattened sphere meshgrid.
        x: optional observation locations.
        y: optional observed values. Observations are drawn only when both
            ``x`` and ``y`` are provided.
        num_meshgrid: number of meshgrid points per axis.

    Returns:
        A ``go.Figure``.
    """
    surface = prediction_trace(prediction_function, num_meshgrid=num_meshgrid)
    has_observations = x is not None and y is not None
    markers = observations_trace(x, y) if has_observations else None
    return plot_traces(surface, markers)

def plot_target(target_function, num_meshgrid=50):
    """Plot a sphere coloured by ``target_function`` evaluated point by point.

    Args:
        target_function: callable evaluated on each individual meshgrid point.
        num_meshgrid: number of meshgrid points per axis.

    Returns:
        A ``go.Figure`` containing the single coloured sphere surface.
    """
    surface = target_trace(target_function, num_meshgrid=num_meshgrid)
    return plot_traces(surface)

In [3]:
# Build a reference model with extrinsic tangent projection, using inducing
# points from sphere_uniform_grid(60), plus a 100x100 sphere meshgrid.
# NOTE(review): `model` and `x` are not used by any later cell visible here
# (DGPSample constructs its own model); presumably kept for interactive
# inspection. Confirm, or remove to keep the notebook lean.
model_args = ModelArguments()
model_args.project_to_tangent = 'extrinsic'
inducing_points = sphere_uniform_grid(60)
model = create_model(model_args, inducing_points).base_model
x = sphere_meshgrid(100, 100)

In [7]:
class DGPSample(torch.nn.Module):
    """Wrap a freshly created deep GP so that every call re-uses the same seed.

    Each ``forward`` call seeds torch with ``self.seed`` and then runs the
    model with pathwise sampling for hidden layers and output, without
    resampling weights. Repeated calls on the same inputs should therefore
    reproduce the same function draw. This assumes the model's only source
    of randomness is torch's global RNG; TODO confirm.

    Args:
        seed: seed applied before every forward pass.
        num_inducing: argument passed to ``sphere_uniform_grid`` to build the
            inducing points.
        project_to_tangent: forwarded to ``ModelArguments``.
    """

    def __init__(self, seed=0, num_inducing=60, project_to_tangent='extrinsic'):
        super().__init__()
        self.seed = seed
        model_args = ModelArguments(project_to_tangent=project_to_tangent)
        inducing_points = sphere_uniform_grid(num_inducing)
        self.model = create_model(model_args, inducing_points).base_model

    def forward(self, x):
        """Evaluate the seeded sample at ``x`` without tracking gradients.

        Returns:
            Element ``[0]`` of the model output.
        """
        # fork_rng snapshots the global RNG state(s) and restores them on exit,
        # even if the model raises. Re-seeding here therefore does not disturb
        # the caller's random stream. The previous code re-seeded with
        # torch.initial_seed(), which rewinds the generator to its initial
        # state instead of restoring the state it had before this call.
        with torch.no_grad(), gpytorch.settings.num_likelihood_samples(1), torch.random.fork_rng():
            torch.manual_seed(self.seed)
            out = self.model(x, sample_hidden='pathwise', sample_output='pathwise', resample_weights=False)[0]
        return out

In [15]:
# One reproducible function draw (seed 10) used for the plot below.
dgp_sample = DGPSample(seed=10)

In [16]:
# DGPSample.forward sets num_likelihood_samples(1) itself; the outer context
# only needs to wrap the computation, so the figure is shown outside it.
with gpytorch.settings.num_likelihood_samples(1):
    fig = plot_prediction(dgp_sample, num_meshgrid=100)
fig.show()