In [None]:
from IPython.display import clear_output

try:
    import qdax
except:
    print("QDax not found. Installing...")
    !pip install qdax[cuda12]
    import qdax

clear_output()

In [None]:
import functools
import math
import time
from typing import Tuple

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from brax.v1.io import html
from IPython.display import HTML
from tqdm import tqdm

from qdax import environments
from qdax.baselines.pbt import PBTTrainingState
from qdax.baselines.td3_pbt import PBTTD3, PBTTD3Config
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig
from qdax.core.emitters.pbt_variation_operators import td3_pbt_variation_fn
from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.custom_types import RNGKey
from qdax.utils.metrics import default_qd_metrics
from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_map_elites_results

In [None]:
# Make the CPU the default JAX platform: arrays created without an explicit
# device (e.g. by the post-processing / plotting cells) live on the host, while
# the distributed MAP-Elites computations below are explicitly placed on the
# accelerator `devices` through `pmap(..., devices=devices)`.
jax.config.update("jax_platform_name", "cpu")

In [None]:
# Accelerator platform MAP-Elites is distributed over: "gpu" or "tpu".
accelerator_platform = "gpu"

# The platform must be passed explicitly: the default platform was set to CPU
# above, so `jax.devices()` without argument would list CPU devices instead.
devices = jax.devices(accelerator_platform)
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")

In [None]:
# --- Environment -----------------------------------------------------------
env_name = "anttrap"
seed = 0

# --- TD3 agent ---------------------------------------------------------------
episode_length = 1000
batch_size = 256
policy_delay = 2
soft_tau_update = 0.005
critic_hidden_layer_size = (256, 256)
policy_hidden_layer_size = (256, 256)

# --- PBT emitter -------------------------------------------------------------
buffer_size = 100_000
pg_population_size_per_device = 10
ga_population_size_per_device = 10
num_training_steps = 5000
env_batch_size = 250
grad_updates_per_step = 1.0
# Mutation strengths passed to `td3_pbt_variation_fn`.
iso_sigma = 0.005
line_sigma = 0.05

# PBT replacement fractions.
fraction_best_to_replace_from = 0.1
fraction_to_replace_from_best = 0.2
fraction_to_replace_from_samples = 0.4
# Only used when agents are exchanged between devices.
fraction_sort_exchange = 0.1

# Number of parallel evaluation environments per agent.
eval_env_batch_size = 1

# --- MAP-Elites --------------------------------------------------------------
num_init_cvt_samples = 50_000
num_centroids = 1024
# MAP-Elites iterations run by each call of the distributed update function.
log_period = 1
# Short demo run. For a longer run, use num_iterations = 450 and
# save_repertoire_freq = 15.
num_iterations = 20
# Counted in logging periods (calls of the update function), not iterations.
save_repertoire_freq = 5

In [None]:
def _create_env(num_envs):
    """Create a batched `env_name` environment with `num_envs` parallel copies.

    All environments share the configured episode length and auto-reset.
    """
    return environments.create(
        env_name=env_name,
        batch_size=num_envs,
        episode_length=episode_length,
        auto_reset=True,
    )


# Environment handed to the PBT emitter: `env_batch_size` copies per PG agent.
env = _create_env(env_batch_size * pg_population_size_per_device)
# Small environment used by the scoring function.
eval_env = _create_env(eval_env_batch_size)

In [None]:
# Bounds of the descriptor space, used below for the CVT centroids and the plots.
min_descriptor, max_descriptor = env.descriptor_limits

In [None]:
# Root PRNG key, plus one dedicated key per environment reset.
key = jax.random.key(seed)
key, train_reset_key, eval_reset_key = jax.random.split(key, 3)

env_states = jax.jit(env.reset)(rng=train_reset_key)
# Fixed start states reused by every evaluation in `scoring_function`.
eval_env_first_states = jax.jit(eval_env.reset)(rng=eval_reset_key)

In [None]:
# TD3 hyperparameters shared by every agent of the PBT population.
td3_config_kwargs = {
    "episode_length": episode_length,
    "batch_size": batch_size,
    "policy_delay": policy_delay,
    "soft_tau_update": soft_tau_update,
    "critic_hidden_layer_size": critic_hidden_layer_size,
    "policy_hidden_layer_size": policy_hidden_layer_size,
}
config = PBTTD3Config(**td3_config_kwargs)

agent = PBTTD3(config=config, action_size=env.action_size)

In [None]:
# PBT emitter configuration. The per-agent training budget, expressed in
# environment steps, is converted into emitter training iterations
# (presumably of `env_batch_size` steps each).
num_training_iterations = num_training_steps // env_batch_size

emitter_config = PBTEmitterConfig(
    num_devices=num_devices,
    pg_population_size_per_device=pg_population_size_per_device,
    ga_population_size_per_device=ga_population_size_per_device,
    buffer_size=buffer_size,
    env_batch_size=env_batch_size,
    num_training_iterations=num_training_iterations,
    grad_updates_per_step=grad_updates_per_step,
    fraction_best_to_replace_from=fraction_best_to_replace_from,
    fraction_to_replace_from_best=fraction_to_replace_from_best,
    fraction_to_replace_from_samples=fraction_to_replace_from_samples,
    fraction_sort_exchange=fraction_sort_exchange,
)

In [None]:
# GA variation operator with the configured mutation strengths bound in.
variation_fn = functools.partial(
    td3_pbt_variation_fn,
    line_sigma=line_sigma,
    iso_sigma=iso_sigma,
)

In [None]:
# PBT emitter combining TD3 training (PG agents) and GA variation.
emitter = PBTEmitter(
    config=emitter_config,
    pbt_agent=agent,
    variation_fn=variation_fn,
    env=env,
)

In [None]:
# Evaluation function of the agent on `eval_env`, using the environment-specific
# descriptor extractor to compute the descriptors of the evaluated policies.
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
eval_policy = agent.get_eval_qd_fn(
    eval_env,
    descriptor_extraction_fn=descriptor_extraction_fn,
)


def scoring_function(genotypes, key):
    """Evaluate a population of agents from the fixed evaluation start states.

    Args:
        genotypes: pytree of agent training states whose leaves share a leading
            population axis (its size is read from the first leaf).
        key: unused. Every agent starts from the same module-level
            `eval_env_first_states`, and no key is passed to `eval_policy`.

    Returns:
        Tuple ``(population_returns, population_descriptors, {})`` as produced
        by `eval_policy`; the extra-scores dict is always empty.
    """
    # `jax.tree_util` is used instead of the deprecated (and, in recent JAX,
    # removed) `jax.tree_leaves` / `jax.tree_map` aliases.
    population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0]

    # Give every agent its own copy of the evaluation start states by
    # prepending a population axis to each leaf.
    first_states = jax.tree_util.tree_map(
        lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), population_size, axis=0),
        eval_env_first_states,
    )
    population_returns, population_descriptors, _, _ = eval_policy(
        genotypes, first_states
    )
    return population_returns, population_descriptors, {}

In [None]:
# Shift fitnesses by the minimum reward over a whole episode so that the QD
# score stays positive.
reward_offset = environments.reward_offset[env_name]
qd_offset = reward_offset * episode_length
metrics_function = functools.partial(default_qd_metrics, qd_offset=qd_offset)

# Distributed MAP-Elites driven by the PBT emitter.
map_elites = DistributedMAPElites(
    emitter=emitter,
    scoring_function=scoring_function,
    metrics_function=metrics_function,
)

In [None]:
# CVT tessellation of the descriptor space into `num_centroids` cells.
key, cvt_key = jax.random.split(key)
centroids = compute_cvt_centroids(
    key=cvt_key,
    num_descriptors=env.descriptor_length,
    num_centroids=num_centroids,
    num_init_cvt_samples=num_init_cvt_samples,
    minval=min_descriptor,
    maxval=max_descriptor,
)

In [None]:
# One PRNG key per device for the distributed initialisation.
key, *device_keys = jax.random.split(key, num=1 + num_devices)
keys = jnp.stack(device_keys)

In [None]:
# Initial training states (and replay buffers, discarded here) of every agent,
# built independently on each device.
population_size_per_device = (
    pg_population_size_per_device + ga_population_size_per_device
)
agent_init_fn = agent.get_init_fn(
    population_size=population_size_per_device,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=buffer_size,
)

# Convert the typed keys to raw key data (legacy PRNGKey format) to work around
# github.com/jax-ml/jax/issues/23647.
keys = jax.random.key_data(keys)

distributed_agent_init = jax.pmap(agent_init_fn, axis_name="p", devices=devices)
training_states, _ = distributed_agent_init(keys)

In [None]:
# Drop the optimizer states of every agent (vmapped over the per-device
# population, pmapped over devices) so the stored repertoire stays lightweight.
strip_optimizer_states = jax.vmap(type(training_states).empty_optimizers_states)
training_states = jax.pmap(
    strip_optimizer_states,
    axis_name="p",
    devices=devices,
)(training_states)

# Initialise the distributed repertoire and emitter state from these agents.
distributed_init_fn = map_elites.get_distributed_init_fn(
    centroids=centroids, devices=devices
)
repertoire, emitter_state, init_metrics = distributed_init_fn(
    genotypes=training_states, key=keys
)

In [None]:
# Each call of `update_fn` runs `log_period` MAP-Elites iterations on all devices.
update_fn = map_elites.get_distributed_update_fn(
    devices=devices,
    num_iterations=log_period,
)

In [None]:
# Environment steps consumed by ONE MAP-Elites iteration, summed over devices:
# - evaluation: every agent (PG + GA) plays `episode_length` steps in each of
#   its `eval_env_batch_size` evaluation environments (assumes full episodes);
# - training: each PG agent runs `num_training_steps // env_batch_size`
#   emitter iterations of `env_batch_size` steps (see the emitter config).
# The training budget is truncated the same way the emitter truncates it, so a
# `num_training_steps` that is not a multiple of `env_batch_size` is not
# over-counted.
num_effective_training_steps = (num_training_steps // env_batch_size) * env_batch_size
env_step_multiplier = (
    (pg_population_size_per_device + ga_population_size_per_device)
    * eval_env_batch_size
    * episode_length
    + num_effective_training_steps * pg_population_size_per_device
) * num_devices

In [None]:
def _first_device_copy(tree):
    """Fetch the first device's slice of every leaf of a pmapped pytree to the host."""
    return jax.tree_util.tree_map(lambda x: jax.device_get(x)[0], tree)


all_metrics = {}
repertoires = []

# tqdm already reports the time per logging period, so no manual timing here.
num_log_periods = num_iterations // log_period
for i in tqdm(range(num_log_periods), total=num_log_periods):
    # Fresh PRNG key for each device.
    key, *keys = jax.random.split(key, num=1 + num_devices)
    keys = jnp.stack(keys)
    repertoire, emitter_state, metrics = update_fn(repertoire, emitter_state, keys)

    # Keep the first device's metrics (presumably identical on every device
    # after the distributed update - TODO confirm).
    metrics_cpu = _first_device_copy(metrics)

    # Append this logging period's values to the metric history.
    for metric_name, metric_values in metrics_cpu.items():
        if metric_name in all_metrics:
            all_metrics[metric_name] = jnp.concatenate(
                [all_metrics[metric_name], metric_values]
            )
        else:
            all_metrics[metric_name] = metric_values

    # Host snapshot of the repertoire every `save_repertoire_freq` logging periods.
    if i % save_repertoire_freq == 0:
        repertoires.append(_first_device_copy(repertoire))

In [None]:
# The training loop ran `num_iterations // log_period` calls of `update_fn`,
# each performing `log_period` MAP-Elites iterations, so the metric history has
# one entry per completed iteration. Using `num_iterations * log_period` here
# would produce an x-axis of the wrong length whenever `log_period != 1`.
num_completed_iterations = (num_iterations // log_period) * log_period
env_steps = (jnp.arange(num_completed_iterations) + 1) * env_step_multiplier

# create the plots and the grid
fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=all_metrics,
    repertoire=repertoires[-1],
    min_descriptor=min_descriptor,
    max_descriptor=max_descriptor,
)

In [None]:
# Grid of the saved repertoire snapshots. Each cell is coloured by the TD3
# exploration noise (`expl_noise`) of its elite instead of its fitness.
# `math` and `plt` are already imported in the imports cell.
num_repertoires = len(repertoires)
num_cols = 5
num_rows = max(1, math.ceil(num_repertoires / num_cols))

fig, axes = plt.subplots(
    nrows=num_rows,
    ncols=num_cols,
    figsize=(6 * num_cols, 6 * num_rows),
    squeeze=False,
)
# Use dedicated loop names so the live distributed `repertoire` used by the
# training loop is not overwritten by a host snapshot.
for snapshot_idx, saved_repertoire in enumerate(repertoires):
    ax = axes[snapshot_idx // num_cols, snapshot_idx % num_cols]

    # Empty cells (fitness -inf) keep -inf, presumably so the plot treats them
    # as empty.
    expl_noise_per_cell = jnp.where(
        saved_repertoire.fitnesses > -jnp.inf,
        saved_repertoire.genotypes.expl_noise,
        -jnp.inf,
    )
    plot_2d_map_elites_repertoire(
        centroids=centroids,
        repertoire_fitnesses=expl_noise_per_cell,
        minval=min_descriptor,
        maxval=max_descriptor,
        ax=ax,
    )

    # Snapshot k was taken after logging period k * save_repertoire_freq had
    # completed, i.e. after (k * save_repertoire_freq + 1) * log_period
    # MAP-Elites iterations - not after k iterations.
    num_iterations_done = (snapshot_idx * save_repertoire_freq + 1) * log_period
    ax.set_title(f"Grid after {env_step_multiplier * num_iterations_done} steps")

# Hide the trailing panels that have no snapshot to show.
for unused_ax in axes.flat[num_repertoires:]:
    unused_ax.set_axis_off()