In [1]:
%env CUDA_VISIBLE_DEVICES=0

import numpy as np
import jax.numpy as jnp
import jax


env: CUDA_VISIBLE_DEVICES=0


In [2]:
from es_map import jax_simple_es_train

In [7]:
# Hyperparameters handed to the ES trainer and to the manual training loop further down.
# ES_NUM_GENERATIONS is also the stopping point of the manual loop.
config = dict(
    ES_NUM_GENERATIONS=500,
    ES_OPTIMIZER_TYPE="ADAM",
    ES_lr=0.01,
    ES_L2_COEFF=0.001,
    ES_sigma=0.02,
)

In [8]:
%%time
# Full ES training run through the packaged trainer (jax_simple_es_train.train) using `config`;
# %%time reports the wall-clock cost of the whole run.
results = jax_simple_es_train.train(config)

starting run with seed:  8479577531
starting iteration:  0
0 938.96344
starting iteration:  1
1 964.645
starting iteration:  2
2 970.0107
starting iteration:  3
3 988.2447
starting iteration:  4
4 987.97424
starting iteration:  5
5 993.34375
starting iteration:  6
6 998.0033
starting iteration:  7
7 1000.71893
starting iteration:  8
8 1003.2823
starting iteration:  9
9 1002.81744
starting iteration:  10
10 1003.18286
starting iteration:  11
11 1004.5213
starting iteration:  12
12 1004.64874
starting iteration:  13
13 1005.97284
starting iteration:  14
14 1005.35187
starting iteration:  15
15 1005.78784
starting iteration:  16
16 1007.6689
starting iteration:  17
17 1011.011
starting iteration:  18
18 1009.71576
starting iteration:  19
19 1011.9353
starting iteration:  20
20 1013.1077
starting iteration:  21
21 1014.0264
starting iteration:  22
22 1017.33185
starting iteration:  23
23 1018.41235
starting iteration:  24
24 1018.4844
starting iteration:  25
25 1023.09937
starting iteratio

In [5]:


# A simple ES loop
# A training loop for plain es with jax
import random
import numpy as np
import torch
import brax
from brax.envs import wrappers
from brax import jumpy as jp
from brax.envs import env
import jax.numpy as jnp
import jax

from es_map import map_elite_utils
from es_map import jax_evaluate
from es_map import jax_es_update




# Experiment constants: environment name, population size, number of extra evaluation
# entries per population, and the sigma passed to jax_es_create_population.
# NOTE(review): sigma duplicates config["ES_sigma"] (both 0.02) — keep the two in sync.
env_name, population_size, evaluation_batch_size, sigma = "ant", 10000, 50, 0.02

# A fresh seed is drawn for every run and printed so that the run can be reproduced.
seed = random.randint(0, 10000000000)
print("starting run with seed: ", seed)

# Root PRNG key, plus a dedicated sub-key for initializing the model parameters.
key, key_init_model = jax.random.split(jax.random.PRNGKey(seed))



starting run with seed:  7456265998


In [6]:
# Build the batched ant environment and an MLP policy sized to its observation/action spaces.
env = jax_evaluate.create_ant(env_name, population_size, evaluation_batch_size)
model = jax_evaluate.create_MLP_model(env.observation_space.shape[1], env.action_space.shape[1])

# Batched helpers: apply the model across the whole population at once, and turn a batch of flat
# parameter vectors (mapped over the first argument only) back into parameter trees.
batch_model_apply_fn = jax.jit(jax.vmap(model.apply))
batch_vec_to_params = jax.vmap(jax_evaluate.vec_to_params_tree, in_axes=[0, None, None])

# Initial parameters, flattened into one vector plus the shapes/indices needed to rebuild the tree.
initial_model_params = model.init(key_init_model)
vec, shapes, indicies = jax_evaluate.params_tree_to_vec(initial_model_params)

# The parameter vector lives on the host as a numpy array. The optimizer step runs in torch (so a
# torch optimizer can be used); it is a tiny fraction of the compute compared with evaluation and
# gradient estimation, so its speed does not matter.
current_params = np.array(vec)
optimizer_state = None
generation_number = 0

# Running observation statistics for the whole experiment. They are only ever extended, and are
# used to compute the current mean and std for observation normalization.
observation_stats = dict(
    sum=np.zeros(env.observation_space.shape[1]),
    sumsq=np.zeros(env.observation_space.shape[1]),
    count=0,
)

# Sample the first population around the initial parameters.
key, key_create_pop = jax.random.split(key)
all_params, perturbations = jax_es_update.jax_es_create_population(
    current_params,
    key_create_pop,
    popsize=population_size,
    eval_batch_size=evaluation_batch_size,
    sigma=sigma,
)



In [7]:
# Unflatten every population member's parameter vector into a parameter tree, then evaluate the
# whole population with the packaged rollout.
all_model_params = batch_vec_to_params(all_params, shapes, indicies)
fitness, bds, new_obs_stats = jax_evaluate.rollout_episodes(
    env, all_model_params, observation_stats, config, batch_model_apply_fn
)

In [9]:
# Split the evaluation results: the first population_size entries are the children used for the
# gradient estimate; the remaining entries (eval_*) are only reported here.
child_fitness = fitness[:population_size]
child_bds = bds[:population_size]
eval_fitness = fitness[population_size:]
eval_bds = bds[population_size:]

print(generation_number,np.mean(np.array(eval_fitness)))

# The rollout returned the observation statistics extended with this population's observations.
# Keep them so that later rollouts normalize with the updated mean/std; otherwise the running
# statistics never grow past their initial (empty) state.
observation_stats = new_obs_stats

# TODO: innovation / evolvability / entropy metrics are not computed yet.

# NOTE(review): bds covers the whole evaluated batch (children + eval entries) while child_fitness
# covers only the children — presumably irrelevant in "fitness" mode; confirm in jax_es_update.
grad = jax_es_update.jax_calculate_gradient(perturbations=perturbations,
                                            child_fitness=child_fitness,
                                            bds=bds,
                                            config=config,
                                            mode="fitness",
                                            archive=None)
# The optimizer step is done in torch, so hand it a torch tensor.
grad = torch.from_numpy(np.array(grad))
current_params,optimizer_state = jax_es_update.do_gradient_step(current_params,grad,optimizer_state,config)


0 954.88007


In [35]:
def rollout_episodes(env,params,obs_stats,config,
                     batched_model_apply_fn,max_steps=1000):
    """Roll out a batched population on `env` and score the first episode of every batch entry.

    Once `done` fires for an entry, its later rewards, observations and behavior descriptors
    are masked out, so only that entry's first episode contributes.

    Args:
        env: gym-wrapped batched brax env; `reset()` returns obs and `step(action)` returns
            (obs, reward, done, info) with `info["bd"]` holding behavior descriptors.
        params: batched parameter tree, passed straight to `batched_model_apply_fn`.
        obs_stats: dict with "sum", "sumsq" and "count" of all observations seen so far; turned
            into normalization statistics by `map_elite_utils.calculate_obs_stats`.
        config: experiment config (currently unused here).
        batched_model_apply_fn: batched model apply function, called as fn(params, obs).
        max_steps: number of env steps to simulate (default 1000, the former hard-coded value).

    Returns:
        cumulative_reward: (batch,) total reward of each entry's first episode, including the
            reward of the step on which that episode ended.
        bds: (batch, bd_dim) behavior descriptor reported on the last step of each first episode.
            It stays zero for entries whose first episode does not end within `max_steps`.
        new_obs_stats: `obs_stats` extended with this rollout's observations, as host values
            (float64 numpy arrays and a Python float count). The input dict is not modified.
    """
    obs = env.reset()

    batch_size = obs.shape[0]
    cumulative_reward = jnp.zeros(batch_size)
    # 1 while an entry is still in its first episode, 0 forever after that episode ends.
    active_episode = jnp.ones_like(cumulative_reward)
    # The bd size is read from the wrapped env's internal (private) state.
    bd_dim = env._state.info["bd"].shape[1]
    bds = jnp.zeros([batch_size,bd_dim])

    # Normalization statistics from all observations seen before this rollout.
    mean,var = map_elite_utils.calculate_obs_stats(obs_stats)
    mean = jnp.array(mean) # copy to the accelerator
    var = jnp.array(var)

    # Accumulate only this rollout's observations on device, starting from zero. The previous
    # totals are added back on the host in float64 at the end, so the experiment-wide running
    # sums are not squeezed through float32 device arrays on every call.
    rollout_sums = jnp.zeros(obs.shape[1:])
    rollout_squared_sums = jnp.zeros(obs.shape[1:])
    rollout_count = jnp.zeros(())

    # TODO: read max_steps from config once a key for it exists.
    for _ in range(max_steps):
        # NOTE(review): this divides by the second value returned by calculate_obs_stats. If that is
        # the variance (as the name suggests) rather than the standard deviation, normalization
        # should use jnp.sqrt(var) — confirm against map_elite_utils.
        normalized_obs = (obs-mean) / var

        model_out = batched_model_apply_fn(params,normalized_obs)
        action = jax_evaluate.get_deterministic_actions(model_out)

        obs,reward,done,info = env.step(action)

        # Credit the reward while the mask still includes this step, so the transition that ends
        # the first episode is counted too.
        cumulative_reward += reward * active_episode

        last_step_of_first_episode = active_episode * done # 1 only on the step that ends the first episode
        active_episode = active_episode * (1 - done) # once the first episode is done, stays 0

        # info["bd"] can contain nan/inf; multiplying those by 0 still yields nan, so zero them first.
        info_bd = jnp.nan_to_num(info["bd"],nan=0.0, posinf=0.0, neginf=0.0)
        bds = bds + last_step_of_first_episode.reshape(-1,1) * info_bd

        # Record observation stats only for entries still in their first episode (zero out others).
        active_obs = active_episode.reshape(-1,1) * obs
        rollout_sums = rollout_sums + jnp.sum(active_obs,axis=0)
        rollout_squared_sums = rollout_squared_sums + jnp.sum(active_obs*active_obs,axis=0)
        rollout_count = rollout_count + jnp.sum(active_episode)

    # Merge this rollout into the running totals as host (numpy / Python) values.
    new_obs_stats = {
        "sum" : np.asarray(obs_stats["sum"], dtype=np.float64) + np.asarray(rollout_sums, dtype=np.float64),
        "sumsq" : np.asarray(obs_stats["sumsq"], dtype=np.float64) + np.asarray(rollout_squared_sums, dtype=np.float64),
        "count" : float(obs_stats["count"]) + float(rollout_count),
    }
    return cumulative_reward,bds,new_obs_stats




In [39]:
%%time
fitness,bds,new_obs_stats = rollout_episodes(env,all_model_params,observation_stats,config,
                                                                        batch_model_apply_fn)

CPU times: user 5.04 s, sys: 619 ms, total: 5.66 s
Wall time: 7.26 s


In [None]:


# Main ES training loop: each generation samples a population, evaluates it, folds the new
# observations into the normalization statistics and takes one optimizer step.
# NOTE(review): the manual single-step cell above already took one gradient step without
# incrementing generation_number, so this loop starts again at generation 0 — confirm intended.
while generation_number < config["ES_NUM_GENERATIONS"]:
    print("starting iteration: ",generation_number)

    # Sample a fresh population around the current parameters.
    key, key_create_pop = jax.random.split(key, 2)
    all_params,perturbations = jax_es_update.jax_es_create_population(current_params,key_create_pop,
                                                                      popsize=population_size,
                                                                      eval_batch_size=evaluation_batch_size,sigma=sigma)
    all_model_params = batch_vec_to_params(all_params,shapes,indicies)

    # Evaluate the population. The first population_size entries are the children used for the
    # gradient; the remaining entries (eval_*) are only reported.
    fitness,bds,new_obs_stats = jax_evaluate.rollout_episodes(env,all_model_params,observation_stats,config,
                                                              batch_model_apply_fn)
    child_fitness = fitness[:population_size]
    child_bds = bds[:population_size]
    eval_fitness = fitness[population_size:]
    eval_bds = bds[population_size:]

    # Keep the running observation statistics growing across generations; otherwise every
    # rollout would normalize with the initial (empty) statistics.
    observation_stats = new_obs_stats

    print(generation_number,np.mean(np.array(eval_fitness)))

    # TODO: innovation / evolvability / entropy metrics are not computed yet.

    # Same call as the manual step above: perturbations must be passed, and keywords avoid
    # depending on the positional parameter order of jax_calculate_gradient.
    grad = jax_es_update.jax_calculate_gradient(perturbations=perturbations,
                                                child_fitness=child_fitness,
                                                bds=bds,
                                                config=config,
                                                mode="fitness",
                                                archive=None)
    # The optimizer step is done in torch, so hand it a torch tensor (as in the manual step).
    grad = torch.from_numpy(np.array(grad))
    current_params,optimizer_state = jax_es_update.do_gradient_step(current_params,grad,optimizer_state,config)

    generation_number += 1

print("Done, reached iteration: ",config["ES_NUM_GENERATIONS"])

