# 04 - Persistent ES on Learning Rate Tuning Problem
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_mlp_pes.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install evosax

import jax
import jax.numpy as jnp
from functools import partial

## Problem as in [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a/vicol21a-supp.pdf) - Toy 2D Regression

In [2]:
def loss(x):
    """Inner loss."""
    return (
        jnp.sqrt(x[0] ** 2 + 5)
        - jnp.sqrt(5)
        + jnp.sin(x[1]) ** 2 * jnp.exp(-5 * x[0] ** 2)
        + 0.25 * jnp.abs(x[1] - 100)
    )

def update(state, i):
    """Performs a single inner problem update, e.g., a single unroll step."""
    (L, x, theta, t_curr, T, K) = state
    lr = jnp.exp(theta[0]) * (T - t_curr) / T + jnp.exp(theta[1]) * t_curr / T
    x = x - lr * jax.grad(loss)(x)
    L += loss(x) * (t_curr < T)
    t_curr += 1
    return (L, x, theta, t_curr, T, K), x

@partial(jax.jit, static_argnums=(3, 4))
def unroll(x_init, theta, t0, T, K):
    """Unroll the inner problem for K steps."""
    L = 0.0
    initial_state = (L, x_init, theta, t0, T, K)
    state, outputs = jax.lax.scan(update, initial_state, None, length=K)
    (L, x_curr, theta, t_curr, T, K) = state
    return L, x_curr

### Initialize Persistent Evolution Strategy

In [3]:
from evosax import Persistent_ES

popsize = 100
T = 100
K = 10

strategy = Persistent_ES(popsize=popsize, num_dims=2)
es_params = strategy.default_params
es_params["T"] = 100
es_params["K"] = 10

rng = jax.random.PRNGKey(5)
state = strategy.initialize(rng, es_params)

# Initialize inner parameters
t = 0
xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])

### Run Outer PES Loop of Inner GD Loops :)

In [4]:
for i in range(5000):
    rng, skey = jax.random.split(rng)
    if t >= es_params["T"]:
        # Reset the inner problem: iteration, parameters
        t = 0
        xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])
    x, state = strategy.ask(rng, state, es_params)

    # Unroll inner problem for K steps using antithetic perturbations
    fitness, xs = jax.vmap(unroll, in_axes=(0, 0, None, None, None))(
        xs, x, t, es_params["T"], es_params["K"]
    )
    
    # Update ES - outer step!
    state = strategy.tell(x, fitness, state, es_params)
    t += es_params["K"]

    # Evaluation!
    if i % 500 == 0:
        L, _ = unroll(
            jnp.array([1.0, 1.0]), state["mean"], 0, es_params["T"], es_params["T"]
        )
        print(i, state["mean"], L)


0 [-1.8013444  1.6298852] 2106.2678
500 [-1.8695656  3.2172225] 1378.6495
1000 [-1.6957031  3.0443091] 1350.8297
1500 [-1.7292564  2.9525437] 1394.0214
2000 [-1.6703017  2.9694104] 1425.172
2500 [-1.5773143  2.965407 ] 1371.3549
3000 [-1.5619453  2.9539077] 1369.0419
3500 [-1.5344132  2.9571912] 1367.5834
4000 [-1.5256056  2.9602153] 1265.4513
4500 [-1.5129912  2.9543905] 1267.741
