In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import scipy
from jax import value_and_grad

import optax
import matplotlib.pyplot as plt

from boax.prediction import kernels, means, models
from boax.optimization import acquisitions, constraints, maximizers, samplers

In [2]:
# Search space: one [lower, upper] row per input dimension (4 rows of [0, 1]).
bounds = jnp.tile(jnp.array([0.0, 1.0]), (4, 1))

In [3]:
def objective(x):
    """Evaluate the objective at one point or a batch of points.

    The inputs at positions 2 and 3 of the last axis are mirrored (`1 - x`),
    every coordinate is mapped affinely via `x * 10 - 5`, and the polynomial
    `t**4 - 16 * t**2 + 5 * t` is summed over the last axis. The sum is scaled
    by -0.005 and shifted by 3.

    Args:
        x: JAX array whose last axis holds the input coordinates. Any number
            of leading batch axes is accepted, including none (a single point).

    Returns:
        Array of objective values with the last axis of `x` reduced away.
    """
    # Index with `...` so a single 1-D point works as well as a 2-D batch.
    mirrored = x.at[..., 2:4].set(1 - x[..., 2:4])
    scaled = mirrored * 10 - 5
    return -0.005 * jnp.sum(scaled**4 - 16 * scaled**2 + 5 * scaled, axis=-1) + 3

In [52]:
def cost(x):
    """Evaluate the constraint function at one point or a batch of points.

    Every coordinate is mapped affinely via `x * 20 - 10`. The result is
    `2 - (head + tail) / 100_000`, where `head = (t_0 - 1)**2` and
    `tail = sum_k k * (2 * t_k**2 - t_{k-1})**2` with weights `k = 2..d`
    for `d` input dimensions.

    Args:
        x: JAX array whose last axis holds the `d` input coordinates; any
            leading batch axes are allowed.

    Returns:
        Array of constraint values with the last axis of `x` reduced away.
    """
    scaled = x * 20 - 10
    # One weight per consecutive coordinate pair (2, 3, ..., d). Deriving the
    # range from the input width avoids silently broadcasting a fixed-length
    # weight vector against inputs that are not 4-dimensional.
    weights = jnp.arange(2, x.shape[-1] + 1)
    head = (scaled[..., 0] - 1) ** 2
    tail = jnp.sum(
        weights * (2 * scaled[..., 1:] ** 2 - scaled[..., :-1]) ** 2,
        axis=-1,
    )

    return -(head + tail) / 100_000 + 2

In [53]:
# Initial hyperparameter values used as the starting point of every `fit` call.
# `length_scale` and `amplitude` are passed through `nn.softplus` inside `fit`,
# so the values stored here are unconstrained.
params = dict(
    mean=jnp.zeros(()),
    length_scale=jnp.zeros((4,)),
    amplitude=jnp.zeros(()),
)

In [54]:
# Adam with a learning rate of 0.01; used by `fit` for hyperparameter updates.
optimizer = optax.adam(learning_rate=0.01)

In [55]:
def fit(x_train, y_train):
    """Tune the model hyperparameters on the data and return the regression model.

    Starting from the module-level `params`, 500 update steps of the
    module-level `optimizer` are run on the negative log-density of `y_train`
    under the multivariate normal given by the prior's mean and covariance at
    `x_train`. The final hyperparameters are then used to build
    `models.gaussian_process_regression` on the same data.

    Args:
        x_train: Training inputs.
        y_train: Training targets observed at `x_train`.

    Returns:
        The object returned by `models.gaussian_process_regression`.
    """
    def kernel(length_scale, amplitude):
        # softplus maps the unconstrained values to positive ones.
        return kernels.scaled(
            kernels.matern_five_halves(nn.softplus(length_scale)),
            nn.softplus(amplitude),
        )

    def prior(mean, length_scale, amplitude):
        # The trailing 1e-4 is presumably the observation noise - confirm against boax.
        return models.gaussian_process(
            means.constant(mean),
            kernel(length_scale, amplitude),
            1e-4,
        )

    def posterior(mean, length_scale, amplitude):
        return models.gaussian_process_regression(
            x_train,
            y_train,
            means.constant(mean),
            kernel(length_scale, amplitude),
            1e-4,
        )

    def target_log_prob(hyperparams):
        # Negated log-density, so minimising it maximises the likelihood.
        loc, covariance = prior(**hyperparams)(x_train)
        return -scipy.stats.multivariate_normal.logpdf(y_train, loc, covariance)

    def train_step(state, iteration):
        current, opt_state = state
        loss, grads = value_and_grad(target_log_prob)(current)
        updates, opt_state = optimizer.update(grads, opt_state)

        return (optax.apply_updates(current, updates), opt_state), loss

    initial_state = (params, optimizer.init(params))
    # Only the final hyperparameters are needed; the optimiser state and the
    # per-step loss history are discarded.
    (next_params, _), _ = lax.scan(jit(train_step), initial_state, jnp.arange(500))

    return posterior(**next_params)

In [63]:
# Number of iterations of the optimisation loop; each iteration appends one
# new observation to the training data.
num_queries: int = 50

In [64]:
def score(key, surrogate, feasibility, best):
    """Pick the next query point by maximising the constrained acquisition.

    The acquisition combines log expected improvement over `best` (computed
    from `surrogate`) with a log "less or equal to 0.0" constraint on
    `feasibility`. It is maximised with BFGS inside the module-level `bounds`,
    and the candidate with the highest acquisition value is returned.

    Args:
        key: JAX PRNG key used to initialise the maximiser's candidates.
        surrogate: Fitted model of the objective.
        feasibility: Fitted model of the constraint function.
        best: Incumbent objective value to improve upon.

    Returns:
        The candidate whose acquisition value is largest.
    """
    acqf = acquisitions.log_constrained(
        acquisitions.log_expected_improvement(surrogate, best),
        constraints.log_less_or_equal(feasibility, 0.0),
    )

    maximizer = maximizers.bfgs(bounds, q=1, num_restarts=100, num_raw_samples=500)
    initial = maximizer.init(key, acqf)
    optimised, values = maximizer.maximize(initial, acqf)

    winner = jnp.argmax(values)
    return optimised[winner]

In [65]:
# Split the root key once: `data_key` seeds the initial design and
# `optimization_key` seeds the acquisition maximiser in the loop below.
data_key, optimization_key = random.split(random.key(0))

# Sample the initial design with `data_key`. The root key has already been
# consumed by the split above, so it must not be used for sampling again.
x_train = random.uniform(data_key, minval=bounds[:, 0], maxval=bounds[:, 1], shape=(4, 4))
y_train = objective(x_train)
c_train = cost(x_train)

In [66]:
for i in range(num_queries):
    print(i)
    # Objective values observed at points that satisfy the constraint (cost <= 0).
    feasible = y_train[c_train <= 0]

    # Refit both models on all data gathered so far.
    surrogate = fit(x_train, y_train)
    feasibility = fit(x_train, c_train)

    # Incumbent: best feasible objective value, or -2 when no feasible point
    # has been observed yet. Emptiness is checked via `size`: `jnp.any` would
    # test whether any *value* is non-zero, not whether any point exists.
    best = jnp.max(feasible) if feasible.size > 0 else jnp.array(-2.)

    # A fresh key per iteration, derived from the optimisation key.
    next_x = score(random.fold_in(optimization_key, i), surrogate, feasibility, best)
    next_y = objective(next_x)
    next_c = cost(next_x)

    # Append the new observation to the training data.
    x_train = jnp.vstack([x_train, next_x])
    y_train = jnp.hstack([y_train, next_y])
    c_train = jnp.hstack([c_train, next_c])

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
