# 04 - Persistent ES on Learning Rate Tuning Problem [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_mlp_pes.ipynb)

## Installation

You will need Python 3.10 or later, and a working JAX installation. For example, you can install JAX on NVIDIA GPU with:

In [None]:
%pip install -U "jax[cuda]"

Then, install `evosax` from PyPI:

In [None]:
%pip install -U "evosax[examples]"

## Import

In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import optax
from evosax.algorithms import NoiseReuseES, PersistentES

In [2]:
# Single fixed seed so every run of the notebook draws the same random keys.
seed = 0
key = jax.random.key(seed)  # root PRNG key; split before each use below

## Problem as in [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a/vicol21a-supp.pdf) - Toy 2D Regression

In [3]:
def loss(x):
    """Inner loss of the toy 2-D problem (Vicol et al., 2021).

    The loss is the sum of three terms over ``x = (x[0], x[1])``:
    a smooth bowl in ``x[0]`` that is zero at the origin, an oscillating
    ``sin(x[1])**2`` ripple damped by ``exp(-5 * x[0]**2)``, and a linear
    V-shaped slope ``0.25 * |x[1] - 100|``.

    Args:
        x: Indexable with at least two entries; ``x[0]`` and ``x[1]`` are used.

    Returns:
        Scalar loss value.
    """
    x0 = x[0]
    x1 = x[1]
    bowl = jnp.sqrt(x0**2 + 5) - jnp.sqrt(5)
    ripple = jnp.sin(x1) ** 2 * jnp.exp(-5 * x0**2)
    slope = 0.25 * jnp.abs(x1 - 100)
    # Summed left-to-right, matching the original expression's evaluation order.
    return bowl + ripple + slope


def update(state, i):
    """Advance the inner problem by one gradient step (one ``lax.scan`` step).

    Args:
        state: Carry tuple ``(L, x, theta, t_curr, T, K)``. The fields are:
            ``L``, the inner loss accumulated so far; ``x``, the inner
            parameters; ``theta``, outer parameters whose exponentials are
            the learning rates at the start and end of the horizon;
            ``t_curr``, the current inner step; ``T``, the horizon length;
            ``K``, the unroll length (carried through unchanged and unused
            here).
        i: Per-step scan input; not used.

    Returns:
        ``(new_state, x_next)``. ``new_state`` is the carry after one step
        and ``x_next`` is the updated inner parameters.
    """
    acc_loss, x, theta, step, horizon, unroll_len = state

    # Learning rate interpolates linearly from exp(theta[0]) at step 0 to
    # exp(theta[1]) at step T.
    lr_start = jnp.exp(theta[0])
    lr_end = jnp.exp(theta[1])
    lr = lr_start * (horizon - step) / horizon + lr_end * step / horizon

    x_next = x - lr * jax.grad(loss)(x)

    # Steps at or beyond the horizon contribute nothing to the loss.
    in_horizon = step < horizon
    acc_loss = acc_loss + loss(x_next) * in_horizon

    new_state = (acc_loss, x_next, theta, step + 1, horizon, unroll_len)
    return new_state, x_next


@partial(jax.jit, static_argnames=("T", "K"))
def unroll(x_init, theta, t0, T, K):
    """Run ``K`` consecutive inner steps and return the accumulated loss.

    Args:
        x_init: Inner parameters at the start of this (possibly truncated)
            unroll.
        theta: Outer parameters that set the learning-rate schedule used by
            ``update``.
        t0: Inner step index at which this unroll starts.
        T: Horizon length. Static under ``jit``, so a new value recompiles.
        K: Number of steps to unroll, used as the scan length. Also static.

    Returns:
        Tuple ``(L, x)``. ``L`` is the inner loss summed over the ``K``
        steps; steps with index ``>= T`` contribute zero. ``x`` is the inner
        parameters after the last step.
    """
    carry0 = (0.0, x_init, theta, t0, T, K)
    carry_final, _ = jax.lax.scan(update, carry0, xs=None, length=K)
    total_loss, x_final = carry_final[0], carry_final[1]
    return total_loss, x_final

## Initialize Persistent Evolution Strategy

In [4]:
population_size = 128

# Outer search space: theta in R^2 (the two log-learning-rates used by
# `update`), initialised at zero and perturbed with a constant std of 0.2.
strategy = PersistentES(
    population_size=population_size,
    solution=jnp.zeros(2),
    std_schedule=optax.constant_schedule(0.2),
)
# T = inner horizon length, K = inner steps unrolled per ES generation.
params = strategy.default_params.replace(T=100, K=10)

key, subkey = jax.random.split(key)
state = strategy.init(subkey, jnp.zeros(2), params)

# Inner parameters: every population member starts from x = (1, 1).
xs = jnp.tile(jnp.array([1.0, 1.0]), (population_size, 1))

params

Params(T=100, K=10)

## Run Persistent ES

In [5]:
# Batch the inner unroll over population members. x and theta vary per
# member; the step counter, T and K are shared.
batched_unroll = jax.vmap(unroll, in_axes=(0, 0, None, None, None))
x_start = jnp.array([1.0, 1.0])

for i in range(5000):
    key, key_ask, key_tell = jax.random.split(key, 3)

    # A zero inner-step counter presumably marks a fresh inner episode:
    # restart every member's inner parameters from x_start.
    if state.inner_step_counter == 0:
        xs = jnp.ones((population_size, 2)) * x_start

    population, state = strategy.ask(key_ask, state, params)

    # Continue each member's inner problem for K more steps from where it
    # stopped last generation, scoring it with its own perturbed theta.
    fitness, xs = batched_unroll(
        xs, population, state.inner_step_counter, params.T, params.K
    )

    state, metrics = strategy.tell(key_tell, population, fitness, state, params)

    # Every 500 generations, score the current mean theta on one full
    # T-step episode started from x_start.
    if i % 500 == 0:
        L, _ = unroll(x_start, state.mean, 0, params.T, params.T)
        print(f"Generation {i:4d} | Mean fitness: {L:.2f}")

Generation    0 | Mean fitness: 2423.47
Generation  500 | Mean fitness: 2321.65
Generation 1000 | Mean fitness: 1545.29
Generation 1500 | Mean fitness: 1238.88
Generation 2000 | Mean fitness: 730.35
Generation 2500 | Mean fitness: 665.97
Generation 3000 | Mean fitness: 618.68
Generation 3500 | Mean fitness: 611.86
Generation 4000 | Mean fitness: 604.82
Generation 4500 | Mean fitness: 602.42


## Initialize Noise Reuse Evolution Strategy

In [6]:
population_size = 128

# Outer search space: theta in R^2 (the two log-learning-rates used by
# `update`), initialised at zero and perturbed with a constant std of 0.2.
strategy = NoiseReuseES(
    population_size=population_size,
    solution=jnp.zeros(2),
    std_schedule=optax.constant_schedule(0.2),
)
# T = inner horizon length, K = inner steps unrolled per ES generation.
params = strategy.default_params.replace(T=100, K=10)

key, subkey = jax.random.split(key)
state = strategy.init(subkey, jnp.zeros(2), params)

# Inner parameters: every population member starts from x = (1, 1).
xs = jnp.tile(jnp.array([1.0, 1.0]), (population_size, 1))

params

Params(T=100, K=10)

In [7]:
# Batch the inner unroll over population members. x and theta vary per
# member; the step counter, T and K are shared.
batched_unroll = jax.vmap(unroll, in_axes=(0, 0, None, None, None))
x_start = jnp.array([1.0, 1.0])

for i in range(5000):
    key, key_ask, key_tell = jax.random.split(key, 3)

    # A zero inner-step counter presumably marks a fresh inner episode:
    # restart every member's inner parameters from x_start.
    if state.inner_step_counter == 0:
        xs = jnp.ones((population_size, 2)) * x_start

    population, state = strategy.ask(key_ask, state, params)

    # Continue each member's inner problem for K more steps from where it
    # stopped last generation, scoring it with its own perturbed theta.
    fitness, xs = batched_unroll(
        xs, population, state.inner_step_counter, params.T, params.K
    )

    state, metrics = strategy.tell(key_tell, population, fitness, state, params)

    # Every 500 generations, score the current mean theta on one full
    # T-step episode started from x_start.
    if i % 500 == 0:
        L, _ = unroll(x_start, state.mean, 0, params.T, params.T)
        print(f"Generation {i:4d} | Mean fitness: {L:.2f}")

Generation    0 | Mean fitness: 2423.47
Generation  500 | Mean fitness: 2095.82
Generation 1000 | Mean fitness: 1656.34
Generation 1500 | Mean fitness: 997.18
Generation 2000 | Mean fitness: 743.01
Generation 2500 | Mean fitness: 656.02
Generation 3000 | Mean fitness: 628.86
Generation 3500 | Mean fitness: 610.77
Generation 4000 | Mean fitness: 600.68
Generation 4500 | Mean fitness: 597.65
