# 01 - Simple ES Benchmark Function
### [Last Update: March 2022] [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)

In [None]:
# Notebook setup: render matplotlib figures inline, automatically reload
# edited Python modules before each cell runs, and use high-DPI figures.
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

# Quietly (-q) install evosax straight from the `main` branch on GitHub.
!pip install -q git+https://github.com/RobertTLange/evosax.git@main

## 2D Rosenbrock with CMA-ES

`evosax` implements several classic benchmark functions, including multi-dimensional versions of `quadratic`, `rosenbrock`, `ackley`, `griewank`, `rastrigin`, `schwefel`, `himmelblau`, and `six-hump`. In the following, we focus on the 2D Rosenbrock case, but feel free to experiment with the others.

In [1]:
import jax
import jax.numpy as jnp
from evosax import CMA_ES
from evosax.problems import ClassicFitness

# Problem: the classic Rosenbrock function in two dimensions. `rollout`
# returns the fitness of the candidates produced by `ask`.
rosenbrock = ClassicFitness("rosenbrock", num_dims=2)

# Search strategy: CMA-ES with population size 20 and elite ratio 0.5.
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
# NOTE(review): `rng` seeds `initialize` here and is split again inside the
# loop; JAX advises splitting off a dedicated init key. Left unchanged so the
# recorded output below stays reproducible.
state = strategy.initialize(rng)

# Ask -> evaluate -> tell. evosax minimises fitness by default.
num_generations, log_every = 50, 10
for gen in range(1, num_generations + 1):
    rng, key_ask, key_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(key_ask, state)
    fitness = rosenbrock.rollout(key_eval, x)
    state = strategy.tell(x, fitness, state)

    if gen % log_every == 0:
        print(f"CMA-ES - # Gen: {gen}|Fitness: {state.best_fitness:.2f}"
              f"|Params: {state.best_member}")



CMA-ES - # Gen: 10|Fitness: 0.13|Params: [0.6441135  0.41466928]
CMA-ES - # Gen: 20|Fitness: 0.00|Params: [0.97413015 0.9518173 ]
CMA-ES - # Gen: 30|Fitness: 0.00|Params: [0.9981632 0.9965331]
CMA-ES - # Gen: 40|Fitness: 0.00|Params: [0.9999719 0.9999461]
CMA-ES - # Gen: 50|Fitness: 0.00|Params: [0.9999997 0.9999994]


## 2D Rosenbrock with Other ES

In [2]:
from evosax import Strategies
rng = jax.random.PRNGKey(0)

# Strategy names looked up in evosax's `Strategies` registry, run in order.
strategy_names = [
    "SimpleES", "SimpleGA", "PSO", "DE", "Sep_CMA_ES",
    "Full_iAMaLGaM", "Indep_iAMaLGaM", "MA_ES", "LM_MA_ES",
    "RmES", "GLD", "SimAnneal",
]

for s_name in strategy_names:
    strategy = Strategies[s_name](popsize=20, num_dims=2)
    # Override the default initialisation bounds (`init_min` / `init_max`).
    es_params = strategy.default_params.replace(init_min=-2, init_max=2)
    # NOTE(review): the same `rng` also feeds the `split` in the loop below;
    # kept as-is so the recorded output stays reproducible.
    state = strategy.initialize(rng, es_params)

    # 30 ask -> evaluate -> tell generations, reporting every 5th one.
    for gen in range(1, 31):
        rng, key_ask, key_eval = jax.random.split(rng, 3)
        x, state = strategy.ask(key_ask, state, es_params)
        fitness = rosenbrock.rollout(key_eval, x)
        state = strategy.tell(x, fitness, state, es_params)

        if gen % 5 == 0:
            print(f"{s_name} - # Gen: {gen}|Fitness: {state.best_fitness:.2f}"
                  f"|Params: {state.best_member}")
    print("=" * 20)

SimpleES - # Gen: 5|Fitness: 0.44|Params: [0.41951638 0.14410228]
SimpleES - # Gen: 10|Fitness: 0.04|Params: [0.91297907 0.8160112 ]
SimpleES - # Gen: 15|Fitness: 0.01|Params: [0.98460174 0.97844493]
SimpleES - # Gen: 20|Fitness: 0.01|Params: [0.98460174 0.97844493]
SimpleES - # Gen: 25|Fitness: 0.01|Params: [0.98460174 0.97844493]
SimpleES - # Gen: 30|Fitness: 0.01|Params: [0.98460174 0.97844493]
SimpleGA - # Gen: 5|Fitness: 6.79|Params: [-0.012256   -0.24003565]
SimpleGA - # Gen: 10|Fitness: 0.68|Params: [0.21533592 0.02063736]
SimpleGA - # Gen: 15|Fitness: 0.39|Params: [0.4103716  0.14900509]
SimpleGA - # Gen: 20|Fitness: 0.18|Params: [0.5903524  0.33600026]
SimpleGA - # Gen: 25|Fitness: 0.17|Params: [0.6199676  0.39935672]
SimpleGA - # Gen: 30|Fitness: 0.13|Params: [0.64335036 0.40990546]
PSO - # Gen: 5|Fitness: 1.11|Params: [-0.01428866  0.02790421]
PSO - # Gen: 10|Fitness: 0.03|Params: [1.0889671 1.1718146]
PSO - # Gen: 15|Fitness: 0.01|Params: [1.109518  1.2260276]
PSO - # Gen: 

## xNES on a Sinusoidal Task

In [3]:
from evosax.strategies import XNES

def f(x):
    """Return -sin(r) / r with r = sum(x ** 2), defined at the origin by its limit.

    Taken from https://github.com/chanshing/xnes

    Args:
        x: A single candidate; all of its entries are squared and summed.

    Returns:
        Scalar fitness to be minimised. Since sin(r) / r <= 1 for r >= 0,
        the global minimum is -1, approached as r -> 0. At r == 0 exactly the
        raw expression is 0 / 0 (NaN), so that point returns its limit, -1.
    """
    r = jnp.sum(x ** 2)
    at_origin = r == 0
    # Double-`where`: divide by a dummy non-zero radius at the origin so that
    # neither the value nor its gradient (e.g. under jax.grad) turn into NaN.
    # For every r != 0, safe_r == r and the result is bit-identical to
    # -sin(r) / r.
    safe_r = jnp.where(at_origin, 1.0, r)
    return jnp.where(at_origin, -1.0, -jnp.sin(safe_r) / safe_r)

# Score a whole population at once by mapping `f` over the leading axis.
batch_func = jax.vmap(f, in_axes=0)

rng = jax.random.PRNGKey(0)
strategy = XNES(popsize=50, num_dims=2)
es_params = strategy.default_params.replace(
    use_adaptive_sampling=True,
    use_fitness_shaping=True,
    eta_bmat=0.01,
    eta_sigma_init=0.1,
)

state = strategy.initialize(rng, es_params)
# Deliberately poor starting mean, far away from the origin.
state = state.replace(mean=jnp.array([9999.0, -9999.0]))

num_iters = 5000
log_every = 500
for gen in range(1, num_iters + 1):
    rng, key_ask = jax.random.split(rng)
    y, state = strategy.ask(key_ask, state, es_params)
    fitness = batch_func(y)
    state = strategy.tell(y, fitness, state, es_params)
    if gen % log_every == 0:
        print(f"xNES - # Gen: {gen}|Fitness: {state.best_fitness:.5f}"
              f"|Params: {state.best_member}")


xNES - # Gen: 500|Fitness: -0.00000|Params: [ 9991.45  -9987.809]
xNES - # Gen: 1000|Fitness: -0.00000|Params: [ 9951.659 -9911.333]
xNES - # Gen: 1500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 2000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 2500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 3000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 3500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 4000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 4500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
xNES - # Gen: 5000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]
