# 09 - Batch Strategy Rollouts
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/09_exp_batch_es.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install git+https://github.com/RobertTLange/evosax.git@main

## Experimental (!!!) - Subpopulation Batch ES Rollouts

In [1]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper
from evosax.problems import GymFitness
from evosax.utils import ParameterReshaper, FitnessShaper

rng = jax.random.PRNGKey(0)
# Run Strategy on CartPole MLP
evaluator = GymFitness("CartPole-v1", num_env_steps=200, num_rollouts=16)

network = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=2,
    hidden_activation="relu",
    output_activation="categorical",
)
pholder = jnp.zeros((1, evaluator.input_shape[0]))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

reshaper = ParameterReshaper(params)
evaluator.set_apply_fn(reshaper.vmap_dict, network.apply)

fit_shaper = FitnessShaper(maximize=True)




ParameterReshaper: 4610 parameters detected for optimization.


In [9]:
from evosax.experimental.subpops import BatchStrategy

strategy = BatchStrategy(
    strategy_name="DE",
    num_dims=reshaper.total_params,
    popsize=100,
    num_subpops=5,
    communication="best_subpop",
)
params = strategy.default_params
state = strategy.initialize(rng, params)

In [10]:
for t in range(20):
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_iter, state, params)
    x_re = reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state, params)

    if t % 1 == 0:
        print(
            t + 1,
            fitness.mean(),
            fitness.max(),
            fitness.std(),
            state.best_fitness,  # Best fitness in all subpops
        )

1 22.464375 26.3125 2.0493376 [-26.3125 -26.3125 -26.3125 -26.3125 -26.3125]
2 22.575624 29.75 2.9526908 [-29.75 -29.75 -29.75 -29.75 -29.75]
3 22.914999 29.125 3.4180415 [-29.75 -29.75 -29.75 -29.75 -29.75]
4 19.238125 28.9375 2.5196369 [-29.75 -29.75 -29.75 -29.75 -29.75]
5 19.704374 33.0625 2.316076 [-33.0625 -33.0625 -33.0625 -33.0625 -33.0625]
6 23.7925 61.875 9.7088585 [-61.875 -61.875 -61.875 -61.875 -61.875]
7 35.21 118.5625 16.621597 [-118.5625 -118.5625 -118.5625 -118.5625 -118.5625]
8 38.021873 86.375 18.679571 [-118.5625 -118.5625 -118.5625 -118.5625 -118.5625]
9 45.83875 148.75 31.13269 [-148.75 -148.75 -148.75 -148.75 -148.75]
10 36.0625 125.6875 28.167828 [-148.75 -148.75 -148.75 -148.75 -148.75]
11 44.895 182.9375 38.524178 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]
12 49.030624 170.0 36.70624 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]
13 47.264374 170.75 32.65505 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]
14 47.146248 174.8125 35.011383 [-182

## Experimental (!!!) - Subpopulation Meta-Batch ES Rollouts

In [12]:
from evosax.experimental.subpops import MetaStrategy

meta_strategy = MetaStrategy(
        meta_strategy_name="CMA_ES",
        inner_strategy_name="DE",
        meta_params=["diff_w", "cross_over_rate"],
        num_dims=reshaper.total_params,
        popsize=100,
        num_subpops=5,
        meta_strategy_kwargs={"elite_ratio": 0.5},
    )
meta_es_params = meta_strategy.default_params_meta
meta_es_params.replace(
    clip_min=jnp.array([0, 0]), clip_max=jnp.array([2, 1])
)
meta_es_params

EvoParams(mu_eff=DeviceArray(1.6496499, dtype=float32), c_1=DeviceArray(0.15949409, dtype=float32), c_mu=DeviceArray(0.02899084, dtype=float32), c_sigma=DeviceArray(0.42194194, dtype=float32), d_sigma=DeviceArray(1.421942, dtype=float32), c_c=DeviceArray(0.63072497, dtype=float32), chi_n=DeviceArray(1.2542727, dtype=float32), weights=DeviceArray([ 0.73042274,  0.2695773 ,  0.        , -0.726532  ,
             -1.2900741 ], dtype=float32), weights_truncated=DeviceArray([0.73042274, 0.2695773 , 0.        , 0.        , 0.        ],            dtype=float32), c_m=1.0, sigma_init=0.065, init_min=DeviceArray([0.8, 0.9], dtype=float32), init_max=DeviceArray([0.8, 0.9], dtype=float32), clip_min=-3.4028235e+38, clip_max=3.4028235e+38)

In [17]:
# META: Initialize the meta strategy state
inner_es_params = meta_strategy.default_params
meta_state = meta_strategy.initialize_meta(rng, meta_es_params)

# META: Get altered inner es hyperparams (placeholder for init)
inner_es_params, meta_state = meta_strategy.ask_meta(
    rng, meta_state, meta_es_params, inner_es_params
)

# INNER: Initialize the inner batch ES
state = meta_strategy.initialize(rng, inner_es_params)

for t in range(20):
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)

    # META: Get altered inner es hyperparams
    inner_es_params, meta_state = meta_strategy.ask_meta(
    rng, meta_state, meta_es_params, inner_es_params
    )

    # INNER: Ask for inner candidate params to evaluate on problem
    x, state = meta_strategy.ask(rng_iter, state, inner_es_params)

    # INNER: Update using pseudo fitness
    x_re = reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = meta_strategy.tell(x, fit_re, state, inner_es_params)

    # META: Update the meta strategy
    meta_state = meta_strategy.tell_meta(
        inner_es_params, fit_re, meta_state, meta_es_params
    )

    if t % 1 == 0:
        print(
            t + 1,
            fitness.mean(),
            fitness.max(),
            fitness.std(),
            state.best_fitness,  # Best fitness in all subpops
        )
        print(inner_es_params.diff_w)
        print(inner_es_params.cross_over_rate)
        print(20 * "=")


1 20.289999 23.5 1.0141777 [-22.125 -23.5   -23.5   -23.5   -22.125]
[0.879098   0.76078224 0.7285294  0.8667383  0.9170291 ]
[0.87730575 0.89257807 0.88649285 0.8704477  0.95789057]
2 23.526875 27.8125 2.8445435 [-27.75   -27.75   -27.5625 -27.8125 -27.5625]
[0.7529136  0.75194293 0.75470865 0.73406416 0.7236362 ]
[0.92120713 0.80959445 0.904763   0.9089725  0.9178807 ]
3 18.4025 23.375 1.7486227 [-27.75   -27.75   -27.5625 -27.8125 -27.5625]
[0.76000106 0.7763824  0.80698913 0.7326284  0.79374725]
[0.8445878  0.90652514 0.9469857  0.9224319  0.8527581 ]
4 20.2475 28.875 2.2587662 [-27.75   -27.75   -27.5625 -28.875  -27.5625]
[0.8238988 0.7451631 0.7800138 0.8030797 0.7789296]
[0.81761867 0.87493116 0.85717034 0.8252281  0.8858736 ]
5 21.651875 26.75 2.3301566 [-27.75   -27.75   -27.5625 -28.875  -27.5625]
[0.7483712  0.78120756 0.7842538  0.8036731  0.8382279 ]
[0.8533484  0.85126436 0.81118304 0.87271136 0.7978533 ]
6 24.300625 32.5625 3.8705084 [-30.1875 -30.75   -28.5625 -32.5625