# 02 - Evolving CartPole Controllers
### Last Update: February 2022 [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/02_mlp_control.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp

## Open-ES with MLP Controller

In [None]:
# Build an MLP policy (4-dim observation in, 2 categorical action outputs)
# and a reshaper that maps flat ES candidate vectors onto its parameter tree.
from evosax import Open_ES, ParameterReshaper, FitnessShaper, NetworkMapper
from evosax.problems import GymFitness

rng = jax.random.PRNGKey(0)
network = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=2,
    hidden_activation="relu",
    output_activation="categorical",
)
# Dummy observation: only used so `init` can trace input shapes.
pholder = jnp.zeros((4,))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

# `total_params` is the search-space dimensionality handed to the strategy.
param_reshaper = ParameterReshaper(params['params'])

In [None]:
# Evaluator that runs the policy via the network's apply function.
evaluator = GymFitness()
evaluator.set_apply_fn(network.apply)

# Batch the rollout over the population dimension of the parameter tree
# (the rng argument is broadcast via `None`), then jit-compile it.
rollout = jax.jit(
    jax.vmap(
        evaluator.rollout,
        in_axes=(None, param_reshaper.vmap_dict),
    )
)

# Shaping applied to raw fitness before `tell`: rank transform, z-scoring,
# weight decay (presumably a penalty on the candidates `x`, which is why
# `apply` receives them -- confirm), and sign handling for maximisation.
fit_shaper = FitnessShaper(
    rank_fitness=True,
    z_score_fitness=True,
    weight_decay=0.1,
    maximize_objective=True,
)

In [None]:
# Open-ES over the flat MLP parameter vector. The outer update uses Adam,
# matching the Adam-specific settings (beta_1 / beta_2 / eps) below.
strategy = Open_ES(popsize=100,
                   num_dims=param_reshaper.total_params,
                   opt_name="adam")

es_params = {
        "lrate_init": 0.01,  # Adam learning rate outer step
        "lrate_decay": 0.999,  # presumably multiplicative decay per update
        "lrate_limit": 0.001,  # presumably lower bound for the decayed rate
        "beta_1": 0.99,  # beta_1 outer step
        "beta_2": 0.999,  # beta_2 outer step
        "eps": 1e-8,  # eps constant outer step
        "sigma_init": 0.1,  # initial perturbation standard deviation
        "sigma_decay": 0.999,  # presumably multiplicative decay per update
        "sigma_limit": 0.01,  # presumably lower bound for the decayed sigma
        "init_min": -0.1,  # range for initialising the search mean
        "init_max": 0.1
}

In [None]:
num_generations = 350
num_rollouts = 20  # rollouts per candidate; fitness is their mean
log_every = 50     # print a progress line every `log_every` generations

# Split off a dedicated key for initialisation: `rng` was already consumed by
# `network.init` above, and JAX keys must not be reused.
rng, rng_init = jax.random.split(rng)
state = strategy.initialize(rng_init, es_params)

for gen in range(num_generations):
    # Fresh, independent keys for sampling the population and evaluating it.
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)
    batch_rng = jax.random.split(rng_eval, num_rollouts)
    # Presumably rollout returns (popsize, num_rollouts) -- TODO confirm;
    # averaging over axis 1 yields one fitness value per candidate.
    fitness = rollout(batch_rng, reshaped_params).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state, es_params)

    if (gen + 1) % log_every == 0:
        print(f"Generation {gen + 1:4d} | mean fitness {float(fitness.mean()):8.2f} "
              f"| best fitness {float(fitness.max()):8.2f}")
    

## PGPE with LSTM Controller

In [None]:
# Recurrent (LSTM) policy with the same 4-dim input / 2 categorical outputs.
# Unlike the MLP, `init` also needs an initial recurrent carry.
from evosax.utils import ParameterReshaper

rng = jax.random.PRNGKey(0)
network = NetworkMapper["LSTM"](
    num_hidden_units=64,
    num_output_units=2,
    output_activation="categorical",
)
carry_init = network.initialize_carry()
pholder = jnp.zeros((4,))  # dummy observation for shape tracing
params = network.init(rng, x=pholder, carry=carry_init, rng=rng)

param_reshaper = ParameterReshaper(params['params'])

In [None]:
# Evaluator for the recurrent policy: it also receives the carry initialiser,
# presumably so every rollout starts from a fresh hidden state (confirm).
evaluator = GymFitness()
evaluator.set_apply_fn(network.apply, network.initialize_carry)

# Batch over the population dimension of the parameter tree, then jit.
rollout = jax.jit(
    jax.vmap(evaluator.rollout, in_axes=(None, param_reshaper.vmap_dict))
)

In [None]:
# PGPE over the flat LSTM parameter vector.
from evosax import PGPE_ES

popsize = 100
strategy = PGPE_ES(popsize=popsize,
                   num_dims=param_reshaper.total_params,
                   elite_ratio=0.1)

# Bound to `es_params` (not `params`) so that the training loop below uses
# these PGPE settings rather than the Open-ES ones, and so the LSTM network
# `params` from the cell above are not overwritten.
es_params = {
        "sigma_init": 0.10,  # initial standard deviation
        "sigma_lrate": 0.2,
        "sigma_decay": 0.999,            # anneal standard deviation
        "sigma_limit": 0.01,             # stop annealing if less than this
        "sigma_max_change": 0.2,         # clips adaptive sigma to 20%
        "lrate_init": 0.01,
        "lrate_decay": 0.999,
        "lrate_limit": 0.001,
        "beta_1": 0.99,  # beta_1 outer step
        "beta_2": 0.999,  # beta_2 outer step
        "eps": 1e-8,  # eps constant outer step
        "init_min": -0.1,
        "init_max": 0.1
}

In [None]:
num_generations = 350
num_rollouts = 20  # rollouts per candidate; fitness is their mean
log_every = 50     # print a progress line every `log_every` generations

# Split off a dedicated key for initialisation: `rng` was already consumed by
# `network.init` above, and JAX keys must not be reused.
rng, rng_init = jax.random.split(rng)
state = strategy.initialize(rng_init, es_params)

for gen in range(num_generations):
    # Fresh, independent keys for sampling the population and evaluating it.
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)
    batch_rng = jax.random.split(rng_eval, num_rollouts)
    # Presumably rollout returns (popsize, num_rollouts) -- TODO confirm;
    # averaging over axis 1 yields one fitness value per candidate.
    fitness = rollout(batch_rng, reshaped_params).mean(axis=1)
    # `fit_shaper` is reused from the Open-ES section.
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state, es_params)

    if (gen + 1) % log_every == 0:
        print(f"Generation {gen + 1:4d} | mean fitness {float(fitness.mean()):8.2f} "
              f"| best fitness {float(fitness.max()):8.2f}")