[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/distributed_mapelites.ipynb)

# Optimizing with MAP-Elites in JAX (multi-devices example)

This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [MAP-Elites](https://arxiv.org/abs/1504.04909).
It can be run locally or on Google Colab. We recommend using a GPU. This notebook shows:

- how to define the problem
- how to create an emitter
- how to create a MAP-Elites instance that runs on multiple devices
- which functions must be defined before training
- how to launch a certain number of training steps

In [None]:
from IPython.display import clear_output

try:
    import qdax
except:
    print("QDax not found. Installing...")
    !pip install qdax[cuda12]
    import qdax

clear_output()

In [None]:
!pip install ipympl | tail -n 1
# %matplotlib widget
# from google.colab import output
# output.enable_custom_widget_manager()

import os

from IPython.display import clear_output
import functools

try:
    from tqdm import tqdm
except:
    !pip install tqdm | tail -n 1
    from tqdm import tqdm

import time

import jax
import jax.numpy as jnp

from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax import environments
from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter

from qdax.utils.metrics import CSVLogger, default_qd_metrics


if "COLAB_TPU_ADDR" in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()


clear_output()

## Set up the platform and get devices

Set up the default platform on which the MAP-Elites repertoire is stored and the MAP-Elites updates happen.

In [None]:
# Platform JAX uses by default: the repertoire lives here and the MAP-Elites
# bookkeeping runs here, while evaluations are pmapped onto the accelerators.
default_device = "cpu"
jax.config.update("jax_platform_name", default_device)

In [None]:
# Accelerators the work is distributed over.
# On a TPU runtime, set `accelerator_kind` to 'tpu' instead of 'gpu'.
accelerator_kind = 'gpu'
devices = jax.devices(accelerator_kind)
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")

## Set up run parameters

In [None]:
#@title QD Training Definitions Fields
#@markdown ---
batch_size_per_device = 100  #@param {type:"number"}
# Total population size over all devices. Deliberately NOT a form field: it
# must always equal batch_size_per_device * num_devices, because the initial
# population is reshaped to (num_devices, batch_size_per_device, ...) and the
# emitter is configured with batch_size_per_device. Editing it independently
# in the form would desynchronise the two.
batch_size = batch_size_per_device * num_devices
env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 100 #@param {type:"integer"}
num_iterations = 1000 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
# Sigmas passed to the isoline variation operator.
iso_sigma = 0.005 #@param {type:"number"}
line_sigma = 0.05 #@param {type:"number"}
# CVT tessellation of the descriptor space into num_centroids cells; the
# bounds are passed to compute_cvt_centroids as minval / maxval.
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_descriptor = 0. #@param {type:"number"}
max_descriptor = 1.0 #@param {type:"number"}
#@markdown ---

## Init environment, policy, population params, init states of the env

Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy that every individual in the population will use. Once the policy is defined, each individual is fully described by its parameters, which correspond to its genotype.

In [None]:
# Environment for `env_name`, plus a jitted version of its reset function.
env = environments.create(env_name, episode_length=episode_length)
reset_fn = jax.jit(env.reset)

# Root random key of the run.
key = jax.random.key(seed)

# Shared policy: the hidden layers followed by one output unit per action
# dimension, squashed by a final tanh.
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

# Initial genotypes: `batch_size` independent parameter sets, each created
# from its own key. The all-zero observation batch only fixes input shapes.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros((batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

## Define the way the policy interacts with the env

Now that the environment and policy have been defined, we need a function that describes how the policy interacts with the environment and how transition data is stored.

In [None]:
# Function that plays one environment step with the policy.
def play_step_fn(
    env_state,
    policy_params,
    key,
):
    """Advance the environment by one step using the policy's action.

    Args:
        env_state: current environment state. Its ``obs`` is fed to the
            policy, and its ``info["state_descriptor"]`` is recorded as the
            transition's ``state_desc``.
        policy_params: parameters applied with ``policy_network``.
        key: random key, returned unchanged.

    Returns:
        A tuple ``(next_state, policy_params, key, transition)``, where
        ``transition`` is a ``QDTransition`` describing this step.
    """
    chosen_actions = policy_network.apply(policy_params, env_state.obs)

    desc_before = env_state.info["state_descriptor"]
    new_state = env.step(env_state, chosen_actions)
    new_info = new_state.info

    return new_state, policy_params, key, QDTransition(
        obs=env_state.obs,
        next_obs=new_state.obs,
        rewards=new_state.reward,
        dones=new_state.done,
        actions=chosen_actions,
        truncations=new_info["truncation"],
        state_desc=desc_before,
        next_state_desc=new_info["state_descriptor"],
    )

## Define the scoring function and the way metrics are computed

The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual.

In [None]:
# Offset from qdax.environments for this task. Scaled by the episode length,
# it is used as the QD-score offset so that QD scores stay positive.
reward_offset = environments.reward_offset[env_name]

# Scoring function: the brax-env scoring helper with this notebook's
# episode length, reset/step functions and descriptor extractor bound in.
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
scoring_fn = functools.partial(
    scoring_function,
    descriptor_extractor=descriptor_extraction_fn,
    play_step_fn=play_step_fn,
    play_reset_fn=reset_fn,
    episode_length=episode_length,
)

# Metrics function with the QD-score offset applied.
metrics_function = functools.partial(
    default_qd_metrics,
    qd_offset=episode_length * reward_offset,
)

## Define the emitter

The emitter is used to evolve the population at each mutation step.

In [None]:
# Emitter: isoline variation with the sigmas chosen above. No separate
# mutation operator is supplied, and variation_percentage is 1.0. Each
# device's emitter is configured with batch_size_per_device.
variation_fn = functools.partial(
    isoline_variation,
    iso_sigma=iso_sigma,
    line_sigma=line_sigma,
)
mixing_emitter = MixingEmitter(
    batch_size=batch_size_per_device,
    variation_percentage=1.0,
    variation_fn=variation_fn,
    mutation_fn=None,
)

## Instantiate and initialise the MAP-Elites algorithm

In [None]:
# Instantiate MAP-Elites
map_elites = DistributedMAPElites(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=metrics_function,
)

# Compute the CVT centroids that tessellate the descriptor space
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=env.descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_descriptor,
    maxval=max_descriptor,
    key=subkey,
)

# One key per device for the initial evaluation. Split a fresh subkey off
# `key` first, so `key` is never consumed without being replaced; otherwise
# any later split of `key` would produce keys correlated with these.
# jax.random.split already returns a stacked array of keys.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=num_devices)

# Add a leading device axis: (batch_size, ...) -> (num_devices, batch_size_per_device, ...)
init_variables = jax.tree.map(
    lambda x: jnp.reshape(x, (num_devices, batch_size_per_device) + x.shape[1:]),
    init_variables,
)

# Evaluate the initial genotypes on all devices to get the initial
# repertoire, emitter state and metrics.
repertoire, emitter_state, init_metrics = map_elites.get_distributed_init_fn(
    centroids=centroids,
    devices=devices,
)(genotypes=init_variables, key=keys)

## Launch MAP-Elites iterations

In [None]:
log_period = 10
# NOTE: if num_iterations is not a multiple of log_period, the remaining
# num_iterations % log_period iterations are not run.
num_loops = num_iterations // log_period

csv_logger = CSVLogger(
    "mapelites-logs.csv",
    header=["loop", "iteration", "qd_score", "max_fitness", "coverage", "time"]
)
all_metrics = {}

# Get update function: each call runs `log_period` MAP-Elites iterations on every device
update_fn = map_elites.get_distributed_update_fn(num_iterations=log_period, devices=devices)

# main loop
for i in tqdm(range(num_loops), total=num_loops):

    # Fresh per-device keys for every call. Passing the same keys on every
    # loop would replay the same random variations again and again.
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, num=num_devices)

    start_time = time.time()

    # main iterations
    repertoire, emitter_state, metrics = update_fn(repertoire, emitter_state, keys)

    # JAX dispatches work asynchronously: wait for the results so that the
    # measured time covers the computation, not just its dispatch.
    metrics = jax.block_until_ready(metrics)
    timelapse = time.time() - start_time

    # get metrics (copy held by the first device)
    metrics = jax.tree.map(lambda x: x[0], metrics)

    # log metrics; value[-1] is the metric after (i + 1) * log_period iterations
    logged_metrics = {"time": timelapse, "loop": 1 + i, "iteration": (i + 1) * log_period}
    # `metric_name` (not `key`) so that the JAX random key is not shadowed.
    for metric_name, value in metrics.items():
        # take last value
        logged_metrics[metric_name] = value[-1]

        # take all values
        if metric_name in all_metrics:
            all_metrics[metric_name] = jnp.concatenate([all_metrics[metric_name], value])
        else:
            all_metrics[metric_name] = value

    csv_logger.log(logged_metrics)

## Retrieve the repertoire from the first device

All devices hold an identical copy of the repertoire, so we keep the copy from the first device.

In [None]:
# Every device holds an identical copy of the repertoire: index each leaf at
# position 0 of its leading device axis to keep only the first copy.
repertoire = jax.tree.map(lambda leaf: leaf[0], repertoire)