# 08 - Indirect Encodings
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/08_encodings.ipynb)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install git+https://github.com/RobertTLange/evosax.git@main

## Experimental (!!!) - Random Encodings

In [1]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper
from evosax.problems import GymFitness
from evosax.utils import ParameterReshaper

# Root PRNG key; later cells split fresh sub-keys from it.
rng = jax.random.PRNGKey(0)

# Fitness evaluator for CartPole-v1 (200 env steps, 16 rollouts).
evaluator = GymFitness("CartPole-v1", num_rollouts=16, num_env_steps=200)

# MLP policy: two hidden ReLU layers of 64 units, categorical output over 2 units.
mlp_config = dict(
    num_hidden_units=64,
    num_hidden_layers=2,
    num_output_units=2,
    hidden_activation="relu",
    output_activation="categorical",
)
network = NetworkMapper["MLP"](**mlp_config)

# A single all-zeros observation is enough to initialise the parameters.
obs_dim = evaluator.input_shape[0]
pholder = jnp.zeros((1, obs_dim))
params = network.init(rng, x=pholder, rng=rng)

# Baseline reshaper over the full parameter tree; the last expression
# displays how many parameters a direct encoding would have to optimise.
reshaper = ParameterReshaper(params)
reshaper.total_params



ParameterReshaper: 4610 parameters detected for optimization.


DeviceArray(4610, dtype=int32)

In [2]:
from evosax.utils import FitnessShaper
from evosax.experimental.decodings import RandomDecoder

# Only optimize 6 parameters! The search runs in a 6-dimensional encoding
# instead of over the 4610 network parameters reported by the previous cell.
num_encoding_dims = 6
# Rebinds `reshaper`: from here on it is the RandomDecoder, not the
# ParameterReshaper created earlier.
reshaper = RandomDecoder(num_encoding_dims, params)
# Register the decoder's vmap dictionary and the network forward function
# with the evaluator used for rollouts.
evaluator.set_apply_fn(reshaper.vmap_dict, network.apply)

# maximize=True: presumably flips the fitness sign so that a minimizing
# strategy maximizes episode return (the logged best_fitness is negative)
# - TODO confirm against the FitnessShaper docs.
fit_shaper = FitnessShaper(maximize=True)


In [3]:
from evosax import DE

# Differential Evolution over `reshaper.total_params` search dimensions.
strategy = DE(popsize=100, num_dims=reshaper.total_params)
state = strategy.initialize(rng)

# 100 generations, counted from 1 so the logged index needs no offset.
for gen in range(1, 101):
    # Advance the root key and draw one sub-key each for rollouts and sampling.
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)
    candidates, state = strategy.ask(rng_iter, state)
    # Decode the candidates into network parameters before evaluation.
    decoded = reshaper.reshape(candidates)
    # Average the rollout results along axis 1 to get one score per candidate.
    fitness = evaluator.rollout(rng_eval, decoded).mean(axis=1)
    shaped = fit_shaper.apply(candidates, fitness)
    state = strategy.tell(candidates, shaped, state)

    # Log every 20th generation: index, mean/max/std of the raw fitness, and
    # the strategy's best (shaped) fitness so far.
    if gen % 20 == 0:
        print(gen, fitness.mean(), fitness.max(), fitness.std(), state.best_fitness)



20 133.08125 200.0 50.5719 -200.0
40 149.9775 200.0 45.9242 -200.0
60 157.07562 200.0 51.393032 -200.0
80 151.68312 200.0 53.497288 -200.0
100 160.70312 200.0 50.2572 -200.0


## Experimental (!!!) - Hypernetwork Encodings

In [4]:
from evosax.experimental.decodings import HyperDecoder

# Hypernetwork settings handed to the decoder.
hypernet_config = {
    "num_latent_units": 3,  # Latent units per module kernel/bias
    "num_hidden_units": 2,  # Hidden dimensionality of a_i^j embedding
}
# Rebinds `reshaper` to the hypernetwork decoder for the following cells.
# NOTE(review): no evaluator.set_apply_fn(...) call follows this swap, so the
# evaluator keeps the apply function registered for the earlier decoder -
# confirm that its vmap dictionary is also valid for HyperDecoder.
reshaper = HyperDecoder(params, hypernet_config=hypernet_config)
# Display the size of the search space exposed by the decoder.
reshaper.total_params

ParameterReshaper: 2306 parameters detected for optimization.


DeviceArray(2306, dtype=int32)

In [5]:
# Fresh Differential Evolution run, sized to the current `reshaper`.
strategy = DE(popsize=100, num_dims=reshaper.total_params)
state = strategy.initialize(rng)

# 100 generations, 1-based so the progress log can print the counter directly.
for gen in range(1, 101):
    # Split off independent keys for evaluation and for proposing candidates.
    rng, rng_eval, rng_iter = jax.random.split(rng, 3)
    proposals, state = strategy.ask(rng_iter, state)
    # Map the proposals through the decoder to obtain network parameters.
    decoded = reshaper.reshape(proposals)
    # One score per proposal: mean of the rollout results along axis 1.
    fitness = evaluator.rollout(rng_eval, decoded).mean(axis=1)
    shaped = fit_shaper.apply(proposals, fitness)
    state = strategy.tell(proposals, shaped, state)

    # Progress report every 20 generations (raw fitness statistics plus the
    # best shaped fitness tracked in the strategy state).
    if gen % 20 == 0:
        print(gen, fitness.mean(), fitness.max(), fitness.std(), state.best_fitness)

20 19.089375 30.3125 6.1910863 -33.4375
40 29.787498 195.5625 32.49928 -200.0
60 31.540625 200.0 44.170444 -200.0
80 28.501875 200.0 47.56071 -200.0
100 28.136875 200.0 44.376225 -200.0
