# Evolving Pendulum Controllers
### [Last Update: July 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/es_in_gymnax.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/gymnax.git@main
!pip install -q git+https://github.com/RobertTLange/evosax.git@main

## Population Rollouts with `gymnax` Environments

In this notebook we will use `gymnax` to parallelize fitness rollouts across population members and initial conditions. Let's start by defining a policy and the corresponding episode rollout using the `RolloutWrapper`:

In [39]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper

# MLP Policy with gaussian readout for pendulum
rng = jax.random.PRNGKey(0)
model = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=3,
    hidden_activation="relu",
    output_activation="categorical",
)

import gymnax

env, env_params = gymnax.make("Acrobot-v1")
pholder = jnp.zeros(env.observation_space(env_params).shape)
policy_params = model.init(
    rng,
    x=pholder,
    rng=rng,
)


from gymnax.experimental import RolloutWrapper

# Define rollout manager for pendulum env
manager = RolloutWrapper(model.apply, env_name="Acrobot-v1")

# Simple single episode rollout for policy
obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)
reward.shape, cum_ret

((500,), DeviceArray([-500.], dtype=float32))

## Open-ES with MLP Controller

Next we instantiate the Evolution Strategy and get a set of parameters to evaluate:

In [42]:
from evosax import OpenES
from evosax import ParameterReshaper, FitnessShaper, NetworkMapper

# Helper for parameter reshaping
param_reshaper = ParameterReshaper(policy_params)

# Instantiate and initialize the evolution strategy
strategy = OpenES(popsize=100,
                  num_dims=param_reshaper.total_params,
                  opt_name="adam")

es_params = strategy.default_params
es_params = es_params.replace(sigma_init=0.1, sigma_decay=0.999, sigma_limit=0.01)
es_params = es_params.replace(opt_params=es_params.opt_params.replace(
    lrate_init=0.1, lrate_decay=0.999, lrate_limit=0.001))

es_state = strategy.initialize(rng)

fit_shaper = FitnessShaper(maximize=True, centered_rank=True)

ParameterReshaper: 4803 parameters detected for optimization.


In [43]:
num_generations = 100
num_mc_evals = 16
print_every_k_gens = 20

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, es_state = strategy.ask(rng_ask, es_state)
    reshaped_params = param_reshaper.reshape(x)
    rng_batch_eval = jax.random.split(rng_eval, num_mc_evals)
    _, _, _, _, _, cum_ret = manager.population_rollout(rng_batch_eval, reshaped_params)
    fitness = cum_ret.mean(axis=1).squeeze()
    fit_re = fit_shaper.apply(x, fitness)
    es_state = strategy.tell(x, fit_re, es_state)
    
    if gen % print_every_k_gens == 0:
        print("Generation: ", gen, "Generation: ", fitness.mean())

Generation:  0 Generation:  -496.91
Generation:  20 Generation:  -488.30875
Generation:  40 Generation:  -265.81
Generation:  60 Generation:  -102.61312
Generation:  80 Generation:  -83.508125
Generation:  100 Generation:  -82.125
Generation:  120 Generation:  -79.756874


KeyboardInterrupt: 

`evosax` also already comes equipped with a fitness rollout wrapper for all `gymnax` environments:

In [None]:
from evosax.problems import GymFitness

evaluator = GymFitness("CartPole-v1", num_env_steps=500, num_rollouts=16)
evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)

## Evolving a Meta-LSTM to Control Different Length 2-Link Pendula

By default the two links in the Acrobot task have length 1. Wouldn't it be cool if we could solve the task for many different lengths? We will now evolve a recurrent controller that is capable of solving the Acrobot swing up task for all link lengths that sum up to two. 

In [44]:
env_params

EnvParams(dt=0.2, link_length_1=1.0, link_length_2=1.0, link_mass_1=1.0, link_mass_2=1.0, link_com_pos_1=0.5, link_com_pos_2=0.5, link_moi=1.0, max_vel_1=12.566370614359172, max_vel_2=28.274333882308138, available_torque=DeviceArray([-1.,  0.,  1.], dtype=float32), torque_noise_max=0.0, max_steps_in_episode=500)

In [54]:
model = NetworkMapper["LSTM"](
    num_hidden_units=32,
    num_output_units=3,
    output_activation="categorical",
)

In [55]:
from gymnax.environments.classic_control.acrobot import EnvParams

# Sample a batch of environment parameters
def sample_link_params(rng, min_link=0.1, max_link=1.9):
    link_length_1 = jax.random.uniform(rng, (), minval=min_link, maxval=max_link)
    link_length_2 = 2 - link_length_1
    return EnvParams(link_length_1=link_length_1, link_length_2=link_length_2)

env_params = sample_link_params(rng)
env_params

EnvParams(dt=0.2, link_length_1=DeviceArray(1.0257741, dtype=float32), link_length_2=DeviceArray(0.9742259, dtype=float32), link_mass_1=1.0, link_mass_2=1.0, link_com_pos_1=0.5, link_com_pos_2=0.5, link_moi=1.0, max_vel_1=12.566370614359172, max_vel_2=28.274333882308138, available_torque=DeviceArray([-1.,  0.,  1.], dtype=float32), torque_noise_max=0.0, max_steps_in_episode=500)

In [62]:
def rollout(rng_input, policy_params, steps_in_episode):
  """Rollout a jitted gymnax episode with lax.scan."""
  # Reset the environment
  rng_reset, rng_episode, rng_link = jax.random.split(rng_input, 3)
  env_params = sample_link_params(rng_link)
  obs, state = env.reset(rng_reset, env_params)
  hidden = model.initialize_carry()

  def policy_step(state_input, tmp):
      """lax.scan compatible step transition in jax env."""
      obs, state, policy_params, rng, prev_a, hidden, cum_reward, valid_mask = state_input
      rng, rng_step, rng_net = jax.random.split(rng, 3)
      one_hot_action = jax.nn.one_hot(prev_a, 3).squeeze()
      aug_in = jnp.hstack([obs, one_hot_action])
      hidden, action = model.apply(policy_params, aug_in, hidden, rng_net)
      next_obs, next_state, reward, done, _ = env.step(
          rng_step, state, action, env_params
      )
      new_cum_reward = cum_reward + reward * valid_mask
      new_valid_mask = valid_mask * (1 - done)
      carry = [next_obs, next_state, policy_params, rng,
               action, hidden, new_cum_reward, new_valid_mask]
      return carry, [obs, action, reward, next_obs, done]

  # Scan over episode step loop
  carry_out, scan_out = jax.lax.scan(
      policy_step,
      [obs, state, policy_params, rng_episode, 0,
      hidden, jnp.array([0.0]), jnp.array([1.0])],
      (),
      steps_in_episode
  )
  return carry_out[-2]

In [63]:
pholder = jnp.zeros((9,))
policy_params = model.init(
    rng,
    x=pholder,
    rng=rng,
    carry=model.initialize_carry()
)

rollout(rng, policy_params, steps_in_episode=500)



DeviceArray([-500.], dtype=float32)

In [66]:
rng_rollout = jax.vmap(rollout, in_axes=(0, None, None))
pop_rollout = jax.vmap(rng_rollout, in_axes=(None, 0, None))

In [67]:
# Helper for parameter reshaping
param_reshaper = ParameterReshaper(policy_params)

# Instantiate and initialize the evolution strategy
strategy = OpenES(popsize=100,
                  num_dims=param_reshaper.total_params,
                  opt_name="adam")

es_params = strategy.default_params
es_params = es_params.replace(sigma_init=0.1, sigma_decay=0.999, sigma_limit=0.01)
es_params = es_params.replace(opt_params=es_params.opt_params.replace(
    lrate_init=0.1, lrate_decay=0.999, lrate_limit=0.001))

es_state = strategy.initialize(rng)

fit_shaper = FitnessShaper(maximize=True, centered_rank=True)

ParameterReshaper: 5475 parameters detected for optimization.


In [68]:
num_generations = 100
num_mc_evals = 16
print_every_k_gens = 20

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, es_state = strategy.ask(rng_ask, es_state)
    reshaped_params = param_reshaper.reshape(x)
    rng_batch_eval = jax.random.split(rng_eval, num_mc_evals)
    cum_ret = pop_rollout(rng_batch_eval, reshaped_params, 500)
    fitness = cum_ret.mean(axis=1).squeeze()
    fit_re = fit_shaper.apply(x, fitness)
    es_state = strategy.tell(x, fit_re, es_state)
    
    if gen % print_every_k_gens == 0:
        print("Generation: ", gen, "Generation: ", fitness.mean())



Generation:  0 Generation:  -493.6856
Generation:  20 Generation:  -449.6606
Generation:  40 Generation:  -225.56062
Generation:  60 Generation:  -139.83624
