# 04 - Persistent ES on Learning Rate Tuning Problem
### [Last Update: March 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_mlp_pes.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

# !pip install -q git+https://github.com/RobertTLange/evosax.git@main

## Problem as in [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a/vicol21a-supp.pdf) - Toy 2D Regression

In [2]:
import jax
import jax.numpy as jnp
from functools import partial


def loss(x):
    """Inner loss."""
    return (
        jnp.sqrt(x[0] ** 2 + 5)
        - jnp.sqrt(5)
        + jnp.sin(x[1]) ** 2 * jnp.exp(-5 * x[0] ** 2)
        + 0.25 * jnp.abs(x[1] - 100)
    )

def update(state, i):
    """Performs a single inner problem update, e.g., a single unroll step."""
    (L, x, theta, t_curr, T, K) = state
    lr = jnp.exp(theta[0]) * (T - t_curr) / T + jnp.exp(theta[1]) * t_curr / T
    x = x - lr * jax.grad(loss)(x)
    L += loss(x) * (t_curr < T)
    t_curr += 1
    return (L, x, theta, t_curr, T, K), x

@partial(jax.jit, static_argnums=(3, 4))
def unroll(x_init, theta, t0, T, K):
    """Unroll the inner problem for K steps."""
    L = 0.0
    initial_state = (L, x_init, theta, t0, T, K)
    state, outputs = jax.lax.scan(update, initial_state, None, length=K)
    (L, x_curr, theta, t_curr, T, K) = state
    return L, x_curr

### Initialize Persistent Evolution Strategy

In [3]:
from evosax import PersistentES

popsize = 100

strategy = PersistentES(popsize=popsize, num_dims=2)
es_params = strategy.default_params.replace(
    T=100, K=10, sigma_init=0.1
)

rng = jax.random.PRNGKey(5)
state = strategy.initialize(rng, es_params)

# Initialize inner parameters
t = 0
xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])

es_params

EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), T=100, K=10, sigma_init=0.1, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)

### Run Outer PES Loop of Inner GD Loops :)

In [4]:
for i in range(5000):
    rng, skey = jax.random.split(rng)
    if t >= es_params.T:
        # Reset the inner problem: iteration, parameters
        t = 0
        xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])
    x, state = strategy.ask(rng, state, es_params)

    # Unroll inner problem for K steps using antithetic perturbations
    fitness, xs = jax.vmap(unroll, in_axes=(0, 0, None, None, None))(
        xs, x, t, es_params.T, es_params.K
    )
    
    # Update ES - outer step!
    state = strategy.tell(x, fitness, state, es_params)
    t += es_params.K

    # Evaluation!
    if i % 500 == 0:
        L, _ = unroll(
            jnp.array([1.0, 1.0]), state.mean, 0, es_params.T, es_params.T
        )
        print(i, state.mean, L)


0 [ 0.05 -0.05] 2423.374
500 [ 0.13214235 -2.474788  ] 2423.2078
1000 [ 3.9050057 -4.4652762] 1183.7357
1500 [ 2.5583386 -4.036586 ] 582.6147
2000 [ 2.7078283 -3.8439238] 564.5876
2500 [ 2.744315  -2.5619094] 559.23505
3000 [ 2.7431633 -3.8979192] 566.58826
3500 [ 2.7665381 -4.55985  ] 558.7182
4000 [ 2.7644894 -3.5615964] 556.5793
4500 [ 2.7446108 -4.953268 ] 559.6667


# Do the same exercise for NR-ES

In [5]:
from evosax import NoiseReuseES

popsize = 100

strategy = NoiseReuseES(popsize=popsize, num_dims=2)
es_params = strategy.default_params.replace(
    T=100, K=10, sigma_init=0.1
)

rng = jax.random.PRNGKey(5)
state = strategy.initialize(rng, es_params)

# Initialize inner parameters
t = 0
xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])

es_params

EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), T=100, K=10, sigma_init=0.1, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)

In [6]:
for i in range(5000):
    rng, skey = jax.random.split(rng)
    if t >= es_params.T:
        # Reset the inner problem: iteration, parameters
        t = 0
        xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])
    x, state = strategy.ask(rng, state, es_params)

    # Unroll inner problem for K steps using antithetic perturbations
    fitness, xs = jax.vmap(unroll, in_axes=(0, 0, None, None, None))(
        xs, x, t, es_params.T, es_params.K
    )
    
    # Update ES - outer step!
    state = strategy.tell(x, fitness, state, es_params)
    t += es_params.K

    # Evaluation!
    if i % 500 == 0:
        L, _ = unroll(
            jnp.array([1.0, 1.0]), state.mean, 0, es_params.T, es_params.T
        )
        print(i, state.mean, L)


0 [ 0.05 -0.05] 2423.374
500 [ 0.09266945 -2.0956526 ] 2423.294
1000 [ 2.6628385 -7.8434124] 579.9972
1500 [ 2.8134286 -8.447359 ] 581.87134
2000 [ 2.767414 -8.371097] 561.2798
2500 [ 2.7782755 -8.828757 ] 558.65204
3000 [ 2.7734683 -8.900436 ] 559.83704
3500 [ 2.7766051 -8.655186 ] 562.86945
4000 [ 2.7784874 -8.7830105] 571.02905
4500 [ 2.7700536 -8.335732 ] 566.3759
