In [None]:
import jax.numpy as jnp
import numpy as np
import jax
import time

from functools import partial
from jax import grad, jit, vmap, random

from parl.algorithms.map_elites import compute_cvt_centroids
from parl.algorithms.deprecated.multi_objectives_map_elites import run_mome
from parl.mutation_operators import (
    polynomial_mutation_function, 
    polynomial_crossover_function, 
    isoline_crossover_function
)
from parl.plotting import plot_2d_map_elites_grid, plot_mome_pareto_fronts

import matplotlib.pyplot as plt
from parl.basic_scorers import rastrigin_scorer

In [None]:
# Show which devices JAX will dispatch computations to.
visible_devices = jax.devices()
visible_devices

### Example of running MOME on a simple two-objective (lagged) Rastrigin problem

In [None]:
# --- Archive / grid ----------------------------------------------------------
num_centroids = 64             # number of CVT cells requested (voronoi case)
pareto_front_max_length = 50   # presumably the max Pareto-front size kept per cell
grid_type = 'voronoi'          # 'voronoi' -> CVT centroids; anything else -> regular grid

# --- Problem -----------------------------------------------------------------
num_variables = 10             # genotype dimension (see the initial population shape)
minval, maxval = -2, 4         # genotype bounds, also reused as descriptor bounds
lag = 2.2                      # passed to `rastrigin_scorer`
base_lag = 0                   # passed to `rastrigin_scorer`
descriptor_mode = "first"      # NOTE(review): not referenced by any cell in this notebook

# --- Variation operators -----------------------------------------------------
eta = 1
proportion_to_mutate = 0.6
proportion_var_to_change = 0.5
crossover_percentage = 1.0

# --- Run ---------------------------------------------------------------------
num_iterations = 1000
batch_size = 100
key = random.PRNGKey(42)       # single seed for the whole notebook

In [None]:
# Fix the scorer's `lag` / `base_lag` keyword arguments from the config cell.
scoring_function = partial(
    rastrigin_scorer,
    lag=lag,
    base_lag=base_lag,
)

In [None]:
# Variation operators, with hyper-parameters bound from the config cell.

# Polynomial mutation, given the genotype bounds `minval` / `maxval`.
mutation_function = partial(
    polynomial_mutation_function,
    proportion_to_mutate=proportion_to_mutate,
    eta=eta,
    minval=minval,
    maxval=maxval,
)

# Polynomial crossover.
# NOTE(review): not passed to `run_mome` later on (the isoline crossover is used instead).
crossover_function = partial(
    polynomial_crossover_function,
    proportion_var_to_change=proportion_var_to_change,
)

# Isoline ("line") crossover; its two sigmas are set here, not in the config cell.
line_crossover_function = partial(
    isoline_crossover_function,
    iso_sigma=0.005,
    line_sigma=0.05,
)

In [None]:
# Reference point for the hypervolume-based metrics computed by `run_mome`
# (one coordinate per objective).
# NOTE(review): -150 per objective is a hand-picked bound; confirm it is
# dominated by every score reachable for this (lag, base_lag) setting. If a
# data-driven point is needed, take the per-objective minimum of
# `scoring_function` over a sample of genotypes.
# Float so it matches the dtype of the fitness values it is compared against.
reference_point = jnp.array([-150.0, -150.0])

In [None]:
# Descriptor-space tessellation used by MOME: either `num_centroids` CVT
# (Voronoi) centroids, or a regular 2-D grid with spacing `step` starting one
# step below `minval` and running up to (roughly) `maxval`.
if grid_type == 'voronoi':

    init_time = time.time()
    centroids = compute_cvt_centroids(
        num_descriptors=2,
        num_init_cvt_samples=20000,
        num_centroids=num_centroids,
        minval=minval,
        maxval=maxval
    )
    duration = time.time() - init_time
    print(f'Computed centroids in {duration:.2f}s')
    # Order the centroids by their first descriptor coordinate.
    centroids = centroids[jnp.argsort(centroids[:, 0])]

else:

    step = 0.5
    # Cartesian product of one axis with itself, built in a single vectorized
    # call instead of a Python double loop over device scalars.
    # `indexing='ij'` + row-major ravel keeps the original row order
    # (first coordinate varies slowest, second fastest).
    grid_axis = jnp.arange(minval - step, maxval + step, step)
    first_coord, second_coord = jnp.meshgrid(grid_axis, grid_axis, indexing='ij')
    centroids = jnp.stack([first_coord.ravel(), second_coord.ravel()], axis=-1)
    num_centroids = len(centroids)

In [None]:
# Bind every fixed MOME hyper-parameter, then JIT-compile the result.
# Only `init_genotypes`, `centroids` and `random_key` are supplied at call time.
# NOTE(review): `crossover_function=` receives the isoline crossover
# (`line_crossover_function`), not the polynomial `crossover_function`.
mome_kwargs = dict(
    scoring_function=scoring_function,
    crossover_function=line_crossover_function,
    mutation_function=mutation_function,
    batch_size=batch_size,
    num_iterations=num_iterations,
    crossover_percentage=crossover_percentage,
    pareto_front_max_length=pareto_front_max_length,
    reference_point=reference_point,
)
run_fn = jax.jit(partial(run_mome, **mome_kwargs))

In [None]:
# Independent keys for the initial population and for the MOME run: reusing a
# single PRNG key for both would correlate the two random streams.
# `key` itself is not reassigned, so re-running this cell gives the same result.
init_key, run_key = random.split(key)

init_time = time.time()

x = jax.random.uniform(init_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32)

map_elites_grid, metrics = run_fn(
    init_genotypes=x,
    centroids=centroids,
    random_key=run_key,
)

# JAX dispatches asynchronously: wait for every output array before stopping
# the clock, otherwise `duration` only covers tracing/compilation + dispatch.
for leaf in jax.tree_util.tree_leaves((map_elites_grid, metrics)):
    leaf.block_until_ready()

# On the first call this also includes JIT compilation time.
duration = time.time() - init_time
print(f'MOME ran for {duration:.2f}s')

In [None]:
# Total MOQD score per logged step: sum the per-cell scores along the last
# axis, skipping entries equal to -inf (presumably cells with no solution yet).
valid_cells = metrics.moqd_score != -jnp.inf
moqd_scores = jnp.sum(metrics.moqd_score, axis=-1, where=valid_cells)

In [None]:
# Run-level metrics over the course of the run, one panel per metric.
panels = [
    (moqd_scores, 'MOQD Score'),
    (metrics.max_hypervolume, 'Max Hypervolume'),
    (metrics.max_sum_scores, 'Max Sum Scores'),
    (metrics.coverage, 'Coverage'),
]

# Build the x-axis from the number of logged entries rather than from
# `num_iterations`, so it always has the same length as the plotted metrics.
# NOTE(review): `batch_size * index` presumably counts evaluated solutions — confirm.
steps = batch_size * jnp.arange(len(moqd_scores))

f, axes = plt.subplots(1, len(panels), figsize=(25, 5))
for ax, (values, label) in zip(axes, panels):
    ax.plot(steps, values)
    ax.set_xlabel('Num steps')
    ax.set_ylabel(label)
plt.show()

In [None]:
# Pareto fronts stored in the final MOME grid. `plot_mome_pareto_fronts` is
# imported in the first cell, so it is not re-imported here.
# NOTE(review): two axes are created, presumably because `with_global=True`
# draws the global front in a second panel — confirm against the plotting helper.
fig, axes = plt.subplots(figsize=(12, 6), ncols=2)
plot_mome_pareto_fronts(
    centroids,
    map_elites_grid,
    minval=minval,
    maxval=maxval,
    color_style='spectral',
    axes=axes,
    with_global=True
)
plt.show()

In [None]:
# Per-cell MOQD score at the last logged step, drawn over the centroid grid.
final_cell_scores = metrics.moqd_score[-1]
plot_2d_map_elites_grid(
    centroids=centroids,
    grid_fitness=final_cell_scores,
    minval=minval,
    maxval=maxval,
)
plt.show()