In [None]:
import sys
import matplotlib.pyplot as plt
import numpy as np

import torch

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.nn import PyroSample, PyroModule
from pyro.infer import Trace_ELBO

sys.path.append("../../sbo")

import sbo
import utilities

## Target function

In [None]:
class BraninHooTarget(sbo.TargetFunction):
    """Branin-Hoo test function with fixed constants."""

    def eval(self, x):
        """Evaluate the Branin-Hoo function.

        The two input coordinates are read from positions 0 and 1 of the
        last axis of ``x``; the result has the shape of ``x`` without that
        last axis.
        """
        # Coordinates along the trailing axis
        first = x[..., 0]
        second = x[..., 1]

        # Fixed constants of the function
        scale = 1.0
        curvature = 5.1 / (4 * np.pi**2)
        slope = 5.0 / np.pi
        offset = 6.0
        amplitude = 10.0
        damping = 1.0 / (8 * np.pi)

        # Term that gets squared, kept in the original evaluation order
        residual = second - curvature*first**2 + slope*first - offset

        return scale * residual**2 + amplitude*(1 - damping)*torch.cos(first) + amplitude

# Instantiate the target with one [lower, upper] pair per input dimension
# (presumably what later cells read back as target.ranges — confirm in sbo.TargetFunction).
target = BraninHooTarget([[-5, 10], [0, 15]])

In [None]:
# Plot the target function over its ranges with the project plotting helper,
# then save the current figure as a PNG at 300 dpi in the working directory.
utilities.plot2d_func(target.eval, target.ranges)
plt.savefig('bh_landscape.png', dpi=300)

## Training data

The single starting point is placed at the midpoint of the search space.

In [None]:
# Seed the random number generators through the project helper so the run is
# repeatable (which generators it seeds is decided inside utilities — TODO confirm).
utilities.set_random_seed(556)

In [None]:
N_train_points = 1

# Search bounds as a float32 tensor with one (lower, upper) row per input dimension.
# torch.as_tensor replaces the legacy torch.FloatTensor constructor.
search_bounds = torch.as_tensor(target.ranges, dtype=torch.float32)

# Alternative initialisation: N_train_points uniformly random points inside the bounds
# X_train = search_bounds[:, 0] + torch.rand(N_train_points, len(search_bounds)) * (search_bounds[:, 1] - search_bounds[:, 0])

# Middle point: average of the lower and upper bound of every dimension,
# kept as a single-row batch of shape (1, number of dimensions).
X_train = search_bounds.mean(dim=1).unsqueeze(0)

# Target value at the training input(s)
y_train = target.eval(X_train)

In [None]:
# Draw the target landscape and mark the training inputs on top of it.
utilities.plot2d_func(target.eval, target.ranges)
# Transposing the (points, 2) array and unpacking it yields the x and y coordinate arrays.
plt.scatter(*X_train.detach().numpy().T,
            marker="x", s=200, c='orange', zorder=2, linewidth=4);

## Parametric model

In [None]:
class ParametricMeanFn(PyroModule):
    """Parametric mean function with three random coefficients.

    ``alpha``, ``beta`` and ``gamma`` are each declared as a ``PyroSample``
    with a ``Uniform(0, 20)`` prior.
    """

    def __init__(self):
        super().__init__()

        # Register the coefficients in a fixed order, each with its own prior object
        for coefficient in ("alpha", "beta", "gamma"):
            setattr(self, coefficient, PyroSample(dist.Uniform(0, 20)))

    def forward(self, X):
        """Return ``alpha*cos(x1) + beta*x1**4 + x2**2 + gamma``.

        ``x1`` and ``x2`` are positions 0 and 1 of the last axis of ``X``.
        """
        first = X[..., 0]
        second = X[..., 1]

        # Coefficients are accessed in the order alpha, beta, gamma
        cosine_term = self.alpha*torch.cos(first)
        quartic_term = self.beta*torch.pow(first, 4)
        square_term = torch.pow(second, 2)

        return cosine_term + quartic_term + square_term + self.gamma

## SBO

In [None]:
# Start from an empty Pyro parameter store
pyro.clear_param_store()

# --- SBO parameters ---
sbo_steps = 20
return_site = "EI"

# --- Model optimisation ---
opti_num_steps = 1000
opti_params = dict(lr=0.1)
optimizer = pyro.optim.Adam(opti_params)
loss = Trace_ELBO()
jitter = 1e-4
# NOTE(review): one noise entry per input dimension; model.gp.noise is replaced by a
# prior in a later cell — confirm the shape sbo.SemiParametricModel expects here.
noise = torch.full((len(target.ranges),), 0.1)

# --- Acquisition function optimizer ---
acqf_optimizer = torch.optim.Adam
acqf_opti_num_steps = 1000
acqf_opti_lr = 0.1
num_candidates = 10

# --- Sampling ---
num_samples = 5

In [None]:
# GP kernel: Matern 5/2 with one initial lengthscale entry per input column
kernel = gp.kernels.Matern52(
    input_dim=X_train.size(1),
    lengthscale=torch.full((X_train.size(1),), 10.0),
)

# Semi-parametric GP model built from the training data, the parametric mean
# function and the kernel
model = sbo.SemiParametricModel(
    X_train, y_train, ParametricMeanFn(), kernel, noise=noise, jitter=jitter
)

# Priors on the GP's parameters (these replace the values passed in above)
model.gp.kernel.lengthscale = PyroSample(
    dist.Uniform(0, 5).expand([X_train.size(1)]).to_event()
)
model.gp.kernel.variance = PyroSample(dist.Uniform(0, 10))
model.gp.noise = PyroSample(dist.Uniform(0, 1))

# No guide yet; sbo.step takes the current guide and returns an updated one
guide = None

In [None]:
# Number of model-optimisation steps for the very first SBO step; every later
# step uses opti_num_steps instead.
first_step_opti_num_steps = 10000

for i in range(sbo_steps):

    opti_num_steps_i = first_step_opti_num_steps if i == 0 else opti_num_steps

    # One SBO step; returns the updated guide, a predictive callable and the
    # loss history of this step's model optimisation.
    guide, predict, losses = sbo.step(model, guide, optimizer, loss, target, acqf_optimizer,
                                      opti_num_steps=opti_num_steps_i, acqf_opti_num_steps=acqf_opti_num_steps,
                                      acqf_opti_lr=acqf_opti_lr, num_samples=num_samples,
                                      num_candidates=num_candidates,
                                      return_site=return_site)

    ######################
    # Visualising the step
    ######################
    if i == 0:
        # First step: only the training loss is shown
        fig = plt.figure(figsize=(5, 3))

        plt.title("Training")
        plt.semilogy(losses)

    else:
        fig = plt.figure(figsize=(15, 3))

        # Panel 1: the "y" output of predict, averaged over axis 0
        # (presumably the sample axis — confirm in sbo)
        plt.subplot(1,3,1)
        plt.title("mean (samples=%d)" % (num_samples))

        with torch.no_grad():
            utilities.plot2d_func(lambda x: predict(x)["y"].mean(0), target.ranges)

        # Panel 2: the return_site output averaged the same way, with the last
        # row of model.X marked (presumably the newest point — confirm in sbo)
        plt.subplot(1,3,2)
        plt.title("Find x (samples=%d)" % (num_samples))

        with torch.no_grad():
            utilities.plot2d_func(lambda x: predict(x)[return_site].mean(0), target.ranges)

        plt.scatter(model.X[-1, 0].detach().numpy(), model.X[-1, 1].detach().numpy(),
            marker="x", s=200, c='orange', zorder=2, linewidth=4);

        # Panel 3: loss history of this step
        plt.subplot(1,3,3)
        plt.title("Training")
        plt.semilogy(losses)

    plt.show()
    # Release the figure explicitly so figures do not pile up over the loop
    # on backends where plt.show() does not close them.
    plt.close(fig)

    print("SBO step: %d" % (i), model.X[-1])