# Structured Bayesian Optimisation with Pyro (2D Eggholder function)

In [None]:
import numpy as np
import torch
import pyro
import math
import time

import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

from torch.distributions import constraints, transform_to

import pyro.distributions as dist
import pyro.contrib.gp as gp

import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# Seed value handed to the torch CPU and CUDA random number generators below.
seed_number = 555

In [None]:
# Seed the CPU generator first, then every CUDA device when a GPU is present.
torch.manual_seed(seed_number)
has_cuda = torch.cuda.is_available()
if has_cuda:
    torch.cuda.manual_seed_all(seed_number)

In [None]:
# Report the versions of the two core libraries used by this notebook.
for library in (torch, pyro):
    print(library.__version__)

# Plot function

In [None]:
def plot_function(f, x1_min, x1_max, x2_min, x2_max, n_points=100, ticks=9):
    """Draw ``f`` as a heat map over a rectangular 2D grid on the current axes.

    Args:
        f: Callable that takes a float tensor of shape ``(n_points**2, 2)``
            (columns are x1 and x2) and returns one value per row.
        x1_min, x1_max: Range covered along the horizontal (x1) axis.
        x2_min, x2_max: Range covered along the vertical (x2) axis.
        n_points: Number of grid points per axis.
        ticks: Number of evenly spaced tick labels per axis.
    """
    x1_grid, x2_grid = np.meshgrid(np.linspace(x1_min, x1_max, n_points),
                                   np.linspace(x2_min, x2_max, n_points))
    inputs = torch.FloatTensor(np.stack([x1_grid.ravel(), x2_grid.ravel()]).T)
    values = f(inputs)
    plt.imshow(values.reshape(n_points, n_points))

    # Grid sample k sits at pixel index k, so the labelled range has to span
    # indices 0 .. n_points - 1 (not 0 .. n_points) for the labels to line up
    # with the data they describe.
    tick_positions = np.linspace(0, n_points - 1, ticks)
    plt.xticks(tick_positions, np.linspace(x1_min, x1_max, ticks))
    plt.yticks(tick_positions, np.linspace(x2_min, x2_max, ticks))
    plt.xlabel('x1')
    plt.ylabel('x2')
    # set_cmap also changes matplotlib's default colormap for later plots.
    plt.set_cmap('jet')
    plt.colorbar()

# Objective function 

We use the Eggholder function as a _challenging_ example because it has many local minima (see https://www.sfu.ca/~ssurjano/egg.html).

In [None]:
def eggholder(x):
    """Evaluate the Eggholder test function.

    f(x1, x2) = -(x2 + 47) * sin(sqrt(|x2 + x1/2 + 47|))
                - x1 * sin(sqrt(|x1 - (x2 + 47)|))

    Args:
        x: Tensor whose last dimension holds ``(x1, x2)``; any leading
            dimensions are treated as batch dimensions.

    Returns:
        Tensor of function values with the last dimension removed.
    """
    x1 = x[..., 0]
    x2 = x[..., 1]
    shifted_x2 = x2 + 47

    # Both terms are subtracted, and the second absolute value is taken of
    # x1 - (x2 + 47); with these the function attains -959.6407 at
    # (512, 404.2319).
    first_term = shifted_x2 * torch.sin(torch.sqrt(torch.abs(shifted_x2 + x1 / 2)))
    second_term = x1 * torch.sin(torch.sqrt(torch.abs(x1 - shifted_x2)))
    return -first_term - second_term

# Global minimum on [-512, 512]^2 is at (x1, x2) = (512, 404.2319), where f = -959.6407

# Initial (training) data

In [None]:
# Initial design: N_points locations drawn uniformly from the square
# [-512, 512) x [-512, 512), together with their objective values.
N_points = 200

lower_corner = torch.FloatTensor([-512, -512])
X = torch.rand(N_points, 2) * 1024 + lower_corner

y = eggholder(X)

# Visualisation

In [None]:
# Box bounds of the search domain along each input.
const_x1_min, const_x1_max = -512, 512
const_x2_min, const_x2_max = -512, 512

steps = 1000    # grid resolution per axis
strides = 100   # passed as the fourth positional argument of the plt.contour calls

# Evaluate the objective on a dense steps x steps grid.
X1 = torch.linspace(const_x1_min, const_x1_max, steps)
X2 = torch.linspace(const_x2_min, const_x2_max, steps)
X1_mesh, X2_mesh = torch.meshgrid(X1, X2)
grid_points = torch.stack((X1_mesh, X2_mesh), dim=2)
Z_mesh = eggholder(grid_points)

# 3D contour view of the objective surface.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.contour3D(X1_mesh, X2_mesh, Z_mesh, 50)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_zlabel('y');

In [None]:
# 2D contour view of the objective on the dense grid built above.
plt.contour(X1_mesh, X2_mesh, Z_mesh, strides);

# Parametric model

In [None]:
def parametric_fn(X, alpha, beta, gamma, delta):
    """Parametric surrogate for the objective surface.

    Computes

        -(alpha * x2 * cos(sqrt(|x2 + x1 - gamma|))
          + beta * x1 * cos(sqrt(|x2 - x1 - delta|)))

    where ``x1`` and ``x2`` are the two entries along the last dimension of
    ``X``. ``alpha``/``beta`` scale the two terms, ``gamma``/``delta`` shift
    the arguments of the square roots.
    """
    x1, x2 = X[..., 0], X[..., 1]

    # Other candidate forms, kept commented out for reference:
    #     return (gamma*torch.sqrt(torch.abs(x1))*torch.cos(alpha*torch.sqrt(torch.abs(x1 + x2)))
    #             + delta*torch.sqrt(torch.abs(x2))*torch.sin(beta*torch.sqrt(torch.abs(x1 - x2))))
    #     return (x1*torch.cos(2*alpha*math.pi*x1) + x2*torch.cos(2*beta*math.pi*x2))

    sum_phase = torch.sqrt(torch.abs(x2 + x1 - gamma))
    diff_phase = torch.sqrt(torch.abs(x2 - x1 - delta))
    x2_term = (x2 * alpha) * torch.cos(sum_phase)
    x1_term = (x1 * beta) * torch.cos(diff_phase)
    return -(x2_term + x1_term)

def parametric_prior():
    """Sample the four coefficients of the parametric model from uniform priors.

    Returns:
        Tuple ``(alpha, beta, gamma, delta)`` of Pyro samples, drawn in that
        order.
    """
    # (site name, lower bound, upper bound); delta has a wider upper bound.
    prior_bounds = (
        ('alpha', -10, 10),
        ('beta', -10, 10),
        ('gamma', -10, 10),
        ('delta', -10, 20),
    )
    return tuple(pyro.sample(site, dist.Uniform(low, high))
                 for site, low, high in prior_bounds)

## Visualising parametric function

In [None]:
# Three side-by-side surfaces, each using a fresh coefficient draw from the prior.
n_panels = 3
plt.figure(figsize=(12,3))
for panel in range(1, n_panels + 1):
    plt.subplot(1, n_panels, panel)
    plot_function(lambda x: parametric_fn(x, *parametric_prior()),
                  const_x1_min, const_x1_max, const_x2_min, const_x2_max)
plt.tight_layout();

## Inference of the parametric model

In [None]:
def guide(*args):
    """Variational guide with one independent Normal per coefficient.

    Registers a mean and a standard deviation in the Pyro parameter store for
    each of ``alpha``, ``beta``, ``gamma`` and ``delta``, then samples every
    coefficient from the corresponding Normal.

    Args:
        *args: Ignored; accepted so the guide can be called with the same
            arguments as the model.

    Returns:
        Tuple ``(alpha, beta, gamma, delta)`` of Pyro samples.
    """
    # (parameter name, initial value, upper end of the interval constraint)
    mean_specs = (('mu_a', 0.5, 10), ('mu_b', 0.5, 10),
                  ('mu_c', 1.0, 10), ('mu_d', 1.0, 20))
    # (parameter name, initial value); all constrained to be positive
    scale_specs = (('sd_a', 0.2), ('sd_b', 0.2), ('sd_c', 1.0), ('sd_d', 1.0))
    sites = ('alpha', 'beta', 'gamma', 'delta')

    means = [pyro.param(name, torch.tensor(init), constraint=constraints.interval(-10, upper))
             for name, init, upper in mean_specs]
    scales = [pyro.param(name, torch.tensor(init), constraint=constraints.positive)
              for name, init in scale_specs]

    return tuple(pyro.sample(site, dist.Normal(loc, scale))
                 for site, loc, scale in zip(sites, means, scales))

def model_parametric(X, y):
    """Generative model: parametric surface observed with fixed Gaussian noise.

    Args:
        X: Input locations passed to ``parametric_fn``.
        y: Observed objective values at ``X``.
    """
    g = parametric_fn(X, *parametric_prior())
    # to_event(1) treats the last dimension as a single multivariate event.
    # It replaces the deprecated `.independent(1)` spelling and matches the
    # `.to_event(1)` call already used for the GP lengthscale prior.
    pyro.sample('f', dist.Normal(g, 0.1).to_event(1), obs=y)

In [None]:
# Start from an empty parameter store, then wire model and guide into SVI
# with an Adam optimiser and the Trace_ELBO loss.
pyro.clear_param_store()
svi = pyro.infer.SVI(
    model_parametric,
    guide,
    pyro.optim.Adam({'lr': 0.005}),
    pyro.infer.Trace_ELBO(),
)

In [None]:
%%time
# Run SVI and keep the loss returned by every step.
num_steps = 2000
losses = []
for i in range(num_steps):
    step_loss = svi.step(X, y)
    losses.append(step_loss)

# Loss curve on a logarithmic vertical axis.
plt.semilogy(losses);

In [None]:
# Print the fitted Normal(mean, sd) of each coefficient from the parameter store.
posterior_report = (
    ("alpha ~ Normal(%0.2f, %0.2f)", 'mu_a', 'sd_a'),
    ("beta ~ Normal(%0.2f, %0.2f)", 'mu_b', 'sd_b'),
    ("gamma ~ Normal(%0.4f, %0.4f)", 'mu_c', 'sd_c'),
    ("delta ~ Normal(%0.4f, %0.4f)", 'mu_d', 'sd_d'),
)
for template, mean_name, scale_name in posterior_report:
    print(template % (pyro.param(mean_name), pyro.param(scale_name)))

In [None]:
# Three side-by-side surfaces, each using a fresh coefficient draw from the
# fitted guide; gradients are not needed for plotting.
n_panels = 3
with torch.no_grad():
    plt.figure(figsize=(12,3))
    for panel in range(1, n_panels + 1):
        plt.subplot(1, n_panels, panel)
        plot_function(lambda x: parametric_fn(x, *guide()),
                      const_x1_min, const_x1_max, const_x2_min, const_x2_max)
    plt.tight_layout();

# Semi-parametric model

In [None]:
pyro.clear_param_store()

class SemiParametricModel(nn.Module):
    """Parametric surface plus a Gaussian process fitted to its residuals.

    The parametric part is ``parametric_fn`` with coefficients
    ``alpha, beta, gamma, delta``; the variational posterior of each
    coefficient is a Normal whose mean and standard deviation are stored as
    unconstrained ``nn.Parameter``s and mapped into their valid range by the
    properties below. The GP (Matern-5/2 kernel) is given the residual
    ``y - parametric_fn(X, ...)`` as its targets each time ``model`` runs.
    """

    def __init__(self, X, y):
        """Store the data and set up variational parameters and the GP.

        Args:
            X: Training inputs; the size of the last dimension is used as the
                GP input dimension.
            y: Observed objective values at ``X``.
        """
        super().__init__()

        # Store data
        D = X.shape[-1]
        self.X = X
        self.y = y

        # Define parameters for parametric model
        # TODO: I couldn't figure out how to do this using `pyro.param`, so instead
        #       I am using `nn.Parameter`. This is annoying, because now constraints
        #       need to be handled manually, using the properties below
        self._mu_a = nn.Parameter(torch.zeros(1))
        self._mu_b = nn.Parameter(torch.zeros(1))
        self._mu_c = nn.Parameter(torch.zeros(1))
        self._mu_d = nn.Parameter(torch.zeros(1))

        self._sd_a = nn.Parameter(torch.zeros(1))
        self._sd_b = nn.Parameter(torch.zeros(1))
        self._sd_c = nn.Parameter(torch.zeros(1))
        self._sd_d = nn.Parameter(torch.zeros(1))

        # Each mean is confined to the support of the matching Uniform prior
        # in `model`: (-10, 10) for alpha/beta/gamma and (-10, 20) for delta.
        # A zero raw value maps to the interval midpoint, so the three
        # narrower means start at 0 and the delta mean starts at 5.
        self._mu_transform = transform_to(constraints.interval(-10, 10))
        self._mu_d_transform = transform_to(constraints.interval(-10, 20))
        self._sd_transform = transform_to(constraints.positive)

        # Define GP regressor (leave the data arguments empty for now)
        self.gp = gp.models.GPRegression(torch.empty((0, D)), torch.empty((0,)),
                                         kernel=gp.kernels.Matern52(input_dim=D, lengthscale=torch.ones(D)))

        #self.gp.kernel.set_prior("lengthscale", dist.Uniform(-100, 100.0).expand((2,)).to_event(1))
        #self.gp.kernel.set_prior("variance", dist.Uniform(0, 100.0))

        # Set priors for GP (these are the values used in the semiparametric BOAT model, which assumes noiseless GP)
        # The lengthscale prior is expanded to D so it always matches the
        # kernel's per-dimension lengthscale defined above.
        self.gp.kernel.set_prior("lengthscale", dist.LogNormal(1.0, 100.0).expand((D,)).to_event(1))
        self.gp.kernel.set_prior("variance", dist.Uniform(0.0, 10.0))
        self.gp.set_prior("noise", dist.Uniform(0.0, 1.0))

        # Set guides for GP
        self.gp.kernel.autoguide("lengthscale", dist.Normal)
        self.gp.kernel.autoguide("variance", dist.Normal)
        self.gp.autoguide("noise", dist.Normal)

    # Constrained views of the raw variational parameters.

    @property
    def mu_a(self):
        """Posterior mean of alpha, constrained to (-10, 10)."""
        return self._mu_transform(self._mu_a)

    @property
    def mu_b(self):
        """Posterior mean of beta, constrained to (-10, 10)."""
        return self._mu_transform(self._mu_b)

    @property
    def mu_c(self):
        """Posterior mean of gamma, constrained to (-10, 10)."""
        return self._mu_transform(self._mu_c)

    @property
    def mu_d(self):
        """Posterior mean of delta, constrained to (-10, 20)."""
        return self._mu_d_transform(self._mu_d)

    @property
    def sd_a(self):
        """Posterior standard deviation of alpha (positive)."""
        return self._sd_transform(self._sd_a)

    @property
    def sd_b(self):
        """Posterior standard deviation of beta (positive)."""
        return self._sd_transform(self._sd_b)

    @property
    def sd_c(self):
        """Posterior standard deviation of gamma (positive)."""
        return self._sd_transform(self._sd_c)

    @property
    def sd_d(self):
        """Posterior standard deviation of delta (positive)."""
        return self._sd_transform(self._sd_d)

    def guide(self):
        """Run the GP guide, then sample the four coefficients from their Normals.

        Returns:
            Tuple ``(alpha, beta, gamma, delta)`` of Pyro samples.
        """
        self.gp.guide()

        alpha = pyro.sample('alpha', dist.Normal(self.mu_a, self.sd_a))
        beta = pyro.sample('beta', dist.Normal(self.mu_b, self.sd_b))
        gamma = pyro.sample('gamma', dist.Normal(self.mu_c, self.sd_c))
        delta = pyro.sample('delta', dist.Normal(self.mu_d, self.sd_d))

        return alpha, beta, gamma, delta

    def model(self):
        """Sample the coefficients from their priors and let the GP explain the residual."""
        alpha = pyro.sample('alpha', dist.Uniform(-10, 10))
        beta = pyro.sample('beta', dist.Uniform(-10, 10))
        gamma = pyro.sample('gamma', dist.Uniform(-10, 10))
        delta = pyro.sample('delta', dist.Uniform(-10, 20))

        g = parametric_fn(self.X, alpha, beta, gamma, delta)

        residual = self.y - g

        # update the GP to now model the residual from the parametric model
        self.gp.set_data(self.X, residual)

        # call GP model function to actually make the observation
        self.gp.model()

    def forward(self, X):
        """Predict at ``X`` using one draw of the coefficients from the guide.

        Returns:
            Tuple ``(mean, var)``: the parametric value plus the first output
            of ``self.gp(X)``, and the second output of ``self.gp(X)``
            unchanged. For Pyro's ``GPRegression`` with default arguments the
            second output is the predictive variance, not a standard deviation.
        """
        # NOTE(review): the GP targets are the residuals stored by the most
        # recent `model()` call, computed with that call's coefficient draw,
        # while `g` below uses a fresh draw from the guide - confirm this
        # mismatch is acceptable.
        g = parametric_fn(X, *self.guide())
        mu, var = self.gp(X)
        return g + mu, var

semi_parametric = SemiParametricModel(X, y)

In [None]:
# Show the names currently registered in Pyro's global parameter store.
result = pyro.get_param_store()
result.get_all_param_names()

In [None]:
# List every nn.Parameter of the model, including those of its submodules (recurse=True).
list(semi_parametric.parameters(recurse=True))

In [None]:
# Two Adam optimisers: one (lr=0.1) for the parameters held directly on the
# model (recurse=False) and one (lr=0.005) for the GP submodule's parameters.
param_opt = torch.optim.Adam(semi_parametric.parameters(recurse=False), lr=0.1)
gp_opt = torch.optim.Adam(semi_parametric.gp.parameters(), lr=0.005)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

num_steps = 5000
losses = []

for i in range(num_steps):
    # Clear both sets of gradients, backpropagate one ELBO estimate, then
    # let both optimisers take a step.
    for opt in (gp_opt, param_opt):
        opt.zero_grad()
    loss = loss_fn(semi_parametric.model, semi_parametric.guide)
    loss.backward()
    for opt in (gp_opt, param_opt):
        opt.step()
    losses.append(loss.item())

plt.semilogy(losses);

In [None]:
# Print the fitted Normal(mean, sd) of each coefficient of the semi-parametric model.
summary_rows = (
    ("alpha ~ Normal(%0.2f, %0.2f)", semi_parametric.mu_a, semi_parametric.sd_a),
    ("beta ~ Normal(%0.2f, %0.2f)", semi_parametric.mu_b, semi_parametric.sd_b),
    ("gamma ~ Normal(%0.4f, %0.4f)", semi_parametric.mu_c, semi_parametric.sd_c),
    ("delta ~ Normal(%0.4f, %0.4f)", semi_parametric.mu_d, semi_parametric.sd_d),
)
for template, mean, scale in summary_rows:
    print(template % (mean.item(), scale.item()))

In [None]:
# Named buffers of the GP submodule, sorted by name.
sorted(semi_parametric.gp.named_buffers())

In [None]:
# Three figures with two panels each; every panel evaluates the model again,
# so each one uses its own coefficient draw from the guide.
# NOTE(review): both panels show element [0] of the prediction under the same
# title - confirm whether the second panel was meant to show element [1].
for i in range(3):
    plt.figure(figsize=(8, 3))
    with torch.no_grad():
        for panel in (121, 122):
            plt.subplot(panel)
            plt.title("GP mean")
            plot_function(lambda X: semi_parametric(X)[0], const_x1_min, const_x1_max, const_x2_min, const_x2_max)
        plt.tight_layout();

In [None]:
# Contour plot of the true objective, for comparison with the model plots above.
plt.contour(X1_mesh, X2_mesh, Z_mesh, strides);

## Acquisition function

In [None]:
def normal_phi(x):
    """Standard normal probability density function, element-wise."""
    return torch.exp(-x.pow(2) / 2) / math.sqrt(2 * math.pi)


def normal_Phi(x):
    """Standard normal cumulative distribution function, element-wise."""
    return (1 + torch.erf(x / math.sqrt(2))) / 2


def expected_improvement_from_moments(mu, variance, y_min):
    """Closed-form expected improvement (minimisation) of a Gaussian prediction.

    With ``delta = y_min - mu`` and ``sigma = sqrt(variance)``:

        EI = delta * Phi(delta / sigma) + sigma * phi(delta / sigma)

    Args:
        mu: Predictive mean(s).
        variance: Predictive variance(s), same shape as ``mu``.
        y_min: Best (lowest) objective value observed so far.

    Returns:
        Tensor of non-negative expected improvements.
    """
    # Floor the variance so that points with (numerically) zero or slightly
    # negative variance give a finite ratio instead of NaN.
    sigma = variance.clamp_min(1e-12).sqrt()
    delta = y_min - mu
    z = delta / sigma
    return delta * normal_Phi(z) + sigma * normal_phi(z)


def expected_improvement(x):
    """Negative expected improvement at ``x`` under the semi-parametric model.

    The sign is flipped so that the result can be passed to a minimiser.
    """
    # The incumbent has to be the best *observed objective value*. The GP's
    # own targets (`semi_parametric.gp.y`) are residuals of the parametric
    # part, whereas `mu` below predicts the full objective.
    y_min = semi_parametric.y.min()

    mu, variance = semi_parametric(x)
    #semi_parametric.gp(x, full_cov=False, noiseless=False)

    return -expected_improvement_from_moments(mu, variance, y_min)

def acquisition_func(x):
    """Acquisition value at ``x``; currently delegates to ``expected_improvement``."""
    return expected_improvement(x)

In [None]:
def train(num_steps=1000):
    """Fit the module-level ``semi_parametric`` model with the Trace_ELBO loss.

    Fresh Adam optimisers are created on every call: lr=0.1 for the parameters
    held directly on the model, lr=0.005 for those of its GP submodule.

    Args:
        num_steps: Number of gradient steps to take.

    Returns:
        List with the loss value (Python float) recorded at every step.
    """
    elbo_loss = pyro.infer.Trace_ELBO().differentiable_loss
    optimisers = (
        torch.optim.Adam(semi_parametric.gp.parameters(), lr=0.005),
        torch.optim.Adam(semi_parametric.parameters(recurse=False), lr=0.1),
    )

    history = []
    for _ in range(num_steps):
        for opt in optimisers:
            opt.zero_grad()
        loss = elbo_loss(semi_parametric.model, semi_parametric.guide)
        loss.backward()
        for opt in optimisers:
            opt.step()
        history.append(loss.item())

    return history

def find_a_candidate(x_init):
    """Minimise the acquisition function from one start point, inside the box bounds.

    The search runs in an unconstrained space: each coordinate is mapped
    through the inverse of an interval transform, optimised with L-BFGS, and
    mapped back.

    Args:
        x_init: 2D tensor whose columns are ``(x1, x2)``; the callers in this
            file pass a single row.

    Returns:
        Detached tensor with the same layout as ``x_init`` holding the
        candidate found, expressed in the original (constrained) domain.
    """
    # Creating constraints
    bounds = ((const_x1_min, const_x1_max), (const_x2_min, const_x2_max))
    transforms = [transform_to(constraints.interval(low, high)) for low, high in bounds]

    def to_constrained(z):
        # Map every column back into its own interval.
        return torch.stack([t(z[:, k]) for k, t in enumerate(transforms)], dim=1)

    # transform x_init to an unconstrained domain as we use an unconstrained optimizer.
    # A start point lying exactly on a bound would map to +/-inf, so it is
    # first nudged a tiny fraction of the range into the interior.
    start_columns = []
    for k, ((low, high), t) in enumerate(zip(bounds, transforms)):
        margin = 1e-6 * (high - low)
        interior = x_init[:, k].clamp(min=low + margin, max=high - margin)
        start_columns.append(t.inv(interior))
    x_uncon = torch.stack(start_columns, dim=1).clone().detach().requires_grad_(True)

    # unconstrained minimiser
    minimizer = optim.LBFGS([x_uncon])

    def closure():
        minimizer.zero_grad()

        y = acquisition_func(to_constrained(x_uncon))

        # Write d(y)/d(x_uncon) into x_uncon.grad without accumulating
        # gradients on the model's own parameters.
        autograd.backward(x_uncon, autograd.grad(y, x_uncon))

        return y

    minimizer.step(closure)

    # after finding a candidate in the unconstrained domain,
    # convert it back to original domain.
    return to_constrained(x_uncon).detach()
  
def next_x(num_candidates=5):
    """Pick the next evaluation point by multi-start acquisition minimisation.

    The first start is the most recently added data point; every later start
    is drawn uniformly at random inside the box bounds.

    Args:
        num_candidates: Number of local searches to run.

    Returns:
        The candidate whose acquisition value is lowest.
    """
    start = semi_parametric.X[-1:]
    candidates, values = [], []

    for _ in range(num_candidates):
        candidate = find_a_candidate(start)
        candidates.append(candidate)
        values.append(acquisition_func(candidate))

        # a new random initial point for the following search
        start = torch.stack((
            candidate[:, 0].new_empty(1).uniform_(const_x1_min, const_x1_max),
            candidate[:, 1].new_empty(1).uniform_(const_x2_min, const_x2_max)), dim=1)

    _, best = torch.cat(values).min(dim=0)
    return candidates[best.item()]

def update_posterior(x_new, viz_flag=False):
    """Evaluate the objective at ``x_new``, add it to the data and refit the model.

    Args:
        x_new: New input location(s), laid out like the rows of
            ``semi_parametric.X``.
        viz_flag: When true, draw the model summary after refitting.
    """
    # `plot_model` reads the module-level `losses`; without this declaration
    # the fresh training curve would stay local and the plot would keep
    # showing the losses of an earlier fit.
    global losses

    # evaluate f at new point
    y_new = eggholder(x_new)

    # incorporate new evaluation
    semi_parametric.X = torch.cat([semi_parametric.X, x_new])
    semi_parametric.y = torch.cat([semi_parametric.y, y_new])

    losses = train()

    if viz_flag:
        plot_model()
        
def plot_model():
    """Draw a four-panel summary of the current state of the optimisation.

    Panels: acquisition function, training losses (read from the module-level
    ``losses``), and the two outputs of the semi-parametric model over the
    search box.
    """
    plt.figure(figsize=(12,3))

    # Acquisition function
    plt.subplot(1,4,1)
    plt.title("Acquisition")
    with torch.no_grad():
        plot_function(acquisition_func, const_x1_min, const_x1_max, const_x2_min, const_x2_max)

    # Losses
    plt.subplot(1,4,2)
    plt.title("Losses")
    plt.semilogy(losses)

    # Semi-param model mu
    # Raw strings keep the backslashes of the mathtext commands without
    # relying on invalid escape sequences; the title text is unchanged.
    plt.subplot(1,4,3)
    plt.title(r"Semi-param model $\mu$")
    with torch.no_grad():
        plot_function(lambda X: semi_parametric(X)[0], const_x1_min, const_x1_max, const_x2_min, const_x2_max)

    # Semi-param model sigma
    # NOTE(review): element [1] is the second output of the GP call, which
    # looks like a variance rather than a standard deviation - confirm the label.
    plt.subplot(1,4,4)
    plt.title(r"Semi-param model $\sigma$")
    with torch.no_grad():
        plot_function(lambda X: semi_parametric(X)[1], const_x1_min, const_x1_max, const_x2_min, const_x2_max)

    plt.tight_layout()

In [None]:
# flag to visualise steps
viz_flag = True

# number of structured Bayesian optimisation iterations to run
sbo_steps = 2

# Restarting the model
#pyro.clear_param_store()

#semi_parametric = SemiParametricModel(X, y)

time_st = time.time()

# Initial fit on the starting data set.
losses = train()

if viz_flag:
    plot_model()

# Each iteration proposes a point, evaluates it and refits the model.
for i in range(sbo_steps):
    xmin = next_x()
    print("Step SBO: ", i+1, "new point: ", xmin)
    update_posterior(xmin, viz_flag)

elapsed = time.time() - time_st
print("Time: ", elapsed)