In [7]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import jax
import jax.numpy as jnp
from evosax import CMA_ES , EvoTF_ES
from evosax.problems import BBOBFitness

# Instantiate the problem evaluator
rosenbrock = BBOBFitness("RosenbrockOriginal", num_dims=100, seed_id=2)
#rosenbrock.visualize(plot_log_fn=True)

In [16]:
# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
#strategy = EvoTF_ES(popsize=512, num_dims=100) #elite_ratio=0.5)
strategy = CMA_ES(popsize=512, num_dims=100, elite_ratio=0.5)
es_params = strategy.default_params.replace(init_min=-2, init_max=2)

state = strategy.initialize(rng, es_params)

# Run ask-eval-tell loop - NOTE: By default minimization
for t in range(6000):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = rosenbrock.rollout(rng_eval, x)
    state = strategy.tell(x, fitness, state, es_params)

    if (t + 1) % 10 == 0:
        # print("CMA-ES - # Gen: {}|Fitness: {:.5f}|Params: {}".format(
        #     t+1, state.best_fitness, state.best_member))
        print("EvoTF_ES - # Gen: {}|Fitness: {:.5f}|Params: {}".format(
            t+1, state.best_fitness, state.best_member))

EvoTF_ES - # Gen: 10|Fitness: 235897.42188|Params: [-3.0815911   1.0019869   1.0508764  -1.2513977  -2.6713443   0.47416937
  0.4258684  -0.5278035  -0.29909804  0.44716775 -2.5095687  -2.1267352
  1.2558817   1.5497535  -1.8494569   0.19007277 -0.9511514  -0.30477732
 -1.6645725  -0.18528163 -1.8829681   1.29323    -0.2579921   0.77920145
  0.10304642 -0.09693801 -0.1157835  -1.487005    0.81089395 -0.8437668
  1.5572705  -3.0200696   0.23751855  1.5797329  -0.17408028 -0.21333718
  0.22853045 -0.8534716   0.19799256 -1.3301114   0.73754245  0.2218988
 -1.4041053  -0.09604776  1.7222635   0.33597195  1.5475081  -0.9972808
 -0.71261054 -0.15965652 -1.0457085   1.3056891   1.1164984  -1.3731906
 -0.91892254 -0.34861252 -0.4935683   0.6779733   1.0775588  -2.393556
  0.11259335  0.839476    0.42494518  1.5990453   1.2144978   2.645577
  0.64020884 -1.0920336   0.23529524  0.04302859  0.46685636  0.63987714
  0.9034476   1.1533213   1.6324841  -2.3266702   0.17756662  2.6642928
 -1.278351