In [1]:
import geometric_kernels.torch
import os 
import warnings 
import torch 
from mdgp.bo_experiment.data import get_initial_data
from mdgp.bo_experiment.model import create_model
from mdgp.experiment_utils.logging import CSVLogger, finalize, log 
from mdgp.bo_experiment.fit import fit
from mdgp.bo_experiment import (
    ExperimentConfig, set_experiment_seed, BOArguments, ModelArguments, FitArguments, optimize_acqf_manifold
)
from tqdm.autonotebook import tqdm 
from plotly import graph_objects as go
from mdgp.utils import sphere_meshgrid
from botorch.fit import fit_gpytorch_mll


# Create floating-point tensors in double precision by default
# (torch.double is the same dtype as torch.float64).
torch.set_default_dtype(torch.double)
# Fix the experiment seed (presumably seeds every RNG the experiment uses).
set_experiment_seed(0)

  from .autonotebook import tqdm as notebook_tqdm
INFO: Using numpy backend
INFO: Created a temporary directory at /tmp/tmpcgampw2h
INFO: Writing /tmp/tmpcgampw2h/_remote_module_non_scriptable.py
INFO: Global seed set to 0
INFO: Global seed set to 0


In [2]:
def base_trace(num_meshgrid=50):
    """Build a grey-scale sphere surface used as a backdrop behind scatter plots.

    Args:
        num_meshgrid: Resolution passed as both grid sizes to ``sphere_meshgrid``.

    Returns:
        A ``go.Surface`` trace with the ``'Greys'`` colorscale and no colorbar.
    """
    xs, ys, zs = sphere_meshgrid(num_meshgrid, num_meshgrid).unbind(-1)
    return go.Surface(x=xs, y=ys, z=zs, colorscale='Greys', showscale=False)


def observations_trace(x, y=None):
    """Scatter trace of observed inputs, drawn slightly outside the sphere surface.

    Args:
        x: Points whose last dimension holds the x/y/z coordinates
            (presumably an ``(n, 3)`` tensor of points on the unit sphere).
        y: Optional observed values used to colour the markers; the trailing
            dimension is squeezed. When ``None``, every marker gets colour 0.

    Returns:
        A ``go.Scatter3d`` marker trace using the ``'Viridis'`` colorscale.
    """
    # Push points outward by 2% so the markers render just above the sphere.
    xs, ys, zs = x.mul(1.02).unbind(-1)
    marker_colors = torch.zeros(x.shape[0]) if y is None else y.squeeze(-1)
    marker = dict(size=5, color=marker_colors, colorscale='Viridis')
    return go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker)


def prediction_trace(prediction_function, num_meshgrid=50):
    """Sphere surface coloured by a batched prediction function.

    Args:
        prediction_function: Callable that takes all grid points flattened to
            ``(num_points, dim)`` and returns one value per point (its output is
            reshaped to the grid layout with ``view_as``).
        num_meshgrid: Resolution passed as both grid sizes to ``sphere_meshgrid``.

    Returns:
        A ``go.Surface`` trace with a ``'Viridis'`` colorscale and a colorbar.
    """
    grid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    xs, ys, zs = grid.unbind(-1)
    # Evaluate every grid point in one batched call, then restore the grid layout.
    flat_points = grid.view(-1, grid.shape[-1])
    values = prediction_function(flat_points).view_as(xs)
    return go.Surface(x=xs, y=ys, z=zs, surfacecolor=values, colorscale='Viridis', showscale=True)


def target_trace(target_function, num_meshgrid=50):
    """Sphere surface coloured by the target function, evaluated point by point.

    Args:
        target_function: Callable that is given a single grid point
            (``grid[i, j]``) and returns a single-element value.
        num_meshgrid: Resolution passed as both grid sizes to ``sphere_meshgrid``.

    Returns:
        A ``go.Surface`` trace with a ``'Viridis'`` colorscale and a colorbar.
    """
    grid = sphere_meshgrid(num_meshgrid, num_meshgrid)
    xs, ys, zs = grid.unbind(-1)

    # 1. Evaluate the target one grid point at a time. Loop bounds come from the
    #    actual grid shape rather than assuming num_meshgrid x num_meshgrid, and
    #    no autograd graph is built because the values are only used for colouring.
    values = torch.zeros_like(xs)
    rows, cols = values.shape
    with torch.no_grad():
        for i in range(rows):
            for j in range(cols):
                values[i, j] = target_function(grid[i, j])

    # 2. Colour the sphere by the target values. `values` already has the grid
    #    shape, so it is passed directly; squeezing the last dim would collapse
    #    the grid when num_meshgrid == 1.
    return go.Surface(x=xs, y=ys, z=zs, surfacecolor=values, colorscale='Viridis', showscale=True)


def plot_traces(*traces):
    """Assemble a ``go.Figure`` from the given traces, skipping any that are ``None``."""
    figure = go.Figure()
    present = list(filter(lambda trace: trace is not None, traces))
    figure.add_traces(present)
    return figure


def plot_observations(x, y=None, num_meshgrid=50):
    """Plot observed input points just above a grey sphere.

    Args:
        x: Observed inputs, forwarded to ``observations_trace``.
        y: Optional observed values. When given, they colour the markers;
            otherwise every marker gets the same colour.
        num_meshgrid: Resolution of the backdrop sphere.

    Returns:
        A ``go.Figure`` with the grey sphere and the observation markers.
    """
    # 1. Grey backdrop sphere.
    backdrop = base_trace(num_meshgrid=num_meshgrid)

    # 2. Observations slightly above the sphere, coloured by `y` when provided.
    scatter = observations_trace(x, y=y)

    # 3. Combine both traces into one figure.
    return plot_traces(backdrop, scatter)


def plot_prediction(prediction_function, x=None, y=None, num_meshgrid=50):
    """Plot a sphere coloured by ``prediction_function``, optionally overlaying observations.

    Args:
        prediction_function: Batched callable, see ``prediction_trace``.
        x: Optional observed inputs.
        y: Optional observed values. Observations are drawn only when both
            ``x`` and ``y`` are given.
        num_meshgrid: Resolution of the prediction sphere.

    Returns:
        A ``go.Figure``.
    """
    surface = prediction_trace(prediction_function, num_meshgrid=num_meshgrid)
    have_observations = x is not None and y is not None
    scatter = observations_trace(x, y) if have_observations else None
    return plot_traces(surface, scatter)

def plot_target(target_function, num_meshgrid=50):
    """Plot a sphere coloured by ``target_function`` (see ``target_trace``); returns a ``go.Figure``."""
    surface = target_trace(target_function, num_meshgrid=num_meshgrid)
    return plot_traces(surface)

### GPU acceleration is probably only useful for deep models; however, we may want it for exact models too, since a BO run might begin with an exact model and later transition to a deep one.
We will certainly need to move the inputs and models to the chosen device.
We also need to set the default device for the LAB backend.

In [3]:
def run_bo(initial_data, target_function, bo_args: BOArguments, model_args: ModelArguments, fit_args: FitArguments, loggers=None, show_fit_progress=False, plot_every=5):
    """Run a Bayesian-optimisation loop that minimises ``target_function``.

    Each iteration builds a fresh model on all observations so far and fits it.
    It then builds an acquisition function, optimises it with
    ``optimize_acqf_manifold``, evaluates the target at the proposed point and
    appends the result to the data. The best (lowest) observation is logged
    after every iteration.

    Args:
        initial_data: Initial inputs; ``target_function(initial_data)`` gives their values.
        target_function: Callable mapping a batch of inputs to observed values.
        bo_args: BO settings (``num_iter`` and the settings forwarded to
            ``optimize_acqf_manifold``).
        model_args: Model settings; ``model_name`` must be ``'exact'`` or ``'deep'``.
        fit_args: Fitting settings (``lr`` etc.), used for deep models.
        loggers: Loggers passed to ``log`` with the incumbent after every iteration.
        show_fit_progress: Forwarded to ``fit`` for deep models.
        plot_every: Show the posterior-mean plot on iterations where
            ``iteration % plot_every == 0``; ``0`` or ``None`` disables plotting.

    Raises:
        ValueError: If ``model_args.model_name`` is neither ``'exact'`` nor ``'deep'``.
    """
    # Silence warnings only while BO runs, instead of permanently replacing the
    # process-wide warning filters; the previous filters are restored on exit
    # (including on exceptions / KeyboardInterrupt).
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")

        x = initial_data
        y = target_function(initial_data)
        best_x = initial_data[y.argmin()]
        best_y = y.min()

        pbar = tqdm(range(bo_args.num_iter), desc="BO")
        for iteration in pbar:
            # 1. Create the model and marginal log-likelihood on all observations so far.
            model = create_model(model_args=model_args, train_x=x, train_y=y, inducing_points=x)
            mll = model_args.mll_factory(model.base_model, y=y)

            # 2. Fit the model to the observations.
            if model_args.model_name == 'exact':
                fit_gpytorch_mll(mll=mll)
            elif model_args.model_name == 'deep':
                # Only the deep model is trained with an explicit optimizer.
                optimizer = model_args.optimizer_factory(model=model.base_model, lr=fit_args.lr)
                print("Fitting...", end='\r')
                fit(model=model.base_model, optimizer=optimizer, criterion=mll, train_inputs=x, train_targets=y, fit_args=fit_args, show_progress=show_fit_progress)
                print("Fitted      ", end='\r')
            else:
                raise ValueError(f"Unknown model name: {model_args.model_name}")

            # 3. Optionally visualise the posterior mean over the sphere.
            def get_mean(points):
                with torch.no_grad():
                    if model_args.model_name == 'exact':
                        # NOTE(review): for exact models this is the only place
                        # `model.eval()` is called, so the model's mode during
                        # acquisition depends on whether this iteration plotted —
                        # confirm the acquisition path sets the mode itself.
                        model.eval()
                        return model(points).mean
                    return model.base_model(points, mean=True)

            if plot_every and iteration % plot_every == 0:
                plot_prediction(get_mean, x=x, y=y, num_meshgrid=100).show()

            # 4. Build the acquisition function for the fitted model.
            acq_function = model_args.acqf_factory(model=model, best_f=best_y)

            # 5. Propose the next input by optimising the acquisition on the manifold.
            with torch.no_grad():
                model.resample_weights()
                new_x, _ = optimize_acqf_manifold(acq_function=acq_function, bo_args=bo_args)

            # 6. Observe the target at the proposed point and add it to the data.
            new_x = new_x.unsqueeze(-2)
            new_y = target_function(new_x)

            print(new_x, new_y)

            x = torch.cat([x, new_x])
            y = torch.cat([y, new_y])

            # 7. Update the incumbent (lowest observation so far).
            if new_y < best_y:
                best_y = new_y
                best_x = new_x.squeeze()

            # 8. Log the incumbent.
            metrics = dict(
                best_x=best_x.tolist(),
                best_y=best_y.item(),
            )
            log(loggers=loggers, metrics=metrics)
            pbar.set_postfix(metrics)

In [4]:

def run_experiment(experiment_config: ExperimentConfig, dir_path: str, show_fit_progress: bool = False):
    """Run one BO experiment described by ``experiment_config``.

    Builds the initial data and the target function from the data arguments. It
    then runs ``run_bo`` and logs the BO trace with a ``CSVLogger`` rooted at
    ``<dir_path>/bo``.

    Args:
        experiment_config: Full experiment configuration (model, data, fit and BO arguments).
        dir_path: Directory that holds the config file and receives the logs.
        show_fit_progress: Forwarded to ``run_bo`` (progress output while fitting deep models).
    """
    print(f"Running experiment with the config: {os.path.join(dir_path, experiment_config.file_name)}")
    # 0. Unpack arguments
    model_args, data_args, fit_args, bo_args = (
        experiment_config.model_arguments, experiment_config.data_arguments, experiment_config.fit_arguments, experiment_config.bo_arguments
    )

    # 1. Get initial data and target function. The target function is observed at input points acquired via BO.
    print("Creating initial observations..")
    target_function = data_args.target_function
    initial_data = get_initial_data(data_args=data_args)

    # 2. Set up loggers for capturing points and observations acquired via BO.
    bo_loggers = [CSVLogger(root_dir=os.path.join(dir_path, 'bo'))]

    # 3. Run the BO loop. Finalise the loggers even if the run fails or is
    #    interrupted (e.g. KeyboardInterrupt), so the partial trace is not left
    #    unfinalised.
    print("Running Bayesian optimisation..")
    try:
        run_bo(initial_data=initial_data, target_function=target_function, bo_args=bo_args, model_args=model_args, fit_args=fit_args, loggers=bo_loggers, show_fit_progress=show_fit_progress)
    finally:
        finalize(bo_loggers)

    print("Done!")

In [9]:
# Configure a short BO run with the deep model on the 'dgp_sample' target,
# starting from the default ExperimentConfig.
experiment_config = ExperimentConfig()
experiment_config.model_arguments.model_name = 'deep'
experiment_config.model_arguments.project_to_tangent = 'extrinsic'
experiment_config.data_arguments.target_function_name = 'dgp_sample'
experiment_config.fit_arguments.num_steps = 10  # few fitting steps per BO iteration
experiment_config.bo_arguments.optimizer_verbosity = 2

run_experiment(experiment_config, dir_path='./', show_fit_progress=True)

Running experiment with the config: ./config.json
Creating initial observations..
Running Bayesian optimisation..


BO:   0%|          | 0/200 [00:00<?, ?it/s]

Fitting...



Fitted      

Optimizing...
Iteration    Cost                       Gradient norm     
---------    -----------------------    --------------    
  1          +1.8211986441326886e+00    1.04045534e-01    
  2          +1.8076031759657774e+00    1.49881009e-01    
  3          +1.7751813347227845e+00    2.43042255e-02    
  4          +1.7751514944255553e+00    3.14778882e-02    
  5          +1.7750377633772747e+00    2.85003746e-02    
  6          +1.7746856966813020e+00    1.55071790e-02    
  7          +1.7745509151067640e+00    3.78137785e-03    
  8          +1.7745508348165824e+00    3.80417620e-03    
  9          +1.7745505166861331e+00    3.73238763e-03    
 10          +1.7745492935928457e+00    3.44242497e-03    
 11          +1.7745452605491754e+00    2.23308800e-03    
 12          +1.7745427493696782e+00    8.50851694e-04    
 13          +1.7745424203571507e+00    4.09398447e-04    
 14          +1.7745423659290864e+00    2.76189964e-04    
 15          +1.7745423232826085e+00    6.

BO:   0%|          | 0/200 [01:04<?, ?it/s]


KeyboardInterrupt: 