###### [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/cmamega_example.ipynb)

# Optimizing with CMA-ME in JAX

This notebook shows how to use QDax to find diverse and high-performing parameters on the Rastrigin (or sphere) problem with [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf). It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

- how to define the problem
- how to create a CMA-ME emitter
- how to create a MAP-Elites instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualise the optimization process

In [None]:
# Reload imported modules automatically before executing code, so edits to
# library sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import math

import jax 
import jax.numpy as jnp

try:
    import flax
except:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.4.1 |tail -n 1
    import flax

try:
    import chex
except:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.3 |tail -n 1
    import chex
    
try:
    import qdax
except:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax

from qdax.core.map_elites import MAPElites
from qdax.core.emitters.cma_emitter import CMAEmitter
from qdax.core.emitters.cma_opt_emitter import CMAOptimizingEmitter
from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter
from qdax.core.emitters.cma_multi_emitter import CMAPoolEmitter
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids, MapElitesRepertoire
# from qdax.utils.plotting import plot_map_elites_results
from qdax.utils.plotting import plot_multidimensional_map_elites_grid

from typing import Dict

In [None]:
# from jax.config import config

# config.update('jax_disable_jit', True)

## Set the hyperparameters

Most hyperparameters are similar to those introduced in the [Differentiable Quality Diversity paper](https://arxiv.org/pdf/2106.03894.pdf).

In [None]:
#@title QD Training Definitions Fields
#@markdown ---
num_iterations = 70000
num_dimensions = 100 #@param {type:"integer"} # try 20 and 100
grid_shape = (500, 500) # 500 x 500 grid over the two behavior descriptors
batch_size = 36 #@param {type:"integer"}
sigma_g = .5 #@param {type:"number"}
minval = -5.12 #@param {type:"number"}
maxval = 5.12 #@param {type:"number"}
# Each descriptor sums num_dimensions / 2 clipped genotype values lying in
# [minval, maxval], so the descriptor bounds are derived from minval/maxval
# instead of hard-coding 5.12 (keeps them in sync when the form is edited).
min_bd = minval * 0.5 * num_dimensions
max_bd = maxval * 0.5 * num_dimensions
emitter_type = "rnd" # "opt", "imp", "rnd"
pool_size = 15
opt_function = "sphere" # "rastrigin", "sphere"
#@markdown ---

## Define the scoring function (Rastrigin or sphere) and the behavior descriptors

In [None]:
def rastrigin_scoring(x: jnp.ndarray):
    """Negated, shifted Rastrigin objective (higher is better).

    With ``z = x + 0.4 * minval`` (i.e. the optimum moved to ``x = -0.4 * minval``,
    presumably so it is not at the centre of the domain), returns
    ``-(10 * d + sum_i (z_i**2 - 10 * cos(2 * pi * z_i)))``.

    Args:
        x: genotype(s); the last axis indexes the ``d`` dimensions, any leading
            axes are treated as a batch.

    Returns:
        The negated Rastrigin value, one per leading index.
    """
    shifted = x + minval * 0.4
    first_term = 10 * x.shape[-1]
    # Reduce over the dimension axis only, consistently with sphere_scoring,
    # so batched inputs yield one score per genotype.
    second_term = jnp.sum(shifted ** 2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=-1)
    return -(first_term + second_term)

def sphere_scoring(x: jnp.ndarray):
    """Negated, shifted sphere objective (higher is better).

    Computes ``-sum_i z_i**2`` over the last axis with ``z = x + 0.4 * minval``,
    so the optimum is at ``x = -0.4 * minval``.
    """
    offset = x + minval * 0.4
    squared = offset * offset
    return -jnp.sum(squared, axis=-1)

# Select the objective according to the `opt_function` hyperparameter.
_SCORING_FUNCTIONS = {
    "sphere": sphere_scoring,
    "rastrigin": rastrigin_scoring,
}
if opt_function not in _SCORING_FUNCTIONS:
    # ValueError subclasses Exception, so existing handlers still catch it.
    raise ValueError(
        f"Invalid opt function name given: {opt_function!r} "
        f"(expected one of {sorted(_SCORING_FUNCTIONS)})"
    )
fitness_scoring = _SCORING_FUNCTIONS[opt_function]

def clip(x: jnp.ndarray, lower=None, upper=None):
    """Fold out-of-bound genotype values back toward zero.

    Components inside ``[lower, upper]`` are kept unchanged; components outside
    are replaced by ``upper / x`` (with the default symmetric bounds this has
    magnitude below 1). This appears to follow the clipping used in the CMA-ME
    benchmark linked at the top of the notebook.

    Args:
        x: array of genotype values (any shape; applied element-wise).
        lower: lower bound, defaults to the notebook-level ``minval``.
        upper: upper bound, defaults to the notebook-level ``maxval``.

    Returns:
        Array with the same shape as ``x``.
    """
    if lower is None:
        lower = minval
    if upper is None:
        upper = maxval
    in_bound = (x >= lower) & (x <= upper)
    # Positional arguments: recent JAX versions make condition/x/y positional-only.
    return jnp.where(in_bound, x, upper / x)

def _behavior_descriptor_1(x: jnp.ndarray):
    """First descriptor: sum of the clipped first half of the genotype(s).

    The split and the sum both act on the last axis, so leading batch axes are
    preserved (the original sliced axis 0 with the length of axis -1).
    """
    half = x.shape[-1] // 2
    return jnp.sum(clip(x[..., :half]), axis=-1)

def _behavior_descriptor_2(x: jnp.ndarray):
    """Second descriptor: sum of the clipped second half of the genotype(s)."""
    half = x.shape[-1] // 2
    return jnp.sum(clip(x[..., half:]), axis=-1)

def _behavior_descriptors(x: jnp.ndarray):
    """Both descriptors stacked along a new last axis, shape ``(..., 2)``."""
    return jnp.stack([_behavior_descriptor_1(x), _behavior_descriptor_2(x)], axis=-1)

In [None]:
def scoring_function(x):
    """Evaluate a single genotype.

    Returns:
        ``(fitness, descriptors, extra_scores)`` where ``extra_scores`` is an
        empty dict.
    """
    fitness = fitness_scoring(x)
    descriptors = _behavior_descriptors(x)
    extra_scores = {}
    return fitness, descriptors, extra_scores

def scoring_fn(x, random_key):
    """Batched scoring function passed to ``MAPElites``.

    Vectorizes ``scoring_function`` over the leading axis of ``x``. The
    objective is deterministic, so ``random_key`` is returned unchanged.
    """
    batched_scoring = jax.vmap(scoring_function)
    fitnesses, descriptors, extra_scores = batched_scoring(x)
    return fitnesses, descriptors, extra_scores, random_key

## Define the metrics that will be used

In [None]:
# Reference objective values used by `metrics_fn` to rescale fitnesses to [0, 100].
# Worst: every coordinate at the lower bound `minval`. For the sphere objective
# this is the farthest point from the shifted optimum; for Rastrigin it is only
# a rough reference, not necessarily the exact worst value.
worst_objective = fitness_scoring(jnp.full(num_dimensions, minval))
# Best: the shifted optimum x = -0.4 * minval (where `x + minval * 0.4` vanishes).
best_objective = fitness_scoring(jnp.full(num_dimensions, -minval * 0.4))

num_centroids = math.prod(grid_shape)

def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:
    """Compute rescaled QD metrics for a repertoire.

    Fitnesses are mapped linearly so that ``worst_objective`` -> 0 and
    ``best_objective`` -> 100. Cells whose fitness is ``-inf`` count as empty.

    Returns:
        Dict with ``qd_score`` (sum of rescaled fitnesses over filled cells),
        ``max_fitness`` (largest rescaled fitness) and ``coverage`` (percentage
        of filled cells).
    """
    fitnesses = repertoire.fitnesses
    is_empty = fitnesses == -jnp.inf
    objective_range = best_objective - worst_objective
    rescaled = (fitnesses - worst_objective) * 100 / objective_range

    return {
        "qd_score": jnp.sum(rescaled, where=~is_empty),
        "max_fitness": jnp.max(rescaled),
        "coverage": 100 * jnp.mean(1.0 - is_empty),
    }

## Define the initial population, the emitter and the MAP-Elites instance

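The emitter is one of QDax's CMA-ME emitters, selected with `emitter_type` (`"opt"`: `CMAOptimizingEmitter`, `"imp"`: `CMAEmitter`, `"rnd"`: `CMARndEmitter`), wrapped in a `CMAPoolEmitter` with `num_states=pool_size`. This emitter is given to a MAP-Elites instance to create an instance of the CMA-ME algorithm.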

In [None]:
random_key = jax.random.PRNGKey(0)
# Split before sampling: `random_key` is later passed to `map_elites.init`, and
# reusing the same key for the population would correlate the two draws.
random_key, population_key = jax.random.split(random_key)
initial_population = jax.random.uniform(population_key, shape=(batch_size, num_dimensions))

centroids = compute_euclidean_centroids(
    grid_shape=grid_shape,
    minval=min_bd,
    maxval=max_bd,
)

# Define emitter
import functools
from qdax.core.emitters.mutation_operators import isoline_variation

# Non-CMA alternative (isoline variation only); not used by default but kept so
# it can be swapped into MAPElites below for comparison.
variation_fn = functools.partial(
    isoline_variation, iso_sigma=0.5, line_sigma=0.
)
mixing_emitter = MixingEmitter(
    mutation_fn=None,
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size
)

# Arguments shared by every CMA-ME emitter variant.
emitter_kwargs = {
    "batch_size": batch_size,
    "genotype_dim": num_dimensions,
    "centroids": centroids,
    "sigma_g": sigma_g,
    "min_count": 1,
    "max_count": None,
}

# Map the `emitter_type` hyperparameter to the corresponding CMA emitter class.
_CMA_EMITTERS = {
    "opt": CMAOptimizingEmitter,
    "imp": CMAEmitter,
    "rnd": CMARndEmitter,
}
if emitter_type not in _CMA_EMITTERS:
    raise ValueError(
        f"Invalid emitter type: {emitter_type!r} "
        f"(expected one of {sorted(_CMA_EMITTERS)})"
    )
base_emitter = _CMA_EMITTERS[emitter_type](**emitter_kwargs)

emitter = CMAPoolEmitter(
    num_states=pool_size,
    emitter=base_emitter,
)

map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=emitter,  # or `mixing_emitter` for the non-CMA baseline
    metrics_function=metrics_fn,
)

In [None]:
# Evaluate the initial population and build the starting repertoire/emitter state.
repertoire, emitter_state, random_key = map_elites.init(
    initial_population,
    centroids,
    random_key,
)

In [None]:
%%time

# Run `num_iterations` MAP-Elites updates with `jax.lax.scan` (a compiled loop,
# no per-step inputs). The carry is (repertoire, emitter_state, random_key);
# the metrics returned by each `scan_update` call are stacked along a new
# leading axis of length `num_iterations`.
(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
    map_elites.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=num_iterations,
)

In [None]:
# Report the last recorded value of each metric, labelled with the number of
# evaluations performed by the scan (batch_size per iteration).
total_evaluations = num_iterations * batch_size
for metric_name, history in metrics.items():
    final_value = history[-1]
    print(f"{metric_name} after {total_evaluations}: {final_value}")

In [None]:
# Full per-iteration metric history: one entry per scan iteration for each key.
metrics

In [None]:
# #@title Visualization

# # create the x-axis array
# env_steps = jnp.arange(num_iterations) * batch_size

# # create the plots and the grid
# fig, axes = plot_map_elites_results(
#     env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd
# )

In [None]:
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt

def plot_map_elites_results(
    env_steps: jnp.ndarray,
    metrics: Dict,
    repertoire: MapElitesRepertoire,
    min_bd: jnp.ndarray,
    max_bd: jnp.ndarray,
):
    """Plot coverage, maximum fitness and QD-score along the optimization.

    Draws three side-by-side panels, one per metric, each plotted against
    ``env_steps``. The panel showing the final repertoire grid is currently
    disabled, so ``repertoire``, ``min_bd`` and ``max_bd`` are accepted for
    signature compatibility but are not used.

    Note: this updates the global matplotlib ``rcParams`` (font sizes etc.).

    Args:
        env_steps: x-axis values, one per entry of each metric array.
        metrics: dict containing the ``"coverage"``, ``"max_fitness"`` and
            ``"qd_score"`` arrays from the optimization process.
        repertoire: the final repertoire obtained (unused).
        min_bd: the minimal possible values for the bd (unused).
        max_bd: the maximal possible values for the bd (unused).

    Returns:
        The figure and its array of three axes.
    """
    font_size = 16
    mpl.rcParams.update(
        {
            "axes.labelsize": font_size,
            "axes.titlesize": font_size,
            "legend.fontsize": font_size,
            "xtick.labelsize": font_size,
            "ytick.labelsize": font_size,
            "text.usetex": False,
            "axes.titlepad": 10,
        }
    )

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(40, 10))

    # (metric key, y-axis label, title) for each panel, left to right.
    panels = (
        ("coverage", "Coverage in %", "Coverage evolution during training"),
        ("max_fitness", "Maximum fitness", "Maximum fitness evolution during training"),
        ("qd_score", "QD Score", "QD Score evolution during training"),
    )
    for ax, (key, ylabel, title) in zip(axes, panels):
        ax.plot(env_steps, metrics[key])
        ax.set_xlabel("Environment steps")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Give each panel a near-square box regardless of the data ranges.
        ax.set_aspect(0.95 / ax.get_data_ratio(), adjustable="box")

    return fig, axes

In [None]:
# x-axis: scan iteration index scaled by the number of evaluations per iteration.
env_steps = batch_size * jnp.arange(num_iterations)

fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=metrics,
    repertoire=repertoire,
    min_bd=min_bd,
    max_bd=max_bd,
)

# File name encodes the objective, dimensionality and emitter variant.
figname = f"cma_me_{opt_function}_{num_dimensions}_{emitter_type}.png"
print("Save figure in: ", figname)
plt.savefig(figname)

In [None]:
# fig, axes = plot_multidimensional_map_elites_grid(
#     repertoire=repertoire,
#     minval=jnp.array([min_bd, min_bd]),
#     maxval=jnp.array([max_bd, max_bd]),
#     grid_shape=grid_shape,
# )