# Grid search example with hyperoptax

In [1]:
from functools import partial
import time

import jax
import flax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm


from hyperoptax.grid_search import GridSearch, RandomSearch
from hyperoptax.spaces import LinearSpace, LogSpace

key = jax.random.PRNGKey(0)


In [2]:
# make a basic model
class Model(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)


# make a basic dataset
key, key_data = jax.random.split(key, 2)
x = jax.random.uniform(key_data, (10000, 10))
y = 5 * x**2 +2
xtrain = x[:8000]
ytrain = y[:8000]
xtest = x[8000:]
ytest = y[8000:]


In [3]:
# end to end loop
def make_and_train_model(learning_rate, final_lr_pct):
    key = jax.random.PRNGKey(0)
    n_epochs = 1000
    key, key_init = jax.random.split(key, 2)
    model = Model()
    # make a train state
    ts = train_state.TrainState.create(
        apply_fn=model.apply,
        params=model.init(key_init, jnp.zeros((10,))),
        tx=optax.adamw(learning_rate=optax.linear_schedule(
            init_value=learning_rate,
            end_value=final_lr_pct,
            transition_steps=n_epochs,
        )),
    )
    # learning loop
    def train_model(ts, n_epochs):
        def _train_step(ts, carry):
            def _loss(params):
                y_pred = ts.apply_fn(params, xtrain)
                return jnp.mean((y_pred - ytrain)**2)
            loss, grads = jax.value_and_grad(_loss)(ts.params)
            ts = ts.apply_gradients(grads=grads)
            return ts, loss

        def _eval_step(ts, carry):
            y_pred = ts.apply_fn(ts.params, xtest)
            return ts, jnp.mean((y_pred - ytest)**2)
    
        def _epoch_step(ts, carry):
            ts, loss_train = _train_step(ts, carry)
            ts, loss_eval = _eval_step(ts, carry)
            return ts, loss_eval

        ts, losses = jax.lax.scan(_epoch_step, ts, None, n_epochs)
        return ts, losses 
    
    # train the model
    ts, losses = train_model(ts, n_epochs)
    # return the negative loss to maximise
    return -losses[-1]

# returns the final loss
make_and_train_model(1e-3, 0.01)


Array(-1.963465, dtype=float32)

In [4]:
search_space = {
    "learning_rate": LogSpace(1e-4, 1e-2, 10),
    "final_lr_pct": LinearSpace(0.01, 0.99, 100),
}
search = GridSearch(search_space, make_and_train_model, n_parallel=10)
start = time.time()
n_iterations = 100
result = search.optimise(n_iterations).flatten()
print("Optimal result:", result)
end = time.time()
print(f"Time taken: {end - start:.2f} seconds to sweep {n_iterations} configs")
print("Optimal loss:", -make_and_train_model(result[0], result[1]))

Optimal result: [0.00599484 0.0990909 ]
Time taken: 11.65 seconds to sweep 100 configs
Optimal loss: 1.959629


# Random search

In [5]:
search = RandomSearch(search_space, make_and_train_model, n_parallel=10)
result = search.optimise(100).flatten()
print("Optimal result:", result)
print("Optimal loss:", -make_and_train_model(result[0], result[1]))

Optimal result: [0.00215444 0.22777778]
Optimal loss: 1.9586651
