# Getting Started

In this notebook we use boax to demonstrate a single step of a typical Bayesian optimization process.

We will begin by defining the latent objective function we want to maximize and its bounds.

In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad
from jax import vmap

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, maximizers

As our latent objective function we choose $f(x) = 1 - \lVert x - 0.5 \rVert$, which peaks at $x = 0.5$ and which we aim to maximize on the interval [0, 1].

In [2]:
# One row per input dimension, laid out as [lower, upper].
bounds = jnp.array([
  [0.0, 1.0],
])

In [3]:
def objective(x):
  """Latent objective: one minus the Euclidean distance from `x` to 0.5.

  The norm is taken over the last axis only, so a batch of points with shape
  `(n, d)` yields `n` values instead of a single norm over the whole array.

  Args:
    x: A scalar, a single point of shape `(d,)`, or a batch of shape `(..., d)`.

  Returns:
    An array of shape `x.shape[:-1]` (a scalar for a scalar or a single point).
  """
  # `atleast_1d` keeps scalar inputs working now that an explicit axis is used.
  return 1 - jnp.linalg.norm(jnp.atleast_1d(x) - 0.5, axis=-1)

To create the observation training data we sample random points from a uniform distribution, evaluate the objective function at those points, add Gaussian noise, and finish by standardizing the noisy observations.

In [4]:
# Independent PRNG streams: one for the training data, one for the maximizer.
data_key, maximizer_key = random.split(random.key(0))

# Training inputs, drawn uniformly between the lower and upper bounds.
x_train = random.uniform(random.fold_in(data_key, 0), minval=bounds[:, 0], maxval=bounds[:, 1], shape=(10, 1))

# Map the objective over the rows so that every training point gets its own
# value, independent of how `objective` treats batched input; then add noise
# and standardize.
y_train = nn.standardize(vmap(objective)(x_train) + 0.1 * random.normal(random.fold_in(data_key, 1), shape=(10,)))

## Fitting a Gaussian Process model to the data

With the observations in place, we can now focus on constructing a Gaussian Process model and fitting it to the data. For this example we choose a simple setup of a zero mean function and a scaled RBF kernel. Note that we use the softplus function to constrain the model's hyperparameters to be strictly positive.

In [5]:
# Unconstrained initial hyperparameters; `fit` passes each through softplus.
params = dict(
  amplitude=jnp.zeros(()),
  length_scale=jnp.zeros(()),
  noise=jnp.array(-5.),
)

In [6]:
# Adam optimizer used by `fit` to update the hyperparameters.
optimizer = optax.adam(learning_rate=0.01)

With the model's hyperparameters and the optimizer initialised, we define the training loop and fit the model to the observations.

In [7]:
def fit(x_train, y_train, num_steps=500):
    """Fit the GP hyperparameters by minimising the negative log-likelihood.

    Starts from the notebook-level `params` and applies `num_steps` updates
    with the notebook-level `optimizer`.

    Args:
        x_train: Training inputs the model is evaluated at.
        y_train: Training targets scored under the model's distribution.
        num_steps: Number of optimisation steps; defaults to the previously
            hard-coded 500.

    Returns:
        `((params, opt_state), losses)`: the final hyperparameters and
        optimizer state, plus the loss recorded at every step (evaluated
        before that step's update is applied).
    """
    def model(amplitude, length_scale, noise):
        # Softplus maps the unconstrained values to strictly positive ones.
        return models.predictive(
            models.gaussian_process(
                means.zero(),
                kernels.scaled(
                    kernels.rbf(nn.softplus(length_scale)),
                    nn.softplus(amplitude)
                ),
            ),
            likelihoods.gaussian(nn.softplus(noise))
        )

    def negative_log_likelihood(hyperparams):
        # Negated so that minimising the loss maximises the log-probability.
        mvn = model(**hyperparams)(x_train)
        return -jnp.sum(distributions.multivariate_normal.logpdf(mvn, y_train))

    def train_step(state, _):
        current_params, opt_state = state
        loss, grads = value_and_grad(negative_log_likelihood)(current_params)
        updates, opt_state = optimizer.update(grads, opt_state)
        return (optax.apply_updates(current_params, updates), opt_state), loss

    # `lax.scan` already traces and compiles `train_step`, so wrapping it in
    # `jit` is unnecessary; `length` replaces the unused iteration counter.
    return lax.scan(
        train_step,
        (params, optimizer.init(params)),
        None,
        length=num_steps,
    )

In [8]:
# Run the optimisation; `history` holds the per-step loss returned by `fit`.
(next_params, next_opt_state), history = fit(x_train, y_train)

In [None]:
# Rebuild the model with the fitted hyperparameters, this time as a GP
# regression over the training data.
fitted_kernel = kernels.scaled(
    kernels.rbf(nn.softplus(next_params['length_scale'])),
    nn.softplus(next_params['amplitude']),
)
regression = models.gaussian_process_regression(means.zero(), fitted_kernel)

surrogate = models.predictive(
    regression(x_train, y_train),
    likelihoods.gaussian(nn.softplus(next_params['noise'])),
)

In [None]:
# Upper confidence bound on top of the surrogate; both the surrogate and the
# `as_normal` transform are mapped over the leading axis with `vmap`.
batched_surrogate = models.outcome_transformed(
    vmap(surrogate),
    vmap(distributions.multivariate_normal.as_normal),
)
acqf = acquisitions.upper_confidence_bound(batched_surrogate, beta=2.0)

In [None]:
# BFGS maximizer built from `bounds`; see `maximizers.bfgs` for the exact
# meaning of `q`, `num_raw_samples` and `num_restarts`.
maximizer = maximizers.bfgs(bounds, q=1, num_raw_samples=20, num_restarts=5)

In [None]:
# Initial candidates for the maximizer, seeded with `maximizer_key`.
candidates = maximizer.init(acqf, maximizer_key)

In [None]:
# One maximizer update; presumably returns the refined candidates and their
# acquisition values (TODO confirm against the boax API).
next_candidates, values = maximizer.update(acqf, candidates)

In [9]:
# Array sizes for the cells below: candidates are (n, q, d), base samples are
# (s, q) and fantasy index points are (n, s, 1, d).
s = 32
n = 10
q = 3
d = 1

In [10]:
# NOTE(review): this rebinds `candidates`, replacing the value returned by
# `maximizer.init` earlier in the notebook.
candidates = random.uniform(random.key(2), shape=(n, q, d))

In [11]:
# Expect (n, q, d).
candidates.shape

(10, 3, 1)

In [12]:
# Standard-normal draws of shape (s, q), passed to `models.sampled` below.
base_samples = random.normal(random.key(3), shape=(s, q))

In [13]:
# Expect (s, q).
base_samples.shape

(32, 3)

In [14]:
# Uniform draws of shape (n, s, 1, d); the fantasy model is evaluated at
# these points further down.
fantasy_index_points = random.uniform(random.key(4), shape=(n, s, 1, d))

In [15]:
# Expect (n, s, 1, d).
fantasy_index_points.shape

(10, 32, 1, 1)

In [16]:
# Sampling model: the GP regression over the training data, drawn from via
# `multivariate_normal.sample` with the fixed `base_samples`.
fitted_kernel = kernels.scaled(
    kernels.rbf(nn.softplus(next_params['length_scale'])),
    nn.softplus(next_params['amplitude']),
)
regression = models.gaussian_process_regression(means.zero(), fitted_kernel)

sampled = models.sampled(
    regression(x_train, y_train),
    distributions.multivariate_normal.sample,
    base_samples,
)

In [17]:
# Apply the sampling model to each of the n candidate batches.
observations = vmap(sampled)(candidates)

In [18]:
# Same check as earlier: `candidates` is still (n, q, d).
candidates.shape

(10, 3, 1)

In [19]:
# Expect (n, s, q).
observations.shape

(10, 32, 3)

In [20]:
# Fantasy model built from the candidates and the observations sampled at
# them, using the fitted kernel hyperparameters.
fitted_kernel = kernels.scaled(
    kernels.rbf(nn.softplus(next_params['length_scale'])),
    nn.softplus(next_params['amplitude']),
)
fantasy_model = models.gaussian_process_fantasy(means.zero(), fitted_kernel)

fantasized = fantasy_model(candidates, observations)

In [21]:
# Evaluate the fantasy model at the fantasy index points.
output = fantasized(fantasy_index_points)

In [22]:
# Expect (n, s, 1).
output.mean.shape

(10, 32, 1)

In [23]:
# Apply `as_normal` to the fantasy model's output; the double `vmap` maps the
# transform over the two leading axes.
to_normal = vmap(vmap(distributions.multivariate_normal.as_normal))
result = models.outcome_transformed(fantasized, to_normal)(fantasy_index_points)

In [25]:
# Expect (n, s, 1).
result.loc.shape

(10, 32, 1)

In [None]:
# NOTE(review): this rebinds `acqf` from the acquisition function defined
# earlier to the result of evaluating `posterior_mean` on `result`; re-running
# the maximizer cells afterwards would pick up this value instead.
acqf = acquisitions.posterior_mean()(result)

In [None]:
# Evaluate the surrogate, converted with `as_normal`, on every candidate
# batch. This rebinds `result`, replacing the fantasy-model value.
surrogate_as_normal = models.outcome_transformed(
    surrogate,
    distributions.multivariate_normal.as_normal,
)
result = vmap(surrogate_as_normal)(candidates)

In [None]:
# Probability of improvement evaluated on `result`; the 0.0 argument is
# presumably the best value to improve upon (TODO confirm).
acquisitions.probability_of_improvement(0.0)(result)

In [None]:
# No output is recorded for this cell; compare with the (n, s, 1) shape shown
# for the fantasy model's `result.loc` above.
result.loc.shape