# 08 - Indirect Encodings
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/08_encodings.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/evosax.git@main
!pip install -q gymnax

^C
[31mERROR: Operation cancelled by user[0m


## Experimental (!!!) - Random Encodings

In [1]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper
from evosax.problems import GymnaxFitness
from evosax.utils import ParameterReshaper

rng = jax.random.PRNGKey(0)
# Run Strategy on CartPole MLP
evaluator = GymnaxFitness("CartPole-v1", num_env_steps=200, num_rollouts=16)

network = NetworkMapper["MLP"](
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=2,
    hidden_activation="relu",
    output_activation="categorical",
)
pholder = jnp.zeros((1, evaluator.input_shape[0]))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

reshaper = ParameterReshaper(params)

# Raw number of parameters to encode by hypernetwork
reshaper.total_params

ParameterReshaper: 4610 parameters detected for optimization.


4610

In [2]:
from evosax.utils import FitnessShaper
from evosax.experimental.decodings import RandomDecoder

# Only optimize 10 parameters!
num_encoding_dims = 6
reshaper = RandomDecoder(num_encoding_dims, params)
evaluator.set_apply_fn(network.apply)

fit_shaper = FitnessShaper(maximize=True)


RandomDecoder: Encoding parameters - 6


In [3]:
from evosax import DE

strategy = DE(
    num_dims=reshaper.total_params,
    popsize=100,
)
state = strategy.initialize(rng)

for t in range(100):
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_iter, state)
    x_re = reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state)

    if (t + 1) % 20 == 0:
        print(
            t + 1,
            fitness.mean(),
            fitness.max(),
            fitness.std(),
            state.best_fitness,
        )

  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)


20 137.16063 200.0 54.08962 -200.0
40 143.86063 200.0 53.067028 -200.0
60 149.36375 200.0 55.794975 -200.0
80 152.18437 200.0 52.207775 -200.0
100 153.35687 200.0 50.88368 -200.0


## Experimental (!!!) - Hypernetwork Encodings

In [4]:
from evosax.experimental.decodings import HyperDecoder

reshaper = HyperDecoder(
        params,
        hypernet_config={
            "num_latent_units": 3,  # Latent units per module kernel/bias
            "num_hidden_units": 2,  # Hidden dimensionality of a_i^j embedding
        },
    )
reshaper.total_params

ParameterReshaper: 2306 parameters detected for optimization.


2306

In [5]:
strategy = DE(
    num_dims=reshaper.total_params,
    popsize=100,
)
state = strategy.initialize(rng)

for t in range(60):
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_iter, state)
    x_re = reshaper.reshape(x)
    fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)
    fit_re = fit_shaper.apply(x, fitness)
    state = strategy.tell(x, fit_re, state)

    if (t + 1) % 20 == 0:
        print(
            t + 1,
            fitness.mean(),
            fitness.max(),
            fitness.std(),
            state.best_fitness
            )

20 16.82375 41.625 7.145526 -41.625
40 58.821247 200.0 74.27308 -200.0
60 51.158123 200.0 69.18335 -200.0
