# 01 - Simple ES Benchmark Function
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install evosax

import jax
import jax.numpy as jnp

## 2D Rosenbrock with CMA-ES

`evosax` implements a set of different classic benchmark functions. These include multi-dimensional versions of `quadratic`, `rosenbrock`, `ackley`, `griewank`, `rastrigin`, `schwefel`, `himmelblau`, `six-hump`. In the following we focus on the 2D Rosenbrock case, but feel free to play around with the others.

In [2]:
from evosax import CMA_ES
from evosax.problems import ClassicFitness

# Instantiate the problem evaluator
rosenbrock = ClassicFitness("rosenbrock", num_dims=2)

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

# Run ask-eval-tell loop - NOTE: By default minimization
for t in range(50):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = rosenbrock.rollout(rng_eval, x)
    state = strategy.tell(x, fitness, state, es_params)

    if (t + 1) % 10 == 0:
        print("CMA-ES - # Gen: {}|Fitness: {:.2f}|Params: {}".format(
            t+1, state["best_fitness"], state["best_member"]))

CMA-ES - # Gen: 10|Fitness: 0.67|Params: [0.18411857 0.02598005]
CMA-ES - # Gen: 20|Fitness: 0.14|Params: [0.6254667  0.39644676]
CMA-ES - # Gen: 30|Fitness: 0.01|Params: [0.9449675  0.89805067]
CMA-ES - # Gen: 40|Fitness: 0.00|Params: [0.999287  0.9985225]
CMA-ES - # Gen: 50|Fitness: 0.00|Params: [0.999984   0.99996775]


## 2D Rosenbrock with Other ES

In [3]:
from evosax import Strategies
rng = jax.random.PRNGKey(0)

for s_name in ["Simple_ES", "Simple_GA", "PSO_ES", "Differential_ES"]:
    if s_name in ["Simple_ES", "Simple_GA"]: 
        strategy = Strategies[s_name](popsize=20, num_dims=2, elite_ratio=0.5)
    else:
        strategy = Strategies[s_name](popsize=20, num_dims=2)
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    for t in range(30):
        rng, rng_gen, rng_eval = jax.random.split(rng, 3)
        x, state = strategy.ask(rng_gen, state, es_params)
        fitness = rosenbrock.rollout(rng_eval, x)
        state = strategy.tell(x, fitness, state, es_params)

        if (t + 1) % 5 == 0:
            print("{} - # Gen: {}|Fitness: {:.2f}|Params: {}".format(
                s_name, t+1, state["best_fitness"], state["best_member"]))
    print(20*"=")

Simple_ES - # Gen: 5|Fitness: 0.30|Params: [0.92314136 0.90602154]
Simple_ES - # Gen: 10|Fitness: 0.30|Params: [0.92314136 0.90602154]
Simple_ES - # Gen: 15|Fitness: 0.01|Params: [1.0070696 1.0240867]
Simple_ES - # Gen: 20|Fitness: 0.01|Params: [1.0070696 1.0240867]
Simple_ES - # Gen: 25|Fitness: 0.01|Params: [1.0070696 1.0240867]
Simple_ES - # Gen: 30|Fitness: 0.01|Params: [1.0070696 1.0240867]
Simple_GA - # Gen: 5|Fitness: 0.48|Params: [1.0857993 1.2479298]
Simple_GA - # Gen: 10|Fitness: 0.34|Params: [1.0905902 1.2473906]
Simple_GA - # Gen: 15|Fitness: 0.16|Params: [1.0978061 1.2433528]
Simple_GA - # Gen: 20|Fitness: 0.06|Params: [1.1003273 1.2336206]
Simple_GA - # Gen: 25|Fitness: 0.01|Params: [1.1077771 1.2309372]
Simple_GA - # Gen: 30|Fitness: 0.01|Params: [1.1069536 1.225423 ]
PSO_ES - # Gen: 5|Fitness: 1.11|Params: [-0.01428866  0.02790421]
PSO_ES - # Gen: 10|Fitness: 0.03|Params: [1.0889671 1.1718146]
PSO_ES - # Gen: 15|Fitness: 0.01|Params: [1.109518  1.2260276]
PSO_ES - # Gen

# xNES on Sinusoidal Task

In [4]:
from evosax.strategies import xNES

def f(x):
    """Taken from https://github.com/chanshing/xnes""" 
    r = jnp.sum(x ** 2)
    return -jnp.sin(r) / r

batch_func = jax.vmap(f, in_axes=0)

rng = jax.random.PRNGKey(0)
strategy = xNES(popsize=50, num_dims=2)
es_params = strategy.default_params
es_params["use_adaptive_sampling"] = True
es_params["use_fitness_shaping"] = True
es_params["eta_bmat"] = 0.01
es_params["eta_sigma"] = 0.1

state = strategy.initialize(rng, es_params)
state["mean"] = jnp.array([9999.0, -9999.0])  # a bad init guess
num_iters = 5000
for t in range(num_iters):
    rng, rng_iter = jax.random.split(rng)
    y, state = strategy.ask(rng_iter, state, es_params)
    fitness = batch_func(y)
    state = strategy.tell(y, fitness, state, es_params)
    if (t + 1) % 500 == 0:
        print("xNES - # Gen: {}|Fitness: {:.5f}|Params: {}".format(
                t+1, state["best_fitness"], state["best_member"]))


xNES - # Gen: 500|Fitness: -0.00000|Params: [ 9972.834 -9983.528]
xNES - # Gen: 1000|Fitness: -0.00000|Params: [ 9955.314 -9948.974]
xNES - # Gen: 1500|Fitness: -0.00000|Params: [ 9955.314 -9948.974]
xNES - # Gen: 2000|Fitness: -0.00000|Params: [ 9166.509 -9119.922]
xNES - # Gen: 2500|Fitness: -0.00000|Params: [ 9105.494 -9127.413]
xNES - # Gen: 3000|Fitness: -0.00000|Params: [ 7209.5283 -6139.299 ]
xNES - # Gen: 3500|Fitness: -1.00000|Params: [ 0.00756863 -0.00739648]
xNES - # Gen: 4000|Fitness: -1.00000|Params: [ 0.00756863 -0.00739648]
xNES - # Gen: 4500|Fitness: -1.00000|Params: [ 0.00756863 -0.00739648]
xNES - # Gen: 5000|Fitness: -1.00000|Params: [ 0.00756863 -0.00739648]
