In [92]:
%env CUDA_VISIBLE_DEVICES=0
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

from es_map import qdax_task

import numpy as np
import brax
from brax import jumpy as jp
from brax.envs import env
import jax.numpy as jnp
import jax

import matplotlib.pyplot as plt


from brax.envs import wrappers
from brax.io import html
from brax import envs

env: CUDA_VISIBLE_DEVICES=0
env: XLA_PYTHON_CLIENT_PREALLOCATE=false


In [93]:
from es_map import behavior_map
from es_map import map_elite_utils
from es_map import jax_evaluate

In [94]:
# Fixed seed so any JAX randomness derived from `key` is repeatable.
key = jax.random.PRNGKey(777)
# Split off a batch holding a single subkey (the second argument is the number of keys).
# NOTE(review): key_envs is not referenced by any of the cells shown in this notebook section.
key_envs = jax.random.split(key, 1)

In [95]:

#run_path = "/scratch/ak1774/runs/large_files_jax/run-20220303_085708-pk4loduh"
#import json
#with open(run_path+'/config.json') as f:
#    config = json.load(f)

In [96]:
# Identifier of the run to inspect; id_to_path matches it as the suffix of a path name.
run_id = "n0img15v"
def id_to_path(id, runs_dir="/scratch/ak1774/runs/large_files_jax"):
    """Return the path inside ``runs_dir`` whose name ends with ``id``.

    Args:
        id: Run identifier; matched as the suffix of an entry name
            (glob pattern ``"*" + id``).
        runs_dir: Directory that holds the run folders. Defaults to the
            location this notebook was written against.

    Returns:
        The matching path as a string. If several entries match, the
        lexicographically smallest one is returned so the choice is
        deterministic.

    Raises:
        FileNotFoundError: If nothing in ``runs_dir`` ends with ``id``.
    """
    import os
    from glob import glob

    pattern = os.path.join(runs_dir, "*" + id)
    matches = sorted(glob(pattern, recursive=False))
    if not matches:
        # A bare IndexError from `[0]` gives no hint about which run was missing.
        raise FileNotFoundError(f"no run matching {pattern!r}")
    return matches[0]
# Resolve the on-disk location of the selected run.
run_path = id_to_path(run_id)

# Run configuration stored as JSON; later cells read "env_name" and "episode_max_length" from it.
import json
with open(run_path+'/config.json') as f:
    config = json.load(f)
# Saved array named "theta" — presumably the trained policy parameters; TODO confirm.
theta = np.load(run_path+"/theta.npy")

# Observation statistics; later cells read its "sum", "sumsq" and "count" entries.
# NOTE(review): pickle.load can execute arbitrary code — only load files from trusted runs.
import pickle
with open(run_path+'/obs_stats.pickle', 'rb') as handle:
    obs_stats = pickle.load(handle)

from es_map import jax_evaluate

In [7]:
def create_env(env_name, episode_length=1000, batch_size=10):
    """Build a batched, gym-style QD environment for ``env_name``.

    Args:
        env_name: One of "ant", "walker", "hopper", "halfcheetah",
            "humanoid", "ant_omni" or "humanoid_omni".
        episode_length: Episode length passed to ``EpisodeWrapper``.
        batch_size: Number of environments stepped in parallel by
            ``VectorWrapper``.

    Returns:
        The environment wrapped in EpisodeWrapper, VectorWrapper,
        AutoResetWrapper and finally VectorGymWrapper.

    Raises:
        ValueError: If ``env_name`` is not a supported name. The check runs
            before any environment package is imported, so a typo fails fast.
    """
    # env_name -> (environment family, submodule name, class name)
    env_specs = {
        "ant": ("uni", "ant", "QDUniAnt"),
        "walker": ("uni", "walker", "QDUniWalker"),
        "hopper": ("uni", "hopper", "QDUniHopper"),
        "halfcheetah": ("uni", "halfcheetah", "QDUniHalfcheetah"),
        "humanoid": ("uni", "humanoid", "QDUniHumanoid"),
        "ant_omni": ("omni", "ant", "QDOmniAnt"),
        "humanoid_omni": ("omni", "humanoid", "QDOmniHumanoid"),
    }
    if env_name not in env_specs:
        # Raising a plain string is itself a TypeError in Python 3; raise a real exception.
        raise ValueError(
            f"unknown env name: {env_name!r}; expected one of {sorted(env_specs)}"
        )

    from es_map.qdax_envs import unidirectional_envs
    from es_map.qdax_envs import omnidirectional_envs

    families = {"uni": unidirectional_envs, "omni": omnidirectional_envs}
    family, module_name, class_name = env_specs[env_name]
    env_cls = getattr(getattr(families[family], module_name), class_name)
    env = env_cls()

    env = wrappers.EpisodeWrapper(env, episode_length=episode_length, action_repeat=1)
    # for each iteration, we evaluate the children and the parent performance in parallel
    env = wrappers.VectorWrapper(env, batch_size=batch_size)
    # not sure if this is necessary, we only look at the first episode
    env = wrappers.AutoResetWrapper(env)
    env = wrappers.VectorGymWrapper(env)
    return env

In [87]:
# Build the environment named in the run config.
env = create_env(config["env_name"])
# NOTE(review): the env built on the previous line is discarded here — this overrides it
# with the omnidirectional humanoid regardless of what the config says.
env = create_env("humanoid_omni")
# Display the configured name for comparison.
config["env_name"]

'humanoid'

In [22]:
# Draw one random action from the env's action space (not consumed by the reset below).
action = env.action_space.sample()
# Reset the batched environment and keep the initial observations.
obs = env.reset()

In [13]:
# Inspect the observation array's shape; later cells use obs.shape[0] as the per-environment batch size.
obs.shape

(10, 87)

In [11]:
# Display the repr of the outermost wrapper.
env

VectorGymWrapper(10)

In [5]:
from es_map.qdax_envs import unidirectional_envs
from es_map.qdax_envs import omnidirectional_envs

In [90]:
# Start a fresh batch of episodes; everything below is sized from obs.shape[0].
obs = env.reset()

# Per-environment accumulators.
cumulative_reward = jnp.zeros(obs.shape[0])
# Starts at 1 for every environment; the rollout loops multiply it by (1 - done) so it drops to 0 and stays there.
active_episode = jnp.ones_like(cumulative_reward)
# Width of the "bd" entry (presumably the behavior descriptor — TODO confirm) in the env state info.
bd_dim = env._state.info["bd"].shape[1]
bds = jnp.zeros([obs.shape[0],bd_dim])

final_pos = jnp.zeros([obs.shape[0],2]) # only for plotting purposes

# also prepare variables to accumulate the values needed to calculate obs stats; new obs are added to these
obs_sums = jnp.array(obs_stats["sum"])
obs_squared_sums = jnp.array(obs_stats["sumsq"])
# wrapped in a one-element array; read back later as obs_count[0]
obs_count = jnp.array([obs_stats["count"]])


# prepare obs stats
mean,var = map_elite_utils.calculate_obs_stats(obs_stats)
mean = jnp.array(mean) # copy into a jax array
var = jnp.array(var)

# Step limit taken from the run config.
max_steps = config["episode_max_length"]

# NOTE(review): overwritten by the `for step_i in ...` loops in the following cells.
step_i = 0


In [91]:
# Debug rollout: take 100 random-action steps and print the first environment's bookkeeping
# at every step, to check how done / bd / position behave around the end of an episode.
# NOTE(review): prev_pos is assigned here and at the end of each iteration but never read.
prev_pos = np.zeros([2,2])
for step_i in range(100):

    # NOTE(review): normalized_obs is not used in this cell (actions are sampled randomly).
    # It divides by `var`; whether calculate_obs_stats returns a variance or a standard
    # deviation is not visible here — confirm.
    normalized_obs = (obs-mean) / var

    action = env.action_space.sample()
    action = jnp.array(action)

    # xy position of body 0 for every environment, captured BEFORE stepping
    before_step_pos = env._state.qp.pos[:,0,0:2] 
    
    obs,reward,done,info = env.step(action)
 
    if step_i == (max_steps-1): # if we reached the max time limit, set done to 1 
        done = jnp.ones_like(done) 
    
    last_step_of_first_episode = active_episode * done # will only ever be 1 when we are at last step of first episode
    active_episode = active_episode * (1 - done) # once the first episode is done, active_episode will become and stay 0

    # active_episode was already updated above, so the reward of the step on which
    # done fires is multiplied by 0 and not accumulated
    cumulative_reward += reward * active_episode

    # bd is sometimes nan and inf (we multiply by 0 in those cases, but it still infects with nan...)
    info_bd = jnp.nan_to_num(info["bd"],nan=0.0, posinf=0.0, neginf=0.0)
    bds = bds + last_step_of_first_episode.reshape(-1,1) * info_bd

    # even when bd is not final pos, it is good to have the final position for plotting purposes
    # we want to take the xy pos of the body (coord system is z up, right handed)
    # the pos is in the shape of [batch_dim,num_bodies,3]
    # we want to take body 0, because that is the torso I think for all robots
    current_pos = env._state.qp.pos[:,0,0:2]   # NOTE current pos is already 0 when the episode finished, we must use the previous one.
    
    # hence the pre-step position is what gets latched on the terminal step
    final_pos = final_pos + last_step_of_first_episode.reshape(-1,1) * before_step_pos

    # record observation stats, only count active episodes (zero out others)
    active_obs = active_episode.reshape(-1,1) * obs
    obs_sums = obs_sums + jnp.sum(active_obs,axis=0)
    obs_squared_sums = obs_squared_sums + jnp.sum(active_obs*active_obs,axis=0)
    obs_count = obs_count + jnp.sum(active_episode)

    # per-step trace for environment 0 (x coordinate only for the positions / bd)
    print(step_i,"##############")
    print("active_episode",active_episode[0])
    print("done",done[0])
    print("last_step_of_first_episode",last_step_of_first_episode[0])
    print("pos",current_pos[0,0],before_step_pos[0,0])
    #print("masked_pos",(last_step_of_first_episode.reshape(-1,1) * before_step_pos)[0,0])
    print("final_pos",final_pos[0,0])
    print("info_bd",info_bd[0,0])
    print("bds",bds[0,0])
    
    prev_pos = current_pos.copy()
    


0 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos 8.1233575e-06 0.0
final_pos 0.0
info_bd 0.0
bds 0.0
1 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos 7.5327196e-05 8.1233575e-06
final_pos 0.0
info_bd 8.1233575e-06
bds 0.0
2 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos 5.524418e-05 7.5327196e-05
final_pos 0.0
info_bd 7.5327196e-05
bds 0.0
3 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos -8.0070415e-05 5.524418e-05
final_pos 0.0
info_bd 5.524418e-05
bds 0.0
4 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos 0.000777387 -8.0070415e-05
final_pos 0.0
info_bd -8.0070415e-05
bds 0.0
5 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos 0.001202599 0.000777387
final_pos 0.0
info_bd 0.000777387
bds 0.0
6 ##############
active_episode 1.0
done 0.0
last_step_of_first_episode 0.0
pos 0.00075968093 0.001202599
final_pos 

In [None]:

# Policy rollout: step the batched env for up to max_steps and record, for each
# environment's FIRST episode only, the return, the behavior descriptor, the final
# xy position and the observation statistics.
for step_i in range(max_steps):

    # NOTE(review): divides by `var`; whether calculate_obs_stats returns a variance or a
    # standard deviation is not visible here — confirm.
    normalized_obs = (obs-mean) / var

    model_out = batched_model_apply_fn(params,normalized_obs)
    action = get_deterministic_actions(model_out)

    # we want the xy pos of the body (coord system is z up, right handed)
    # the pos is in the shape of [batch_dim,num_bodies,3]
    # we take body 0, because that is the torso I think for all robots
    # It must be read BEFORE stepping: the step-by-step debug cell earlier in this notebook
    # found that the post-step position is already reset to 0 on the step an episode finishes.
    before_step_pos = env._state.qp.pos[:,0,0:2]

    obs,reward,done,info = env.step(action)

    if step_i == (max_steps-1): # if we reached the max time limit, set done to 1
        done = jnp.ones_like(done)
    last_step_of_first_episode = active_episode * done # will only ever be 1 when we are at last step of first episode
    active_episode = active_episode * (1 - done) # once the first episode is done, active_episode will become and stay 0

    # active_episode was already updated above, so the reward of the step on which
    # done fires is multiplied by 0 and not accumulated
    cumulative_reward += reward * active_episode

    # bd is sometimes nan and inf (we multiply by 0 in those cases, but it still infects with nan...)
    info_bd = jnp.nan_to_num(info["bd"],nan=0.0, posinf=0.0, neginf=0.0)
    bds = bds + last_step_of_first_episode.reshape(-1,1) * info_bd

    # even when bd is not final pos, it is good to have the final position for plotting purposes;
    # latch the pre-step position on the terminal step of the first episode
    final_pos = final_pos + last_step_of_first_episode.reshape(-1,1) * before_step_pos

    # record observation stats, only count active episodes (zero out others)
    active_obs = active_episode.reshape(-1,1) * obs
    obs_sums = obs_sums + jnp.sum(active_obs,axis=0)
    obs_squared_sums = obs_squared_sums + jnp.sum(active_obs*active_obs,axis=0)
    obs_count = obs_count + jnp.sum(active_episode)

# turn the obs stats back into the same dict layout that was loaded ("sum", "sumsq", "count"),
# with the two per-feature arrays converted to numpy
new_obs_stats = {
    "sum" : np.array(obs_sums),
    "sumsq" : np.array(obs_squared_sums),
    # NOTE(review): unlike the two arrays above, this element is not passed through np.array — confirm that is intended.
    "count" : obs_count[0],
}

In [41]:
# Scratch check of the masking trick used in the rollout: a one-hot row mask...
a = np.zeros(5)
a[2] = 1
# ...and a random (5, 2) matrix to mask. Unseeded, so the values differ on every run.
b = np.random.randn(5,2)
b

array([[-1.33474468,  0.48940328],
       [-0.21938814, -0.20488821],
       [-1.21252944,  1.64795951],
       [-0.86664792, -0.37448839],
       [ 0.27313648,  0.39886298]])

In [43]:
# Reshaping the mask to a column broadcasts it across b's columns, keeping only row 2.
a.reshape(-1,1)*b

array([[-0.        ,  0.        ],
       [-0.        , -0.        ],
       [-1.21252944,  1.64795951],
       [-0.        , -0.        ],
       [ 0.        ,  0.        ]])

In [45]:
# Same masking check with jax arrays instead of numpy arrays.
aa = jnp.array(a)
bb = jnp.array(b)
aa.reshape(-1,1)*bb

DeviceArray([[-0.       ,  0.       ],
             [-0.       , -0.       ],
             [-1.2125294,  1.6479595],
             [-0.       , -0.       ],
             [ 0.       ,  0.       ]], dtype=float32)