In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from functools import partial

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad
from jax import vmap
from jax import scipy

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, maximizers

In [2]:
# Search box of the 2-d problem: one [lower, upper] row per input dimension.
bounds = jnp.array([
    [0.0, 1.0],
    [0.0, 1.0],
])

In [3]:
def objective(x):
    """Evaluate sin(2*pi*x0) * cos(2*pi*x1) on the last axis of ``x``.

    Args:
        x: Array whose last axis holds at least two coordinates; any leading
            axes are treated as batch axes.

    Returns:
        Array with the shape of ``x`` minus its last axis.
    """
    first = x[..., 0]
    second = x[..., 1]
    wave_first = jnp.sin(2 * jnp.pi * first)
    wave_second = jnp.cos(2 * jnp.pi * second)
    return wave_first * wave_second

In [4]:
# One root key (seed 0) is split into three independent streams; further
# sub-keys are derived from them with `random.fold_in`.
data_key, sample_key, maximizer_key = random.split(random.key(0), 3)

# 20 training inputs drawn uniformly inside the box described by `bounds`.
x_train = random.uniform(
    random.fold_in(data_key, 0),
    shape=(20, 2),
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
)

# Objective values corrupted by unit-variance Gaussian noise, then passed
# through `nn.standardize`.
y_train = nn.standardize(
    objective(x_train)
    + random.normal(random.fold_in(data_key, 1), shape=(20,))
)

In [5]:
def prior(amplitude, length_scale, noise):
    """Build the predictive prior: zero-mean GP, scaled RBF kernel, Gaussian likelihood.

    All three arguments are unconstrained values; each one is mapped through
    ``nn.softplus`` before it is handed to boax.

    Args:
        amplitude: Unconstrained scale applied to the RBF kernel.
        length_scale: Unconstrained RBF length scale.
        noise: Unconstrained parameter of the Gaussian likelihood.

    Returns:
        The object produced by ``models.predictive``; it is called on input
        points (see ``target_log_prob``).
    """
    mean_fn = means.zero()
    kernel_fn = kernels.scaled(
        kernels.rbf(nn.softplus(length_scale)),
        nn.softplus(amplitude),
    )
    latent_model = models.gaussian_process(mean_fn, kernel_fn)
    observation_model = likelihoods.gaussian(nn.softplus(noise))
    return models.predictive(latent_model, observation_model)

In [6]:
def posterior(amplitude, length_scale, noise):
    """Build the predictive GP regression model conditioned on the training data.

    Uses the same zero mean, scaled RBF kernel and Gaussian likelihood as
    ``prior``, but the latent model is ``models.gaussian_process_regression``
    applied to the module-level ``x_train`` / ``y_train``. All three arguments
    are unconstrained and are mapped through ``nn.softplus`` before use.

    Args:
        amplitude: Unconstrained scale applied to the RBF kernel.
        length_scale: Unconstrained RBF length scale.
        noise: Unconstrained parameter of the Gaussian likelihood.

    Returns:
        The object produced by ``models.predictive``.
    """
    mean_fn = means.zero()
    kernel_fn = kernels.scaled(
        kernels.rbf(nn.softplus(length_scale)),
        nn.softplus(amplitude),
    )
    regression = models.gaussian_process_regression(mean_fn, kernel_fn)
    conditioned = regression(x_train, y_train)
    observation_model = likelihoods.gaussian(nn.softplus(noise))
    return models.predictive(conditioned, observation_model)

In [7]:
def target_log_prob(params):
    """Loss for hyperparameter fitting: negated log-density of the training targets.

    Despite the name, the returned value is the *negative* of
    ``multivariate_normal.logpdf``, so minimising it maximises the
    log-density of ``y_train`` under the prior model evaluated at ``x_train``.

    Args:
        params: Mapping with the keyword arguments accepted by ``prior``
            (``amplitude``, ``length_scale``, ``noise``).

    Returns:
        The negated log-density.
    """
    marginal = prior(**params)(x_train)
    log_density = distributions.multivariate_normal.logpdf(marginal, y_train)
    return -log_density

In [8]:
# Initial unconstrained hyperparameters (scalars); `prior` / `posterior` map
# each of them through softplus before use.
params = dict(
    amplitude=jnp.zeros(()),
    length_scale=jnp.zeros(()),
    noise=jnp.array(-5.),
)

In [9]:
# Adam optimiser used to fit the GP hyperparameters.
optimizer = optax.adam(learning_rate=0.01)

In [10]:
def train_step(state, iteration):
    """Perform one gradient step on ``target_log_prob``; shaped for ``lax.scan``.

    Args:
        state: Pair ``(params, opt_state)`` carried between steps.
        iteration: Scan input for this step; not used by the body.

    Returns:
        ``((new_params, new_opt_state), loss)`` where ``loss`` is the value of
        ``target_log_prob`` at the incoming parameters.
    """
    current_params, current_opt_state = state

    loss, grads = value_and_grad(target_log_prob)(current_params)
    updates, new_opt_state = optimizer.update(grads, current_opt_state)
    new_params = optax.apply_updates(current_params, updates)

    return (new_params, new_opt_state), loss

In [11]:
# Run 500 optimisation steps starting from `params` with a fresh optimiser
# state. `history` stacks the loss returned by every `train_step` call.
(next_params, next_opt_state), history = lax.scan(
    jit(train_step),
    (params, optimizer.init(params)),
    xs=jnp.arange(500),
)

In [12]:
s = 128  # leading axis of `base_samples` and `fantasies`
n = 10   # number of candidate batches
q = 3    # points per candidate batch
d = 2    # presumably the input dimension; not referenced by the visible cells

In [13]:
# Draw s * n * q values with boax's Halton-normal sampler and arrange them
# as an (s, n, q) array; they are passed to `models.sampled` below.
base_samples = jnp.reshape(
    samplers.halton_normal()(
        random.fold_in(sample_key, 0),
        s * n * q,
    ),
    (s, n, q),
)

In [14]:
# Candidate batches scored by the acquisition function: n * q points drawn
# with boax's Halton-uniform sampler and grouped into n batches of q points.
# The sampling box is taken from `bounds` (lower = first column, upper =
# second column), matching how `x_train` is drawn, so the candidates always
# follow the search-space definition instead of a hard-coded unit square.
candidates = jnp.reshape(
    samplers.halton_uniform(
        distributions.uniform.uniform(bounds[:, 0], bounds[:, 1])
    )(random.fold_in(sample_key, 1), n * q),
    (n, q, -1),
)

In [15]:
# Fantasy points for the look-ahead model: s * n points drawn with boax's
# Halton-uniform sampler and arranged as (s, n, 1, -1), i.e. a single point
# per (s, n) entry. The sampling box is taken from `bounds` (lower = first
# column, upper = second column), matching how `x_train` is drawn, rather
# than a hard-coded unit square.
fantasies = jnp.reshape(
    samplers.halton_uniform(
        distributions.uniform.uniform(bounds[:, 0], bounds[:, 1])
    )(random.fold_in(sample_key, 2), s * n),
    (s, n, 1, -1),
)

In [16]:
# Look-ahead surrogate used by the knowledge-gradient acquisition:
#   1. `models.sampled` combines the fitted posterior with `base_samples`
#      through `multivariate_normal.sample` (both vmapped over the leading axis),
#   2. `models.fantasized` feeds those samples and the `fantasies` points into
#      a `gaussian_process_fantasy` built with the fitted kernel,
#   3. `models.predictive` applies the Gaussian likelihood on top.
#
# `next_params` holds *unconstrained* values. `prior` and `posterior` map every
# hyperparameter through `nn.softplus`, so the same transform has to be applied
# here -- including the noise, whose raw value starts at -5 and would otherwise
# enter the likelihood as a negative number that does not match the fitted model.
surrogate = models.predictive(
    models.fantasized(
        models.sampled(
            vmap(posterior(**next_params)),
            vmap(distributions.multivariate_normal.sample),
            base_samples,
        ),
        models.gaussian_process_fantasy(
            means.zero(),
            kernels.scaled(
                kernels.rbf(nn.softplus(next_params['length_scale'])),
                nn.softplus(next_params['amplitude']),
            ),
        ),
        fantasies,
    ),
    likelihoods.gaussian(
        nn.softplus(next_params['noise']),
    ),
)

In [19]:
# q-knowledge-gradient acquisition on top of the fantasized surrogate. The
# surrogate's output is converted with `multivariate_normal.as_normal`,
# vmapped twice so it is applied across the two leading axes.
# NOTE(review): `best` is fixed at 0.0 rather than derived from the fitted
# model -- confirm this baseline is intended.
acqf = acquisitions.q_knowledge_gradient(
    models.outcome_transformed(
        surrogate,
        vmap(
            vmap(distributions.multivariate_normal.as_normal),
        ),
    ),
    best=0.0,
)

In [22]:
# Score the candidate batches with the knowledge-gradient acquisition.
# Presumably yields one value per candidate batch (n entries) -- TODO confirm.
values = acqf(candidates)