# 00 - Getting Started with `evosax` - The Ask-Eval-Tell API
### [Last Update: March 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/00_getting_started.ipynb)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install git+https://github.com/RobertTLange/evosax.git@main

## Evolution Strategy Instantiation

In [None]:
import jax
import jax.numpy as jnp
from evosax import CMA_ES
from evosax.problems import ClassicFitness

# Instantiate the evolution strategy instance
strategy = CMA_ES(num_dims=2, popsize=10)

# Get default hyperparameters (e.g. lrate, etc.)
es_params = strategy.default_params
es_params = es_params.replace(init_min=-3, init_max=3)

# Initialize the strategy
rng = jax.random.PRNGKey(0)
state = strategy.initialize(rng, es_params)

# Have a look at the hyperparameters (change if desired)
es_params

## Classic Evolution Strategy Benchmarks

In [None]:
# Instantiate helper class for classic evolution strategies benchmarks
evaluator = ClassicFitness("rosenbrock", num_dims=2)
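
For reference, the Rosenbrock function is commonly defined as the sum over `i` of `100 * (x[i+1] - x[i]**2)**2 + (1 - x[i])**2`, with a global minimum of 0 at `x = (1, ..., 1)`. The NumPy sketch below evaluates this textbook form for a whole population at once. The helper name `rosenbrock_population` is hypothetical, and `ClassicFitness` may differ in its implementation details.

```python
import numpy as np


def rosenbrock_population(x: np.ndarray) -> np.ndarray:
    """Evaluate the textbook Rosenbrock function row-wise.

    x has shape (popsize, num_dims); the result has shape (popsize,).
    """
    x_curr, x_next = x[:, :-1], x[:, 1:]
    return np.sum(100.0 * (x_next - x_curr**2) ** 2 + (1.0 - x_curr) ** 2, axis=1)


# Three hand-picked candidates; the first one is the global minimum
population = np.array([[1.0, 1.0], [0.0, 0.0], [-1.0, 1.0]])
fitness = rosenbrock_population(population)
print(fitness)
```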

Given our initialized strategy, we are now ready to `ask` for a set of candidate parameters. We then evaluate these candidates on the 2D Rosenbrock problem and `tell` the resulting fitness values to our strategy. The strategy updates its `state`, and we can iterate.

In [None]:
# Ask for a set of candidate solutions to evaluate
x, state = strategy.ask(rng, state, es_params)
# Evaluate the population members
fitness = evaluator.rollout(rng, x)
# Update the evolution strategy
state = strategy.tell(x, fitness, state, es_params)
state
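
To make the division of labour between `ask`, evaluation, and `tell` concrete, here is a minimal sketch of the same loop in plain NumPy. It assumes a toy Gaussian strategy (not CMA-ES) and a sphere objective as a stand-in; the class `ToyGaussianES` and all of its arguments are hypothetical and not part of `evosax`. Unlike JAX, NumPy's generator is stateful, so no explicit key splitting is needed here.

```python
import numpy as np


class ToyGaussianES:
    """Toy strategy: sample around a mean, then move the mean to the elite average."""

    def __init__(self, num_dims, popsize, sigma=0.5, num_elite=5):
        self.num_dims, self.popsize = num_dims, popsize
        self.sigma, self.num_elite = sigma, num_elite

    def initialize(self, rng, init_min=-3.0, init_max=3.0):
        # The state holds everything that changes between generations
        return {"mean": rng.uniform(init_min, init_max, self.num_dims)}

    def ask(self, rng, state):
        # Propose popsize candidates around the current mean
        noise = rng.standard_normal((self.popsize, self.num_dims))
        return state["mean"] + self.sigma * noise, state

    def tell(self, x, fitness, state):
        # Minimization: the elite are the candidates with the lowest fitness
        elite = x[np.argsort(fitness)[: self.num_elite]]
        return {"mean": elite.mean(axis=0)}


def sphere(x):
    """Stand-in objective with its minimum at the origin."""
    return np.sum(x**2, axis=1)


rng = np.random.default_rng(0)
strategy = ToyGaussianES(num_dims=2, popsize=10)
state = strategy.initialize(rng)
best_per_gen = []
for _ in range(20):
    x, state = strategy.ask(rng, state)       # ask
    fitness = sphere(x)                       # eval
    state = strategy.tell(x, fitness, state)  # tell
    best_per_gen.append(fitness.min())
```

The strategy object itself stays immutable across generations; only the `state` dictionary is threaded through the loop, which mirrors the functional style used above.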

## Running the ES Loop with Logging

In [None]:
from evosax.utils import ESLog
# Jittable logging helper
num_gens = 50
es_logging = ESLog(num_dims=2, num_generations=num_gens, top_k=3, maximize=False)
log = es_logging.initialize()

In [None]:
state = strategy.initialize(rng, es_params)
for i in range(num_gens):
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)
    # Ask for a set of candidates
    x, state = strategy.ask(rng_ask, state, es_params)
    # Evaluate the candidates
    fitness = evaluator.rollout(rng_eval, x)
    # Update the strategy based on fitness
    state = strategy.tell(x, fitness, state, es_params)
    # Update the log with results
    log = es_logging.update(log, x, fitness)
    
es_logging.plot(log, "2D Rosenbrock CMA-ES", ylims=(0, 30))
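
The `top_k` argument suggests that the log keeps a running record of the best solutions seen so far. The NumPy sketch below shows one way such a record could be maintained for a minimization problem. The function `update_top_k` and the dictionary layout are hypothetical; the internals of `ESLog` may differ.

```python
import numpy as np


def update_top_k(log, x, fitness, top_k=3):
    """Merge a new generation into a running record of the best top_k solutions."""
    all_fit = np.concatenate([log["top_fitness"], fitness])
    all_x = np.concatenate([log["top_params"], x], axis=0)
    keep = np.argsort(all_fit)[:top_k]  # ascending order: lower fitness is better
    return {"top_fitness": all_fit[keep], "top_params": all_x[keep]}


# Start with placeholder entries that any real solution will replace
log = {"top_fitness": np.full(3, np.inf), "top_params": np.zeros((3, 2))}
log = update_top_k(log, np.array([[0.0, 0.0], [1.0, 1.0]]), np.array([5.0, 2.0]))
log = update_top_k(log, np.array([[2.0, 2.0], [3.0, 3.0]]), np.array([1.0, 9.0]))
print(log["top_fitness"])
```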

## Simultaneous PyTree Evaluation with `evosax`'s `ParameterReshaper`

`evosax` supports automatically reshaping the proposed flat vectors into pytrees for neural network fitness evaluation. The transformation is again JAX-composable (`jit`, `vmap`, etc.). Below is an example for a Flax-based multi-layer perceptron:

In [None]:
from flax import linen as nn


class MLP(nn.Module):
    """Simple ReLU MLP."""

    num_hidden_units: int
    num_hidden_layers: int
    num_output_units: int

    @nn.compact
    def __call__(self, x, rng):
        for _ in range(self.num_hidden_layers):
            x = nn.Dense(features=self.num_hidden_units)(x)
            x = nn.relu(x)
        x = nn.Dense(features=self.num_output_units)(x)
        return jax.random.categorical(rng, x)
    

# Instantiate the model callables and get a placeholder pytree
network = MLP(64, 2, 2)
policy_params = network.init(rng, jnp.zeros((4,)), rng)

In [None]:
from evosax.utils import ParameterReshaper

# Instantiate the reshape helper & get total number of parameters to reshape
param_reshaper = ParameterReshaper(policy_params)
param_reshaper.total_params
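
As a sanity check on the number reported by `total_params`, the parameter count of this MLP can be worked out by hand. The sketch below assumes that each `nn.Dense` layer holds a kernel plus a bias vector (the Flax default) and that the input has 4 features, as in the placeholder above. The helper `dense_param_count` is hypothetical.

```python
def dense_param_count(num_inputs: int, num_outputs: int) -> int:
    """Kernel entries plus one bias per output unit."""
    return num_inputs * num_outputs + num_outputs


# Layer widths for MLP(64, 2, 2) applied to a 4-dimensional input
layer_sizes = [4, 64, 64, 2]
total_params = sum(
    dense_param_count(n_in, n_out)
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
)
print(total_params)  # 320 + 4160 + 130 = 4610
```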

Now let's instantiate another evolution strategy (`DE` - Differential Evolution) and generate a set of population members:

In [None]:
from evosax import DE
strategy = DE(popsize=100, num_dims=param_reshaper.total_params)
state = strategy.initialize(rng, strategy.default_params)
x, state = strategy.ask(rng, state, strategy.default_params)
x.shape

As we can see, this is simply an array of shape (#population members, #parameters). To reshape this array into stacked pytrees, we can simply pass it to the reshaper:

In [None]:
net_params = param_reshaper.reshape(x)
net_params.keys(), net_params['params']['Dense_0']['kernel'].shape
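
Conceptually, the reshaping slices each flat row into consecutive chunks and gives every chunk its layer shape, while keeping the population axis in front. The NumPy sketch below illustrates this idea for a smaller, hypothetical 4 -> 64 -> 2 network. The function `reshape_population` and the ordering of the leaves are assumptions; `ParameterReshaper` may order and store parameters differently.

```python
import numpy as np

# Hypothetical layer shapes for a 4 -> 64 -> 2 network (one hidden layer)
shapes = {
    "Dense_0": {"bias": (64,), "kernel": (4, 64)},
    "Dense_1": {"bias": (2,), "kernel": (64, 2)},
}


def reshape_population(x, shapes):
    """Split each row of x into named arrays, keeping the population axis first."""
    popsize, offset, tree = x.shape[0], 0, {}
    for layer, leaves in shapes.items():
        tree[layer] = {}
        for name, shape in leaves.items():
            size = int(np.prod(shape))
            chunk = x[:, offset : offset + size]
            tree[layer][name] = chunk.reshape(popsize, *shape)
            offset += size
    assert offset == x.shape[1], "flat vector length must match the tree"
    return tree


total = 64 + 4 * 64 + 2 + 64 * 2  # 450 parameters in total
x = np.arange(100 * total, dtype=np.float32).reshape(100, total)
net_params = reshape_population(x, shapes)
print(net_params["Dense_0"]["kernel"].shape)  # (100, 4, 64)
```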

If you now want to map over the population member axis, you can do so with the help of the `vmap_dict` (more about this later):

In [None]:
# Get dictionary to vectorize/parallelize rollouts with
param_reshaper.vmap_dict
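
A dictionary of this kind can serve as the `in_axes` argument of `jax.vmap`: it mirrors the parameter tree and tells JAX which axis of each leaf indexes the population. The sketch below shows the mechanism with a tiny linear model standing in for `network.apply`. The model, the parameter names, and the hand-written `in_axes` are illustrative assumptions, not the actual contents of `vmap_dict`.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, obs):
    """Tiny linear model standing in for `network.apply`."""
    return obs @ params["kernel"] + params["bias"]


popsize = 5
# Stacked parameters: every leaf carries the population on axis 0
stacked_params = {
    "kernel": jnp.ones((popsize, 4, 2)),
    "bias": jnp.zeros((popsize, 2)),
}
obs = jnp.ones((4,))

# in_axes mirrors the parameter tree: map axis 0 of each leaf, share the observation
in_axes = ({"kernel": 0, "bias": 0}, None)
batched_apply = jax.vmap(apply_fn, in_axes=in_axes)
out = batched_apply(stacked_params, obs)
print(out.shape)  # (5, 2)
```

Each population member sees the same observation but applies its own slice of the parameters, which is the pattern needed to evaluate a whole population in one call.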

## Fitness Shaping with `evosax`'s `FitnessShaper`

By default `evosax` will minimize the objective. If you want to instead maximize it (as you commonly do with RL returns) or want to apply any other common ES fitness shaping, you can use the `FitnessShaper`:

In [None]:
from evosax import FitnessShaper
fit_shaper = FitnessShaper(centered_rank=True, w_decay=0.01, maximize=True)

x = jnp.array([[1.0], [2.0], [3.0]])
fit = jnp.array([0.0, 1.0, 2.0])
fit_shaper.apply(x, fit)
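
To illustrate what a centered-rank transformation does, here is a minimal NumPy sketch. It replaces raw fitness values by evenly spaced ranks in [-0.5, 0.5] and flips the sign first when maximizing, so that the best member always receives the lowest value. The function `centered_rank` is hypothetical and omits weight decay; the exact output of `FitnessShaper` may differ.

```python
import numpy as np


def centered_rank(fitness: np.ndarray, maximize: bool = False) -> np.ndarray:
    """Map raw fitness to evenly spaced values in [-0.5, 0.5].

    The best member always receives -0.5, so a minimizing strategy can
    consume the result directly.
    """
    f = -fitness if maximize else fitness
    ranks = np.argsort(np.argsort(f))  # rank 0 belongs to the smallest f
    return ranks / (len(f) - 1) - 0.5


fit = np.array([0.0, 1.0, 2.0])
shaped = centered_rank(fit, maximize=True)
print(shaped)

# Only the ordering matters, so rescaling the fitness leaves the result unchanged
shaped_scaled = centered_rank(np.array([0.0, 10.0, 2000.0]), maximize=True)
```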

## ARS on CartPole Task

`evosax` also comes with a simple fitness evaluation helper for a JAX-based version of CartPole. You will have to make use of the `vmap_dict` in order to vectorize the rollouts along the population axis:

In [None]:
from evosax.problems import GymFitness

evaluator = GymFitness("CartPole-v1", num_env_steps=200, num_rollouts=16)
evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)

In [None]:
from evosax import ARS

strategy = ARS(popsize=100,
               num_dims=param_reshaper.total_params,
               elite_ratio=0.1, opt_name="sgd")

es_params = strategy.default_params
es_params = es_params.replace(opt_params=es_params.opt_params.replace(momentum=0.0))
es_params

In [None]:
num_generations = 250
num_rollouts = 20
print_every_k_gens = 20

es_logging = ESLog(param_reshaper.total_params,
                   num_generations,
                   top_k=5,
                   maximize=True)
log = es_logging.initialize()
fit_shaper = FitnessShaper(maximize=True)

state = strategy.initialize(rng, es_params)

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, reshaped_params).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state, es_params)
    log = es_logging.update(log, x, fitness)
    
    if gen % print_every_k_gens == 0:
        print("Generation: ", gen, "Performance: ", -state.best_fitness)
        
es_logging.plot(log, "CartPole Augmented Random Search")
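
For intuition about what the strategy does between `ask` and `tell`, the sketch below implements one step of the commonly described Augmented Random Search update in plain NumPy: antithetic perturbations, selection of the best directions, and a step normalized by the spread of the elite returns. The function `ars_step`, its default values, and the linear toy return are hypothetical. The `evosax` implementation (here configured with `opt_name="sgd"`) may differ in details such as the optimizer and sign conventions.

```python
import numpy as np


def ars_step(mean, objective, rng, popsize=8, num_elite=2, sigma=0.1, lrate=0.05):
    """One maximizing ARS-style update of the search mean."""
    # Antithetic sampling: evaluate each direction with a plus and a minus sign
    directions = rng.standard_normal((popsize // 2, mean.shape[0]))
    f_pos = objective(mean + sigma * directions)
    f_neg = objective(mean - sigma * directions)

    # Keep the directions whose better side scored highest
    elite = np.argsort(-np.maximum(f_pos, f_neg))[:num_elite]
    # Normalize the step by the spread of the elite returns
    sigma_r = np.concatenate([f_pos[elite], f_neg[elite]]).std() + 1e-8
    diffs = (f_pos[elite] - f_neg[elite])[:, None]
    grad = (diffs * directions[elite]).sum(axis=0)
    return mean + lrate / (num_elite * sigma_r) * grad


def linear_return(x):
    """Toy return to maximize: the sum of all coordinates."""
    return np.sum(x, axis=-1)


rng = np.random.default_rng(0)
mean = np.zeros(3)
new_mean = ars_step(mean, linear_return, rng)
print(linear_return(mean), linear_return(new_mean))
```

On this linear toy return, every elite direction contributes a step that increases the return, so a single update already moves the mean uphill.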