# Constrained Airplane Design

In this notebook we simulate the design of an airplane component using constrained Bayesian optimisation.

We begin by defining the latent (black-box) objective function and the latent cost function. The goal is to maximise the objective subject to the constraint that the cost is at most zero.

In [None]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax import prediction, optimization
from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models, objectives
from boax.optimization import acquisitions, constraints, optimizers

In [None]:
# Search box: one row per input dimension, columns are (lower, upper) = (0, 1).
bounds = jnp.tile(jnp.array([0.0, 1.0]), (4, 1))

In [None]:
def objective(x):
    """Latent objective to be maximised.

    Inputs in the unit box are mirrored on dimensions 2 and 3 (``x -> 1 - x``),
    rescaled to [-5, 5], and fed through a negated, scaled Styblinski-Tang-style
    polynomial ``z**4 - 16 z**2 + 5 z`` summed over the last axis.

    Args:
        x: Array-like whose last axis has size >= 4 (the input dimensions).
            Any leading batch shape is allowed, including none (a single point).

    Returns:
        Array with the last axis reduced, i.e. shape ``x.shape[:-1]``.
    """
    x = jnp.asarray(x)  # `.at` below requires a JAX array (lists/numpy arrays lack it)
    # Use `...` rather than `:` so both a single point and arbitrary batch
    # shapes work, consistent with `cost`.
    flipped = x.at[..., [2, 3]].set(1 - x[..., [2, 3]])
    z = flipped * 10 - 5
    return -0.005 * jnp.sum(z**4 - 16 * z**2 + 5 * z, axis=-1) + 3

In [None]:
def cost(x):
    """Latent cost (constraint) function; points with ``cost(x) <= 0`` are feasible.

    Inputs in the unit box are rescaled to [-10, 10] and passed through a
    Dixon-Price-style sum ``(z_1 - 1)^2 + sum_{i>=2} i (2 z_i^2 - z_{i-1})^2``,
    which is then negated, scaled by 1e-5 and shifted by +2.

    Args:
        x: Array-like whose last axis holds the input dimensions (>= 1).
            Any leading batch shape is allowed.

    Returns:
        Array of shape ``x.shape[:-1]``.
    """
    z = jnp.asarray(x) * 20 - 10  # asarray: `list * 20` would repeat the list
    # Weights 2..d derived from the input dimension instead of being hard-coded
    # to 4 dimensions (identical to jnp.arange(2, 5) when d == 4).
    weights = jnp.arange(2, z.shape[-1] + 1)
    first = (z[..., 0] - 1) ** 2
    rest = jnp.sum(weights * (2 * z[..., 1:] ** 2 - z[..., :-1]) ** 2, axis=-1)

    return -(first + rest) / 100_000 + 2

In [None]:
# Split the root key once and give each sub-key a single purpose: reusing the
# root key after splitting it (as `random.key(0)` was here) breaks JAX's
# one-use-per-key rule and left `data_key` unused.
data_key, sampler_key, optimizer_key = random.split(random.key(0), 3)
# 10 initial design points drawn uniformly inside `bounds`.
x_train = random.uniform(data_key, minval=bounds[:, 0], maxval=bounds[:, 1], shape=(10, 4))
y_train = objective(x_train)
c_train = cost(x_train)

In [None]:
# Raw (unconstrained) starting hyperparameters for the GP fits; the projection
# in `fit` maps `amplitude` and `length_scale` through softplus.
params = dict(
    mean=jnp.zeros(shape=()),
    length_scale=jnp.zeros(shape=(4,)),  # one length scale per input dimension
    amplitude=jnp.zeros(shape=()),
)

In [None]:
# Adam optimiser shared by every hyperparameter fit.
adam = optax.adam(learning_rate=0.01)

In [None]:
def fit(x_train, y_train, num_steps=500, init_params=None):
    """Fit GP hyperparameters to ``(x_train, y_train)`` by minimising the NLL.

    Runs ``num_steps`` updates of the module-level ``adam`` optimiser inside a
    ``lax.scan`` and returns the projected (constrained) parameters.

    Args:
        x_train: Training inputs.
        y_train: Training targets of the GP being fitted.
        num_steps: Number of Adam steps (default 500, the previous fixed value).
        init_params: Optional raw (unconstrained) starting parameters with keys
            'mean', 'length_scale' and 'amplitude'. ``None`` (default) uses the
            module-level ``params``.

    Returns:
        Dict with 'mean' unchanged and 'amplitude' / 'length_scale' passed
        through ``nn.softplus`` (so they are positive).
    """
    start_params = params if init_params is None else init_params

    def model(params):
        # GP prior with constant mean and scaled Matern-5/2 kernel; the
        # likelihood is fixed at `likelihoods.gaussian(1e-4)` (not learned).
        return models.outcome_transformed(
            models.gaussian_process(
                means.constant(params['mean']),
                kernels.scaled(
                    kernels.matern_five_halves(params['length_scale']),
                    params['amplitude'],
                ),
            ),
            likelihoods.gaussian(1e-4)
        )

    # Named `nll` (not `objective`) so it does not shadow the module-level
    # `objective` test function.
    def nll(params):
        return objectives.negative_log_likelihood(
            distributions.multivariate_normal.logpdf
        )

    def projection(params):
        # Map raw parameters to their constrained values.
        return {
            'mean': params['mean'],
            'amplitude': nn.softplus(params['amplitude']),
            'length_scale': nn.softplus(params['length_scale']),
        }

    # The loss function is the same for every step, so build it once.
    loss_fn = prediction.construct(model, nll, projection)

    def step(state, _):
        raw_params, opt_state = state
        loss, grads = value_and_grad(loss_fn)(raw_params, x_train, y_train)
        updates, next_opt_state = adam.update(grads, opt_state)
        next_raw_params = optax.apply_updates(raw_params, updates)

        return (next_raw_params, next_opt_state), loss

    (next_params, _), _ = lax.scan(
        jit(step),
        (start_params, adam.init(start_params)),
        jnp.arange(num_steps),
    )

    return projection(next_params)

In [None]:
# Multi-start initial points for the acquisition optimiser: Halton samples over
# `bounds`, reshaped to (NUM_RESTARTS, BATCH_SIZE, -1). The last axis is
# presumably the input dimension — TODO confirm the sampler returns (n, d).
NUM_RESTARTS = 100  # number of independent optimiser starts
BATCH_SIZE = 1      # points per start (presumably the q-batch size)

x0 = jnp.reshape(
    samplers.halton_uniform(
        distributions.uniform.uniform(bounds[:, 0], bounds[:, 1])
    )(
        sampler_key,
        NUM_RESTARTS * BATCH_SIZE,
    ),
    (NUM_RESTARTS, BATCH_SIZE, -1)
)

In [None]:
def optimize(x_train, y_train, c_train, num_iterations=5):
    """Run constrained Bayesian optimisation for ``num_iterations`` rounds.

    Each round:
      1. fits one GP to the objective values and one to the cost values,
      2. maximises log-constrained expected improvement (constraint:
         cost <= 0) with BFGS started from the module-level ``x0``,
      3. evaluates ``objective`` and ``cost`` at the best candidate and
         appends the new observation to the training data.

    Args:
        x_train: Observed inputs, shape (n, d).
        y_train: Observed objective values, shape (n,).
        c_train: Observed cost values, shape (n,).
        num_iterations: Number of BO rounds (default 5, the previous fixed value).

    Returns:
        Tuple ``(x_train, y_train, c_train)`` with the new observations appended.
    """
    def model(params, inputs, targets):
        # Build the GP regression model on the *given* targets. The previous
        # version closed over `y_train`, so the cost model was built on
        # objective values instead of cost values.
        return models.outcome_transformed(
            models.gaussian_process_regression(
                means.constant(params['mean']),
                kernels.scaled(
                    kernels.matern_five_halves(params['length_scale']),
                    params['amplitude'],
                ),
            )(
                inputs,
                targets,
            ),
            likelihoods.gaussian(1e-4),
            distributions.multivariate_normal.as_normal,
        )

    for i in range(num_iterations):
        # Incumbent = best objective value among feasible points (cost <= 0).
        # Test for emptiness explicitly: `jnp.any(feasible)` checks the values'
        # truthiness, so an all-zero feasible set would wrongly fall back.
        # -2 is the lowest value `objective` attains on the unit box.
        feasible = y_train[c_train <= 0]
        best = jnp.max(feasible) if feasible.size > 0 else jnp.array(-2.)

        objective_params = fit(x_train, y_train)
        cost_params = fit(x_train, c_train)

        acqf = optimization.construct_log_constrained(
            models.joined(
                model(objective_params, x_train, y_train),
                model(cost_params, x_train, c_train),
            ),
            acquisitions.log_expected_improvement(best),
            constraints.log_less_or_equal(0.0),
        )

        bfgs = optimizers.bfgs(acqf, bounds, x0, 40)
        candidates = bfgs.init(random.fold_in(optimizer_key, i))
        next_candidates, values = bfgs.update(candidates)

        # Keep the candidate (batch) from the restart with the highest value.
        next_x = next_candidates[jnp.argmax(values)]
        next_y = objective(next_x)
        next_c = cost(next_x)

        x_train = jnp.vstack([x_train, next_x])
        y_train = jnp.hstack([y_train, next_y])
        c_train = jnp.hstack([c_train, next_c])

    return x_train, y_train, c_train

In [None]:
# Run the optimisation loop from the initial design; results include the
# initial points followed by the newly acquired ones.
next_x_train, next_y_train, next_c_train = optimize(
    x_train=x_train,
    y_train=y_train,
    c_train=c_train,
)