In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
from evosax import CMA_ES , EvoTF_ES
from evosax.problems import BBOBFitness

# Instantiate the problem evaluator
rosenbrock = BBOBFitness("RosenbrockOriginal", num_dims=10, seed_id=2)
#rosenbrock.visualize(plot_log_fn=True)

In [3]:
# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = EvoTF_ES(popsize=512, num_dims=10) #elite_ratio=0.5)
#strategy = CMA_ES(popsize=512, num_dims=100, elite_ratio=0.5)
es_params = strategy.default_params.replace(init_min=-1000, init_max=1000)

state = strategy.initialize(rng, es_params)

# Run ask-eval-tell loop - NOTE: By default minimization
for t in range(100):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = rosenbrock.rollout(rng_eval, x)
    state = strategy.tell(x, fitness, state, es_params)

    if (t + 1) % 10 == 0:
        # print("CMA-ES - # Gen: {}|Fitness: {:.5f}|Params: {}".format(
        #     t+1, state.best_fitness, state.best_member))
        print("EvoTF_ES - # Gen: {}|Fitness: {:.5f}".format(
            t+1, state.best_fitness,))
        # print("EvoTF_ES - # Gen: {}".format(
        #     t+1, state.best_fitness))

Loaded pretrained EvoTF model from ckpt: 2024_03_SNES_small.pkl
EvoTF_ES - # Gen: 10|Fitness: 25969597349888.00000
EvoTF_ES - # Gen: 20|Fitness: 20689186717696.00000
EvoTF_ES - # Gen: 30|Fitness: 15526163120128.00000
EvoTF_ES - # Gen: 40|Fitness: 10913834860544.00000
EvoTF_ES - # Gen: 50|Fitness: 5563140276224.00000
EvoTF_ES - # Gen: 60|Fitness: 2070152740864.00000
EvoTF_ES - # Gen: 70|Fitness: 791563534336.00000
EvoTF_ES - # Gen: 80|Fitness: 595409633280.00000
EvoTF_ES - # Gen: 90|Fitness: 534578331648.00000
EvoTF_ES - # Gen: 100|Fitness: 534578331648.00000
