# Simple ES Benchmark Function

In [1]:
# Notebook-wide IPython configuration. Comments are kept on their own lines
# because text after a line magic is passed to the magic as its argument.

# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Load the autoreload extension and set it to mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for 'retina' (high-DPI) figure output.
%config InlineBackend.figure_format = 'retina'

import jax
import matplotlib.pyplot as plt

## 2D Rosenbrock with CMA-ES

In [2]:
from evosax import CMA_ES
from evosax.problems import ClassicFitness

# Instantiate the problem evaluator (2D Rosenbrock). The multi-strategy cell
# further down reuses this same `rosenbrock` object.
rosenbrock = ClassicFitness("rosenbrock", num_dims=2)

# Instantiate the search strategy
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
params = strategy.default_params

# Seed once, then hand `initialize` its own subkey. The key that is carried
# forward and split inside the loop is therefore never also passed to a
# consumer (JAX guidance: do not reuse a PRNG key).
rng = jax.random.PRNGKey(0)
rng, rng_init = jax.random.split(rng)
state = strategy.initialize(rng_init, params)

# Run ask-eval-tell loop - NOTE: By default minimization
for t in range(50):
    # Fresh subkeys every generation: one for proposing candidates (ask),
    # one for the fitness rollout; `rng` itself is only ever split.
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, params)
    fitness = rosenbrock.rollout(rng_eval, x)
    state = strategy.tell(x, fitness, state, params)

    # Report progress every 10th generation (1-based generation counter)
    if (t + 1) % 10 == 0:
        print("# Gen: {}|Fitness: {:.2f}|Params: {}".format(
            t+1, state["best_fitness"], state["best_member"]))

## 2D Rosenbrock with Other ES

In [None]:
from evosax import Strategies
# Run the same ask-eval-tell loop for several other strategies on the
# `rosenbrock` evaluator created in the CMA-ES cell above.
for s_name in ["Simple_ES", "Simple_GA", "PSO_ES", "Differential_ES"]:
    # Only these two strategies are constructed with an explicit elite_ratio
    # here; the others get just the population size and dimensionality.
    extra_kwargs = (
        {"elite_ratio": 0.5} if s_name in ["Simple_ES", "Simple_GA"] else {}
    )
    strategy = Strategies[s_name](popsize=20, num_dims=2, **extra_kwargs)
    params = strategy.default_params

    # Reseed per strategy so each run is reproducible on its own and does not
    # depend on how many keys the previous strategies consumed. `initialize`
    # gets a dedicated subkey so the key split in the loop is never reused.
    rng = jax.random.PRNGKey(0)
    rng, rng_init = jax.random.split(rng)
    state = strategy.initialize(rng_init, params)

    # Label the output; the per-generation lines below do not name the strategy
    print("## Strategy: {}".format(s_name))

    for t in range(50):
        # Fresh subkeys every generation: one for ask, one for the rollout
        rng, rng_gen, rng_eval = jax.random.split(rng, 3)
        x, state = strategy.ask(rng_gen, state, params)
        fitness = rosenbrock.rollout(rng_eval, x)
        state = strategy.tell(x, fitness, state, params)

        # Report progress every 10th generation (1-based generation counter)
        if (t + 1) % 10 == 0:
            print("# Gen: {}|Fitness: {:.2f}|Params: {}".format(
                t+1, state["best_fitness"], state["best_member"]))