# Hyperparameter Tuning

In this notebook we use boax to solve a simulated hyperparameter tuning problem for an SVM.

We will begin by defining the latent objective function we want to maximize and its bounds.

In [None]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax import prediction, optimization
from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models, objectives
from boax.optimization import acquisitions, optimizers

We use a two-dimensional synthetic objective function that simulates the accuracy of an SVM.

In [None]:
# Search box: one (lower, upper) row per input dimension, [0, 2] for both.
bounds = jnp.tile(jnp.array([0.0, 2.0]), (2, 1))

In [None]:
def objective(x):
    """Synthetic two-dimensional stand-in for SVM accuracy (to be maximized).

    Args:
        x: Array whose last axis holds the two coordinates; any leading
            (batch) axes are broadcast over.

    Returns:
        Array of objective values with the leading shape of ``x``.
    """
    first = x[..., 0]
    second = x[..., 1]
    # Oscillating interaction term between the two coordinates.
    wave = jnp.sin(5 * first / 2 - 2.5) * jnp.cos(2.5 - 5 * second)
    # Quadratic trend that depends only on the second coordinate.
    trend = (5 * second / 2 + 0.5) ** 2 / 10
    return (wave + trend) / 5 + 0.2

In [None]:
# Three independent PRNG keys: initial design, quasi-random samplers, BFGS init.
data_key, sampler_key, optimizer_key = random.split(random.key(0), num=3)
# 10 initial points drawn uniformly inside `bounds` (column 0 = lower, 1 = upper).
x_train = random.uniform(
    data_key,
    shape=(10, 2),
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
)
y_train = objective(x_train)

In [None]:
# Initial GP hyperparameters in unconstrained space. `fit` passes amplitude and
# length_scale through softplus, so zeros start them at softplus(0) = log(2).
params = dict(
    mean=jnp.zeros(()),
    length_scale=jnp.zeros((2,)),  # presumably one per input dimension
    amplitude=jnp.zeros(()),
)

In [None]:
# Adam optimizer used by `fit` to train the GP hyperparameters.
adam = optax.adam(learning_rate=0.01)

In [None]:
def fit(x_train, y_train, num_steps=500, init_params=None):
    """Fit GP hyperparameters by gradient descent on the model's loss.

    Runs ``num_steps`` Adam updates (module-level ``adam``) inside
    ``lax.scan``. Parameters are optimized in an unconstrained space and
    mapped through ``projection`` (softplus on amplitude and length scale)
    before being used by the model.

    Args:
        x_train: Training inputs.
        y_train: Training targets.
        num_steps: Number of Adam iterations; must be a Python int because it
            fixes the scan length (default 500).
        init_params: Unconstrained starting parameters with keys 'mean',
            'length_scale' and 'amplitude'. Defaults to the module-level
            ``params``.

    Returns:
        The projected (constrained) hyperparameters as a dict.
    """
    if init_params is None:
        init_params = params

    def model(params):
        # Constant-mean GP with a scaled Matern-5/2 kernel; the Gaussian
        # likelihood parameter is fixed at 1e-4 and not trained.
        return models.outcome_transformed(
            models.gaussian_process(
                means.constant(params['mean']),
                kernels.scaled(
                    kernels.matern_five_halves(params['length_scale']),
                    params['amplitude'],
                ),
            ),
            likelihoods.gaussian(1e-4)
        )

    # Named objective_fn so it does not shadow the module-level SVM `objective`.
    def objective_fn(params):
        return objectives.negative_log_likelihood(
            distributions.multivariate_normal.logpdf
        )

    def projection(params):
        # Map unconstrained values to positive amplitude / length scales.
        return {
            'mean': params['mean'],
            'amplitude': nn.softplus(params['amplitude']),
            'length_scale': nn.softplus(params['length_scale']),
        }

    # The loss is identical for every step, so build it once outside the body.
    loss_and_grad = value_and_grad(
        prediction.construct(model, objective_fn, projection)
    )

    def step(state, _):
        current, opt_state = state
        loss, grads = loss_and_grad(current, x_train, y_train)
        updates, opt_state = adam.update(grads, opt_state)
        return (optax.apply_updates(current, updates), opt_state), loss

    # lax.scan traces and compiles `step` itself, so no explicit jit is needed.
    (next_params, _), _ = lax.scan(
        step,
        (init_params, adam.init(init_params)),
        None,
        length=num_steps,
    )

    return projection(next_params)

In [None]:
# Fixed Halton-based normal base samples, passed to `models.sampled` in
# `optimize`: 32 rows, each with 4 entries (presumably one per point of a
# 4-point candidate batch, matching the 4 in `x0`).
base_samples = samplers.halton_normal()(
    random.fold_in(sampler_key, 0),
    32 * 4,
).reshape((32, 4))

In [None]:
# BFGS restart points: 100 * 4 Halton-uniform draws inside `bounds`, grouped
# into 100 batches of 4 points each (trailing axis inferred from the sampler).
x0 = samplers.halton_uniform(
    distributions.uniform.uniform(bounds[:, 0], bounds[:, 1]),
)(sampler_key, 100 * 4).reshape((100, 4, -1))

In [None]:
def optimize(x_train, y_train, num_iterations=5):
    """Run a short batch Bayesian-optimization loop on the synthetic objective.

    Each iteration refits the GP hyperparameters with ``fit``, builds a
    q-probability-of-improvement acquisition on a sampled GP model
    (``models.sampled`` with the fixed ``base_samples``), maximizes it with
    BFGS from the restart points ``x0``, and evaluates ``objective`` at the
    highest-scoring batch of candidates.

    Args:
        x_train: Observed inputs, shape (n, 2).
        y_train: Observed objective values, shape (n,).
        num_iterations: Number of acquisition rounds to run (default 5).

    Returns:
        ``(x_train, y_train)`` with every evaluated batch appended.
    """
    def model(params):
        # Closes over x_train / y_train, so each call conditions the GP on
        # the data collected so far in the loop below.
        return models.sampled(
            models.outcome_transformed(
                models.gaussian_process_regression(
                    means.constant(params['mean']),
                    kernels.scaled(
                        kernels.matern_five_halves(params['length_scale']),
                        params['amplitude'],
                    ),
                )(
                    x_train,
                    y_train,
                ),
                likelihoods.gaussian(1e-4),
            ),
            distributions.multivariate_normal.sample,
            base_samples,
        )

    for i in range(num_iterations):
        params = fit(x_train, y_train)

        acqf = optimization.construct(
            model(params),
            acquisitions.q_probability_of_improvement(
                tau=1.0,
                # Improvement is measured against the best observed *value*;
                # jnp.argmax would pass the index of that value instead.
                best=jnp.max(y_train),
            ),
        )

        # NOTE(review): 40 is presumably the BFGS iteration budget — confirm
        # against the boax optimizers.bfgs signature.
        bfgs = optimizers.bfgs(acqf, bounds, x0, 40)
        # Derive a distinct key per iteration for bfgs.init.
        candidates = bfgs.init(random.fold_in(optimizer_key, i))
        next_candidates, values = bfgs.update(candidates)

        # Keep the restart whose batch achieved the highest acquisition value.
        next_x = next_candidates[jnp.argmax(values)]
        next_y = objective(next_x)

        x_train = jnp.vstack([x_train, next_x])
        y_train = jnp.hstack([y_train, next_y])

    return x_train, y_train

In [None]:
# Run the Bayesian-optimization loop starting from the initial design.
next_x_train, next_y_train = optimize(x_train=x_train, y_train=y_train)