### Instead of jointly training every DGP parameter, we want to traing the coordinate frame separately. Perhaps, instead, we will train the parameters for 50 epochs, than the coordinate frame for 50 epochs, then the parameters, then the coordinate frame, etc. 

In [1]:
import torch 
from plotly import graph_objects as go

from mdgp.utils import sphere_meshgrid


def base_trace(num_meshgrid=50): 
    meshgrid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    x_meshgrid, y_meshgrid, z_meshgrid = meshgrid.unbind(-1)
    grey_sphere_trace = go.Surface(x=x_meshgrid, y=y_meshgrid, z=z_meshgrid, colorscale='Greys', showscale=False)
    return grey_sphere_trace


def observations_trace(x, y=None):
    x_scatter, y_scatter, z_scatter = x.mul(1.02).unbind(-1)
    if y is None:
        color_scatter = torch.zeros(x.shape[0])
    else: 
        color_scatter = y.squeeze(-1)
    scatter_trace = go.Scatter3d(x=x_scatter, y=y_scatter, z=z_scatter, mode='markers', marker=dict(size=5, color=color_scatter, colorscale='Viridis'))
    return scatter_trace


def prediction_trace(prediction_function, num_meshgrid=50):
    meshgrid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    x_meshgrid, y_meshgrid, z_meshgrid = meshgrid.unbind(-1)
    y = prediction_function(meshgrid.view(-1, meshgrid.shape[-1])).view_as(x_meshgrid)
    sphere_trace = go.Surface(x=x_meshgrid, y=y_meshgrid, z=z_meshgrid, surfacecolor=y, colorscale='Viridis', showscale=True)
    return sphere_trace


def target_trace(target_function, num_meshgrid=50):
    # 1. Compute y by evaluating target function on meshgrid
    meshgrid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    x_meshgrid, y_meshgrid, z_meshgrid = meshgrid.unbind(-1)
    y = torch.zeros_like(x_meshgrid)
    for i in range(num_meshgrid): 
        for j in range(num_meshgrid): 
            y[i, j] = target_function(meshgrid[i, j])
    color_meshgrid = y.squeeze(-1)

    # 2. Plot a sphere colored by y
    sphere_trace = go.Surface(x=x_meshgrid, y=y_meshgrid, z=z_meshgrid, surfacecolor=color_meshgrid, colorscale='Viridis', showscale=True)
    return sphere_trace


def plot_traces(*traces): 
    fig = go.Figure()
    fig.add_traces([trace for trace in traces if trace is not None])
    return fig


def plot_observations(x, y=None, num_meshgrid=50): 
    # 1. Plot a grey sphere with meshgrid
    grey_sphere_trace = base_trace(num_meshgrid=num_meshgrid)

    # 2. Plot a scatter plot of observations slightly above the sphere
    scatter_trace = observations_trace(x, y=None)

    # 3. Add all traces to figure
    fig = plot_traces(grey_sphere_trace, scatter_trace)
    return fig


def plot_prediction(prediction_function, x=None, y=None, num_meshgrid=50):
    # 1. Plot a sphere colored by prediction function 
    sphere_trace = prediction_trace(prediction_function, num_meshgrid=num_meshgrid)

    # 2. (Optional) Plot observations slightly above the sphere
    if x is not None and y is not None: 
        scatter_trace = observations_trace(x, y)
    else: 
        scatter_trace = None

    # 3. Add all traces to figure
    fig = plot_traces(sphere_trace, scatter_trace)
    return fig

def plot_target(target_function, num_meshgrid=50): 
    # 1. Plot a sphere colored by target function
    sphere_trace = target_trace(target_function, num_meshgrid=num_meshgrid)

    # 2. Add all traces to figure
    fig = plot_traces(sphere_trace)
    return fig 

INFO: Using numpy backend


In [2]:
import torch 
import geometric_kernels.torch 
from tqdm.autonotebook import tqdm 
from mdgp.bo_experiment.fit import FitArguments, train_step 
torch.set_default_dtype(torch.float64)

def fit(model, optimizer, criterion, train_inputs, train_targets, train_loggers=None, fit_args: FitArguments = None, show_progress=True): 
    metrics = {'elbo': None}
    pbar = tqdm(range(1, fit_args.num_steps + 1), desc="Fitting", leave=False, disable=not show_progress)
    for step in pbar:
        # Training step and display training metrics 
        optimizer.zero_grad(set_to_none=True)
        loss = train_step(model=model, inputs=train_inputs, targets=train_targets, criterion=criterion, sample_hidden=fit_args.sample_hidden, 
                          loggers=train_loggers, step=step)
        loss.backward()
        optimizer.step() 
        metrics.update({'elbo': loss.item()})
        # Display metrics 
        pbar.set_postfix(metrics)
    return model 

  from tqdm.autonotebook import tqdm
INFO: Created a temporary directory at /tmp/tmpyu8eso87
INFO: Writing /tmp/tmpyu8eso87/_remote_module_non_scriptable.py


In [3]:
from mdgp.bo_experiment.model import ModelArguments, create_model

In [13]:
from mdgp.bayesian_optimisation.target_functions import Levy, Rosenbrock, StyblinskiTang, ProductOfSines
from pymanopt.manifolds import Sphere 
import numpy as np
from mdgp.utils import rotate, sphere_uniform_grid, sphere_kmeans_centers, spherical_antiharmonic
from mdgp.experiment_utils import get_target_function

target_function_ = ProductOfSines(Sphere(3)).compute_function_torch
def experimental_target_function(x):
    x = rotate(x, 0, np.pi * 3/2, 0)
    y = torch.zeros_like(x[..., 0])
    for i in np.ndindex(x.shape[:-1]): 
        y[i] = target_function_(x[i])
    return -(y + 100) / 200 * (1 - x[..., 0]) ** 2

experimental_target_function = lambda x: get_target_function('singular')(rotate(x, pitch=np.pi / 2))
experimental_target_function = lambda x: spherical_antiharmonic(x, 1, 2)

In [14]:
x = sphere_uniform_grid(400) + torch.tensor([[0., 0., 1.]])
x = x / x.norm(dim=-1, keepdim=True)
inducing_points = sphere_kmeans_centers(x, 60)
inducing_points = sphere_uniform_grid(60)
y = experimental_target_function(x)

In [15]:
def get_frame_parameters(model): 
    return [p for name, p in model.named_parameters() if 'frame' in name]

def get_non_frame_parameters(model): 
    return [p for name, p in model.named_parameters() if 'frame' not in name]

In [21]:
model_args = ModelArguments() 

from mdgp.utils import cart_to_sph

class CartToSph(torch.nn.Module): 
    def forward(self, x):
        return torch.stack(cart_to_sph(x), dim=-1)

# model_args.parametrised_frame = torch.nn.Sequential(
#     CartToSph(),
#     torch.nn.Linear(2, 10),
#     torch.nn.ReLU(),
#     torch.nn.Linear(10, 10),
#     torch.nn.ReLU(), 
#     torch.nn.Linear(10, 3),
# )
model_args.parametrised_frame = [10, 10]
model_args.project_to_tangent = 'intrinsic'
model_args.outputscale_prior_class = 'normal'
model_args.outputscale_std = 0.01
model_args.num_hidden = 1
# model_args.outputscale_mean = 0.01
# model_args.parametrised_frame = False 
model = create_model(model_args, inducing_points=inducing_points).base_model

In [17]:
mll = model_args.mll_factory(model, y)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, maximize=True)
# optimizer = torch.optim.Adam([
#     {'params': get_non_frame_parameters(model), 
#      'lr': 0.01}, {'params': get_frame_parameters(model), 'lr': 0.005}], maximize=True)

fit_args = FitArguments(num_steps=1000)
for i in range(1):
    print(f"Iter {i}")
    fit(model, optimizer, mll, x, y, fit_args=fit_args)
    
    # fit(model, optimizer_frame, mll, x, y, fit_args=fit_args_frame)
    # fit(model, optimizer_model, mll, x, y, fit_args=fit_args_model)

Iter 0


                                                                         

In [18]:
from gpytorch.metrics import negative_log_predictive_density

with torch.no_grad():
    test_x = sphere_uniform_grid(600)
    test_y = experimental_target_function(test_x)
    loss = negative_log_predictive_density(model.likelihood(model(x)), y)
    print(loss.mean())

tensor(-1.3618)


In [20]:
from gpytorch.metrics import negative_log_predictive_density
import gpytorch 

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(1):
    fig = plot_prediction(lambda x: model(x, sample_hidden='pathwise', sample_output='pathwise'), num_meshgrid=100)
    fig.show()