# 01 - Simple ES Benchmark Function
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)

In [1]:
# Notebook setup: render matplotlib figures inline, auto-reload edited
# modules before cells run, and draw figures at retina resolution.
# (Comments are kept on their own line: text after a line magic would be
# passed to the magic as an argument.)
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

## 2D Rosenbrock with CMA-ES

In [2]:
from evosax import CMA_ES
from evosax.problems import ClassicFitness

# Instantiate the problem evaluator
rosenbrock = ClassicFitness("rosenbrock", num_dims=2)

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
params = strategy.default_params

# Give initialization its own subkey: `rng` is split again inside the loop,
# and a JAX key must not be both consumed and split (key reuse).
rng, rng_init = jax.random.split(rng)
state = strategy.initialize(rng_init, params)

num_generations = 50
log_every = 10

# Run ask-eval-tell loop - NOTE: By default minimization
for t in range(num_generations):
    # Fresh subkeys for sampling the population and for evaluating it.
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, params)
    fitness = rosenbrock.rollout(rng_eval, x)
    state = strategy.tell(x, fitness, state, params)

    # Periodically report the best fitness/member tracked in the state.
    if (t + 1) % log_every == 0:
        print("# Gen: {}|Fitness: {:.2f}|Params: {}".format(
            t + 1, state["best_fitness"], state["best_member"]))

## 2D Rosenbrock with Other ES

In [None]:
from evosax import Strategies
rng = jax.random.PRNGKey(0)

# Extra constructor kwargs per strategy (iterated in this order); only
# Simple_ES and Simple_GA are given an elite ratio.
strategy_kwargs = {
    "Simple_ES": {"elite_ratio": 0.5},
    "Simple_GA": {"elite_ratio": 0.5},
    "PSO_ES": {},
    "Differential_ES": {},
}

for s_name, extra_kwargs in strategy_kwargs.items():
    # Label the output so each block of log lines can be attributed.
    print("Strategy: {}".format(s_name))
    strategy = Strategies[s_name](popsize=20, num_dims=2, **extra_kwargs)
    params = strategy.default_params

    # Dedicated subkey for initialization (avoids reusing `rng`, which is
    # split again in the loop below).
    rng, rng_init = jax.random.split(rng)
    state = strategy.initialize(rng_init, params)

    # `rosenbrock` is the ClassicFitness evaluator built in the CMA-ES cell.
    for t in range(50):
        rng, rng_gen, rng_eval = jax.random.split(rng, 3)
        x, state = strategy.ask(rng_gen, state, params)
        fitness = rosenbrock.rollout(rng_eval, x)
        state = strategy.tell(x, fitness, state, params)

        if (t + 1) % 10 == 0:
            print("# Gen: {}|Fitness: {:.2f}|Params: {}".format(
                t + 1, state["best_fitness"], state["best_member"]))

## xNES on a Sinusoidal Task

In [None]:
from evosax.strategies import xNES


def f(x):
    """Negated radial sinc, -sin(r) / r with r = sum(x ** 2), to be minimized.

    For a 2-D point (x, y) this is -sin(x^2 + y^2) / (x^2 + y^2). At the
    origin (r == 0) it returns the limit value -1 rather than the NaN that a
    plain 0 / 0 division produces.
    """
    r = jnp.sum(x ** 2)
    # jnp.sinc(z) = sin(pi * z) / (pi * z) and is defined as 1 at z == 0,
    # so sinc(r / pi) == sin(r) / r without the singularity at r == 0.
    return -jnp.sinc(r / jnp.pi)

# Evaluate a whole population in one call by mapping `f` over axis 0.
batch_func = jax.vmap(f, in_axes=0)

rng = jax.random.PRNGKey(0)
strategy = xNES(popsize=50, num_dims=2)
params = strategy.default_params
params["use_adaptive_sampling"] = True
params["use_fitness_shaping"] = True
params["eta_bmat"] = 0.01
params["eta_sigma"] = 0.1

# Dedicated subkey for initialization; `rng` is split again in the loop.
rng, rng_init = jax.random.split(rng)
state = strategy.initialize(rng_init, params)
state["mean"] = jnp.array([9999.0, -9999.0])  # a bad init guess

fitness_log = []  # minimum fitness of each generation
num_iters = 5000
for t in range(num_iters):
    rng, rng_iter = jax.random.split(rng)
    y, state = strategy.ask(rng_iter, state, params)
    fitness = batch_func(y)
    state = strategy.tell(y, fitness, state, params)
    fitness_log.append(jnp.min(fitness))
    if t % 500 == 0:
        # Iteration, best fitness seen so far, current search mean.
        print(t, jnp.min(jnp.array(fitness_log)), state["mean"])
