In [1]:
import os

# Work from the QDax repository checkout (presumably so the local `qdax` package
# and the relative paths used later resolve). Set QDAX_ROOT to point at a
# different checkout instead of editing this hard-coded default.
QDAX_ROOT = os.environ.get('QDAX_ROOT', '/home/sumeet/QDax')
os.chdir(QDAX_ROOT)

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import functools

from qdax.core.neuroevolution.networks.networks import MLP
from qdax import environments
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.mdp_utils import scoring_function
from qdax.types import Genotype
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from jax._src.flatten_util import ravel_pytree

from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image

In [3]:
# Environment configuration: task name, PRNG seed, rollout horizon and the
# batch dimension used for the vmapped policy init / env reset.
env_name = 'ant_uni'
seed = 1111
episode_length = 1000
env_batch_size = 1

# Build the environment with the horizon chosen above.
env = environments.create(env_name, episode_length=episode_length)

In [4]:
# Policy MLP: two hidden layers of 128 units followed by an output layer with
# one unit per action dimension, using tanh as the final activation.
policy_layer_sizes = (128, 128, env.action_size)
policy_network = MLP(
    final_activation=jnp.tanh,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    layer_sizes=policy_layer_sizes,
)

In [5]:
# Init a random key
random_key = jax.random.PRNGKey(seed)

# Initialise one set of policy parameters per batch member, each from its own
# independently split key.
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=env_batch_size)
fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# Create the initial environment states. The same subkey is repeated for every
# batch member, so all environments receive an identical reset key.
# NOTE: these reset keys must not be reused for parameter init (a previous
# version re-ran `policy_network.init` here with them, overwriting the params
# above with identical, key-reusing copies).
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))
init_states = reset_fn(keys)

2023-01-09 17:02:56.901624: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.8 which is older than the ptxas CUDA version (12.0.76). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:

def load_archive(
    random_key,
    path='/home/sumeet/QDax/experiments/pga_me_ant_uni_testrun_seed_1111/checkpoints/checkpoint_00399/',
):
    """Load a saved MAP-Elites repertoire from a checkpoint directory.

    A throw-away set of policy parameters is initialised only to obtain the
    pytree structure (via ``ravel_pytree``) needed to unflatten the stored
    genotypes.

    Args:
        random_key: JAX PRNG key; a subkey split from it initialises the
            template parameters. The key itself is not returned or advanced
            for the caller.
        path: checkpoint directory handed to ``MapElitesRepertoire.load``.
            Defaults to the original experiment's checkpoint 399.

    Returns:
        The repertoire returned by ``MapElitesRepertoire.load``.
    """
    _, subkey = jax.random.split(random_key)
    # Unbatched observation -> unbatched params; presumably this matches the
    # structure of a single stored genotype (TODO confirm against the saver).
    fake_batch = jnp.zeros(shape=(env.observation_size,))
    fake_params = policy_network.init(subkey, fake_batch)

    # Only the unflatten function is needed; the flat vector is discarded.
    _, reconstruction_fn = ravel_pytree(fake_params)
    return MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=path)

In [9]:
repertoire = load_archive(random_key=random_key)

# Locate the elite with the highest fitness and its behaviour descriptor.
best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)
best_bd = repertoire.descriptors[best_idx]

# Concatenate into a single string: passing three separate arguments made
# print() insert its default ' ' separator at the start of lines 2 and 3.
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Behavior descriptor of the best individual in the repertoire: {best_bd}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)


Best fitness in the repertoire: 5042.92
 Behavior descriptor of the best individual in the repertoire: [0.275      0.36400002 0.305      0.07      ]
 Index in the repertoire of this individual: 3230



In [10]:
# Select the parameters of the best individual from the batched genotypes.
my_params = jax.tree_util.tree_map(
    lambda x: x[best_idx],
    repertoire.genotypes
)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(policy_network.apply)

# Only this many leading states of the trajectory are rendered below.
render_steps = 500

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
# Bound the loop by the episode length so an env/policy that never reports
# `done` cannot hang the kernel; when the episode ends normally this collects
# exactly the same states as an unbounded `while not state.done` loop.
for _ in range(episode_length):
    if state.done:
        break
    rollout.append(state)
    action = jit_inference_fn(my_params, state.obs)
    state = jit_env_step(state, action)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")

HTML(html.render(env.sys, [s.qp for s in rollout[:render_steps]]))

The trajectory of this individual contains 1000 transitions.
