# Evolving Pendulum Controllers
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/02_evolution.ipynb)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/gymnax.git@main
!pip install -q git+https://github.com/RobertTLange/evosax.git@main

## Population Rollouts with `gymnax` Environments

In this notebook we will use `gymnax` to parallelize fitness rollouts across population members and initial conditions. Let's start by defining a policy and the corresponding episode rollout using the `RolloutWrapper`:

In [13]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper
import gymnax
from gymnax.experimental import RolloutWrapper

# MLP Policy with categorical readout for acrobot
rng = jax.random.PRNGKey(0)
model = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=3,
    hidden_activation="relu",
    output_activation="categorical",
)


# Create placeholder params for env
env, env_params = gymnax.make("Acrobot-v1")
pholder = jnp.zeros(env.observation_space(env_params).shape)
policy_params = model.init(
    rng,
    x=pholder,
    rng=rng,
)

# Define rollout manager for pendulum env
manager = RolloutWrapper(model.apply, env_name="Acrobot-v1")

# Simple single episode rollout for policy
obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)
reward.shape, cum_ret

((500,), DeviceArray([-500.], dtype=float32))

## Open-ES with MLP Controller

Next we instantiate the Evolution Strategy from `evosax` and set the hyperparameters of the strategy:

In [14]:
from evosax import OpenES
from evosax import ParameterReshaper, FitnessShaper

# Helper for parameter reshaping into appropriate datastructures
param_reshaper = ParameterReshaper(policy_params, n_devices=1)

# Instantiate and initialize the evolution strategy
strategy = OpenES(popsize=100,
                  num_dims=param_reshaper.total_params,
                  opt_name="adam")

es_params = strategy.default_params
es_params = es_params.replace(sigma_init=0.1, sigma_decay=0.999, sigma_limit=0.01)
es_params = es_params.replace(opt_params=es_params.opt_params.replace(
    lrate_init=0.1, lrate_decay=0.999, lrate_limit=0.001))

ParameterReshaper: 4803 parameters detected for optimization.


We then initialize the state of the search distribution and use a standard fitness shaping utility for OpenES:

In [15]:
es_state = strategy.initialize(rng)

fit_shaper = FitnessShaper(maximize=True, centered_rank=True)

Finally, we are ready to evolve our control policy using the simple `ask`/`tell` API:

In [16]:
num_generations = 100
num_mc_evals = 32
print_every_k_gens = 20

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    # Ask for candidates to evaluate
    x, es_state = strategy.ask(rng_ask, es_state)
    
    # Reshape parameters into flax FrozenDicts
    reshaped_params = param_reshaper.reshape(x)
    rng_batch_eval = jax.random.split(rng_eval, num_mc_evals)
    
    # Perform population evaluation
    _, _, _, _, _, cum_ret = manager.population_rollout(rng_batch_eval, reshaped_params)
    
    # Mean over MC rollouts, shape fitness and update strategy
    fitness = cum_ret.mean(axis=1).squeeze()
    fit_re = fit_shaper.apply(x, fitness)
    es_state = strategy.tell(x, fit_re, es_state)
    
    if (gen + 1) % print_every_k_gens == 0:
        print("Generation: ", gen + 1, "Generation: ", fitness.mean())

Generation:  20 Generation:  -489.0306
Generation:  40 Generation:  -244.8928
Generation:  60 Generation:  -107.26843
Generation:  80 Generation:  -86.66281
Generation:  100 Generation:  -80.88281


`evosax` also already comes equipped with a fitness rollout wrapper for all `gymnax` environments:

In [17]:
from evosax.problems import GymFitness

evaluator = GymFitness("Acrobot-v1", num_env_steps=500, num_rollouts=16, n_devices=1)
evaluator.set_apply_fn(param_reshaper.vmap_dict, model.apply)
evaluator.rollout(rng_eval, reshaped_params).mean(axis=1)

DeviceArray([-84.25  , -78.4375, -79.1875, -80.0625, -83.625 , -77.875 ,
             -85.9375, -75.8125, -80.0625, -78.5   , -78.0625, -78.125 ,
             -79.3125, -78.375 , -83.3125, -77.625 , -78.4375, -79.25  ,
             -79.4375, -78.8125, -79.5625, -78.3125, -79.3125, -77.9375,
             -80.6875, -75.5625, -79.0625, -78.875 , -77.625 , -77.8125,
             -78.375 , -73.25  , -80.375 , -78.4375, -78.4375, -96.8125,
             -79.4375, -74.8125, -81.5625, -89.4375, -76.5625, -84.5625,
             -78.625 , -78.375 , -79.625 , -77.75  , -76.4375, -84.5   ,
             -78.375 , -79.125 , -78.25  , -79.6875, -79.625 , -78.875 ,
             -77.25  , -78.875 , -77.8125, -84.3125, -79.4375, -77.375 ,
             -74.875 , -79.375 , -78.5625, -79.25  , -79.125 , -80.    ,
             -85.0625, -84.125 , -80.625 , -77.125 , -77.625 , -82.5   ,
             -78.    , -80.5   , -80.25  , -83.9375, -75.625 , -80.1875,
             -85.1875, -78.5625, -80.8125, -84.8125

## Evolving a Meta-LSTM to Control Different Length 2-Link Pendula

By default the two links in the Acrobot task have length 1. Wouldn't it be cool if we could solve the task for many different lengths? We will now evolve a recurrent controller that is capable of solving the Acrobot swing up task for all link lengths that sum up to two. In order to do so we will be sample a link combination that sums to 2:

In [18]:
# Let's have a look at the default environment settings
env_params

EnvParams(dt=0.2, link_length_1=1.0, link_length_2=1.0, link_mass_1=1.0, link_mass_2=1.0, link_com_pos_1=0.5, link_com_pos_2=0.5, link_moi=1.0, max_vel_1=12.566370614359172, max_vel_2=28.274333882308138, available_torque=DeviceArray([-1.,  0.,  1.], dtype=float32), torque_noise_max=0.0, max_steps_in_episode=500)

In [19]:
from gymnax.environments.classic_control.acrobot import EnvParams

# Sample a batch of environment parameters
def sample_link_params(rng, min_link=0.1, max_link=1.9):
    link_length_1 = jax.random.uniform(rng, (), minval=min_link, maxval=max_link)
    link_length_2 = 2 - link_length_1
    return EnvParams(link_length_1=link_length_1, link_length_2=link_length_2)

env_params = sample_link_params(rng)
env_params

EnvParams(dt=0.2, link_length_1=DeviceArray(0.19904068, dtype=float32), link_length_2=DeviceArray(1.8009593, dtype=float32), link_mass_1=1.0, link_mass_2=1.0, link_com_pos_1=0.5, link_com_pos_2=0.5, link_moi=1.0, max_vel_1=12.566370614359172, max_vel_2=28.274333882308138, available_torque=DeviceArray([-1.,  0.,  1.], dtype=float32), torque_noise_max=0.0, max_steps_in_episode=500)

We now simply incorporate the sampling step in our rollout routine:

In [20]:
def rollout(rng_input, policy_params, steps_in_episode):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode, rng_link = jax.random.split(rng_input, 3)
    env_params = sample_link_params(rng_link)
    obs, state = env.reset(rng_reset, env_params)
    hidden = model.initialize_carry()

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng, prev_a, hidden, cum_reward, valid_mask = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        one_hot_action = jax.nn.one_hot(prev_a, 3).squeeze()
        aug_in = jnp.hstack([obs, one_hot_action])
        hidden, action = model.apply(policy_params, aug_in, hidden, rng_net)
        next_obs, next_state, reward, done, _ = env.step(
          rng_step, state, action, env_params
        )
        new_cum_reward = cum_reward + reward * valid_mask
        new_valid_mask = valid_mask * (1 - done)
        carry = [next_obs, next_state, policy_params, rng,
               action, hidden, new_cum_reward, new_valid_mask]
        return carry, [obs, action, reward, next_obs, done]

    # Scan over episode step loop
    carry_out, scan_out = jax.lax.scan(
      policy_step,
      [obs, state, policy_params, rng_episode, 0,
      hidden, jnp.array([0.0]), jnp.array([1.0])],
      (),
      steps_in_episode
    )
    return carry_out[-2]

We train an LSTM policy which can integrate the information of observation dynamics over time:

In [21]:
model = NetworkMapper["LSTM"](
    num_hidden_units=32,
    num_output_units=3,
    output_activation="categorical",
)

pholder = jnp.zeros((9,))
policy_params = model.init(
    rng,
    x=pholder,
    rng=rng,
    carry=model.initialize_carry()
)

In [22]:
rng_rollout = jax.vmap(rollout, in_axes=(0, None, None))
pop_rollout = jax.jit(jax.vmap(rng_rollout, in_axes=(None, 0, None)), static_argnums=2)
rollout(rng, policy_params, steps_in_episode=500)

DeviceArray([-500.], dtype=float32)

In [23]:
# Helper for parameter reshaping
param_reshaper = ParameterReshaper(policy_params, n_devices=1)

# Instantiate and initialize the evolution strategy
strategy = OpenES(popsize=100,
                  num_dims=param_reshaper.total_params,
                  opt_name="adam")

es_params = strategy.default_params
es_params = es_params.replace(sigma_init=0.1, sigma_decay=0.999, sigma_limit=0.01)
es_params = es_params.replace(opt_params=es_params.opt_params.replace(
    lrate_init=0.1, lrate_decay=0.999, lrate_limit=0.001))

es_state = strategy.initialize(rng)

fit_shaper = FitnessShaper(maximize=True, centered_rank=True)

ParameterReshaper: 5475 parameters detected for optimization.


In [24]:
num_generations = 200
num_mc_evals = 64
print_every_k_gens = 20

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, es_state = strategy.ask(rng_ask, es_state)
    reshaped_params = param_reshaper.reshape(x)
    rng_batch_eval = jax.random.split(rng_eval, num_mc_evals)
    cum_ret = pop_rollout(rng_batch_eval, reshaped_params, 500)
    fitness = cum_ret.mean(axis=1).squeeze()
    fit_re = fit_shaper.apply(x, fitness)
    es_state = strategy.tell(x, fit_re, es_state)
    
    if (gen + 1) % print_every_k_gens == 0:
        print("Generation: ", gen + 1, "Generation: ", fitness.mean())

Generation:  20 Generation:  -402.45718
Generation:  40 Generation:  -174.04187
Generation:  60 Generation:  -124.062965
Generation:  80 Generation:  -107.56796
Generation:  100 Generation:  -95.232185
Generation:  120 Generation:  -97.57
Generation:  140 Generation:  -88.842186
Generation:  160 Generation:  -87.27656
Generation:  180 Generation:  -91.187965
Generation:  200 Generation:  -87.63703
