In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'
import sys
sys.path.append("..")

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from evosax.problems.low_d_optimisation import (batch_rosenbrock,
                                                batch_himmelblau,
                                                batch_hump_camel)
from evosax.utils import init_logger, update_logger

# Demo: Ask-Evaluate-Tell API
## 2D Rosenbrock Function

In [2]:
from evosax.strategies.gaussian import init_strategy, ask, tell

In [3]:
# Experiment configuration: PRNG seed, problem dimensionality,
# Rosenbrock coefficients (a, b) passed to `batch_rosenbrock`,
# number of generations and how often to print progress.
rng = jax.random.PRNGKey(0)
num_params = 2
a, b = 1, 100
num_generations = 100
print_every_gen = 5

# Strategy initialisation: mean starts at the origin with sigma 1.0.
# Derive the mean's dimensionality from `num_params` (not a literal 2) so the
# two cannot drift apart when the problem size is changed.
mean_init, sigma_init = jnp.zeros(num_params), 1.0
# Fix population size, fix elite size
pop_size, mu = 10, 5
params, memory = init_strategy(mean_init, sigma_init, pop_size, mu)

# A single ask-evaluate-tell step: sample candidates `x`, score them on the
# Rosenbrock function, then feed the scores back to update `memory`.
rng, rng_input = jax.random.split(rng)
x, memory = ask(rng_input, params, memory)
value = batch_rosenbrock(x, a, b)
memory = tell(x, value, params, memory)
# Display the updated mean and sigma held in the strategy state.
memory["mean"], memory["sigma"]



(DeviceArray([0.18893634, 0.03640625], dtype=float32),
 DeviceArray([1., 1.], dtype=float32))

In [4]:
# Inspect the hyperparameter dict returned by `init_strategy` in the cell above.
params

{'pop_size': 10,
 'n_dim': 2,
 'weights': DeviceArray([0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. , 0. , 0. ], dtype=float32),
 'c_m': 1,
 'c_sigma': 0.0,
 'tol_fun': 1e-12,
 'min_generations': 10}

### Run the Full Rosenbrock Example with JIT

In [5]:
# Re-initialise the strategy so this run starts from scratch; the single step
# taken in the earlier cell is discarded. Note that `rng` is NOT reset: it
# continues from the key state left by the earlier cell.
# Use a float sigma_init, consistent with the earlier cell (an int literal may
# yield an integer-dtype sigma in the strategy state).
mean_init, sigma_init = jnp.zeros(num_params), 1.0
params, memory = init_strategy(mean_init, sigma_init, pop_size, mu)

# Number of best entries the logger keeps (exposed as `top_values` /
# `top_params`; index 0 is what gets printed below).
top_k = 3

evo_logger = init_logger(top_k, num_params)

for generation in range(num_generations):
    # Ask - Eval - Tell - Log
    rng, rng_input = jax.random.split(rng)
    x, memory = ask(rng_input, params, memory)
    value = batch_rosenbrock(x, a, b)
    memory = tell(x, value, params, memory)
    evo_logger = update_logger(evo_logger, x, value, memory, top_k)
    if (generation + 1) % print_every_gen == 0:
        best_value = evo_logger["top_values"][0]
        best_params = evo_logger["top_params"][0]
        print(f"# Gen: {generation + 1}|Fitness: {best_value:.2f}|Params: {best_params}")

# Gen: 5|Fitness: 2.51|Params: [ 0.0354423  -0.12457814]
# Gen: 10|Fitness: 0.91|Params: [0.74999964 0.6545119 ]
# Gen: 15|Fitness: 0.39|Params: [1.335381  1.7303352]
# Gen: 20|Fitness: 0.10|Params: [0.6815662  0.46708193]
# Gen: 25|Fitness: 0.10|Params: [0.6815662  0.46708193]
# Gen: 30|Fitness: 0.10|Params: [0.6815662  0.46708193]
# Gen: 35|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 40|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 45|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 50|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 55|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 60|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 65|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 70|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 75|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 80|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 85|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 90|Fitness: 0.01|Params: [1.0614763 1.1363623]
# Gen: 95|Fitness