# 07 - Evolving Brax Controllers
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/evosax.git@main
!pip install -q brax

## Open-ES with MLP Controller

In [16]:
import jax
import jax.numpy as jnp

from evosax import OpenES, ParameterReshaper, FitnessShaper, NetworkMapper
from evosax.utils import ESLog
from evosax.problems import BraxFitness

# Instantiate brax rollout wrapper & network architecture
evaluator = BraxFitness("ant", num_env_steps=1000, num_rollouts=16)

rng = jax.random.PRNGKey(0)
network = NetworkMapper["MLP"](
    num_hidden_units=32,
    num_hidden_layers=4,
    num_output_units=evaluator.action_shape,
    hidden_activation="tanh",
    output_activation="tanh",
)
pholder = jnp.zeros((1, evaluator.input_shape[0]))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

param_reshaper = ParameterReshaper(params)

# Set mapping dictionary for parallelization
evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)

ParameterReshaper: 6248 parameters detected for optimization.


In [17]:
strategy = OpenES(popsize=256,
                  num_dims=param_reshaper.total_params,
                  opt_name="adam")
strategy.default_params

EvoParams(opt_params=OptParams(lrate_init=0.01, lrate_decay=0.999, lrate_limit=0.001, momentum=0.9, beta_1=None, beta_2=None, eps=None, max_speed=None), sigma_init=0.04, sigma_decay=0.999, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)

In [18]:
num_generations = 1000
print_every_k_gens = 20

es_logging = ESLog(param_reshaper.total_params,
                   num_generations,
                   top_k=5,
                   maximize=True)
log = es_logging.initialize()

fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=True,
                           w_decay=0.1,
                           maximize=True)

state = strategy.initialize(rng)

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, state = strategy.ask(rng_ask, state)
    reshaped_params = param_reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, reshaped_params).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state)
    log = es_logging.update(log, x, fitness)
    
    if gen % print_every_k_gens == 0:
        print("Generation: ", gen, "Generation: ", log["log_top_1"][gen])

Generation:  0 Generation:  203.22612
Generation:  20 Generation:  203.62665


KeyboardInterrupt: 

# Visualize Learning Curve and Policy

In [None]:
# Plot the learning curve over generations
es_logging.plot(log, "Ant MLP OpenAI-ES")

In [31]:
from IPython.display import HTML
from brax import envs
from brax.io import html

env = envs.create(env_name="ant")
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(network.apply)

net_params = param_reshaper.reshape_single(state.mean)

rollout = []
rng = jax.random.PRNGKey(seed=0)
env_state = jit_env_reset(rng=rng)
cum_reward = 0
for _ in range(1000):
    rollout.append(env_state)
    act_rng, rng = jax.random.split(rng)
    norm_obs = evaluator.obs_normalizer.normalize_obs(env_state.obs, evaluator.obs_params)
    act = jit_inference_fn(net_params, env_state.obs, act_rng)
    env_state = jit_env_step(env_state, act)
    cum_reward += env_state.reward

print("Cumulative reward:", cum_reward)
HTML(html.render(env.sys, [s.qp for s in rollout]))

Cumulative reward: -2459.6707
