In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax.numpy as jnp
import jax
import time
from typing import Tuple, Any

from functools import partial

from qdax.core.containers.repertoire import compute_cvt_centroids
from qdax.core.containers.mome_repertoire import MOMERepertoire
from qdax.core.multi_objectives_map_elites import MOME, compute_moqd_metrics, add_init_metrics
from qdax.core.emitters.mutation_operators import (
    polynomial_mutation, 
    polynomial_crossover, 
    isoline_variation
)
from qdax.utils.plotting import plot_2d_map_elites_grid, plot_mome_pareto_fronts

import matplotlib.pyplot as plt

from qdax.types import Fitness, Descriptor, RNGKey, Metrics

In [None]:
# Show the devices JAX reports as available (displayed as the cell output).
jax.devices()

In [None]:
def rastrigin_scorer(
    genotypes: jnp.ndarray, base_lag: float, lag: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Score genotypes on two shifted, negated Rastrigin objectives.

    Both objectives apply the same formula to the genotypes after subtracting a
    shift from every variable; they differ only by the shift (`base_lag` for the
    first objective, `lag` for the second). The values are negated.

    Args:
        genotypes: 2D array with one row per solution and one column per
            variable (it is sliced as `genotypes[:, :2]` and reduced on axis 1).
        base_lag: shift subtracted from every variable for the first objective.
        lag: shift subtracted from every variable for the second objective.
            Shifts may be non-integer (this notebook uses 2.2).

    Returns:
        A `(scores, descriptors)` pair: `scores` stacks the two objectives on
        the last axis (one row per solution), `descriptors` holds the first two
        variables of each genotype.
    """

    def _negated_rastrigin(shift: float) -> jnp.ndarray:
        """Negated Rastrigin value of each row after shifting by `shift`."""
        shifted = genotypes - shift
        return -(
            10 * genotypes.shape[1]
            + jnp.sum(
                shifted ** 2 - 10 * jnp.cos(2 * jnp.pi * shifted),
                axis=1,
            )
        )

    descriptors = genotypes[:, :2]
    scores = jnp.stack(
        [_negated_rastrigin(base_lag), _negated_rastrigin(lag)], axis=-1
    )

    return scores, descriptors

### Example of running MOME on a two-objective Rastrigin problem

In [None]:
# --- Problem definition ------------------------------------------------------
num_variables = 100  # number of columns of the initial genotypes
minval = -2  # lower bound used for sampling, mutation, centroids and plots
maxval = 4  # upper bound used for sampling, mutation, centroids and plots
base_lag = 0  # shift of the first objective in rastrigin_scorer
lag = 2.2  # shift of the second objective in rastrigin_scorer

# --- Repertoire --------------------------------------------------------------
grid_type = 'voronoi'  # 'voronoi' -> CVT centroids, anything else -> regular grid
num_centroids = 64  # number of centroids requested from compute_cvt_centroids
pareto_front_max_length = 50  # forwarded to mome.init

# --- Variation operators -----------------------------------------------------
eta = 1  # forwarded to polynomial_mutation
proportion_to_mutate = 0.6  # forwarded to polynomial_mutation
proportion_var_to_change = 0.5  # forwarded to polynomial_crossover
crossover_percentage = 1.0  # forwarded to mome.update

# --- Run length --------------------------------------------------------------
batch_size = 100  # given to MOME; also the number of initial genotypes
num_iterations = 1000  # number of scan steps of the optimization loop

# --- Not referenced by the cells visible in this notebook --------------------
key = jax.random.PRNGKey(42)
descriptor_mode = "first"

In [None]:
# Bind both objective shifts so that the scorer only takes the genotypes.
scoring_function = partial(
    rastrigin_scorer,
    base_lag=base_lag,
    lag=lag,
)

In [None]:
# Each operator is pre-bound to its hyper-parameters with functools.partial.

# Polynomial mutation, restricted to the [minval, maxval] bounds.
mutation_function = partial(
    polynomial_mutation,
    eta=eta,
    minval=minval,
    maxval=maxval,
    proportion_to_mutate=proportion_to_mutate,
)

# Polynomial crossover.
crossover_function = partial(
    polynomial_crossover,
    proportion_var_to_change=proportion_var_to_change,
)

# Iso-line variation; defined here but not passed to MOME in the cells shown.
line_crossover_function = partial(
    isoline_variation,
    iso_sigma=0.005,
    line_sigma=0.05,
)

In [None]:
# Reference point later passed to compute_moqd_metrics.
# With `my_boolean` left to False a fixed point is used; set it to True to
# estimate the point as the worst value of each objective over a 2D grid.
my_boolean = False
if my_boolean:
    x, y = jnp.arange(-5, 5, step=0.2), jnp.arange(-5, 5, step=0.2)

    # Every (x, y) pair of the grid, one 2-variable genotype per row.
    all_enumerations = jnp.array([(x_i, y_j) for x_i in x for y_j in y])

    # Score the grid with the notebook's scoring function. It returns
    # (scores, descriptors); only the scores are needed here.
    all_scores = scoring_function(all_enumerations)[0]

    # Per-objective minimum over the whole grid.
    reference_point = jnp.min(all_scores, axis=0)
else:
    reference_point = jnp.array([-150, -150])

In [None]:
if grid_type == 'voronoi':
    # CVT centroids; the computation is timed and the duration printed.
    init_time = time.time()
    centroids = compute_cvt_centroids(
        num_centroids=num_centroids,
        num_descriptors=2,
        num_init_cvt_samples=20000,
        minval=minval,
        maxval=maxval,
    )
    duration = time.time() - init_time
    print(f'Computed centroids in {duration:.2f}s')

    # Reorder the centroids by increasing first coordinate.
    centroids = centroids[jnp.argsort(centroids[:, 0])]
else:
    # Regular grid: every (i, j) pair of the same 1D axis, which starts one
    # step below `minval` and uses `maxval + step` as its (exclusive) stop.
    step = 0.5
    grid_axis = jnp.arange(minval - step, maxval + step, step)
    centroids = jnp.array([[i, j] for i in grid_axis for j in grid_axis])
    num_centroids = len(centroids)

In [None]:
# Keep JIT compilation enabled (switch the flag to True to debug without JIT).
# The config object is taken from the top-level `jax` import: importing it from
# the `jax.config` submodule is deprecated and fails on recent JAX releases.
config = jax.config

config.update('jax_disable_jit', False)

In [None]:
# MOME instance built from the batch size and the partially-applied scorer.
mome = MOME(
    batch_size=batch_size,
    scoring_function=scoring_function,
)

In [None]:
# Initialisation: bind the centroids and the pareto front length, then jit.
bound_init = partial(
    mome.init,
    centroids=centroids,
    pareto_front_max_length=pareto_front_max_length,
)
init_function = jax.jit(bound_init)

# One MOME update: bind the variation operators and the crossover share, then jit.
bound_update = partial(
    mome.update,
    crossover_function=crossover_function,
    mutation_function=mutation_function,
    crossover_percentage=crossover_percentage,
)
iteration_function = jax.jit(bound_update)

@jax.jit
def iteration_fn(
    carry: Tuple[MOMERepertoire, RNGKey], unused: Any
) -> Tuple[Tuple[MOMERepertoire, RNGKey], Metrics]:
    """Run one MOME update and compute its metrics, in `jax.lax.scan` form.

    Args:
        carry: `(repertoire, random_key)` pair threaded from one step to the next.
        unused: per-step scan input; ignored.

    Returns:
        The updated `(repertoire, random_key)` carry, and the metrics computed
        by `compute_moqd_metrics` on the updated repertoire with the
        notebook-level `reference_point`.
    """
    # iterate over grid
    grid, random_key = carry
    grid, random_key = iteration_function(grid, random_key)

    # get metrics
    metrics = compute_moqd_metrics(grid, reference_point)
    return (grid, random_key), metrics

In [None]:
random_key = jax.random.PRNGKey(42)

# Split once: `subkey` is consumed here to sample the initial genotypes, while
# `random_key` is kept for the optimization loop. Sampling with `random_key`
# itself would reuse the very same key in both places.
random_key, subkey = jax.random.split(random_key)
init_genotypes = jax.random.uniform(
    subkey, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32
)

# init algorithm
map_elites_grid = init_function(init_genotypes)

# metrics of the initial repertoire, prepended to the run history later on
init_metrics = compute_moqd_metrics(map_elites_grid, reference_point)

In [None]:
init_time = time.time()

# run optimization loop
(map_elites_grid, random_key), metrics = jax.lax.scan(
    iteration_fn, (map_elites_grid, random_key), (), length=num_iterations
)

# JAX dispatches work asynchronously: wait for the results before reading the
# clock, otherwise the timer may stop while the scan is still running.
# The measured duration also covers tracing and compilation of the scan.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), metrics)

duration = time.time() - init_time
print(f'MOME ran for {duration:.2f}s')

# prepend the metrics of the initial repertoire to the per-iteration metrics
metrics = add_init_metrics(metrics, init_metrics)

In [None]:
# Sum the MOQD scores over the last axis, leaving out the entries equal to -inf.
moqd_scores = jnp.sum(
    metrics.moqd_score,
    axis=-1,
    where=metrics.moqd_score != -jnp.inf,
)

In [None]:
f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 5))

# x-axis values: batch_size times the iteration index, starting at 0 for the
# metrics of the initial repertoire.
steps = batch_size * jnp.arange(start=0, stop=num_iterations+1)

# One panel per metric: (axis, values, y-axis label).
panels = (
    (ax1, moqd_scores, 'MOQD Score'),
    (ax2, metrics.max_hypervolume, 'Max Hypervolume'),
    (ax3, metrics.max_sum_scores, 'Max Sum Scores'),
    (ax4, metrics.coverage, 'Coverage'),
)
for ax, values, ylabel in panels:
    ax.plot(steps, values)
    ax.set_xlabel('Num steps')
    ax.set_ylabel(ylabel)
plt.show()

In [None]:
# Two side-by-side axes handed over to the QDax pareto front plotting helper.
fig, axes = plt.subplots(ncols=2, figsize=(12, 6))
plot_mome_pareto_fronts(
    centroids,
    map_elites_grid,
    axes=axes,
    minval=minval,
    maxval=maxval,
    color_style='spectral',
    with_global=True,
)
plt.show()

In [None]:
# Grid plot coloured with the last entry of the MOQD score history.
plot_2d_map_elites_grid(
    grid_fitness=metrics.moqd_score[-1],
    centroids=centroids,
    minval=minval,
    maxval=maxval,
)
plt.show()