In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
# Imports

import jax
import jax.numpy as jnp

from evosax import OpenES, CMA_ES, ParameterReshaper, FitnessShaper, NetworkMapper
from evosax.utils import ESLog
from evosax.problems import GymnaxFitness

from envs.custom_gymnax_fitness import CustomGymnaxFitness, SpecificGymnaxFitness

In [3]:
# Seeding

seed = 0
rng = jax.random.PRNGKey(seed)

In [4]:
# Setting up Network and Param Reshaper

network = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=2,
    hidden_activation="relu",
    output_activation="categorical",
)
pholder = jnp.zeros((4,))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

param_reshaper = ParameterReshaper(params)

ParameterReshaper: 4610 parameters detected for optimization.


In [5]:
# Initialising 'Environment', which in this case is the evaluator
# GymnaxFitness handles initialisation of environment

evaluator = SpecificGymnaxFitness("CartPole-v1", num_env_steps=200, num_rollouts=16)
evaluator.set_apply_fn(network.apply)

In [6]:
# Setting up EvoAlgo Strategy, in this case OpenES
# popsize is the population size per generation

open_strategy = OpenES(popsize=100,
                num_dims=param_reshaper.total_params,
                opt_name="adam",
                lrate_init=0.1
                )

cma_strategy = CMA_ES(
    popsize=100,
    num_dims=param_reshaper.total_params,
)

strategy = open_strategy

strategy.default_params

EvoParams(opt_params=OptParams(lrate_init=0.1, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)

In [7]:
# Setting up Logging
# num_generations is the number of 'updates'

num_generations = 100
print_every_k_gens = 20

es_logging = ESLog(param_reshaper.total_params,
                   num_generations=num_generations,
                   top_k=5,
                   maximize=True # As in, we are trying to maximise the reward
                   )
log = es_logging.initialize()

# Setting up FitnessShaper, where 'Fitness' is effectively the Reward

fit_shaper = FitnessShaper(centered_rank=True,
                           w_decay=0.0,
                           maximize=True)

In [8]:
# Full Training Cell

state = strategy.initialize(rng)

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, state = strategy.ask(rng_ask, state) # Get NN Params from Evo
    reshaped_params = param_reshaper.reshape(x) # Reshape NN Params for NN Rollout

    # Use NN Params for MLP Policy Action, send to RL Env, and get Rewards (Fitness)
    fitness = evaluator.rollout(rng_eval, reshaped_params).mean(axis=1)

    fit_re = fit_shaper.apply(x, fitness) # Reshape Fitness (Reward)
    state = strategy.tell(x, fit_re, state) # Update Evo Algo
    log = es_logging.update(log, x, fitness) # Update Logger
    
    if gen % print_every_k_gens == 0:
        print("Generation: ", gen, "Performance: ", log["log_top_1"][gen])

Generation:  0 Performance:  22.875
Generation:  20 Performance:  200.0
Generation:  40 Performance:  200.0
Generation:  60 Performance:  200.0
Generation:  80 Performance:  200.0
