# 07 - Evolving Brax Controllers
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb)

In [None]:
# Notebook display and module-reload settings.
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

# Install evosax from its main branch plus the Brax / EvoJAX packages used below.
# NOTE(review): versions are unpinned. The visualization cell uses the
# `qp`-based Brax API (`env.sys`, `state.qp`, `html.render`), which presumably
# needs an older brax release — TODO pin brax/evojax to reproduce the outputs.
!pip install -q git+https://github.com/RobertTLange/evosax.git@main
!pip install -q brax evojax

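## Open-ES with MLP Controller

We evolve the weights of an MLP controller for the Brax `ant` task with evosax's `OpenES`. EvoJAX's `SimManager` runs the (observation-normalized) rollouts. `Evosax2JAX_Wrapper` exposes the evosax strategy through EvoJAX's `ask`/`tell` solver interface.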

In [1]:
import numpy as np
from evojax.obs_norm import ObsNormalizer
from evojax.sim_mgr import SimManager
from evojax.task.brax_task import BraxTask
from evojax.policy import MLPPolicy

from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

In [2]:
def get_brax_task(
    env_name="ant",
    hidden_dims=None,
):
    """Build train/test Brax tasks and an MLP policy sized for them.

    Args:
        env_name: Brax environment name, forwarded to both ``BraxTask``s.
        hidden_dims: Hidden layer sizes for ``MLPPolicy``. Defaults to
            ``[32, 32, 32, 32]``. ``None`` is the sentinel so that every call
            builds a fresh list instead of sharing one mutable default.

    Returns:
        ``(train_task, test_task, policy)``. The policy's input/output sizes
        are taken from the train task's observation/action shapes.
    """
    if hidden_dims is None:
        hidden_dims = [32, 32, 32, 32]
    train_task = BraxTask(env_name, test=False)
    test_task = BraxTask(env_name, test=True)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.act_shape[0],
        # Copy so the policy never aliases a caller's (or the default) list.
        hidden_dims=list(hidden_dims),
    )
    return train_task, test_task, policy

In [3]:
# Run settings. The ES population size must match the number of parameter
# vectors SimManager evaluates per generation, so both read one constant.
ENV_NAME = "ant"
POP_SIZE = 256
SEED = 0
NUM_GENERATIONS = 1000
LOG_EVERY = 50  # test-evaluate the best params every LOG_EVERY generations

train_task, test_task, policy = get_brax_task(ENV_NAME)
solver = Evosax2JAX_Wrapper(
    Strategies["OpenES"],
    param_size=policy.num_params,
    pop_size=POP_SIZE,
    es_config={
        "maximize": True,
        "centered_rank": True,
        "lrate_init": 0.01,
        "lrate_decay": 0.999,
        "lrate_limit": 0.001,
    },
    es_params={
        "sigma_init": 0.05,
        "sigma_decay": 0.999,
        "sigma_limit": 0.01,
    },
    seed=SEED,
)
# Use a real (non-dummy) observation normalizer.
obs_normalizer = ObsNormalizer(obs_shape=train_task.obs_shape, dummy=False)
sim_mgr = SimManager(
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=SEED,
    obs_normalizer=obs_normalizer,
    pop_size=POP_SIZE,
    use_for_loop=False,
    n_repeats=16,
    test_n_repeats=1,
    n_evaluations=128,
)

print(f"START EVOLVING {policy.num_params} PARAMS.")
# Run ES Loop: ask for a population, score it on the train task, update the ES.
for gen_counter in range(NUM_GENERATIONS):
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)
    if gen_counter == 0 or (gen_counter + 1) % LOG_EVERY == 0:
        # train_perf: mean score of the current population.
        # test_perf: mean test-task score of the best parameters so far.
        test_scores, _ = sim_mgr.eval_params(
            params=solver.best_params, test=True
        )
        print(
            {"num_gens": gen_counter + 1},
            {
                "train_perf": float(np.nanmean(scores)),
                "test_perf": float(np.nanmean(test_scores)),
            },
        )

START EVOLVING 6248 PARAMS.
{'num_gens': 1} {'train_perf': 984.8515625, 'test_perf': 996.6640625}
{'num_gens': 50} {'train_perf': 977.61962890625, 'test_perf': 995.9266357421875}
{'num_gens': 100} {'train_perf': 972.63818359375, 'test_perf': 998.8201904296875}
{'num_gens': 150} {'train_perf': 974.7550048828125, 'test_perf': 1005.531982421875}
{'num_gens': 200} {'train_perf': 1276.9852294921875, 'test_perf': 1715.9610595703125}
{'num_gens': 250} {'train_perf': 1773.773681640625, 'test_perf': 2281.2216796875}
{'num_gens': 300} {'train_perf': 2309.93212890625, 'test_perf': 2937.6201171875}
{'num_gens': 350} {'train_perf': 2772.38134765625, 'test_perf': 3408.61474609375}
{'num_gens': 400} {'train_perf': 3173.67919921875, 'test_perf': 3793.0986328125}
{'num_gens': 450} {'train_perf': 3442.1396484375, 'test_perf': 4159.99365234375}
{'num_gens': 500} {'train_perf': 3810.6884765625, 'test_perf': 4592.73876953125}
{'num_gens': 550} {'train_perf': 4118.63671875, 'test_perf': 4821.9951171875}
{'n

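## Visualize the Evolved Policy

Roll out the best parameters found by the ES for one `ant` episode (PRNG seed 42), print the cumulative reward, and render the trajectory with Brax's HTML viewer.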

In [20]:
from IPython.display import HTML
from brax import envs
from brax.io import html
import jax

# Must be the same Brax environment the controller was trained on.
env = envs.create(env_name="ant")
task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(env.step)
act_fn = jax.jit(policy.get_actions)
obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)

best_params = solver.best_params
# Normalizer statistics held by the SimManager used for training.
obs_params = sim_mgr.obs_params


def _policy_view(state):
    """Return a batch-of-one copy of ``state`` with normalized observations.

    The policy is called on a leading batch dimension (``best_params[None, :]``)
    and on normalized observations. The environment itself keeps stepping the
    raw, unbatched state, so the observation is normalized exactly once.
    """
    norm_obs = obs_norm_fn(state.obs[None, :], obs_params).reshape(1, -1)
    return state.replace(obs=norm_obs)


total_reward = 0
rollout = []
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
# Reset the policy with the same batch-of-one view it is later called with.
policy_state = policy_reset_fn(_policy_view(task_state))
while not task_state.done:
    rollout.append(task_state)
    act, policy_state = act_fn(
        _policy_view(task_state), best_params[None, :], policy_state
    )
    # Step the raw state; only the policy's input is normalized.
    task_state = step_fn(task_state, act[0])
    total_reward = total_reward + task_state.reward

print("Cumulative reward:", total_reward)
HTML(html.render(env.sys, [s.qp for s in rollout]))

Cumulative reward: 1117.1326
