# Constrained Airplane Design

In this notebook we simulate the design of an airplane component using constrained Bayesian optimisation.

We begin by defining the latent objective function, which we want to maximise, and the latent cost function, which acts as a constraint: a design is feasible only if its cost is at most zero.

In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, constraints, optimizers

In [2]:
# Search space: the unit box [0, 1]^4, one (lower, upper) row per input dimension.
bounds = jnp.stack([jnp.zeros(4), jnp.ones(4)], axis=1)

In [3]:
def objective(x):
    """Objective to maximise: a negated, rescaled Styblinski-Tang function.

    Coordinates 2 and 3 are reflected (x -> 1 - x). All coordinates are then
    mapped from [0, 1] to [-5, 5], and the Styblinski-Tang terms
    ``z**4 - 16 z**2 + 5 z`` are summed and scaled by -0.005, then offset by +3.

    Args:
        x: array whose last axis holds the design variables (at least 4, since
            indices 2 and 3 are reflected). Any leading batch shape is
            accepted, including none (a single point).

    Returns:
        Array with the batch shape of ``x`` (the last axis is reduced).
    """
    # `...` rather than `:` so an unbatched point of shape (4,) also works,
    # matching the indexing style used by `cost`.
    reflected = x.at[..., [2, 3]].set(1 - x[..., [2, 3]])
    z = reflected * 10 - 5
    return -0.005 * jnp.sum(z**4 - 16 * z**2 + 5 * z, axis=-1) + 3

In [4]:
def cost(x):
    """Constraint function: a negated, rescaled Dixon-Price function plus 2.

    Inputs in [0, 1] are mapped to [-10, 10]. The Dixon-Price value
    ``(z_1 - 1)**2 + sum_{i=2..d} i * (2 z_i**2 - z_{i-1})**2`` is divided by
    -100_000, and 2 is added. The optimisation loop treats a design as
    feasible when ``cost(x) <= 0``.

    Args:
        x: array whose last axis holds the d >= 1 design variables; any
            leading batch shape is accepted.

    Returns:
        Array with the batch shape of ``x`` (the last axis is reduced).
    """
    z = x * 20 - 10
    # Dixon-Price weights i = 2..d, derived from the input width rather than
    # hard-coded for d = 4.
    weights = jnp.arange(2, z.shape[-1] + 1)
    part1 = (z[..., 0] - 1)**2
    part2 = jnp.sum(weights * (2 * z[..., 1:]**2 - z[..., :-1])**2, axis=-1)

    return -(part1 + part2) / 100_000 + 2

In [7]:
# Independent PRNG streams for the initial design, the Halton start points
# and the per-iteration optimiser restarts.
data_key, sample_key, optimization_key = random.split(random.key(0), 3)
# Initial design: 10 uniform random points in `bounds`. Sample with the
# dedicated `data_key`; the root key has already been split above and must
# not be reused.
x_train = random.uniform(
    data_key,
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
    shape=(10, bounds.shape[0]),
)
y_train = objective(x_train)
c_train = cost(x_train)

In [8]:
# Initial, unconstrained GP hyperparameters. `fit` passes length_scale and
# amplitude through softplus, so zeros start both at softplus(0) = log 2.
params = dict(
    mean=jnp.zeros(()),
    length_scale=jnp.zeros((4,)),  # one length scale per input dimension
    amplitude=jnp.zeros(()),
)

In [9]:
# Adam optimiser (fixed step size 0.01) used by `fit` to tune the GP hyperparameters.
adam = optax.adam(learning_rate=0.01)

In [10]:
def fit(x_train, y_train, init_params=None, num_steps=500):
    """Fit GP hyperparameters to ``(x_train, y_train)`` with Adam.

    Minimises the negative log-likelihood of ``y_train`` under the predictive
    distribution of a constant-mean GP with a scaled Matern-5/2 kernel and a
    fixed ``likelihoods.gaussian(1e-4)`` likelihood. The updates use the
    module-level ``adam`` optimiser.

    Args:
        x_train: training inputs.
        y_train: training targets.
        init_params: starting hyperparameter dict with keys ``mean``,
            ``length_scale`` and ``amplitude``, in unconstrained form (softplus
            is applied to the last two). Defaults to the module-level
            ``params``.
        num_steps: number of Adam steps.

    Returns:
        ``((fitted_params, opt_state), losses)``. ``losses[k]`` is the negative
        log-likelihood evaluated before the k-th update.
    """
    if init_params is None:
        init_params = params

    def model(mean, length_scale, amplitude):
        # softplus keeps the length scales and amplitude positive.
        return models.predictive(
            models.gaussian_process(
                means.constant(mean),
                kernels.scaled(
                    kernels.matern_five_halves(nn.softplus(length_scale)),
                    nn.softplus(amplitude),
                ),
            ),
            likelihoods.gaussian(1e-4)
        )

    def negative_log_likelihood(hyperparams):
        mvn = model(**hyperparams)(x_train)
        return -jnp.sum(distributions.multivariate_normal.logpdf(mvn, y_train))

    def train_step(state, _):
        hyperparams, opt_state = state
        loss, grads = value_and_grad(negative_log_likelihood)(hyperparams)
        updates, opt_state = adam.update(grads, opt_state)
        hyperparams = optax.apply_updates(hyperparams, updates)
        return (hyperparams, opt_state), loss

    # lax.scan traces and compiles `train_step` itself; an explicit `jit` on
    # the body is redundant.
    return lax.scan(
        train_step,
        (init_params, adam.init(init_params)),
        None,
        length=num_steps,
    )

In [11]:
q = 1              # points per candidate batch (the middle axis of `x0`)
num_samples = 40   # passed to `optimizers.bfgs`; NOTE(review): presumably the number of restarts kept - confirm
raw_samples = 100  # number of Halton-drawn candidate batches used as BFGS start points

In [12]:
num_queries = 5
# Each BO iteration appends a batch of `q` points, so derive the iteration
# count from `q` instead of hard-coding a batch size of 1.
num_iterations = num_queries // q

In [13]:
# BFGS start points: raw_samples * q points drawn by `samplers.halton_uniform`
# over the box `bounds`, grouped into `raw_samples` candidate batches of `q`
# points each, i.e. shape (raw_samples, q, d).
x0 = samplers.halton_uniform(
    distributions.uniform.uniform(bounds[:, 0], bounds[:, 1])
)(sample_key, raw_samples * q).reshape(raw_samples, q, -1)

In [14]:
def _gp_posterior(fit_params, x, y):
    """Return a GP conditioned on ``(x, y)`` with normal marginal outputs.

    ``fit_params`` is a hyperparameter dict as returned by ``fit``, still in
    unconstrained form (softplus is applied to length_scale and amplitude).
    The GP uses the same constant-mean, scaled Matern-5/2 form and
    ``likelihoods.gaussian(1e-4)`` likelihood as ``fit``.
    """
    return models.outcome_transformed(
        models.predictive(
            models.gaussian_process_regression(
                means.constant(fit_params['mean']),
                kernels.scaled(
                    kernels.matern_five_halves(nn.softplus(fit_params['length_scale'])),
                    nn.softplus(fit_params['amplitude']),
                ),
            )(x, y),
            likelihoods.gaussian(1e-4),
        ),
        distributions.multivariate_normal.as_normal,
    )


for i in range(num_iterations):
    # Incumbent: best objective value among training points with cost <= 0.
    feasible = y_train[c_train <= 0]
    # Test emptiness with `.size`. `jnp.any(feasible)` inspects the values
    # and would wrongly fall back whenever every feasible objective value is 0.
    # The fallback -2 is the global minimum of `objective` on the unit box
    # (each Styblinski-Tang term peaks at 250, when its rescaled coordinate
    # is 5), so any real point can improve on it.
    best = jnp.max(feasible) if feasible.size > 0 else jnp.array(-2.)

    (objective_params, _), _ = fit(x_train, y_train)
    (cost_params, _), _ = fit(x_train, c_train)

    # NOTE(review): presumably combines log-EI on the objective model with the
    # log-probability that the cost model satisfies c(x) <= 0; confirm
    # against the boax docs.
    acqf = optimizers.construct_log_constrained(
        models.joined(
            _gp_posterior(objective_params, x_train, y_train),
            _gp_posterior(cost_params, x_train, c_train),
        ),
        acquisitions.log_expected_improvement(best),
        constraints.log_less_or_equal(0.0),
    )

    bfgs = optimizers.bfgs(acqf, bounds, x0, num_samples)

    # Fresh, reproducible optimiser randomness for each iteration.
    candidates = bfgs.init(random.fold_in(optimization_key, i))
    next_candidates, values = bfgs.update(candidates)

    # Keep the highest-scoring candidate batch and evaluate both functions on it.
    next_x = next_candidates[jnp.argmax(values)]
    next_y = objective(next_x)
    next_c = cost(next_x)

    x_train = jnp.vstack([x_train, next_x])
    y_train = jnp.hstack([y_train, next_y])
    c_train = jnp.hstack([c_train, next_c])