[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/Kheperax/blob/main/examples/main/me_training.ipynb)

# Running MAP-Elites on Kheperax

This example is directly inspired by the [MAP-Elites example of QDax](https://github.com/adaptive-intelligent-robotics/QDax/blob/96163f218f0ec1918aa237acefe3671f201c141f/examples/mapelites.ipynb).

In [1]:
# Uncomment the following line if you run this notebook on Google Colab.
# It is placed before the import on purpose: on a fresh runtime the import
# below fails until the package has been installed.
# %pip install kheperax[cuda12]

import kheperax

In [None]:
# Imports

import functools
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.core.map_elites import MAPElites
from qdax.utils.metrics import default_qd_metrics
from qdax.utils.plotting import plot_2d_map_elites_repertoire

from kheperax.tasks.config import KheperaxConfig
from kheperax.tasks.main import KheperaxTask

In [6]:
# Parameters
# --- Reproducibility ---
seed = 42

# --- Evaluation budget ---
# The number of update steps is the total budget divided by the batch size
# (integer division, so any remainder of the budget is dropped).
batch_size = 2048
num_evaluations = 1_000_000
num_iterations = num_evaluations // batch_size

# --- Archive, episode and policy ---
grid_shape = (50, 50)
episode_length = 250
mlp_policy_hidden_layer_sizes = (8,)

# --- Isoline variation operator ---
iso_sigma = 0.2
line_sigma = 0.0

In [7]:
# Init a random key
# Root PRNG key derived from `seed`. The following cells never reuse it directly:
# they call `jax.random.split` on it and consume the resulting subkey.
random_key = jax.random.PRNGKey(seed)

In [8]:
# Task configuration: start from the defaults, then override the fields
# this example tunes (values come from the parameters cell).
config_kheperax = KheperaxConfig.get_default()
config_kheperax.episode_length = episode_length
config_kheperax.mlp_policy_hidden_layer_sizes = mlp_policy_hidden_layer_sizes

# Example of how to modify the robot's attributes
# (the same could be done with the maze): build an updated robot with
# `replace` and store it back on the configuration.
robot_config = config_kheperax.robot.replace(
    lasers_return_minus_one_if_out_of_range=True
)
config_kheperax.robot = robot_config

In [9]:
# Build the Kheperax task from the configuration. A fresh subkey is split off
# so the root `random_key` is not consumed by the task creation itself.
random_key, subkey = jax.random.split(random_key)
env, policy_network, scoring_fn = KheperaxTask.create_default_task(
    config_kheperax,
    random_key=subkey,
)

In [10]:
# Initial population of controllers: one set of policy parameters per
# individual of the first batch.
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, batch_size)

# All-zero placeholder observations, one row per individual, handed to `init`.
fake_batch = jnp.zeros((batch_size, env.observation_size))

# Vectorise `init` so that keys[i] is paired with fake_batch[i].
batched_init = jax.vmap(policy_network.init)
init_variables = batched_init(keys, fake_batch)

In [11]:
# --- Emitter ---
# Isoline variation with the sigmas set in the parameters cell.
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)

# The mutation function is a pass-through returning its two arguments unchanged;
# variation_percentage=1.0 asks the emitter to rely on the variation operator.
mixing_emitter = MixingEmitter(
    mutation_fn=lambda x, y: (x, y),
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size,
)

# --- Metrics ---
# Default QD metrics with a fixed offset of 0.5 passed as `qd_offset`.
metrics_fn = functools.partial(default_qd_metrics, qd_offset=0.5)

# --- Algorithm ---
# MAP-Elites ties together the task scoring function, the emitter and the metrics.
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=metrics_fn,
)

In [None]:
# Centroids of the archive: a grid of shape `grid_shape` spanning the
# behaviour-descriptor limits reported by the environment.
min_bd, max_bd = env.behavior_descriptor_limits
centroids = compute_euclidean_centroids(
    grid_shape=grid_shape, minval=min_bd, maxval=max_bd
)

# Initial repertoire and emitter state, built from the initial controllers.
# `init` also hands back the random key to use from here on.
repertoire, emitter_state, random_key = map_elites.init(
    init_variables,
    centroids,
    random_key,
)


In [None]:
# JIT-compile a single MAP-Elites update step so it is traced once and
# reused for every iteration of the loop below.
update_fn = jax.jit(map_elites.update)

# Run MAP-Elites loop
for iteration in range(num_iterations):
    # Each step consumes the current state and returns the updated one,
    # together with the metrics of that step and a new random key.
    repertoire, emitter_state, metrics, random_key = update_fn(
        repertoire,
        emitter_state,
        random_key,
    )

    # `.item()` is called on every metric value so the log shows plain scalars.
    scalar_metrics = {k: v.item() for (k, v) in metrics.items()}

    # Report the number of completed iterations (1-based) so that the final
    # line reads N/N instead of stopping at N-1/N.
    print(f"{iteration + 1}/{num_iterations} - {scalar_metrics}")

In [None]:
%matplotlib inline

# plot archive
# Draws the repertoire obtained after the training loop: its centroids,
# fitnesses and descriptors, over the same [min_bd, max_bd] range that was
# used to compute the centroids.
fig, axes = plot_2d_map_elites_repertoire(
    centroids=repertoire.centroids,
    repertoire_fitnesses=repertoire.fitnesses,
    minval=min_bd,
    maxval=max_bd,
    repertoire_descriptors=repertoire.descriptors,
    # Optional bounds, left disabled here — presumably the limits of the
    # fitness colour scale; confirm against the QDax plotting documentation.
    # vmin=-0.2,
    # vmax=0.0,
)
plt.show()