# Hyperparameter Tuning

In this notebook we simulate using boax for a hyperparameter tuning problem of an SVM.

We will begin by defining the latent objective function we want to maximize and its bounds.

In [1]:
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad
from jax import vmap

import optax

from boax.core import distributions, samplers
from boax.prediction import kernels, likelihoods, means, models
from boax.optimization import acquisitions, maximizers

We use a two-dimensional synthetic objective function simulating the accuracy of an SVM.

In [2]:
# One [lower, upper] row per input dimension of the objective.
bounds = jnp.array([
    [0.0, 2.0],
    [0.0, 2.0],
])

In [3]:
def objective(x):
    """Evaluate the synthetic two-dimensional objective.

    Args:
        x: Array whose last axis holds the two input coordinates; any
            leading axes are treated as batch dimensions.

    Returns:
        The objective value for each point, i.e. an array with the last
        axis of `x` removed.
    """
    first = x[..., 0]
    second = x[..., 1]

    # Oscillating interaction between the two coordinates.
    wave = jnp.sin(5 * first / 2 - 2.5) * jnp.cos(2.5 - 5 * second)
    # Quadratic term that depends on the second coordinate only.
    bowl = (5 * second / 2 + 0.5) ** 2 / 10

    return (wave + bowl) / 5 + 0.2

In [4]:
# Initial hyperparameters used as the starting point of every `fit` call.
# `length_scale` and `amplitude` are stored unconstrained; `fit` passes them
# through `nn.softplus` before building the kernel.
params = dict(
    mean=jnp.zeros(()),
    length_scale=jnp.zeros((2,)),
    amplitude=jnp.zeros(()),
)

In [5]:
# Adam optimizer shared by every hyperparameter fit.
optimizer = optax.adam(learning_rate=0.01)

In [6]:
def fit(x_train, y_train):
    """Fit the Gaussian-process hyperparameters to the given observations.

    Runs 500 steps of the notebook-level `optimizer`, starting from the
    notebook-level `params`, minimising the negative log-density of
    `y_train` under the model evaluated at `x_train`.

    Args:
        x_train: Training inputs.
        y_train: Observed objective values at `x_train`.

    Returns:
        The result of `lax.scan`: a pair `((params, opt_state), losses)`
        holding the final carry and the loss recorded at each step.
    """
    def build_model(mean, length_scale, amplitude):
        # Hyperparameters are stored unconstrained; softplus maps them to
        # positive values before they reach the kernel.
        kernel = kernels.scaled(
            kernels.matern_five_halves(nn.softplus(length_scale)),
            nn.softplus(amplitude),
        )
        return models.predictive(
            models.gaussian_process(means.constant(mean), kernel),
            likelihoods.gaussian(1e-4),
        )

    def negative_log_likelihood(hyperparams):
        prediction = build_model(**hyperparams)(x_train)
        log_density = distributions.multivariate_normal.logpdf(prediction, y_train)
        return -jnp.sum(log_density)

    def step(carry, _):
        current, opt_state = carry
        loss, grads = value_and_grad(negative_log_likelihood)(current)
        updates, opt_state = optimizer.update(grads, opt_state)
        return (optax.apply_updates(current, updates), opt_state), loss

    initial_carry = (params, optimizer.init(params))
    return lax.scan(jit(step), initial_carry, jnp.arange(500))

In [7]:
# Passed to the maximizer as `q`; also the number of initial training points
# and the trailing dimension of the base-sample arrays.
batch_size = 4
# Leading dimension of the base-sample arrays.
sample_size = 32
# Passed to `maximizers.bfgs` as `num_restarts`.
num_restarts = 40
# Passed to `maximizers.bfgs` as `num_raw_samples`.
num_raw_samples = 100

In [8]:
# Number of optimization-loop iterations: `num_queries` split into batches of
# `batch_size` (integer division, so any remainder is dropped).
num_queries = 20
num_iterations = num_queries // batch_size

In [9]:
# Independent keys for the initial design, the base samples and the
# acquisition maximization, all derived from one fixed seed.
data_key, sampler_key, optimization_key = random.split(random.key(0), 3)

# Initial design: `batch_size` points drawn uniformly within `bounds`.
x_train = random.uniform(
    data_key,
    shape=(batch_size, 2),
    minval=bounds[:, 0],
    maxval=bounds[:, 1],
)
y_train = objective(x_train)

In [10]:
# Base samples for the acquisition passed to `maximizer.init`, drawn once from
# `samplers.halton_normal()` and reshaped to
# (sample_size, num_raw_samples, batch_size).
raw_base_samples = jnp.reshape(
    samplers.halton_normal()(random.fold_in(sampler_key, 0), sample_size * num_raw_samples * batch_size),
    (sample_size, num_raw_samples, batch_size)
)

In [11]:
# Base samples for the acquisition passed to `maximizer.update`, drawn once
# from `samplers.halton_normal()` with a different folded-in key and reshaped
# to (sample_size, num_restarts, batch_size).
restart_base_samples = jnp.reshape(
    samplers.halton_normal()(random.fold_in(sampler_key, 1), sample_size * num_restarts * batch_size),
    (sample_size, num_restarts, batch_size)
)

In [12]:
# Optimization loop: refit the surrogate, maximize the acquisition, evaluate
# the objective at the chosen point(s) and append them to the training data.
for i in range(num_iterations):
    # Best objective value observed so far.
    best = jnp.max(y_train)

    # Refit the hyperparameters on all data gathered so far; only the final
    # parameters are kept.
    (next_params, _), _ = fit(x_train, y_train)

    fitted_kernel = kernels.scaled(
        kernels.matern_five_halves(nn.softplus(next_params['length_scale'])),
        nn.softplus(next_params['amplitude']),
    )
    conditioned = models.gaussian_process_regression(
        means.constant(next_params['mean']),
        fitted_kernel,
    )(x_train, y_train)
    surrogate = models.predictive(conditioned, likelihoods.gaussian(1e-4))

    def acquisition(base_samples):
        """Build the acquisition on the current surrogate for the given base samples."""
        return acquisitions.q_probability_of_improvement(
            models.sampled(
                vmap(surrogate),
                vmap(distributions.multivariate_normal.sample),
                base_samples,
            ),
            tau=1.0,
            best=best,
        )

    maximizer = maximizers.bfgs(
        bounds,
        q=batch_size,
        num_raw_samples=num_raw_samples,
        num_restarts=num_restarts,
    )

    # A fresh key per iteration, derived from the loop index.
    candidates = maximizer.init(
        acquisition(raw_base_samples),
        random.fold_in(optimization_key, i),
    )
    next_candidates, values = maximizer.update(
        acquisition(restart_base_samples),
        candidates,
    )

    # Keep the candidate with the highest returned value.
    next_x = next_candidates[jnp.argmax(values)]
    next_y = objective(next_x)

    x_train = jnp.vstack([x_train, next_x])
    y_train = jnp.hstack([y_train, next_y])