# Getting Started

In this notebook we use boax to demonstrate a single step of a typical Bayesian optimization process.

We will begin by defining the latent objective function we want to maximize and its bounds.

In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, optimizers

As our latent objective function we choose one minus the distance to the point 0.5, which we aim to maximize on the interval [0, 1].

In [2]:
# Search domain: one row per input dimension, columns are (lower, upper).
bounds = jnp.asarray([[0.0, 1.0]])

In [3]:
def objective(x):
  """Latent objective: one minus the Euclidean distance of each point to 0.5.

  Args:
    x: Array whose last axis holds the input coordinates; any leading axes
      are treated as batch dimensions. A scalar is treated as a single
      one-dimensional point.

  Returns:
    Array with the leading (batch) shape of `x`. The maximum value 1 is
    attained where every coordinate equals 0.5.
  """
  # Reduce over the coordinate axis only. Without `axis`, a batch such as an
  # (n, 1) array is collapsed into one norm over all entries, so every point
  # would receive the same objective value.
  return 1 - jnp.linalg.norm(jnp.atleast_1d(x) - 0.5, axis=-1)

To create the observation training data we sample random points from a uniform distribution, evaluate the objective function at those points, add Gaussian noise, and finally standardize the resulting targets.

In [24]:
# Derive independent PRNG keys for data generation, initial-point sampling,
# and the acquisition optimizer from a single seed.
data_key, sampler_key, maximizer_key = random.split(random.key(0), num=3)

# Ten training inputs drawn uniformly between the lower and upper bounds.
x_train = random.uniform(
  random.fold_in(data_key, 0),
  shape=(10, 1),
  minval=bounds[:, 0],
  maxval=bounds[:, 1],
)

# Objective values perturbed by Gaussian noise (scale 0.1), then standardized.
y_train = nn.standardize(
  objective(x_train)
  + 0.1 * random.normal(random.fold_in(data_key, 1), shape=(10,))
)

## Fitting a Gaussian Process model to the data

With the observations in place, we can now focus on constructing a Gaussian Process model and fitting it to the data. For this example we choose a simple setup of a constant zero mean function and a scaled RBF kernel. Note that we use the softplus function to constrain the model's hyperparameters (amplitude, length scale, and noise) to be strictly positive.

In [25]:
# Unconstrained initial hyperparameters; each is passed through softplus
# before it is handed to the kernel or the likelihood.
params = dict(
  amplitude=jnp.zeros(()),
  length_scale=jnp.zeros(()),
  noise=jnp.array(-5.),
)

In [26]:
# Adam optimizer used to fit the hyperparameters.
adam = optax.adam(learning_rate=0.01)

In [27]:
def fit(x_train, y_train):
    """Tune the Gaussian process hyperparameters on the observations.

    Runs 500 Adam steps on the negative log-density of `y_train` under the
    model's multivariate normal evaluated at `x_train`. The starting point
    and the optimizer are read from the notebook-level `params` and `adam`.

    Args:
        x_train: Training inputs.
        y_train: Training targets.

    Returns:
        A pair `((params, opt_state), losses)`: the final carry of the scan
        and the loss recorded at every step.
    """

    def build_model(amplitude, length_scale, noise):
        # softplus maps the unconstrained values to strictly positive ones.
        kernel = kernels.scaled(
            kernels.rbf(nn.softplus(length_scale)),
            nn.softplus(amplitude),
        )
        prior = models.gaussian_process(means.zero(), kernel)
        return models.predictive(prior, likelihoods.gaussian(nn.softplus(noise)))

    def negative_log_prob(hyperparams):
        mvn = build_model(**hyperparams)(x_train)
        return -jnp.sum(distributions.multivariate_normal.logpdf(mvn, y_train))

    def train_step(carry, _):
        current_params, opt_state = carry
        loss, grads = value_and_grad(negative_log_prob)(current_params)
        updates, next_opt_state = adam.update(grads, opt_state)
        return (optax.apply_updates(current_params, updates), next_opt_state), loss

    initial_carry = (params, adam.init(params))
    return lax.scan(jit(train_step), initial_carry, jnp.arange(500))

In [28]:
# Fitted hyperparameters, final optimizer state, and the per-step loss history.
(next_params, next_opt_state), history = fit(x_train=x_train, y_train=y_train)

In [29]:
# Build the surrogate from the fitted hyperparameters: a zero-mean GP with a
# scaled RBF kernel, conditioned on the training data via
# `gaussian_process_regression`, combined with a Gaussian likelihood.
# The same softplus transforms as during fitting are applied so the
# hyperparameters are interpreted identically.
surrogate = models.predictive(
    models.gaussian_process_regression(
        means.zero(),
        kernels.scaled(
            kernels.rbf(nn.softplus(next_params['length_scale'])),
            nn.softplus(next_params['amplitude'])
        ),
    )(
        x_train,
        y_train,
    ),
    likelihoods.gaussian(nn.softplus(next_params['noise']))
)

## Constructing and optimizing an acquisition function

In [30]:
# Draw 100 starting points with boax's `halton_uniform` sampler over the
# uniform distribution spanning the bounds, then reshape them to (100, 1, -1).
# NOTE(review): the middle axis is presumably the number of jointly evaluated
# candidates (q) — confirm against the boax optimizer documentation.
x0 = jnp.reshape(
    samplers.halton_uniform(distributions.uniform.uniform(bounds[:, 0], bounds[:, 1]))(
        sampler_key,
        100,
    ),
    (100, 1, -1)
)

In [31]:
# Combine the surrogate with an upper confidence bound acquisition.
# NOTE(review): `outcome_transformed(..., as_normal)` presumably converts the
# surrogate's multivariate normal output into per-point normals before the
# acquisition is evaluated — confirm in the boax docs.
# NOTE(review): 2.0 is presumably the exploration weight (beta) — confirm.
acqf = optimizers.construct(
    models.outcome_transformed(
        surrogate,
        distributions.multivariate_normal.as_normal
    ),
    acquisitions.upper_confidence_bound(2.0),
)

In [32]:
# Set up the BFGS-based optimizer for the acquisition function over the
# bounds, starting from the points in `x0`.
# NOTE(review): the trailing 10 is presumably the number of restarts or
# initial candidates kept — confirm against `optimizers.bfgs`.
bfgs = optimizers.bfgs(acqf, bounds, x0, 10)

In [35]:
# Initialize the optimizer's candidates using the dedicated PRNG key.
candidates = bfgs.init(maximizer_key)

In [36]:
# Run the optimizer update on the initial candidates.
# NOTE(review): `values` are presumably the acquisition values at
# `next_candidates` — confirm against the boax optimizer docs.
next_candidates, values = bfgs.update(candidates)