# 05 - PBT on Quadratic Toy Problem
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [None]:
import jax
import jax.numpy as jnp
from evosax.strategies import PBT_ES
from functools import partial


@partial(jax.vmap, in_axes=(0, 0, None))
def step(theta, h, lrate):
    """Take one gradient-ascent step on the h-weighted quadratic surrogate.

    Vectorized with ``jax.vmap`` over the leading axis of ``theta`` and ``h``;
    the same ``lrate`` is used for every row.

    Args:
        theta: Current parameters.
        h: Hyperparameters that weight each squared parameter.
        lrate: Step size of the ascent update.

    Returns:
        The updated parameters (we maximize, hence the gradient is added).
    """
    # Gradient of 1.2 - sum(h * theta ** 2) with respect to theta.
    ascent_direction = -2.0 * h * theta
    update = lrate * ascent_direction
    return theta + update


@partial(jax.vmap, in_axes=0)
def evaluate(theta):
    """Score parameters on the true objective ``1.2 - sum(theta ** 2)``.

    Plays the role of the ground-truth metric (e.g. validation performance)
    in the toy problem of Jaderberg et al. (2017). Vectorized with
    ``jax.vmap`` over the leading axis of ``theta``.

    Args:
        theta: Parameters to score.

    Returns:
        One scalar score per row of ``theta``; higher is better.
    """
    squared = theta ** 2
    penalty = jnp.sum(squared)
    return 1.2 - penalty


@partial(jax.vmap, in_axes=0)
def surrogate_objective(theta, h):
    """Score parameters on the surrogate ``1.2 - sum(h * theta ** 2)``.

    This is the hyperparameter-dependent objective of the toy problem in
    Jaderberg et al. (2017). Vectorized with ``jax.vmap`` over the leading
    axis of both arguments.

    Args:
        theta: Parameters to score.
        h: Hyperparameters that weight each squared parameter.

    Returns:
        One scalar surrogate score per row.
    """
    weighted_squares = h * theta ** 2
    penalty = jnp.sum(weighted_squares)
    return 1.2 - penalty


# Experiment configuration, named so the loop below documents itself.
NUM_GENERATIONS = 20  # outer loop: number of PBT tell/ask rounds
STEPS_PER_GENERATION = 10  # inner loop: gradient steps between two rounds
LRATE = 0.01  # step size passed to `step` for every population member

rng = jax.random.PRNGKey(1)
# NOTE(review): the two positional arguments are presumably population size
# and number of hyperparameter dimensions -- confirm against the evosax API.
strategy = PBT_ES(2, 2)
params = strategy.default_params
params["noise_scale"] = 0.5
state = strategy.initialize(rng, params)

# Set the initial parameters and hyperparameters by hand: both members start
# at the same theta, and each member's h weights only one coordinate.
theta = jnp.array([[0.9, 0.9], [0.9, 0.9]])
h = jnp.array([[0, 1], [1, 0]])

# Run NUM_GENERATIONS rounds. Each round trains every member for
# STEPS_PER_GENERATION gradient steps on its surrogate, scores it on the true
# objective, and lets the strategy propose the next hyperparameters.
fitness_log = []
theta_log = []
for gen in range(NUM_GENERATIONS):
    rng, rng_gen = jax.random.split(rng, 2)
    for i in range(STEPS_PER_GENERATION):
        theta = step(theta, h, LRATE)
        theta_log.append(theta)
    fitness = evaluate(theta)
    # Report the scores of the current hyperparameters, then request new ones.
    state = strategy.tell(h, fitness, state, params)
    copy_id, h, state = strategy.ask(rng_gen, state, params)
    # Re-index theta with the ids returned by `ask` (each member continues
    # from the parameters of the member it was told to copy).
    theta = theta[copy_id]
    fitness_log.append(fitness)

# One entry per gradient step: NUM_GENERATIONS * STEPS_PER_GENERATION in total.
theta_log = jnp.array(theta_log)

# import matplotlib.pyplot as plt
#
# fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# axs[0].plot(fitness_log)
# axs[1].scatter(theta_log[:, 0], theta_log[:, 1], s=8)
