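# Traditional Multi-Objective Genetic Algorithms (NSGA2 & SPEA2) on Rastrigin

This notebook runs the NSGA2 and SPEA2 implementations from QDax on a multi-objective Rastrigin problem and plots the Pareto front found by each algorithm.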

In [3]:
# Automatically reload edited Python modules (e.g. a locally edited qdax checkout)
# before each cell executes, so library changes are picked up without restarting.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import jax.numpy as jnp
import jax

import matplotlib.pyplot as plt

import time

from functools import partial


from qdax.core.nsga2 import (
    run_nsga2
)
from qdax.core.spea2 import (
    run_spea2
)

from qdax.core.emitters.mutation_operators import (
    polynomial_crossover, 
    polynomial_mutation
)
from qdax.utils.pareto_front import compute_pareto_front
from qdax.utils.plotting import plot_global_pareto_front


from qdax.basic_scorers import rastrigin_scorer
from qdax.utils.pareto_front import compute_pareto_front
from qdax.plotting import plot_global_pareto_front

ModuleNotFoundError: No module named 'qdax.basic_scorers'

## NSGA2

In [None]:
# Parameters
population_size = 1000  # passed as `population_size` to both run_nsga2 and run_spea2
num_iterations = 1000  # passed as `num_iterations` to both runners
proportion_mutation = 0.80  # the runners receive crossover_percentage = 1 - proportion_mutation
minval, maxval = -5.12, 5.12  # genotype bounds for initialisation and mutation (the usual Rastrigin box)
batch_size = 100  # number of initial genotypes; also passed as `batch_size` to both runners
genotype_dim = 6  # dimensionality of each genotype
lag, base_lag = 2.2, 0  # presumably the shifts of the two Rastrigin objectives — confirm in rastrigin_scorer
seed = 0  # seed for the PRNG key shared by both runs

# NOTE(review): `rastrigin_scorer` is imported from `qdax.basic_scorers`, which raised
# ModuleNotFoundError in the imports cell's recorded output. This cell cannot run until
# that import points at a module that exists in the installed qdax — TODO fix.

# Mutation & Crossover
# These are the operators imported from qdax.core.emitters.mutation_operators
# (`polynomial_crossover` / `polynomial_mutation`); the `*_function` suffixed names
# previously used here were never defined and raised NameError.
crossover_function = partial(
    polynomial_crossover,
    proportion_var_to_change=0.5,
)

mutation_function = partial(
    polynomial_mutation,
    proportion_to_mutate=0.5,
    eta=0.05,  # presumably the polynomial-mutation distribution index — confirm in qdax
    minval=minval,
    maxval=maxval,
)

# Scoring function
scoring_function = partial(
    rastrigin_scorer,
    lag=lag,
    base_lag=base_lag,
)


def scoring_fn(genotypes):
    """Score a batch of genotypes and return only the first element of the scorer's output.

    `rastrigin_scorer` returns an indexable result; the runners below receive only its
    first element (the scores later read back as `solutions.scores`). The remaining
    elements are presumably descriptors — TODO confirm against the scorer.
    """
    return scoring_function(genotypes)[0]


# Initialize
key = jax.random.PRNGKey(seed)
key, sub_key = jax.random.split(key)
# Uniform initial population of shape (batch_size, genotype_dim) inside [minval, maxval).
init_genotypes = jax.random.uniform(
    sub_key, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32
)

In [None]:
# Time the full NSGA2 run.
# JAX dispatches work asynchronously: run_nsga2 can return before its arrays are
# computed. Block on the resulting scores before reading the clock so the reported
# duration covers the actual computation (and any tracing/compilation done inside
# the call). perf_counter is the monotonic clock intended for interval timing.
init_time = time.perf_counter()

solutions = run_nsga2(
    init_genotypes=init_genotypes,
    random_key=key,
    scoring_function=scoring_fn,
    crossover_function=crossover_function,
    mutation_function=mutation_function,
    batch_size=batch_size,
    crossover_percentage=1 - proportion_mutation,
    num_iterations=num_iterations,
    population_size=population_size,
)
solutions.scores.block_until_ready()

duration = time.perf_counter() - init_time
print(f'Duration: {duration:.2f}s')


In [None]:
# Boolean mask selecting the Pareto-front points among the NSGA2 scores.
pareto_bool = compute_pareto_front(solutions.scores)

# Draw only those points on a square figure.
fig, ax = plt.subplots(figsize=(6, 6))
plot_global_pareto_front(solutions.scores[pareto_bool], ax=ax)
ax.set(title='NSGA2')
plt.show()

## SPEA2

In [None]:
# Time the full SPEA2 run, starting from the same initial genotypes and key as NSGA2.
# JAX dispatches work asynchronously: run_spea2 can return before its arrays are
# computed. Block on the resulting scores before reading the clock so the reported
# duration covers the actual computation (and any tracing/compilation done inside
# the call). perf_counter is the monotonic clock intended for interval timing.
init_time = time.perf_counter()

solutions = run_spea2(
    init_genotypes=init_genotypes,
    random_key=key,
    scoring_function=scoring_fn,
    crossover_function=crossover_function,
    mutation_function=mutation_function,
    batch_size=batch_size,
    crossover_percentage=1 - proportion_mutation,
    num_iterations=num_iterations,
    population_size=population_size,
    num_neighbours=1,  # presumably the neighbour count for SPEA2's density estimate — confirm in qdax
)
solutions.scores.block_until_ready()

duration = time.perf_counter() - init_time
print(f'Duration: {duration:.2f}s')


In [None]:

# Boolean mask selecting the Pareto-front points among the SPEA2 scores.
pareto_bool = compute_pareto_front(solutions.scores)

# Draw only those points on a square figure.
fig, ax = plt.subplots(figsize=(6, 6))
plot_global_pareto_front(solutions.scores[pareto_bool], ax=ax)
ax.set(title='SPEA2')
plt.show()