# 08 - Evolution Strategy with Parallelization [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/08_multi_gpu.ipynb)

For a complete tutorial about parallelization with JAX, please refer to the [advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) in the official [documentation](https://docs.jax.dev/en/latest/).

## Installation

You will need Python 3.10 or later, and a working JAX installation. For example, you can install JAX on NVIDIA GPU with:

In [None]:
%pip install -U "jax[cuda]"

Then, install `evosax` from PyPi:

In [None]:
%pip install -U "evosax[examples]"

## Import

In [1]:
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import matplotlib.pyplot as plt
import optax

In [2]:
seed = 0
key = jax.random.key(seed)

## Number of devices to use

In [3]:
num_devices: int | None = None  # None or number of devices to use
num_devices = jax.device_count() if num_devices is None else num_devices

assert num_devices <= jax.device_count(), (
    f"Requested {num_devices} devices, but only {jax.device_count()} available."
)

## Mesh

In [7]:
devices = jax.devices()[:num_devices]
mesh = Mesh(devices, ("devices",))

## Sharding

In [8]:
replicate_sharding = NamedSharding(mesh, PartitionSpec())
parallel_sharding = NamedSharding(mesh, PartitionSpec("devices"))

## Humanoid environment

In [9]:
from evosax.problems import BraxProblem as Problem
from evosax.problems.networks import MLP, tanh_output_fn

policy = MLP(
    layer_sizes=(32, 32, 32, 32, 17),
    output_fn=tanh_output_fn,
)

problem = Problem(
    env_name="humanoid",
    policy=policy,
    episode_length=1000,
    num_rollouts=16,
    use_normalize_obs=True,
)

key, subkey = jax.random.split(key)
problem_state = problem.init(key)

key, subkey = jax.random.split(key)
solution = problem.sample(subkey)

In [10]:
print(f"Number of pararmeters: {sum(leaf.size for leaf in jax.tree.leaves(solution))}")

Number of pararmeters: 11569


## Open_ES

In [17]:
from evosax.algorithms import Open_ES as ES

num_generations = 512
population_size = 256

lr_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=num_generations,
    decay_rate=0.1,
)
std_schedule = optax.exponential_decay(
    init_value=0.05,
    transition_steps=num_generations,
    decay_rate=0.2,
)
es = ES(
    population_size=population_size,
    solution=solution,
    optimizer=optax.adam(learning_rate=lr_schedule),
    std_schedule=std_schedule,
)

## Replicate params across devices

In [18]:
params = jax.device_put(es.default_params, replicate_sharding)

## Initialize state

In [19]:
key, subkey = jax.random.split(key)
state = jax.jit(es.init, out_shardings=replicate_sharding)(subkey, solution, params)

In [20]:
key_ask, key_eval, key_tell = jax.random.split(key, 3)

## Ask

In [21]:
population, state = jax.jit(
    es.ask,
    out_shardings=(parallel_sharding, replicate_sharding)
)(key_ask, state, params)

In [22]:
# Population is sharded across devices
jax.debug.visualize_array_sharding(jax.tree.leaves(population)[0])

## Eval

In [23]:
fitness, problem_state, _ = problem.eval(key_eval, population, problem_state)

In [24]:
# Fitness is sharded across devices
jax.debug.visualize_array_sharding(fitness)

## Tell

In [25]:
state, _ = es.tell(key_tell, population, -fitness, state, params)

In [31]:
# State is replicated across devices
jax.debug.visualize_array_sharding(state.mean)

## Run

In [32]:
def step(carry, key):
    state, params, problem_state = carry
    key_ask, key_eval, key_tell = jax.random.split(key, 3)

    population, state = jax.jit(
        es.ask,
        out_shardings=(parallel_sharding, replicate_sharding)
    )(key_ask, state, params)

    fitness, problem_state, _ = problem.eval(key_eval, population, problem_state)

    state, metrics = es.tell(
        key_tell, population, -fitness, state, params
    )  # Minimize fitness

    return (state, params, problem_state), metrics

In [33]:
key, subkey = jax.random.split(key)
state = es.init(subkey, solution, params)

fitness_log = []
log_period = 32
for i in range(num_generations // log_period):
    # Train
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, log_period)
    (state, params, problem_state), metrics = jax.lax.scan(
        step,
        (state, params, problem_state),
        keys,
    )

    # Eval
    mean = es.get_mean(state)
    key, subkey = jax.random.split(key)
    fitness, problem_state, info = problem.eval(
        key, jax.tree.map(lambda x: x[None], mean), problem_state
    )
    print(f"Generation {(i + 1) * log_period:3d} | Mean fitness: {fitness.mean():.2f}")

Generation  32 | Mean fitness: 806.80
Generation  64 | Mean fitness: 857.39
Generation  96 | Mean fitness: 1289.51
Generation 128 | Mean fitness: 2215.57
Generation 160 | Mean fitness: 4937.22
Generation 192 | Mean fitness: 4966.06
Generation 224 | Mean fitness: 5001.54
Generation 256 | Mean fitness: 5012.36
Generation 288 | Mean fitness: 5022.10
Generation 320 | Mean fitness: 5031.70
Generation 352 | Mean fitness: 5041.81
Generation 384 | Mean fitness: 5056.41
Generation 416 | Mean fitness: 5066.43
Generation 448 | Mean fitness: 5090.71
Generation 480 | Mean fitness: 5114.17
Generation 512 | Mean fitness: 5160.81


## Visualize policy

In [34]:
mean = es.get_mean(state)
mean = es._unravel_solution(state.best_solution)

key, subkey = jax.random.split(key)
fitness, problem_state, info = problem.eval(
    key, jax.tree.map(lambda x: x[None], mean), problem_state
)
fitness[0]

Array(5182.995, dtype=float32)

In [35]:
from brax.io import html
from IPython.display import HTML

rollout = [
    jax.tree_util.tree_map(lambda x: x[0, 0, t], info["env_states"].pipeline_state)
    for t in range(200)
]

html_content = html.render(
    problem.env.sys.tree_replace({"opt.timestep": problem.env.dt}), rollout
)
HTML(html_content)

In [36]:
# Write to file
with open("humanoid_visualization.html", "w") as f:
    f.write(html_content)

print("Visualization saved to 'humanoid_visualization.html'")

Visualization saved to 'humanoid_visualization.html'
