# Hyperparameter Tuning

In this notebook we use Boax to simulate a hyperparameter tuning problem for an SVM.

We will begin by defining the latent objective function we want to maximize and its bounds.

In [1]:
from jax import config

# Enable 64-bit (double-precision) arrays in JAX; by default JAX uses 32-bit types.
# Double precision is highly recommended for the GP fitting and optimization below.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, optimizers

We use a two-dimensional synthetic objective function that simulates the accuracy of an SVM.

In [2]:
# Box constraints of the search space: one [lower, upper] row per input dimension.
bounds = jnp.array([
    [0.0, 2.0],  # dimension 0
    [0.0, 2.0],  # dimension 1
])

In [3]:
def objective(x):
    """Synthetic 2-D objective standing in for SVM accuracy (to be maximized).

    Args:
        x: array whose last axis holds the two inputs; any leading batch
            dimensions are preserved.

    Returns:
        Array of objective values with the leading (batch) shape of ``x``.
    """
    first, second = x[..., 0], x[..., 1]
    # Oscillating interaction term between the two inputs.
    wave = jnp.sin(5 * first / 2 - 2.5) * jnp.cos(2.5 - 5 * second)
    # Quadratic trend in the second input.
    bowl = (5 * second / 2 + 0.5) ** 2 / 10
    return (wave + bowl) / 5 + 0.2

In [4]:
# Derive three independent PRNG keys from one seed: initial data, the
# quasi-random samplers, and the acquisition optimizer.
data_key, sampler_key, optimization_key = random.split(random.key(0), num=3)

# Initial design: 10 points drawn uniformly inside `bounds`, plus their objective values.
x_train = random.uniform(
    data_key,
    shape=(10, 2),
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
)
y_train = objective(x_train)

In [5]:
# Initial unconstrained GP hyperparameters. `fit` passes `length_scale` and
# `amplitude` through `nn.softplus`, so zeros here start them at softplus(0).
params = dict(
    mean=jnp.zeros(()),
    length_scale=jnp.zeros((2,)),  # one entry per input dimension (2-D bounds)
    amplitude=jnp.zeros(()),
)

In [6]:
# Adam optimizer (learning rate 0.01) used by `fit` to tune the GP hyperparameters.
adam = optax.adam(learning_rate=0.01)

In [16]:
def fit(x_train, y_train, init_params=None, num_steps=500, optimizer=None):
    """Fit GP hyperparameters by minimizing the negative log marginal likelihood.

    The model is a GP with a constant mean and a scaled Matern-5/2 kernel,
    wrapped in a Gaussian likelihood with fixed parameter 1e-4 (presumably the
    observation-noise level). `length_scale` and `amplitude` are optimized as
    unconstrained values and mapped through `softplus` to keep them positive.

    Args:
        x_train: training inputs (presumably shape (n, d)).
        y_train: training targets (presumably shape (n,)).
        init_params: initial unconstrained parameter dict with keys 'mean',
            'length_scale' and 'amplitude'. Defaults to the notebook-level
            `params`.
        num_steps: number of optimization steps (default 500).
        optimizer: an optax gradient transformation. Defaults to the
            notebook-level `adam`.

    Returns:
        The `lax.scan` result `((final_params, final_opt_state), losses)`, where
        `losses[k]` is the negative log marginal likelihood evaluated before
        the k-th update.
    """
    # Fall back to the notebook-level defaults so existing `fit(x, y)` calls
    # behave exactly as before.
    if init_params is None:
        init_params = params
    if optimizer is None:
        optimizer = adam

    def model(mean, length_scale, amplitude):
        # GP prior + Gaussian likelihood; softplus maps the raw kernel
        # parameters to strictly positive values.
        return models.predictive(
            models.gaussian_process(
                means.constant(mean),
                kernels.scaled(
                    kernels.matern_five_halves(nn.softplus(length_scale)),
                    nn.softplus(amplitude),
                ),
            ),
            likelihoods.gaussian(1e-4),
        )

    def negative_log_likelihood(p):
        # Loss to minimize: -log p(y_train | x_train, p).
        mvn = model(**p)(x_train)
        return -jnp.sum(distributions.multivariate_normal.logpdf(mvn, y_train))

    def train_step(state, _):
        current_params, opt_state = state
        loss, grads = value_and_grad(negative_log_likelihood)(current_params)
        # Passing params lets optimizers that need them (e.g. weight decay) work too.
        updates, opt_state = optimizer.update(grads, opt_state, current_params)
        return (optax.apply_updates(current_params, updates), opt_state), loss

    # lax.scan traces and compiles `train_step` itself, so wrapping the step
    # in an extra `jit` is unnecessary.
    return lax.scan(
        train_step,
        (init_params, optimizer.init(init_params)),
        xs=None,
        length=num_steps,
    )

In [17]:
# Bayesian-optimization settings.
s = 32             # rows of `base_samples` (shape (s, q)) fed to the sampled posterior
q = 4              # batch size: points proposed per BO iteration
num_samples = 40   # passed to `optimizers.bfgs` (presumably the number of restarts) — TODO confirm
raw_samples = 100  # number of initial candidate batches in `x0` (shape (raw_samples, q, d))

In [18]:
num_queries = 20  # total objective evaluations budget
# Number of BO rounds of q points each (integer quotient; any remainder is dropped).
num_iterations = divmod(num_queries, q)[0]

In [19]:
# Base samples for the sampled acquisition function: s * q draws from the
# Halton-based normal sampler, arranged as an (s, q) array. Computed once,
# outside the BO loop, and reused in every iteration.
base_samples = samplers.halton_normal()(
    random.fold_in(sampler_key, 0),
    s * q,
).reshape(s, q)

In [20]:
# Initial candidate batches for the acquisition optimizer: raw_samples * q
# points from a Halton-based uniform sampler over `bounds`, regrouped into
# raw_samples batches of q points (last axis inferred as the input dimension).
x0 = samplers.halton_uniform(
    distributions.uniform.uniform(bounds[:, 0], bounds[:, 1])
)(
    random.fold_in(sampler_key, 1),
    raw_samples * q,
).reshape(raw_samples, q, -1)

In [21]:
# Bayesian-optimization loop. Each iteration refits the GP hyperparameters on
# all data observed so far, builds a sample-based q-probability-of-improvement
# acquisition function, optimizes it with BFGS, and evaluates the objective at
# the best candidate batch.
for i in range(num_iterations):
    # Refit the GP hyperparameters on the current training data.
    (next_params, _), _ = fit(x_train, y_train)

    # GP conditioned on the training data, using the fitted hyperparameters
    # (softplus-transformed as in `fit`) and the same likelihood as in `fit`.
    posterior = models.predictive(
        models.gaussian_process_regression(
            means.constant(next_params['mean']),
            kernels.scaled(
                kernels.matern_five_halves(nn.softplus(next_params['length_scale'])),
                nn.softplus(next_params['amplitude'])
            ),
        )(
            x_train,
            y_train,
        ),
        likelihoods.gaussian(1e-4)
    )

    acqf = optimizers.construct(
        # Evaluate the posterior through the fixed `base_samples`.
        models.sampled(
            posterior,
            distributions.multivariate_normal.sample,
            base_samples,
        ),
        acquisitions.q_probability_of_improvement(
            tau=1.0,
            # The incumbent must be the best observed objective VALUE.
            # `jnp.argmax(y_train)` would be an index into y_train, which is
            # not on the objective's scale.
            best=jnp.max(y_train)
        )
    )

    bfgs = optimizers.bfgs(
        acqf,
        bounds,
        x0,
        num_samples,
    )

    # A distinct key per iteration (presumably used to choose the starting
    # points from x0).
    candidates = bfgs.init(random.fold_in(optimization_key, i))
    next_candidates, values = bfgs.update(candidates)

    # Keep the candidate batch with the highest acquisition value.
    next_x = next_candidates[jnp.argmax(values)]
    next_y = objective(next_x)

    # Append the new observations to the training set.
    x_train = jnp.vstack([x_train, next_x])
    y_train = jnp.hstack([y_train, next_y])