# Fitting with Priors

In [None]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax import prediction, optimization
from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models, objectives
from boax.optimization import acquisitions, optimizers

As our latent objective function we choose $f(x) = 1 - \lVert x - 0.5 \rVert$, which we aim to maximize over the interval $[0, 1]$; its maximum is $f(0.5) = 1$.

In [None]:
# Search space as a (num_dims, 2) array of [lower, upper] rows: here the
# single input dimension ranges over [0, 1].
bounds = jnp.array([0.0, 1.0])[jnp.newaxis, :]

In [None]:
def objective(x):
    """Latent objective to maximize: ``1 - ||x - 0.5||``, peaked at x = 0.5.

    The Euclidean norm is taken over the last (feature) axis. A batch of
    points with shape ``(n, d)`` therefore yields ``n`` values, one per
    point, and a single point of shape ``(d,)`` yields a scalar. Without
    ``axis=-1`` the norm would collapse the whole batch into one number.
    """
    return 1 - jnp.linalg.norm(x - 0.5, axis=-1)

To create the observation training data, we sample random points from a uniform distribution over the bounds, evaluate the objective function at those points, and finish by adding Gaussian noise.

In [None]:
# Independent PRNG streams: training data, the Halton sampler that produces
# the initial acquisition candidates, and the acquisition optimizer.
data_key, sampler_key, optimizer_key = random.split(random.key(0), 3)

# Ten one-dimensional inputs drawn uniformly inside `bounds`.
x_train = random.uniform(
    random.fold_in(data_key, 0),
    shape=(10, 1),
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
)

# Observations: objective values corrupted by Gaussian noise with std 0.1.
observation_noise = 0.1 * random.normal(random.fold_in(data_key, 1), shape=(10,))
y_train = objective(x_train) + observation_noise

## Fitting a Gaussian Process model to the data

With the observations in place, we can now focus on constructing a Gaussian Process model and fitting it to the data. For this example we choose a simple setup of a constant zero mean function and a scaled RBF kernel. To regularize the fit, we place Gamma priors on the amplitude, length scale, and noise, which penalize the negative log-likelihood. Note that we use the softplus function to constrain the model's hyperparameters to be strictly positive.

In [None]:
# Raw (unconstrained) hyperparameters, all starting at 0.0; `fit` maps them
# through softplus to obtain strictly positive values.
params = {name: jnp.zeros(()) for name in ('amplitude', 'length_scale', 'noise')}

In [None]:
# Adam optimizer shared by every hyperparameter fit (step size 0.01).
adam = optax.adam(learning_rate=0.01)

In [None]:
def fit(x_train, y_train):
    def model(params):
        return models.outcome_transformed(
            models.gaussian_process(
                means.zero(),
                kernels.scaled(
                    kernels.rbf(params['amplitude']),
                    params['length_scale'],
                ),
            ),
            likelihoods.gaussian(params['noise']),
        )

    def objective(params):
        return objectives.penalized(
            objectives.negative_log_likelihood(
                distributions.multivariate_normal.logpdf
            ),
            jnp.sum(
                distributions.gamma.logpdf(
                    distributions.gamma.gamma(2.0, 0.15),
                    params['amplitude'],
                )
            ),
            jnp.sum(
                distributions.gamma.logpdf(
                    distributions.gamma.gamma(3.0, 6.0),
                    params['length_scale'],
                )
            ),
            jnp.sum(
                distributions.gamma.logpdf(
                    distributions.gamma.gamma(1.1, 0.05),
                    params['noise'],
                )
            ),
        )

    def projection(params):
        return {
            'amplitude': nn.softplus(params['amplitude']),
            'length_scale': nn.softplus(params['length_scale']),
            'noise': nn.softplus(params['noise']) + 1e-4,
        }

    def step(state, iteration):
        loss_fn = prediction.construct(model, objective, projection)
        loss, grads = value_and_grad(loss_fn)(state[0], x_train, y_train)
        updates, opt_state = adam.update(grads, state[1])
        params = optax.apply_updates(state[0], updates)
        
        return (params, opt_state), loss
    
    (next_params, _), _ = lax.scan(
        jit(step),
        (params, adam.init(params)),
        jnp.arange(500)
    )

    return projection(next_params)

## Constructing and optimizing an acquisition functions

In [None]:
x0 = jnp.reshape(
    samplers.halton_uniform(
        distributions.uniform.uniform(bounds[:, 0], bounds[:, 1])
    )(
        sampler_key,
        100,
    ),
    (100, 1, -1)
)

In [None]:
def optimize(x_train, y_train):
    def model(params):
        return models.outcome_transformed(
            models.gaussian_process_regression(
                means.zero(),
                kernels.scaled(
                    kernels.rbf(params['amplitude']),
                    params['length_scale']
                )
            )(
                x_train,
                y_train,
            ),
            likelihoods.gaussian(params['noise']),
            distributions.multivariate_normal.as_normal,
        )

    for i in range(10):
        params = fit(x_train, nn.standardize(y_train))

        acqf = optimization.construct(
            model(params),
            acquisitions.upper_confidence_bound(2.0),
        )
        
        bfgs = optimizers.bfgs(acqf, bounds, x0, 10)
        candidates = bfgs.init(random.fold_in(optimizer_key, i))
        next_candidates, values = bfgs.update(candidates)

        next_x = next_candidates[jnp.argmax(values)]
        next_y = objective(next_x)
        
        x_train = jnp.vstack([x_train, next_x])
        y_train = jnp.hstack([y_train, next_y])

    return x_train, y_train

In [None]:
# Run the Bayesian-optimization loop starting from the noisy initial design.
next_x_train, next_y_train = optimize(x_train=x_train, y_train=y_train)