# 04 - Persistent ES on Learning Rate Tuning Problem
### [Last Update: March 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_mlp_pes.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install git+https://github.com/RobertTLange/evosax.git@main

## Problem as in [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a/vicol21a-supp.pdf) - Toy 2D Regression

In [2]:
import jax
import jax.numpy as jnp
from functools import partial


def loss(x):
    """Inner loss."""
    return (
        jnp.sqrt(x[0] ** 2 + 5)
        - jnp.sqrt(5)
        + jnp.sin(x[1]) ** 2 * jnp.exp(-5 * x[0] ** 2)
        + 0.25 * jnp.abs(x[1] - 100)
    )

def update(state, i):
    """Performs a single inner problem update, e.g., a single unroll step."""
    (L, x, theta, t_curr, T, K) = state
    lr = jnp.exp(theta[0]) * (T - t_curr) / T + jnp.exp(theta[1]) * t_curr / T
    x = x - lr * jax.grad(loss)(x)
    L += loss(x) * (t_curr < T)
    t_curr += 1
    return (L, x, theta, t_curr, T, K), x

@partial(jax.jit, static_argnums=(3, 4))
def unroll(x_init, theta, t0, T, K):
    """Unroll the inner problem for K steps."""
    L = 0.0
    initial_state = (L, x_init, theta, t0, T, K)
    state, outputs = jax.lax.scan(update, initial_state, None, length=K)
    (L, x_curr, theta, t_curr, T, K) = state
    return L, x_curr

### Initialize Persistent Evolution Strategy

In [3]:
from evosax import PersistentES

popsize = 100
T = 100
K = 10

strategy = PersistentES(popsize=popsize, num_dims=2)
es_params = strategy.default_params
es_params["T"] = 100
es_params["K"] = 10

rng = jax.random.PRNGKey(5)
state = strategy.initialize(rng, es_params)

# Initialize inner parameters
t = 0
xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])

### Run Outer PES Loop of Inner GD Loops :)

In [4]:
for i in range(5000):
    rng, skey = jax.random.split(rng)
    if t >= es_params["T"]:
        # Reset the inner problem: iteration, parameters
        t = 0
        xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])
    x, state = strategy.ask(rng, state, es_params)

    # Unroll inner problem for K steps using antithetic perturbations
    fitness, xs = jax.vmap(unroll, in_axes=(0, 0, None, None, None))(
        xs, x, t, es_params["T"], es_params["K"]
    )
    
    # Update ES - outer step!
    state = strategy.tell(x, fitness, state, es_params)
    t += es_params["K"]

    # Evaluation!
    if i % 500 == 0:
        L, _ = unroll(
            jnp.array([1.0, 1.0]), state["mean"], 0, es_params["T"], es_params["T"]
        )
        print(i, state["mean"], L)


0 [ 0.01000024 -0.01000024] 2423.4482
500 [ 0.08029073 -0.78276646] 2423.3083
1000 [ 0.17780854 -0.68988854] 2423.4324
1500 [ 1.8179075  -0.55037224] 1234.2625
2000 [ 2.6217515 -0.4867051] 583.6262
2500 [ 2.7118611  -0.47680327] 580.5625
3000 [ 2.7385468  -0.51297283] 563.72296
3500 [ 2.7507095  -0.54018897] 571.6591
4000 [ 2.7493122  -0.54298735] 570.99243
4500 [ 2.7636638  -0.59655845] 575.1831
