In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax 
import jax.numpy as jnp
from jax import random, grad
from qdax.core.map_elites import MAPElites
from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter
from qdax.core.containers.repertoire import compute_cvt_centroids, MapElitesRepertoire
from typing import Dict

In [None]:
def rastrigin_scoring(x: jnp.ndarray):
    """Negated Rastrigin objective whose optimum is shifted to 5.12 * 0.4.

    Higher is better. For a 1-D ``x`` the maximum value 0 is reached when
    every coordinate equals 5.12 * 0.4.
    """
    shifted = x - 5.12 * 0.4
    dim = x.shape[-1]
    per_coord = shifted ** 2 - 10 * jnp.cos(2 * jnp.pi * shifted)
    return -(10 * dim + jnp.sum(per_coord))

def clip(x: jnp.ndarray, bound: float = 5.12):
    """Map coordinates into a bounded range for the Rastrigin descriptors.

    Coordinates with ``-bound <= x <= bound`` are returned unchanged.
    Coordinates outside that interval are replaced by ``bound / x``, which
    lies in (-1, 1); they are not saturated at the bound.

    Args:
        x: array of coordinates.
        bound: half-width of the pass-through interval (default 5.12).

    Returns:
        Array of the same shape as ``x``.
    """
    inside = (x >= -bound) & (x <= bound)
    # Feed the division a harmless denominator for in-range entries. A 0/1 mask
    # applied after computing bound / x would turn x == 0 into inf * 0 = nan,
    # both in the value and in its gradient.
    safe_x = jnp.where(inside, 1.0, x)
    return jnp.where(inside, x, bound / safe_x)

def _rastrigin_descriptor_1(x: jnp.ndarray):
    """First descriptor: mean of ``clip`` applied to the first half of ``x``."""
    half = x.shape[0] // 2
    first_half = x[:half]
    return clip(first_half).mean()

def _rastrigin_descriptor_2(x: jnp.ndarray):
    """Second descriptor: mean of ``clip`` applied to the second half of ``x``."""
    half = x.shape[0] // 2
    second_half = x[half:]
    return clip(second_half).mean()

def rastrigin_descriptors(x: jnp.ndarray):
    """Return both Rastrigin descriptors of ``x`` as a length-2 array."""
    first = _rastrigin_descriptor_1(x)
    second = _rastrigin_descriptor_2(x)
    return jnp.stack([first, second])

def rastrigin_fn(x: jnp.ndarray):
    """Return the ``(score, descriptors)`` pair for a single genotype ``x``."""
    score = rastrigin_scoring(x)
    descriptors = rastrigin_descriptors(x)
    return score, descriptors

# Gradient of the (negated, shifted) Rastrigin score with respect to the genotype.
rastrigin_grad_scores = jax.grad(rastrigin_scoring)

In [None]:
def scoring_function(x):
    scores, descriptors = rastrigin_scoring(x), rastrigin_descriptors(x)
    gradients = jnp.array([rastrigin_grad_scores(x), grad(_rastrigin_descriptor_1)(x), grad(_rastrigin_descriptor_2)(x)]).T
    gradients = jnp.nan_to_num(gradients)

    # Compute normalized gradients
    norm_gradients = jax.tree_map(
        lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),
        gradients,
    )
    grads = jax.tree_map(
        lambda x, y: x / y, gradients, norm_gradients
    )
    grads = jnp.nan_to_num(grads)
    extra_scores = {
        'gradients': gradients,
        'normalized_grads': grads
    }

    return scores, descriptors, extra_scores

# Batched scorer: maps scoring_function over the leading (population) axis.
scoring_fn = jax.vmap(scoring_function, in_axes=0)

In [None]:
# --- Task / algorithm configuration ---
num_dimensions = 1000   # genotype length
num_centroids = 10000   # number of CVT cells in the archive
minval = -5.12          # lower bound passed to compute_cvt_centroids
maxval = 5.12           # upper bound passed to compute_cvt_centroids
batch_size = 36         # emitter batch size (also the initial population size)
learning_rate = 1       # passed to CMAMEGAEmitter
sigma_g = 10            # passed to CMAMEGAEmitter

# Reference objective values that metrics_fn uses to rescale fitnesses.
# "Worst" reference: every coordinate at minval.
# "Best" reference: every coordinate at the shifted optimum 0.4 * maxval
# (rastrigin_scoring shifts its optimum to 5.12 * 0.4).
worst_objective = rastrigin_scoring(jnp.full(num_dimensions, minval))
best_objective = rastrigin_scoring(jnp.full(num_dimensions, maxval * 0.4))


def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:
    """Compute QD metrics for the current repertoire.

    Fitnesses are rescaled linearly with the module-level ``worst_objective``
    and ``best_objective`` references, so ``worst_objective`` maps to 0 and
    ``best_objective`` maps to 1.

    Returns:
        ``qd_score``: sum of rescaled fitnesses over filled cells, divided by
            the total number of cells in the repertoire.
        ``max_fitness``: largest rescaled fitness in the repertoire.
        ``coverage``: percentage of cells that are filled.
    """
    # Empty cells are the ones whose fitness is -inf.
    grid_empty = repertoire.fitnesses == -jnp.inf
    adjusted_fitness = (repertoire.fitnesses - worst_objective) / (
        best_objective - worst_objective
    )
    # Normalize by the repertoire's own cell count rather than the global
    # num_centroids, so the metric stays correct if the archive size changes.
    num_cells = repertoire.fitnesses.size
    qd_score = jnp.sum(adjusted_fitness, where=~grid_empty) / num_cells
    coverage = 100 * jnp.mean(1.0 - grid_empty)
    max_fitness = jnp.max(adjusted_fitness)
    return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}

# Seed the experiment. Draw the initial population from a dedicated subkey,
# so the key that later cells consume (random_key) is not the one used here.
random_key = random.PRNGKey(0)
random_key, init_pop_key = random.split(random_key)
# Uniform samples in [0, 1) (jax.random.uniform defaults).
initial_population = random.uniform(init_pop_key, shape=(batch_size, num_dimensions))

# Archive tessellation: CVT cells over the 2-D descriptor space bounded by
# minval and maxval.
centroids = compute_cvt_centroids(
    num_centroids=num_centroids,
    num_descriptors=2,
    minval=minval,
    maxval=maxval,
    num_init_cvt_samples=10000,
)

# CMA-MEGA emitter driven by the gradient-augmented batched scorer.
emitter = CMAMEGAEmitter(
    num_descriptors=2,
    batch_size=batch_size,
    scoring_function=scoring_fn,
    sigma_g=sigma_g,
    learning_rate=learning_rate,
)

# MAP-Elites driver wiring together the scorer, the emitter and the metrics.
map_elites = MAPElites(
    emitter=emitter,
    scoring_function=scoring_fn,
    metrics_function=metrics_fn,
)

In [None]:
# Initialize the repertoire and emitter state with a fresh subkey. This keeps
# the key later threaded through the training scan distinct from the one
# consumed here.
random_key, init_key = random.split(random_key)
repertoire, emitter_state = map_elites.init(initial_population, centroids, init_key)

@jax.jit
def update_scan_fn(carry, unused):
    """Run one MAP-Elites update; shaped as a ``jax.lax.scan`` body.

    ``carry`` is ``(repertoire, emitter_state, random_key)``. ``unused`` is the
    per-step scan input and is ignored. Returns the updated carry and the
    metrics that ``map_elites.update`` produced for this step.
    """
    del unused  # required by the scan signature; nothing is read from it
    repertoire, emitter_state, random_key = carry
    repertoire, emitter_state, metrics, random_key = map_elites.update(
        repertoire, emitter_state, random_key
    )
    new_carry = (repertoire, emitter_state, random_key)
    return new_carry, metrics

In [None]:
%%time
num_iterations = 20000

# Run every MAP-Elites update inside one compiled scan. Per-step metrics are
# stacked along a leading axis of length num_iterations.
initial_carry = (repertoire, emitter_state, random_key)
final_carry, metrics = jax.lax.scan(
    update_scan_fn, initial_carry, (), length=num_iterations
)
repertoire, emitter_state, random_key = final_carry

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt


font_size = 16

# Use one font size for labels, titles, legends and ticks.
size_keys = (
    "axes.labelsize",
    "axes.titlesize",
    "legend.fontsize",
    "xtick.labelsize",
    "ytick.labelsize",
)
mpl_params = {key: font_size for key in size_keys}
mpl_params["text.usetex"] = False
mpl_params["axes.titlepad"] = 10

mpl.rcParams.update(mpl_params)

# Visualize the training evolution: one panel per tracked metric.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(40, 10))

# metrics[name][i] is recorded after the (i + 1)-th update. Assumes each update
# scores batch_size emitted solutions; the initial population is not counted.
env_steps = jnp.arange(1, num_iterations + 1) * batch_size

panels = [
    ("coverage", "Coverage in %", "Coverage evolution during training"),
    ("max_fitness", "Maximum fitness", "Maximum fitness evolution during training"),
    ("qd_score", "QD Score", "QD Score evolution during training"),
]
for ax, (metric_name, ylabel, title) in zip(axes, panels):
    ax.plot(env_steps, metrics[metric_name])
    ax.set_xlabel("Num evaluations")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Keep each panel roughly square regardless of its data ranges.
    ax.set_aspect(0.95 / ax.get_data_ratio(), adjustable="box")

plt.show()