[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dcrlme.ipynb)

# Optimizing with DCRL-ME in JAX

This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME)](https://arxiv.org/abs/2401.08632).
This algorithm extends and improves upon [Descriptor-Conditioned Gradients MAP-Elites (DCG-ME)](https://dl.acm.org/doi/abs/10.1145/3583131.3590503).
It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

- how to define the problem
- how to create the DCRL emitter
- how to create a MAP-Elites instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualize the results of the training process

In [None]:
from IPython.display import clear_output

try:
    import qdax
except:
    print("QDax not found. Installing...")
    !pip install qdax[cuda12]
    import qdax

clear_output()

In [None]:
!pip install ipympl | tail -n 1
# %matplotlib widget
# from google.colab import output
# output.enable_custom_widget_manager()

import os

from IPython.display import clear_output
import functools
import time
from typing import Any, Tuple

import jax
import jax.numpy as jnp

from qdax import environments
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.emitters.dcrl_me_emitter import DCRLMEConfig, DCRLMEEmitter
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.map_elites import MAPElites
from qdax.core.neuroevolution.buffers.buffer import DCRLTransition
from qdax.core.neuroevolution.networks.networks import MLP, MLPDC
from qdax.custom_types import EnvState, Params, RNGKey
from qdax.environments import descriptor_extractor
from qdax.environments.wrappers import OffsetRewardWrapper, ClipRewardWrapper
from qdax.tasks.brax_envs import scoring_function_brax_envs
from qdax.utils.plotting import plot_map_elites_results

from qdax.utils.metrics import CSVLogger, default_qd_metrics


# When the COLAB_TPU_ADDR environment variable is set (presumably a Colab TPU
# runtime — confirm), import the helper lazily and set up the TPU for JAX.
if "COLAB_TPU_ADDR" in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

clear_output()

In [None]:
#@title QD Training Definitions Fields
# Seed of the JAX random key created at the start of the run
seed = 42 #@param {type:"integer"}

env_name = "ant_omni" #@param['ant_uni', 'hopper_uni', 'walker_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 250 #@param {type:"integer"}
# Descriptor bounds: used as minval/maxval for the CVT centroids and for plotting
min_descriptor = -30.0 #@param {type:"number"}
max_descriptor = 30.0 #@param {type:"number"}

# The main loop runs num_iterations // log_period scans of log_period iterations
num_iterations = 1000 #@param {type:"integer"}
# Size of the initial population; also passed to the DCRL-ME emitter config
batch_size = 256 #@param {type:"integer"}

# Archive
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
policy_hidden_layer_sizes = (128, 128) #@param {type:"raw"}

# DCRL-ME
# NOTE(review): with these defaults ga_batch_size + dcrl_batch_size + ai_batch_size
# equals batch_size (128 + 64 + 64 = 256) — presumably they must stay in sync; confirm.
ga_batch_size = 128 #@param {type:"integer"}
dcrl_batch_size = 64 #@param {type:"integer"}
ai_batch_size = 64 #@param {type:"integer"}
lengthscale = 0.1 #@param {type:"number"}

# GA emitter
# Both sigmas are forwarded to isoline_variation
iso_sigma = 0.005 #@param {type:"number"}
line_sigma = 0.05 #@param {type:"number"}

# DCRL emitter
# All values below are forwarded unchanged to DCRLMEConfig
critic_hidden_layer_size = (256, 256) #@param {type:"raw"}
num_critic_training_steps = 3000 #@param {type:"integer"}
num_pg_training_steps = 150 #@param {type:"integer"}
replay_buffer_size = 1_000_000 #@param {type:"integer"}
discount = 0.99 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
critic_learning_rate = 3e-4 #@param {type:"number"}
actor_learning_rate = 3e-4 #@param {type:"number"}
policy_learning_rate = 5e-3 #@param {type:"number"}
noise_clip = 0.5 #@param {type:"number"}
policy_noise = 0.2 #@param {type:"number"}
soft_tau_update = 0.005 #@param {type:"number"}
policy_delay = 2 #@param {type:"number"}
#@markdown ---

## Init environment, policy, population params, init states of the env

Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy that every individual in the population will use. Once the policy is defined, all individuals are defined by their parameters, which correspond to their genotype.

In [None]:
# Seed the JAX random key
key = jax.random.key(seed)

# Build the environment. DCRL needs positive rewards, so the environment is
# wrapped with the per-environment reward offset and then with a clip at zero.
env = environments.create(env_name, episode_length=episode_length)
env = OffsetRewardWrapper(env, offset=environments.reward_offset[env_name])
env = ClipRewardWrapper(env, clip_min=0.0)
reset_fn = jax.jit(env.reset)

# Policy and descriptor-conditioned actor use the same layer sizes,
# initializer and final activation.
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    final_activation=jnp.tanh,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    layer_sizes=policy_layer_sizes,
)
actor_dc_network = MLPDC(
    final_activation=jnp.tanh,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    layer_sizes=policy_layer_sizes,
)

# Initial population: one parameter set per key, initialized on dummy
# all-zero observations.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch_obs = jnp.zeros(shape=(batch_size, env.observation_size))
init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

## Define the way the policy interacts with the env

In [None]:
# Define the function to play a step with the policy in the environment
def play_step_fn(
    env_state: EnvState, policy_params: Params, key: RNGKey
) -> Tuple[EnvState, Params, RNGKey, DCRLTransition]:
    """Play one environment step with the policy and record the transition.

    Args:
        env_state: current state of the environment.
        policy_params: parameters applied to the policy network.
        key: random key; it is returned untouched.

    Returns:
        The next environment state, the unchanged policy parameters, the
        unchanged random key and the transition collected during the step.
    """
    chosen_actions = policy_network.apply(policy_params, env_state.obs)
    stepped_state = env.step(env_state, chosen_actions)

    # NaN-filled placeholder of length env.descriptor_length, used for both
    # descriptor fields of the transition.
    nan_descriptor = jnp.zeros(env.descriptor_length) * jnp.nan

    step_transition = DCRLTransition(
        obs=env_state.obs,
        next_obs=stepped_state.obs,
        rewards=stepped_state.reward,
        dones=stepped_state.done,
        truncations=stepped_state.info["truncation"],
        actions=chosen_actions,
        state_desc=env_state.info["state_descriptor"],
        next_state_desc=stepped_state.info["state_descriptor"],
        desc=nan_descriptor,
        desc_prime=nan_descriptor,
    )

    return stepped_state, policy_params, key, step_transition

## Define the scoring function and the way metrics are computed

The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual.

In [None]:
# Prepare the scoring function
descriptor_extraction_fn = descriptor_extractor[env_name]
scoring_fn = functools.partial(
    scoring_function_brax_envs,
    episode_length=episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    descriptor_extractor=descriptor_extraction_fn,
)

# Get minimum reward value to make sure qd_score are positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_function = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * episode_length,
)

## Define the emitter: DCRL Emitter

The emitter is used to evolve the population at each mutation step. In this example, the emitter is the Descriptor-Conditioned RL emitter, the one used in DCRL-ME. It trains a critic with the transitions experienced in the environment and uses the critic to apply descriptor-conditioned gradient updates to the evolved policies.

In [None]:
# Forward every hyperparameter from the definitions cell to the emitter config
dcrl_emitter_config = DCRLMEConfig(
    # Batch sizes
    batch_size=batch_size,
    ga_batch_size=ga_batch_size,
    dcrl_batch_size=dcrl_batch_size,
    ai_batch_size=ai_batch_size,
    # Training steps
    num_critic_training_steps=num_critic_training_steps,
    num_pg_training_steps=num_pg_training_steps,
    # Learning rates
    critic_learning_rate=critic_learning_rate,
    actor_learning_rate=actor_learning_rate,
    policy_learning_rate=policy_learning_rate,
    # Remaining hyperparameters
    lengthscale=lengthscale,
    critic_hidden_layer_size=critic_hidden_layer_size,
    replay_buffer_size=replay_buffer_size,
    discount=discount,
    reward_scaling=reward_scaling,
    noise_clip=noise_clip,
    policy_noise=policy_noise,
    soft_tau_update=soft_tau_update,
    policy_delay=policy_delay,
)

In [None]:
# Get the emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)

dcrl_emitter = DCRLMEEmitter(
    config=dcrl_emitter_config,
    policy_network=policy_network,
    actor_network=actor_dc_network,
    env=env,
    variation_fn=variation_fn,
)

## Instantiate and initialise the MAP Elites algorithm

In [None]:
# Instantiate MAP Elites
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=dcrl_emitter,
    metrics_function=metrics_function,
)

# Compute the centroids
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=env.descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_descriptor,
    maxval=max_descriptor,
    key=subkey,
)

# compute initial repertoire
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = map_elites.init(
    init_params, centroids, subkey
)

In [None]:
# Number of MAP-Elites iterations run per scan (and per logged CSV row)
log_period = 10
# NOTE: integer division — iterations beyond a multiple of log_period are not run
num_loops = num_iterations // log_period

# Every metric starts as the same empty array and is grown by concatenation
metrics = dict.fromkeys(
    ["iteration", "qd_score", "coverage", "max_fitness", "time"], jnp.array([])
)
csv_logger = CSVLogger("dcrlme-logs.csv", header=list(metrics.keys()))

# Main loop
map_elites_scan_update = map_elites.scan_update
for i in range(num_loops):
    start_time = time.time()
    (
        repertoire,
        emitter_state,
        key,
    ), current_metrics = jax.lax.scan(
        map_elites_scan_update,
        (repertoire, emitter_state, key),
        (),
        length=log_period,
    )
    # JAX dispatches work asynchronously: wait for the scan results before
    # stopping the clock, otherwise the measured time can under-report the
    # actual computation.
    jax.block_until_ready(current_metrics)
    timelapse = time.time() - start_time

    # Metrics: add 1-based iteration indices and the wall-clock time of this
    # scan (repeated for each of its iterations), then append to the history
    current_metrics["iteration"] = jnp.arange(
        1 + log_period * i, 1 + log_period * (i + 1), dtype=jnp.int32
    )
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    metrics = jax.tree.map(
        lambda metric, current_metric: jnp.concatenate(
            [metric, current_metric], axis=0
        ),
        metrics,
        current_metrics,
    )

    # Log the most recent value of every metric
    csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))

In [None]:
#@title Visualization

# Create the x-axis array
env_steps = metrics["iteration"]

%matplotlib inline
# Create the plots and the grid
fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)