# 07 - Evolving Brax Controllers
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/evosax.git@main
!pip install -q brax evojax

## Open-ES with MLP Controller

In [1]:
import jax
import numpy as np
from evojax.obs_norm import ObsNormalizer
from evojax.sim_mgr import SimManager
from evojax.task.brax_task import BraxTask
from evojax.policy import MLPPolicy

from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

In [2]:
def get_brax_task(
    env_name = "ant",
    hidden_dims = None,
):
    """Create the train/test Brax tasks and an MLP policy sized to them.

    Args:
        env_name: Environment name passed to ``BraxTask`` for both tasks.
        hidden_dims: Hidden layer sizes forwarded to ``MLPPolicy``. When
            ``None`` (the default), a fresh ``[32, 32, 32, 32]`` list is
            used for this call.

    Returns:
        A ``(train_task, test_task, policy)`` tuple, where the policy's
        input/output dimensions are taken from the training task's
        ``obs_shape[0]`` and ``act_shape[0]``.
    """
    if hidden_dims is None:
        # Build the default per call: a list literal in the signature would
        # be one shared object, so a mutation through any returned policy
        # would leak into every later default call.
        hidden_dims = [32, 32, 32, 32]
    train_task = BraxTask(env_name, test=False)
    test_task = BraxTask(env_name, test=True)
    policy = MLPPolicy(
        input_dim=train_task.obs_shape[0],
        output_dim=train_task.act_shape[0],
        hidden_dims=hidden_dims,
    )
    return train_task, test_task, policy

In [3]:
# Show the devices JAX can see in this runtime before starting the run.
print(jax.devices())

[cuda(id=0)]


In [5]:
# Build the "ant" train/test tasks and the MLP policy to evolve.
train_task, test_task, policy = get_brax_task("ant")

# OpenES settings, kept in named dicts so they are easy to spot and tweak.
open_es_config = {
    "maximize": True,
    "centered_rank": True,
    "lrate_init": 0.01,
    "lrate_decay": 0.999,
    "lrate_limit": 0.001,
}
open_es_params = {
    "sigma_init": 0.05,
    "sigma_decay": 0.999,
    "sigma_limit": 0.01,
}
solver = Evosax2JAX_Wrapper(
    Strategies["OpenES"],
    param_size=policy.num_params,
    pop_size=256,
    es_config=open_es_config,
    es_params=open_es_params,
    seed=0,
)

# `dummy=False` is the same value the notebook previously spelled `not True`.
obs_normalizer = ObsNormalizer(obs_shape=train_task.obs_shape, dummy=False)

# The simulation manager is given the same population size as the solver.
sim_mgr = SimManager(
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=0,
    obs_normalizer=obs_normalizer,
    pop_size=256,
    use_for_loop=False,
    n_repeats=16,
    test_n_repeats=1,
    n_evaluations=128,
)

print(f"START EVOLVING {policy.num_params} PARAMS.")
# Run ES Loop: ask for candidates, score them on the training task, and
# feed the scores back to the solver.
for gen_counter in range(10):
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    # Report only after the first generation and after every second one.
    generation = gen_counter + 1
    if gen_counter != 0 and generation % 2 != 0:
        continue

    test_scores, _ = sim_mgr.eval_params(params=solver.best_params, test=True)
    train_perf = float(np.nanmean(scores))
    test_perf = float(np.nanmean(test_scores))
    print(
        {"num_gens": generation},
        {"train_perf": train_perf, "test_perf": test_perf},
    )

MLPPolicy: 2024-08-27 16:18:08,672 [INFO] MLPPolicy.num_params = 4328
SimManager: 2024-08-27 16:18:08,827 [INFO] use_for_loop=False


START EVOLVING 4328 PARAMS.
{'num_gens': 1} {'train_perf': 126.9566421508789, 'test_perf': 598.7470703125}
{'num_gens': 2} {'train_perf': 127.49504089355469, 'test_perf': 158.68240356445312}
{'num_gens': 4} {'train_perf': 143.42929077148438, 'test_perf': 282.1661071777344}
{'num_gens': 6} {'train_perf': 148.34744262695312, 'test_perf': 298.7984924316406}
{'num_gens': 8} {'train_perf': 150.30282592773438, 'test_perf': 315.92645263671875}
{'num_gens': 10} {'train_perf': 162.5469970703125, 'test_perf': 239.8553009033203}


In [6]:
# Evaluate the solver's current best parameters on the test task once more
# and print them next to the training scores left over from the last
# generation of the loop above.
test_scores, _ = sim_mgr.eval_params(params=solver.best_params, test=True)
generation_info = {"num_gens": gen_counter + 1}
performance_info = {
    "train_perf": float(np.nanmean(scores)),
    "test_perf": float(np.nanmean(test_scores)),
}
print(generation_info, performance_info)

{'num_gens': 10} {'train_perf': 162.5469970703125, 'test_perf': 246.593994140625}


## Visualize the Evolved Policy

In [17]:

from IPython.display import HTML
from brax import envs
from brax.io import html
import jax
from brax.io import model
from brax.io import json

# Create a standalone "ant" environment for one rollout and jit-compile
# every function that is called inside the rollout loop.
env = envs.create(env_name="ant")
task_reset_fn = jax.jit(env.reset)
policy_reset_fn = jax.jit(policy.reset)
step_fn = jax.jit(env.step)
act_fn = jax.jit(policy.get_actions)
obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)

# Parameters found by the solver and the normalizer statistics that the
# simulation manager holds.
best_params = solver.best_params
obs_params = sim_mgr.obs_params

total_reward = 0
rollout = []
rng = jax.random.PRNGKey(seed=42)
task_state = task_reset_fn(rng=rng)
policy_state = policy_reset_fn(task_state)
while not task_state.done:
    rollout.append(task_state)
    # The policy is called with a leading batch axis of size one. The target
    # shape is taken from the observation itself instead of a hard-coded
    # size, so this cell is not tied to one environment's observation width.
    batched_obs = task_state.obs[None, :]
    normalized_obs = obs_norm_fn(batched_obs, obs_params).reshape(
        batched_obs.shape
    )
    # Only the policy sees the normalized observation; the simulator is
    # stepped from the unmodified state, so the observation is normalized
    # exactly once per step.
    # NOTE(review): this relies on env.step computing the next observation
    # from the pipeline state rather than from `state.obs` - confirm for
    # custom environments.
    policy_input = task_state.replace(obs=normalized_obs)
    act, policy_state = act_fn(policy_input, best_params[None, :], policy_state)
    task_state = step_fn(task_state, act[0])
    total_reward = total_reward + task_state.reward

print("Cumulative reward:", total_reward)
# The last expression of the cell renders the collected pipeline states.
HTML(html.render(env.sys, [s.pipeline_state for s in rollout]))

Cumulative reward: 505.7522


In [15]:
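# Disabled sanity check: render only the initial state of the 'ant'
# environment using the 'positional' backend.
# env_name = 'ant'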
# backend = 'positional'
# env = envs.get_environment(env_name=env_name,
#                            backend=backend)
# state = jax.jit(env.reset)(rng = jax.random.PRNGKey(seed=42))
# HTML(html.render(env.sys, [state.pipeline_state]))
