# Getting Started

In this notebook we use Boax to demonstrate a single step of a typical Bayesian optimization process.

We will begin by defining the latent objective function we want to maximize, as well as the bounded search space that guides the optimization process.

```python
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)
```
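Enabling 64-bit mode changes the default floating-point dtype of newly created JAX arrays. Double precision mainly helps with the Cholesky factorisations of kernel matrices, which can become ill-conditioned in 32-bit arithmetic when the noise variance is small. A minimal check that the flag took effect:

```python
import jax
import jax.numpy as jnp

# Arrays created before this call keep their 32-bit dtype,
# so the flag belongs at the very top of the notebook.
jax.config.update("jax_enable_x64", True)

x = jnp.zeros(3)
print(x.dtype)  # float64
```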

```python
from functools import partial

from jax import jit
from jax import lax
from jax import numpy as jnp
from jax import random
from jax import scipy
from jax import value_and_grad
from jax import vmap

import optax
import matplotlib.pyplot as plt

from boax.prediction import bijectors, kernels, means, processes
from boax.optimization import acquisitions, maximizers, spaces
```

As our latent objective we choose a sum of two sinusoids, f(x) = sin(4x) + cos(2x), which we aim to maximize on the interval [-3, 3].

```python
# The one-dimensional search interval [-3, 3].
space = spaces.continuous(jnp.array([[-3, 3]]))
```

```python
def objective(x):
  # f(x) = sin(4x) + cos(2x), evaluated on the first (and only) input dimension.
  return jnp.sin(4 * x[..., 0]) + jnp.cos(2 * x[..., 0])
```
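Because this objective is cheap and one-dimensional, a dense grid search gives a reference optimum to compare the Bayesian optimization result against. The function has period π, so [-3, 3] contains two global maxima, near x ≈ 0.317 and x ≈ -2.824, both with value ≈ 1.760. The sketch below (grid resolution chosen arbitrarily) recovers one of them:

```python
import jax.numpy as jnp

def objective(x):
  return jnp.sin(4 * x[..., 0]) + jnp.cos(2 * x[..., 0])

# Dense grid over the search interval [-3, 3], shape (2001, 1).
grid = jnp.linspace(-3.0, 3.0, 2001)[:, None]
values = objective(grid)

best = jnp.argmax(values)
best_x, best_value = float(grid[best, 0]), float(values[best])
print(best_x, best_value)
```

Exhaustive evaluation like this is only feasible because the objective is known and low-dimensional; in a real Bayesian optimization setting each evaluation is expensive, which is what motivates the surrogate model below.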

To create the training data, we sample random points from a uniform distribution over the search interval, evaluate the objective function at those points, and add Gaussian noise to the results.

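```python
sample_key, noise_key = random.split(random.key(0))
# Ten training inputs drawn uniformly from [-3, 3].
x_train = random.uniform(sample_key, minval=-3, maxval=3, shape=(10, 1))
# Noisy observations: objective values plus Gaussian noise with standard deviation 0.3.
y_train = objective(x_train) + 0.3 * random.normal(noise_key, shape=(10,))
```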

## Fitting a Gaussian Process model to the data

With the observations in place, we can now construct a Gaussian Process model and fit it to the data. For this example we choose a simple setup: a constant zero mean function and a scaled RBF kernel. Note that we use the softplus bijector to constrain the model's hyperparameters (amplitude, length scale, and noise) to be strictly positive.

```python
bijector = bijectors.softplus()
```
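The softplus transform softplus(x) = log(1 + exp(x)) maps any real number to a positive one, so the optimizer can work on unconstrained values. The sketch below uses `jax.nn.softplus` directly rather than Boax's bijector object, and the helper `softplus_inverse` is a hypothetical name introduced here. It shows what the raw initial values used later in this notebook (0 and -5) correspond to:

```python
import jax.numpy as jnp
from jax import nn

def softplus_inverse(y):
    # Hypothetical helper: inverse of softplus(x) = log(1 + exp(x)), valid for y > 0.
    return jnp.log(jnp.expm1(y))

raw = jnp.array([0.0, -5.0])        # raw values used to initialise the hyperparameters
constrained = nn.softplus(raw)      # approx. [0.6931, 0.0067], i.e. log(2) and a small noise level
recovered = softplus_inverse(constrained)
print(constrained, recovered)
```

So a raw value of 0 starts the amplitude and length scale at log(2) ≈ 0.693, while -5 starts the noise at roughly 0.0067.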

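```python
def process(params):
    # Map the unconstrained parameters to strictly positive values.
    amplitude = bijector.forward(params['amplitude'])
    length_scale = bijector.forward(params['length_scale'])
    noise = bijector.forward(params['noise'])

    # Scaled RBF kernel, vectorised over both arguments to produce a covariance matrix.
    kernel = kernels.scale(amplitude, kernels.rbf(length_scale))
    kernel = vmap(vmap(kernel, in_axes=(None, 0)), in_axes=(0, None))

    return processes.gaussian(vmap(means.zero()), kernel, noise)
```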
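The nested `vmap` in `process` turns a kernel defined on a single pair of points into one that returns the full covariance (Gram) matrix between two sets of points. A minimal plain-JAX sketch of that pattern, assuming a hand-written scaled RBF kernel `scaled_rbf` (hypothetical; Boax's exact parameterisation, for example whether the amplitude is squared, may differ):

```python
import jax.numpy as jnp
from jax import vmap

def scaled_rbf(x, y, amplitude, length_scale):
    # Hypothetical stand-in for kernels.scale(amplitude, kernels.rbf(length_scale)):
    # k(x, y) = amplitude * exp(-||x - y||^2 / (2 * length_scale^2)).
    sq_dist = jnp.sum((x - y) ** 2)
    return amplitude * jnp.exp(-0.5 * sq_dist / length_scale**2)

def k(x, y):
    return scaled_rbf(x, y, amplitude=1.0, length_scale=0.5)

# Inner vmap: one x against every y (axis 0 of ys).
# Outer vmap: repeat for every x (axis 0 of xs). Result has shape (n, m).
gram = vmap(vmap(k, in_axes=(None, 0)), in_axes=(0, None))

xs = jnp.linspace(-3.0, 3.0, 5)[:, None]   # (5, 1)
ys = jnp.linspace(-1.0, 1.0, 3)[:, None]   # (3, 1)
print(gram(xs, ys).shape)  # (5, 3)
```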

Next we initialise the model's hyperparameters and the optimizer, and then fit the model to the observations.

```python
# Raw (unconstrained) values; the softplus bijector maps them to positive numbers.
params = {
  'amplitude': jnp.zeros(()),
  'length_scale': jnp.zeros(()),
  'noise': jnp.array(-5.),
}
```

```python
optimizer = optax.adam(0.01)
opt_state = optimizer.init(params)
```

```python
def train_step(state, iteration):
  params, opt_state = state

  def loss_fn(params):
    # Negative log-likelihood of the observations under the process prior.
    loc, scale = process(params).prior(x_train)
    return -scipy.stats.multivariate_normal.logpdf(y_train, loc, scale)

  loss, grads = value_and_grad(loss_fn)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return (params, opt_state), loss
```
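The loss above is the negative log-density of the training targets under a multivariate normal. For a Gaussian Process this is usually computed through a Cholesky factorisation cov = L Lᵀ, giving ½ rᵀ cov⁻¹ r + Σᵢ log Lᵢᵢ + (n/2) log 2π with residual r = y - loc. The sketch below spells that out and checks it against `jax.scipy.stats.multivariate_normal.logpdf`; the covariance (a unit RBF kernel plus 0.1 on the diagonal) is a hypothetical example, and whether Boax's `prior` already includes the noise term is not shown here:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve
from jax.scipy.stats import multivariate_normal

def neg_log_likelihood(y, loc, cov):
    # -log N(y | loc, cov) via a Cholesky factorisation cov = L @ L.T.
    n = y.shape[0]
    L = jnp.linalg.cholesky(cov)
    residual = y - loc
    alpha = cho_solve((L, True), residual)             # cov^{-1} (y - loc)
    quad = residual @ alpha                            # Mahalanobis term
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))      # log |cov|
    return 0.5 * (quad + log_det + n * jnp.log(2.0 * jnp.pi))

x = jnp.linspace(-3.0, 3.0, 10)
y = jnp.sin(x)
# Hypothetical covariance: unit RBF kernel plus noise on the diagonal.
cov = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 0.1 * jnp.eye(10)
loc = jnp.zeros(10)

manual = neg_log_likelihood(y, loc, cov)
reference = -multivariate_normal.logpdf(y, loc, cov)
print(manual, reference)
```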

```python
# Run 500 optimisation steps; `history` collects the loss of every step.
(next_params, next_opt_state), history = lax.scan(
    jit(train_step),
    (params, opt_state),
    jnp.arange(500)
)
```

## Constructing and optimizing an acquisition function

As a final step we use the posterior of our Gaussian Process model to construct the Upper Confidence Bound (UCB) acquisition function, which we then maximize using BFGS.
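UCB scores a candidate optimistically, as its posterior mean plus a multiple of its posterior standard deviation, so points that are either promising or uncertain score highly. The sketch below computes a zero-mean GP posterior in plain JAX with hand-picked (not fitted) hyperparameters on 1-D inputs, then applies UCB in the common form μ(x) + √β σ(x). The helper names are hypothetical, and whether Boax scales β the same way (and whether the `2` passed in the next cell is this β) is an assumption:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

def rbf(a, b, length_scale=0.5):
    # Unit-amplitude RBF Gram matrix between 1-D point sets a (n,) and b (m,).
    return jnp.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale**2)

def gp_posterior(x_test, x_obs, y_obs, noise=0.01):
    # Zero-mean GP posterior mean and standard deviation at x_test.
    K = rbf(x_obs, x_obs) + noise * jnp.eye(x_obs.shape[0])
    L = jnp.linalg.cholesky(K)
    K_star = rbf(x_test, x_obs)                        # (m, n)
    mean = K_star @ cho_solve((L, True), y_obs)
    v = solve_triangular(L, K_star.T, lower=True)      # (n, m)
    var = 1.0 - jnp.sum(v**2, axis=0)                  # prior variance is 1
    return mean, jnp.sqrt(jnp.clip(var, 0.0))

def upper_confidence_bound(mean, std, beta=2.0):
    # Optimistic score: posterior mean plus sqrt(beta) times the uncertainty.
    return mean + jnp.sqrt(beta) * std

demo_x = jnp.linspace(-3.0, 3.0, 6)
demo_y = jnp.sin(4 * demo_x) + jnp.cos(2 * demo_x)
test_x = jnp.linspace(-3.0, 3.0, 101)

mean, std = gp_posterior(test_x, demo_x, demo_y)
ucb_scores = upper_confidence_bound(mean, std)
next_x = test_x[jnp.argmax(ucb_scores)]
print(next_x)
```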

```python
acquisition = acquisitions.upper_confidence_bound(
    2,
    # Condition the fitted process on the training observations.
    partial(process(next_params).posterior, observation_index_points=x_train, observations=y_train)
)
```

```python
candidates, scores = maximizers.bfgs(50)(acquisition, space)
```
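The internals of `maximizers.bfgs` are not shown here. As a hypothetical illustration of the general idea, the sketch below runs BFGS from many random starting points using `jax.scipy.optimize.minimize` and keeps the best result. It is applied to the known objective rather than the acquisition so the answer can be checked against the grid optimum (≈ 1.760). Using 50 random starts is purely illustrative; whether the `50` above plays the same role is an assumption. Bounds are handled crudely by clipping the unconstrained BFGS result back into [-3, 3]:

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def objective(x):
  return jnp.sin(4 * x[..., 0]) + jnp.cos(2 * x[..., 0])

lower, upper = -3.0, 3.0

@jax.jit
def multistart_bfgs(starts):
    def solve(x0):
        # BFGS minimises, so maximise by minimising the negated function.
        result = minimize(lambda x: -objective(x), x0, method="BFGS")
        # Unconstrained BFGS may leave the box; project back onto the bounds.
        x = jnp.clip(result.x, lower, upper)
        return x, objective(x)
    return jax.vmap(solve)(starts)

starts = jax.random.uniform(jax.random.key(1), (50, 1), minval=lower, maxval=upper)
demo_candidates, demo_scores = multistart_bfgs(starts)
best = jnp.argmax(demo_scores)
print(demo_candidates[best], demo_scores[best])
```

Multiple starts matter because the acquisition surface is typically multimodal, and a single local BFGS run can stall at a suboptimal local maximum.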