In [1]:
import time
from functools import partial
from IPython.display import Image
import matplotlib.pyplot as plt 
import seaborn as sns

import jax 
import jax.numpy as jnp
from brax import envs
from brax.io import image
from brax import jumpy as jp
import evosax

ModuleNotFoundError: No module named 'seaborn'

In [46]:
# Rollout length (number of env steps per action sequence) and the base PRNG key.
T = 100
rng = jax.random.PRNGKey(0)

# Build the Brax "ant" environment and a jit-compiled handle to its step function.
env = envs.create(env_name="ant")
jit_env_step = jax.jit(env.step)

In [47]:
@jax.jit
def eval_action_sequence(actions):
    """Roll out a fixed sequence of actions in `env` and return the summed reward.

    The environment is reset with ``PRNGKey(42)`` on every call, so each
    evaluation starts from the same reset key. ``actions`` is indexed along
    its leading axis, one entry per environment step.

    Args:
        actions: array of actions; ``actions[t]`` is applied at step ``t``.

    Returns:
        The sum of ``state.reward`` over all steps of the rollout.
    """
    initial_state = env.reset(rng=jax.random.PRNGKey(42))
    num_steps = len(actions)

    def step_and_accumulate(t, loop_vars):
        # loop_vars carries [environment state, running reward total].
        env_state, total = loop_vars
        env_state = jit_env_step(env_state, actions[t])
        return [env_state, total + env_state.reward]

    _, total_reward = jax.lax.fori_loop(
        0, num_steps, step_and_accumulate, [initial_state, 0]
    )
    return total_reward

In [78]:
def run_es(rng, strategy, T, epochs=500):
    """Optimise an open-loop action sequence with an ask/tell evolution strategy.

    Each epoch asks `strategy` for a population, reshapes every candidate into
    a ``(T, env.action_size)`` action sequence, scores all of them with
    ``eval_action_sequence`` and tells the strategy the negated rewards
    (the reward is negated before ``tell`` and ``best_fitness`` is negated
    back when it is recorded).

    Args:
        rng: JAX PRNG key used to initialise the strategy and to draw samples.
        strategy: object exposing ``default_params``, ``initialize``, ``ask``
            and ``tell`` as used below.
        T: number of environment steps encoded in each candidate.
        epochs: number of ask/evaluate/tell iterations to run.

    Returns:
        A tuple ``(rewards, times, state)`` where ``rewards[i]`` is
        ``-state.best_fitness`` after epoch ``i``, ``times[i]`` is the
        wall-clock time in seconds elapsed from the start of the loop until
        epoch ``i`` finished, and ``state`` is the final strategy state.
    """
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)
    rewards = []
    times = []
    start = time.time()
    for _ in range(epochs):
        # The third key is unused; the 3-way split is kept so the random
        # stream is identical to what a 3-way split has always produced here.
        rng, rng_gen, _unused_key = jax.random.split(rng, 3)
        x, state = strategy.ask(rng_gen, state, es_params)
        actions = x.reshape(-1, T, env.action_size)
        reward = jax.vmap(eval_action_sequence)(actions)
        state = strategy.tell(x, -reward, state, es_params)
        best_reward = -state.best_fitness
        # JAX dispatches work asynchronously: wait for this epoch's result so
        # the timestamp below measures computation, not just enqueueing.
        jnp.asarray(best_reward).block_until_ready()
        rewards.append(best_reward)
        # Stamp the reward with the time at which it was actually available
        # (end of the epoch), not the time at which the epoch started.
        times.append(time.time() - start)
    return rewards, times, state

In [83]:
# CMA-ES over the flattened action sequence (T steps x action_size), population of 256.
strategy = evosax.CMA_ES(popsize=256, num_dims=env.action_size * T)
CMA_rewards, CMA_times, CMA_states = run_es(rng, strategy, T=T)

In [81]:
# PSO over the flattened action sequence (T steps x action_size), population of 256.
strategy = evosax.PSO(popsize=256, num_dims=env.action_size * T)
PSO_rewards, PSO_times, PSO_states = run_es(rng, strategy, T=T)

In [84]:
# GLD over the flattened action sequence (T steps x action_size), population of 256.
strategy = evosax.GLD(popsize=256, num_dims=env.action_size * T)
GLD_rewards, GLD_times, GLD_states = run_es(rng, strategy, T=T)

In [85]:
# ARS over the flattened action sequence (T steps x action_size), population of 256.
strategy = evosax.ARS(popsize=256, num_dims=env.action_size * T)
ARS_rewards, ARS_times, ARS_states = run_es(rng, strategy, T=T)

In [86]:
# PersistentES over the flattened action sequence (T steps x action_size), population of 256.
strategy = evosax.PersistentES(popsize=256, num_dims=env.action_size * T)
PES_rewards, PES_times, PES_states = run_es(rng, strategy, T=T)

In [2]:
# One curve per strategy: recorded reward against elapsed time.
for _times, _rewards, _label in (
    (CMA_times, CMA_rewards, "CMA_ES"),
    (PSO_times, PSO_rewards, "PSO"),
    (GLD_times, GLD_rewards, "GLD"),
    (ARS_times, ARS_rewards, "ARS"),
    (PES_times, PES_rewards, "PES"),
):
    plt.plot(_times, _rewards, label=_label)

plt.title("Environment : Ant")
plt.xlabel("time [s]")
plt.ylabel("reward")
plt.legend()
plt.show()

NameError: name 'CMA_times' is not defined

In [50]:
def get_image(evo_state):
    """Replay the best member of an evolution state and render the rollout.

    ``evo_state.best_member`` is reshaped to ``(T, env.action_size)`` and the
    actions are applied one by one, starting from a reset with
    ``PRNGKey(42)``. Only the states reached after each step are rendered;
    the reset state itself is not part of the frames.

    Args:
        evo_state: strategy state exposing a ``best_member`` array.

    Returns:
        The result of ``image.render`` for the collected ``qp`` values at
        320x320.
    """
    best_actions = evo_state.best_member.reshape(T, env.action_size)
    env_state = env.reset(rng=jax.random.PRNGKey(42))
    qps = []
    for step_action in best_actions:
        env_state = jit_env_step(env_state, step_action)
        qps.append(env_state.qp)
    return image.render(env.sys, qps, width=320, height=320)