# 09 - Batch Strategy Rollouts
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/09_exp_batch_es.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/evosax.git@main
!pip install -q gymnax

## Experimental (!!!) - Subpopulation Batch ES Rollouts

In [2]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper
from evosax.problems import GymnaxFitness
from evosax.utils import ParameterReshaper, FitnessShaper

rng = jax.random.PRNGKey(0)
# Run Strategy on CartPole MLP
evaluator = GymnaxFitness("CartPole-v1", num_env_steps=200, num_rollouts=16)

network = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=2,
    hidden_activation="relu",
    output_activation="categorical",
)
pholder = jnp.zeros((1, evaluator.input_shape[0]))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

reshaper = ParameterReshaper(params)
evaluator.set_apply_fn(network.apply)

fit_shaper = FitnessShaper(maximize=True)


ParameterReshaper: 4610 parameters detected for optimization.


In [3]:
from evosax.experimental.subpops import BatchStrategy

strategy = BatchStrategy(
    strategy_name="DE",
    num_dims=reshaper.total_params,
    popsize=100,
    num_subpops=5,
    communication="best_subpop",
)
params = strategy.default_params
state = strategy.initialize(rng, params)

In [4]:
for t in range(20):
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_iter, state, params)
    x_re = reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state, params)

    if t % 1 == 0:
        print(
            t + 1,
            fitness.mean(),
            fitness.max(),
            fitness.std(),
            state.best_fitness,  # Best fitness in all subpops
        )

  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)


1 23.370625 25.5 1.0780935 [-25.5 -25.5 -25.5 -25.5 -25.5]
2 22.074375 29.375 2.6189167 [-29.375 -29.375 -29.375 -29.375 -29.375]
3 20.88875 22.5625 0.74370164 [-29.375 -29.375 -29.375 -29.375 -29.375]
4 24.71 33.3125 2.8831882 [-33.3125 -33.3125 -33.3125 -33.3125 -33.3125]
5 21.94875 31.0625 3.4156966 [-33.3125 -33.3125 -33.3125 -33.3125 -33.3125]
6 21.903124 44.5625 5.5148053 [-44.5625 -44.5625 -44.5625 -44.5625 -44.5625]
7 30.741875 112.5625 14.554363 [-112.5625 -112.5625 -112.5625 -112.5625 -112.5625]
8 34.178123 137.625 25.262201 [-137.625 -137.625 -137.625 -137.625 -137.625]
9 40.82125 177.3125 33.683407 [-177.3125 -177.3125 -177.3125 -177.3125 -177.3125]
10 51.761875 185.3125 50.306335 [-185.3125 -185.3125 -185.3125 -185.3125 -185.3125]
11 62.724373 200.0 54.560562 [-200. -200. -200. -200. -200.]
12 79.213745 200.0 65.029335 [-200. -200. -200. -200. -200.]
13 77.94187 200.0 58.734848 [-200. -200. -200. -200. -200.]
14 70.84062 200.0 59.714462 [-200. -200. -200. -200. -200.]
15 9

## Experimental (!!!) - Subpopulation Meta-Batch ES Rollouts

In [5]:
from evosax.experimental.subpops import MetaStrategy

meta_strategy = MetaStrategy(
        meta_strategy_name="CMA_ES",
        inner_strategy_name="DE",
        meta_params=["diff_w", "cross_over_rate"],
        num_dims=reshaper.total_params,
        popsize=100,
        num_subpops=5,
        meta_strategy_kwargs={"elite_ratio": 0.5},
    )
meta_es_params = meta_strategy.default_params_meta
meta_es_params.replace(
    clip_min=jnp.array([0, 0]), clip_max=jnp.array([2, 1])
)
meta_es_params

EvoParams(mu_eff=DeviceArray(1.6496499, dtype=float32), c_1=DeviceArray(0.15949409, dtype=float32), c_mu=DeviceArray(0.02899084, dtype=float32), c_sigma=DeviceArray(0.42194194, dtype=float32), d_sigma=DeviceArray(1.421942, dtype=float32), c_c=DeviceArray(0.63072497, dtype=float32), chi_n=DeviceArray(1.2542727, dtype=float32, weak_type=True), c_m=1.0, sigma_init=1.0, init_min=DeviceArray([0.8, 0.9], dtype=float32), init_max=DeviceArray([0.8, 0.9], dtype=float32), clip_min=-3.4028235e+38, clip_max=3.4028235e+38)

In [6]:
# META: Initialize the meta strategy state
inner_es_params = meta_strategy.default_params
meta_state = meta_strategy.initialize_meta(rng, meta_es_params)

# META: Get altered inner es hyperparams (placeholder for init)
inner_es_params, meta_state = meta_strategy.ask_meta(
    rng, meta_state, meta_es_params, inner_es_params
)

# INNER: Initialize the inner batch ES
state = meta_strategy.initialize(rng, inner_es_params)

for t in range(20):
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)

    # META: Get altered inner es hyperparams
    inner_es_params, meta_state = meta_strategy.ask_meta(
    rng, meta_state, meta_es_params, inner_es_params
    )

    # INNER: Ask for inner candidate params to evaluate on problem
    x, state = meta_strategy.ask(rng_iter, state, inner_es_params)

    # INNER: Update using pseudo fitness
    x_re = reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = meta_strategy.tell(x, fit_re, state, inner_es_params)

    # META: Update the meta strategy
    meta_state = meta_strategy.tell_meta(
        inner_es_params, fit_re, meta_state, meta_es_params
    )

    if t % 1 == 0:
        print(
            t + 1,
            fitness.mean(),
            fitness.max(),
            fitness.std(),
            state.best_fitness,  # Best fitness in all subpops
        )
        print(inner_es_params.diff_w)
        print(inner_es_params.cross_over_rate)
        print(20 * "=")


1 21.616875 29.6875 3.9452865 [-28.8125 -22.8125 -28.8125 -27.8125 -29.6875]
[ 2.8802464  -0.15859854  1.0776247   1.7196195   0.8666483 ]
[ 3.227366    2.2789712   1.0106008   0.59644413 -0.31195474]
2 18.93375 43.125 5.404836 [-43.125  -22.8125 -28.8125 -27.8125 -33.    ]
[0.3164339  0.38105685 1.4904149  0.54773945 1.7374508 ]
[ 1.2298826   0.5813654   0.74932307  0.25477123 -0.1159527 ]
3 18.708124 51.875 6.0304956 [-43.125  -22.8125 -28.8125 -27.8125 -51.875 ]
[ 1.4259593  -0.06665286  0.8067391   0.25669232  1.0614586 ]
[0.86134017 0.4752618  0.9046646  0.07983232 0.47602725]
4 24.71875 72.1875 13.535652 [-72.1875 -22.875  -56.875  -27.8125 -51.875 ]
[ 0.28945768  1.2563071   1.3974397  -0.01699698  0.9223877 ]
[ 1.0732273   1.7199515   1.1780126  -0.05289584  1.3038926 ]
5 33.605625 95.875 18.421816 [-77.25   -25.3125 -95.875  -27.8125 -65.25  ]
[ 0.14544031  0.43611148 -0.20747858  0.57340264  1.0367332 ]
[0.9635209 1.6865335 1.6758163 1.1580497 1.3674667]
6 42.816875 128.75 32