# Constrained Airplane Design

In this notebook, we simulate the design of an airplane component using constrained Bayesian optimisation.

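We begin by defining the latent objective function, which we want to maximise, and the latent cost function, which acts as the constraint: a design is considered feasible only if its cost is at most zero.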

In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad
from jax import vmap

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, constraints, maximizers

In [2]:
# Unit box for the 4 design variables: column 0 holds the lower bounds,
# column 1 the upper bounds (consumed as minval/maxval further below).
bounds = jnp.stack([jnp.zeros(4), jnp.ones(4)], axis=1)

In [3]:
def objective(x):
    """Objective to maximise: a negated, rescaled Styblinski–Tang function.

    Every coordinate of ``x`` (expected in [0, 1]) is mapped affinely to
    ``z`` in [-5, 5]; coordinates 2 and 3 are mirrored via ``1 - x`` first.
    The value returned is ``3 - 0.005 * sum(z**4 - 16 * z**2 + 5 * z)``.
    On the unit box its minimum is -2.0, reached at x = [1, 1, 0, 0].

    Args:
        x: points with the design variables on the last axis (at least 4 of
            them, since coordinates 2 and 3 are mirrored), e.g. shape (n, 4)
            or (4,). Any leading axes are treated as batch dimensions.

    Returns:
        One objective value per point (the last axis is reduced).
    """
    # `...` instead of `:` so the mirroring always targets the last axis,
    # matching how `cost` indexes its input.
    mirrored = x.at[..., [2, 3]].set(1 - x[..., [2, 3]])
    z = mirrored * 10 - 5
    return -0.005 * jnp.sum(z**4 - 16 * z**2 + 5 * z, axis=-1) + 3

In [4]:
def cost(x):
    """Constraint value per point: a negated, rescaled Dixon–Price function.

    Every coordinate of ``x`` (expected in [0, 1]) is mapped affinely to
    ``z`` in [-10, 10], and the value returned is ``2 - DP(z) / 100_000`` with
    ``DP(z) = (z_1 - 1)**2 + sum_{i=2..d} i * (2 * z_i**2 - z_{i-1})**2``.
    The optimisation loop treats a point as feasible when this value is <= 0.

    Args:
        x: points with the design variables on the last axis, e.g. shape
            (n, d) or (d,). Any leading axes are treated as batch dimensions.

    Returns:
        One cost value per point (the last axis is reduced).
    """
    z = x * 20 - 10
    # Weights 2..d, derived from the input so the function is not tied to d = 4.
    weights = jnp.arange(2, x.shape[-1] + 1)
    first_term = (z[..., 0] - 1) ** 2
    chained_terms = jnp.sum(weights * (2 * z[..., 1:] ** 2 - z[..., :-1]) ** 2, axis=-1)

    return -(first_term + chained_terms) / 100_000 + 2

In [5]:
# Initial, unconstrained GP hyperparameters, all zero. `fit` maps length_scale
# and amplitude through softplus, so they effectively start at softplus(0) = log 2.
params = dict(
    mean=jnp.zeros(()),
    length_scale=jnp.zeros((4,)),  # shape (4,) — presumably one per input dimension
    amplitude=jnp.zeros(()),
)

In [6]:
# Adam with a fixed step size of 0.01; `fit` uses it for every hyperparameter fit.
optimizer = optax.adam(learning_rate=0.01)

In [7]:
def fit(x_train, y_train, init_params=None, num_steps=500):
    """Fit a Matern-5/2 GP to (x_train, y_train) and return its posterior predictive.

    The hyperparameters (constant mean, length scale(s) and amplitude) are
    learned by minimising the negative log density of ``y_train`` under the
    prior predictive (the negative log marginal likelihood). The optimiser is
    the module-level ``optimizer``. Length scale and amplitude are stored
    unconstrained and passed through softplus so they stay positive.

    Args:
        x_train: training inputs — presumably shape (n, d); TODO confirm.
        y_train: training targets, one per training input.
        init_params: starting hyperparameters as a dict with keys 'mean',
            'length_scale' and 'amplitude'. Defaults to the module-level
            ``params`` (read at call time).
        num_steps: number of optimisation steps (500 by default).

    Returns:
        The posterior predictive model conditioned on (x_train, y_train) at
        the fitted hyperparameters.
    """
    if init_params is None:
        init_params = params

    def kernel(length_scale, amplitude):
        # Shared by prior and posterior so both use exactly the same covariance.
        return kernels.scaled(
            kernels.matern_five_halves(nn.softplus(length_scale)),
            nn.softplus(amplitude),
        )

    def prior(mean, length_scale, amplitude):
        # NOTE(review): 1e-4 is presumably the Gaussian observation-noise level — confirm.
        return models.predictive(
            models.gaussian_process(
                means.constant(mean),
                kernel(length_scale, amplitude),
            ),
            likelihoods.gaussian(1e-4),
        )

    def posterior(mean, length_scale, amplitude):
        return models.predictive(
            models.gaussian_process_regression(
                means.constant(mean),
                kernel(length_scale, amplitude),
            )(x_train, y_train),
            likelihoods.gaussian(1e-4),
        )

    def negative_log_likelihood(hyperparams):
        # Loss to minimise: -log p(y_train) under the prior predictive.
        mvn = prior(**hyperparams)(x_train)
        return -jnp.sum(distributions.multivariate_normal.logpdf(mvn, y_train))

    def train_step(state, _):
        current, opt_state = state
        loss, grads = value_and_grad(negative_log_likelihood)(current)
        updates, next_opt_state = optimizer.update(grads, opt_state)
        return (optax.apply_updates(current, updates), next_opt_state), loss

    # lax.scan compiles its body itself, so no separate jit is needed; the
    # per-step losses it stacks are not used.
    (fitted, _), _ = lax.scan(
        train_step,
        (init_params, optimizer.init(init_params)),
        xs=None,
        length=num_steps,
    )

    return posterior(**fitted)

In [8]:
# Settings for the optimisation loop below.
num_queries = 5        # optimisation rounds run by the loop
batch_size = 1         # points proposed per round; also the size of the initial design
num_raw_samples = 500  # forwarded to maximizers.bfgs (presumably random raw samples) — TODO confirm
num_restarts = 100     # forwarded to maximizers.bfgs (presumably optimiser restarts) — TODO confirm

In [9]:
# Split the seed once: `data_key` draws the initial design and
# `optimization_key` seeds the acquisition maximizer. The design must use
# `data_key`; re-using `random.key(0)` after splitting it is JAX key reuse.
data_key, optimization_key = random.split(random.key(0))
x_train = random.uniform(
    data_key,
    shape=(batch_size, bounds.shape[0]),  # one column per bounded design variable
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
)
y_train = objective(x_train)
c_train = cost(x_train)

In [10]:
# The acquisition maximizer does not depend on the data, so build it once.
maximizer = maximizers.bfgs(
    bounds,
    q=batch_size,
    num_raw_samples=num_raw_samples,
    num_restarts=num_restarts,
)

# Incumbent used while no evaluated point is feasible: -2.0 is the minimum of
# `objective` on the unit box (at x = [1, 1, 0, 0]), so nothing is worse.
fallback_best = jnp.array(-2.)


def normal_marginals(model):
    """Vectorise a fitted model over query points and expose per-point normal marginals."""
    return models.outcome_transformed(
        vmap(model),
        vmap(distributions.multivariate_normal.as_normal),
    )


for i in range(num_queries):
    # Incumbent: best objective value among points that satisfy cost <= 0.
    feasible = y_train[c_train <= 0]
    # Test emptiness by size: jnp.any(feasible) inspects the *values* and would
    # wrongly report "no feasible point" when every feasible y is exactly 0.
    best = fallback_best if feasible.size == 0 else jnp.max(feasible)

    surrogate = fit(x_train, y_train)
    lei = acquisitions.log_expected_improvement(normal_marginals(surrogate), best=best)

    feasibility = fit(x_train, c_train)
    lle = constraints.log_less_or_equal(normal_marginals(feasibility), bound=0.0)

    acqf = acquisitions.log_constrained(lei, lle)

    candidates = maximizer.init(acqf, random.fold_in(optimization_key, i))
    next_candidates, values = maximizer.update(acqf, candidates)

    # Keep the best restart and evaluate it on the true functions.
    next_x = next_candidates[jnp.argmax(values)]
    next_y = objective(next_x)
    next_c = cost(next_x)

    x_train = jnp.vstack([x_train, next_x])
    y_train = jnp.hstack([y_train, next_y])
    c_train = jnp.hstack([c_train, next_c])