# Traditional Multi-Objective Genetic Algorithms (NSGA2 & SPEA2) on the Rastrigin Function

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax.numpy as jnp
import jax

from typing import Tuple

import matplotlib.pyplot as plt

import time

from functools import partial


from qdax.core.nsga2 import (
    NSGA2
)
from qdax.core.spea2 import (
    SPEA2
)
from qdax.core.genetic_algorithm import GeneticAlgorithm

from qdax.core.emitters.mutation_operators import (
    polynomial_crossover, 
    polynomial_mutation
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.pareto_front import compute_pareto_front
from qdax.utils.plotting import plot_global_pareto_front

from qdax.utils.pareto_front import compute_pareto_front
from qdax.utils.plotting import plot_global_pareto_front
from qdax.utils.metrics import default_ga_metrics

from qdax.types import Genotype, Fitness, Descriptor

In [None]:
def _shifted_rastrigin(genotypes: jnp.ndarray, shift: float) -> jnp.ndarray:
    """Return the Rastrigin function of ``genotypes - shift``, reduced over axis 1.

    Computes ``10 * d + sum_i[(x_i - shift)**2 - 10 * cos(2 * pi * (x_i - shift))]``
    with ``d = genotypes.shape[1]``; its minimum value 0 is reached at
    ``x_i == shift`` for every dimension.
    """
    shifted = genotypes - shift
    return 10 * genotypes.shape[1] + jnp.sum(
        shifted ** 2 - 10 * jnp.cos(2 * jnp.pi * shifted),
        axis=1,
    )


def rastrigin_scorer(
    genotypes: Genotype, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
    """Score genotypes on a bi-objective, shifted Rastrigin problem.

    Each objective is a negated Rastrigin function (so it is maximised, with
    best value 0): ``f1`` is centred at ``base_lag`` and ``f2`` at ``lag``.
    The first two genotype dimensions are returned as descriptors.

    Args:
        genotypes: 2-D array indexed as ``(batch, dim)``.
        base_lag: shift of the first objective's optimum (may be a float).
        lag: shift of the second objective's optimum (may be a float).

    Returns:
        ``(scores, descriptors)``: ``scores`` stacks ``[f1, f2]`` on the last
        axis, giving shape ``(batch, 2)``; ``descriptors`` is
        ``genotypes[:, :2]``.
    """
    descriptors = genotypes[:, :2]
    f1 = -_shifted_rastrigin(genotypes, base_lag)
    f2 = -_shifted_rastrigin(genotypes, lag)
    scores = jnp.stack([f1, f2], axis=-1)

    return scores, descriptors


In [None]:
# ---- Hyper-parameters ----
population_size = 1000  # passed to NSGA2.init / SPEA2.init
num_iterations = 1000  # length of the jax.lax.scan optimisation loop
proportion_mutation = 0.80  # complement of the emitter's variation_percentage
minval, maxval = -5.12, 5.12  # bounds for initial sampling and mutation
batch_size = 100  # emitter batch size, also the number of initial genotypes
genotype_dim = 6
lag, base_lag = 2.2, 0  # optimum shifts of the second / first objective
num_neighbours = 1  # only used by SPEA2.init

# ---- Variation operators ----
crossover_function = partial(polynomial_crossover, proportion_var_to_change=0.5)

mutation_function = partial(
    polynomial_mutation,
    minval=minval,
    maxval=maxval,
    eta=0.05,
    proportion_to_mutate=0.5,
)

# ---- Emitter mixing mutation and crossover ----
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=1 - proportion_mutation,
)

# ---- Scoring function with the objective shifts bound ----
scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag)

def scoring_fn(x, random_key):
    """Wrap ``scoring_function`` for use as the NSGA2/SPEA2 scoring function.

    Returns a 3-tuple: the fitnesses computed by ``scoring_function`` (its
    descriptors are discarded), an empty dict (presumably the extra-scores
    slot — confirm against the QDax GA API), and ``random_key`` unchanged.
    """
    fitnesses, _descriptors = scoring_function(x)
    return fitnesses, {}, random_key

# Seed the PRNG, then sample the initial genotypes uniformly in [minval, maxval).
random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
init_genotypes = jax.random.uniform(
    subkey,
    shape=(batch_size, genotype_dim),
    minval=minval,
    maxval=maxval,
    dtype=jnp.float32,
)

## NSGA2

In [None]:
# instantiate NSGA2 with the shared mixing emitter, scoring function and
# default GA metrics
nsga2 = NSGA2(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=default_ga_metrics,
)

# init NSGA2 from the initial genotypes; returns the repertoire, the emitter
# state and a new random key
repertoire, emitter_state, random_key = nsga2.init(
    init_genotypes,
    population_size,
    random_key,
)

In [None]:
%%time

# run optimization loop: num_iterations applications of nsga2.scan_update
(repertoire, emitter_state, random_key), _ = jax.lax.scan(
    nsga2.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=num_iterations,
)

# JAX dispatches work asynchronously, so without waiting here %%time would
# mostly measure compilation and dispatch rather than the full run. A `for`
# statement (not an expression) keeps the cell from displaying an array.
for leaf in jax.tree_util.tree_leaves(repertoire):
    leaf.block_until_ready()

In [None]:
# Mask of the non-dominated individuals in the final NSGA2 population, then plot them.
pareto_bool = compute_pareto_front(repertoire.fitnesses)
fig, ax = plt.subplots(figsize=(6, 6))
plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)
ax.set_title("NSGA2")
plt.show()

## SPEA2

In [None]:
# instantiate SPEA2 with the same emitter, scoring function and metrics as NSGA2
spea2 = SPEA2(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=default_ga_metrics,
)

# init SPEA2 from the same initial genotypes (plus its num_neighbours
# setting); returns the repertoire, the emitter state and a new random key
repertoire, emitter_state, random_key = spea2.init(
    init_genotypes,
    population_size,
    num_neighbours,
    random_key,
)

In [None]:
%%time

# run optimization loop: num_iterations applications of spea2.scan_update
(repertoire, emitter_state, random_key), _ = jax.lax.scan(
    spea2.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=num_iterations,
)

# JAX dispatches work asynchronously, so without waiting here %%time would
# mostly measure compilation and dispatch rather than the full run. A `for`
# statement (not an expression) keeps the cell from displaying an array.
for leaf in jax.tree_util.tree_leaves(repertoire):
    leaf.block_until_ready()

In [None]:
# Mask of the non-dominated individuals in the final SPEA2 population, then plot them.
pareto_bool = compute_pareto_front(repertoire.fitnesses)
fig, ax = plt.subplots(figsize=(6, 6))
plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)
ax.set_title("SPEA2")
plt.show()