# Getting Started

In this notebook we use boax to demonstrate a single step of a typical Bayesian optimization process.

We will begin by defining the latent objective function we want to maximize and its bounds.

```python
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import scipy
from jax import value_and_grad

import optax

from boax.prediction import kernels, means, models
from boax.optimization import acquisitions, maximizers
```

As our latent objective function we choose a sum of two sinusoids, sin(4x) + cos(2x), which we aim to maximize on the interval [-3, 3].

```python
# Search-space bounds: one [lower, upper] row per input dimension.
bounds = jnp.array([[-3.0, 3.0]])
```

```python
def objective(x):
    # x has shape (..., 1); the objective acts on its single coordinate.
    return jnp.sin(4 * x[..., 0]) + jnp.cos(2 * x[..., 0])
```
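Because this toy objective is cheap to evaluate, we can locate its true maximum on [-3, 3] by brute force. This gives a reference for judging the candidate proposed at the end. Since sin(4x) + cos(2x) has period π, the interval contains two maximizers (near x ≈ -2.82 and x ≈ 0.32) that share the same value, about 1.760. The sketch below re-implements the objective in NumPy (`objective_np` is a hypothetical helper, not part of boax) and evaluates it on a dense grid. The grid resolution is an arbitrary choice.

```python
import numpy as np

def objective_np(x):
    # Same formula as `objective`, written with NumPy: x has shape (..., 1).
    return np.sin(4 * x[..., 0]) + np.cos(2 * x[..., 0])

# Dense 1-D grid over the search interval [-3, 3], shaped (n, 1) like x_train.
grid = np.linspace(-3.0, 3.0, 100_001)[:, None]
values = objective_np(grid)

# The grid point with the largest objective value approximates the true maximizer.
best = np.argmax(values)
x_star, f_star = grid[best, 0], values[best]
print(f"approximate maximizer x* = {x_star:.4f}, f(x*) = {f_star:.4f}")
```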

To create the training data, we sample ten points uniformly at random within the bounds, evaluate the objective function at those points, and add Gaussian noise with standard deviation 0.3.

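```python
# Independent PRNG keys for the inputs, the observation noise, and the maximizer.
sample_key, noise_key, maximizer_key = random.split(random.key(0), 3)
# 10 training inputs drawn uniformly from the bounds, shape (10, 1).
x_train = random.uniform(sample_key, minval=bounds[0, 0], maxval=bounds[0, 1], shape=(10, 1))
# Noisy observations: objective values plus Gaussian noise with standard deviation 0.3.
y_train = objective(x_train) + 0.3 * random.normal(noise_key, shape=(10,))
```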

## Fitting a Gaussian Process model to the data

With the observations in place, we can now construct a Gaussian Process model and fit it to the data. For this example we choose a simple setup: a constant zero mean function and a scaled RBF kernel. We pass every hyperparameter (amplitude, length scale, and noise) through the softplus function, which maps any real number to a strictly positive one, so the optimizer can work with unconstrained values.

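```python
def prior(amplitude, length_scale, noise):
    # GP prior with a zero mean and a scaled RBF kernel; softplus keeps
    # every hyperparameter strictly positive.
    return models.gaussian_process(
        means.zero(),
        kernels.scaled(kernels.rbf(nn.softplus(length_scale)), nn.softplus(amplitude)),
        nn.softplus(noise),
    )
```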
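The softplus transform, softplus(z) = log(1 + e^z), is positive for every real z, so the hyperparameters can be optimized without constraints. The NumPy sketch below shows what the raw values set later in the notebook (0, 0 and -5) turn into. It also evaluates a scaled RBF kernel under the common parameterization k(x, x') = a · exp(-‖x - x'‖² / (2ℓ²)). Whether boax's `kernels.scaled` and `kernels.rbf` use exactly this scaling (for example a versus a² for the amplitude) is an assumption. `softplus` and `scaled_rbf` are hypothetical helpers.

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return np.logaddexp(0.0, z)

def scaled_rbf(x1, x2, amplitude, length_scale):
    # Hypothetical helper: a * exp(-||x1 - x2||^2 / (2 * l^2)) for inputs of shape (n, d), (m, d).
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return amplitude * np.exp(-0.5 * sq_dist / length_scale**2)

# Raw (unconstrained) values matching the initial hyperparameters of the notebook.
amplitude = softplus(0.0)      # log(2), about 0.693
length_scale = softplus(0.0)   # log(2), about 0.693
noise = softplus(-5.0)         # about 0.0067

x = np.array([[0.0], [0.5], [2.0]])
K = scaled_rbf(x, x, amplitude, length_scale)
print(amplitude, length_scale, noise)
print(K.round(4))
```

The kernel matrix is symmetric, and its diagonal equals the amplitude because every point is at distance zero from itself.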

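```python
def target_log_prob(params):
    # Negative log marginal likelihood of the training targets under the GP prior.
    # Despite its name, this is the loss that the optimizer minimizes.
    mean, cov = prior(**params)(x_train)
    return -scipy.stats.multivariate_normal.logpdf(y_train, mean, cov)
```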
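The function above returns the *negative* log marginal likelihood, which Adam minimizes. For a mean vector m and covariance K of size n, it equals ½(y - m)ᵀK⁻¹(y - m) + ½ log|K| + (n/2) log 2π. The first term rewards fitting the data, and the second penalizes overly flexible kernels. The NumPy sketch below computes this stably through a Cholesky factor K = LLᵀ. It assumes the noise enters as a variance added to the kernel diagonal, although boax's exact convention may differ. `gp_neg_log_marginal_likelihood` is a hypothetical name.

```python
import numpy as np

def gp_neg_log_marginal_likelihood(y, mean, cov):
    # -log N(y | mean, cov), computed through a Cholesky factorization cov = L L^T.
    n = y.shape[0]
    L = np.linalg.cholesky(cov)
    alpha = np.linalg.solve(L, y - mean)      # alpha = L^{-1} (y - mean)
    data_fit = 0.5 * alpha @ alpha            # 0.5 * (y - m)^T cov^{-1} (y - m)
    complexity = np.sum(np.log(np.diag(L)))   # 0.5 * log det(cov)
    return data_fit + complexity + 0.5 * n * np.log(2 * np.pi)

# Toy data in the spirit of the notebook: 10 noisy samples of the objective.
rng = np.random.default_rng(0)
x = rng.uniform(-3.0, 3.0, size=(10, 1))
y = np.sin(4 * x[:, 0]) + np.cos(2 * x[:, 0]) + 0.3 * rng.normal(size=10)

# Scaled RBF kernel plus noise variance on the diagonal (assumed parameterization).
sq_dist = (x - x.T) ** 2   # pairwise squared distances for 1-D inputs, shape (10, 10)
amplitude, length_scale, noise = np.log(2), np.log(2), 0.0067
cov = amplitude * np.exp(-0.5 * sq_dist / length_scale**2) + noise * np.eye(10)
nll = gp_neg_log_marginal_likelihood(y, np.zeros(10), cov)
print(nll)
```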

Next we initialize the model's hyperparameters and the optimizer, and then fit the model to the observations.

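```python
# Raw (unconstrained) hyperparameters; after softplus, amplitude and length scale
# start at log(2) ≈ 0.69 and the noise at softplus(-5) ≈ 0.0067.
params = {
    'amplitude': jnp.zeros(()),
    'length_scale': jnp.zeros(()),
    'noise': jnp.array(-5.),
}
```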

```python
optimizer = optax.adam(0.01)
```

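```python
def train_step(state, iteration):
    # One Adam step on the negative log marginal likelihood.
    # `state` is the (params, opt_state) carry; `iteration` is unused.
    params, opt_state = state
    loss, grads = value_and_grad(target_log_prob)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), loss
```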

```python
# Run 500 optimization steps; `history` stacks the loss from every step.
(next_params, next_opt_state), history = lax.scan(
    jit(train_step),
    (params, optimizer.init(params)),
    jnp.arange(500),
)
```
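`lax.scan` threads a carry (here the `(params, opt_state)` pair) through repeated calls of `train_step` and stacks each call's second output (here the loss) into `history`. The pure-Python sketch below mirrors the reference semantics that the JAX documentation gives for `lax.scan`. It leaves out the `length` and `reverse` arguments and returns a list instead of a stacked array. `scan_reference` is a hypothetical name used only for illustration.

```python
def scan_reference(f, init, xs):
    # Pure-Python equivalent of jax.lax.scan(f, init, xs) for a list-like xs.
    # The carry is passed from step to step, and the per-step outputs are collected.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

# Toy usage: a running sum whose carry is the total and whose output is the partial sum.
def step(total, x):
    total = total + x
    return total, total

final, partial_sums = scan_reference(step, 0, range(5))
print(final, partial_sums)
```

In the notebook, `history` therefore contains one loss value per iteration, 500 in total. This makes it easy to check whether training has converged.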

## Constructing and optimizing an acquisition function

As a final step, we use the posterior of our Gaussian Process model to construct the Upper Confidence Bound (UCB) acquisition function, which we then maximize with a multi-start BFGS optimizer.

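```python
# Condition the GP on the training data, using the fitted hyperparameters
# passed through the same softplus transform as in `prior`.
surrogate = models.gaussian_process_regression(
    x_train,
    y_train,
    means.zero(),
    kernels.scaled(
        kernels.rbf(nn.softplus(next_params['length_scale'])),
        nn.softplus(next_params['amplitude']),
    ),
    nn.softplus(next_params['noise']),
)
```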
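`gaussian_process_regression` conditions the GP on the training data. In the textbook zero-mean case, with K = k(X, X) + σ²I, the posterior at test inputs X* has mean k(X*, X)K⁻¹y and covariance k(X*, X*) - k(X*, X)K⁻¹k(X, X*). The NumPy sketch below implements these formulas. It assumes the scaled RBF kernel a · exp(-‖x - x'‖² / (2ℓ²)) and treats the noise as a variance added to the diagonal. It is not boax's implementation, and `rbf` and `gp_posterior` are hypothetical helpers.

```python
import numpy as np

def rbf(x1, x2, amplitude, length_scale):
    # Scaled RBF kernel a * exp(-||x1 - x2||^2 / (2 l^2)); inputs of shape (n, d) and (m, d).
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return amplitude * np.exp(-0.5 * sq_dist / length_scale**2)

def gp_posterior(x_train, y_train, x_test, amplitude, length_scale, noise):
    # Zero-mean GP regression: posterior mean and variance of f at x_test.
    K = rbf(x_train, x_train, amplitude, length_scale) + noise * np.eye(len(x_train))
    K_s = rbf(x_test, x_train, amplitude, length_scale)        # shape (m, n)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))   # K^{-1} y
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)                               # L^{-1} k(X, X*)
    var = amplitude - np.sum(v**2, axis=0)                      # diagonal of posterior covariance
    return mean, np.maximum(var, 0.0)                           # clip tiny negative round-off

# Usage on noise-free toy data: 7 evenly spaced inputs and a small noise variance.
x_tr = np.linspace(-3.0, 3.0, 7)[:, None]
y_tr = np.sin(4 * x_tr[:, 0]) + np.cos(2 * x_tr[:, 0])
x_te = np.array([[0.25], [10.0]])  # one point inside the data range, one far outside
mean, var = gp_posterior(x_tr, y_tr, x_te, amplitude=1.0, length_scale=1.0, noise=1e-4)
print(mean, var)
```

Far from the data, the posterior falls back to the prior (mean 0, variance equal to the amplitude). Near the observations, the variance shrinks, and this uncertainty is what UCB exploits.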

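```python
# UCB acquisition with exploration weight beta = 2.0.
acqf = acquisitions.upper_confidence_bound(surrogate, beta=2.0)
# Multi-start BFGS over the bounds, proposing a single candidate (q=1).
maximizer = maximizers.bfgs(bounds, q=1, num_restarts=25, num_raw_samples=500)
```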

```python
# Draw initial candidates, then refine them to maximize the acquisition function.
candidates = maximizer.init(maximizer_key, acqf)
next_candidates, values = maximizer.maximize(candidates, acqf)
```
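UCB trades off exploitation (a high posterior mean μ(x)) against exploration (a high posterior standard deviation σ(x)). BoTorch defines it as UCB(x) = μ(x) + √β · σ(x). We assume boax's `beta` plays the same role, but its exact scaling is not confirmed here. The NumPy sketch below imitates the multi-start strategy suggested by the `num_raw_samples` and `num_restarts` arguments. It scores many random points, keeps the best as starting points, and refines each one locally. For simplicity, it refines with projected gradient ascent using finite-difference gradients instead of BFGS, and it uses a hand-written stand-in for the GP posterior. Every name in it (`ucb`, `toy_posterior`, `maximize_multistart`) is hypothetical.

```python
import numpy as np

def ucb(mean, std, beta=2.0):
    # Upper confidence bound in the BoTorch-style convention: mu + sqrt(beta) * sigma.
    return mean + np.sqrt(beta) * std

def toy_posterior(x):
    # Hypothetical stand-in for a GP posterior over 1-D inputs x of shape (n,).
    mean = np.sin(x)
    std = 0.2 + 0.1 * np.cos(3 * x) ** 2
    return mean, std

def acquisition(x):
    return ucb(*toy_posterior(x))

def maximize_multistart(acq, lower, upper, seed, num_raw_samples=500, num_restarts=25,
                        steps=200, lr=0.05, eps=1e-5):
    rng = np.random.default_rng(seed)
    # 1. Score random raw samples and keep the best ones as starting points.
    raw = rng.uniform(lower, upper, size=num_raw_samples)
    starts = raw[np.argsort(acq(raw))[-num_restarts:]]
    # 2. Refine every start with projected gradient ascent (central finite differences).
    x = starts.copy()
    for _ in range(steps):
        grad = (acq(x + eps) - acq(x - eps)) / (2 * eps)
        x = np.clip(x + lr * grad, lower, upper)
    # 3. Return the best refined point and its acquisition value.
    values = acq(x)
    best = np.argmax(values)
    return x[best], values[best]

x_next, ucb_next = maximize_multistart(acquisition, -3.0, 3.0, seed=0)
print(x_next, ucb_next)
```

The random raw samples keep the local search from getting stuck in a poor basin. The restarts then polish only the most promising regions, which is cheaper than refining every raw sample.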