# 03 - Evolving a MNIST CNN with OpenES
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp

In [2]:
from evosax import Open_ES, ParameterReshaper, NetworkMapper
from evosax.problems import SupervisedFitness

rng = jax.random.PRNGKey(0)
network = NetworkMapper["CNN"](
        depth_1=1,
        depth_2=1,
        features_1=8,
        features_2=16,
        kernel_1=5,
        kernel_2=5,
        strides_1=1,
        strides_2=1,
        num_linear_layers=0,
        num_output_units=10,
    )
pholder = jnp.zeros((1, 28, 28, 1))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

param_reshaper = ParameterReshaper(params['params'])

In [3]:
# Set up the dataloader for batch evaluations (may take a sec)
evaluator = SupervisedFitness("MNIST", batch_size=256)
evaluator.set_apply_fn(network.apply)
rollout = jax.jit(jax.vmap(evaluator.rollout, in_axes=(None, param_reshaper.vmap_dict)))

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
from evosax import Open_ES
strategy = Open_ES(popsize=100, num_dims=param_reshaper.total_params, opt_name="adam")
es_params = {
    "sigma_init": 0.01,  # Initial scale of isotropic Gaussian noise
    "sigma_decay": 0.999,  # Multiplicative decay factor
    "sigma_limit": 0.01,  # Smallest possible scale
    "lrate_init": 0.001,  # Initial learning rate
    "lrate_decay": 0.9999,  # Multiplicative decay factor
    "lrate_limit": 0.0001,  # Smallest possible lrate
    "beta_1": 0.99,   # Adam - beta_1
    "beta_2": 0.999,  # Adam - beta_2
    "eps": 1e-8,  # eps constant,
    "init_min": 0.0,  # Range of parameter archive initialization - Min
    "init_max": 0.0,  # Range of parameter archive initialization - Max
    "clip_min": -10,  # Range of parameter proposals - Min
    "clip_max": 10  # Range of parameter proposals - Max
}

In [5]:
from evosax import FitnessShaper
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=True,
                           w_decay=0.1)

In [6]:
num_generations = 7500
print_every_k_gens = 100
state = strategy.initialize(rng, es_params)

for gen in range(num_generations):
    rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)
    # rollout will pytree w. test_acc, test_loss, train_acc, train_loss
    fitness_loss = []
    # run 4 sequential batch evals (circumvent accelerator memory problems)
    for i in range(4):
        rng_eval, rng_eval_i = jax.random.split(rng_eval)
        fitness = rollout(rng_eval_i, reshaped_params)
        fitness_loss.append(fitness['train_loss'])
    
    fitness_loss = jnp.mean(jnp.stack(fitness_loss, axis=1), axis=1)
    fit_re = fit_shaper.apply(x, fitness_loss)
    state = strategy.tell(x, fit_re, state, es_params)
    if gen % print_every_k_gens == 0:
        print(gen, fitness['train_acc'].mean(), fitness['test_acc'].mean(), 
              fitness_loss.mean(), fitness['test_loss'].mean())

0 0.105 0.09765625 2.3026407 2.3026216
100 0.10601562 0.20328124 2.2948122 2.2905602
200 0.3980078 0.42074218 2.116134 2.100343
300 0.67105466 0.71335936 1.4127865 1.3928121
400 0.7093359 0.76124996 0.93174905 0.83537227
500 0.7780078 0.79531246 0.6982758 0.6634201
600 0.8074609 0.80781245 0.61464864 0.69131184
700 0.8813281 0.83 0.56244344 0.55927765
800 0.8567578 0.84925777 0.4986973 0.4897943
900 0.8617578 0.84734374 0.51981276 0.538323
1000 0.85679686 0.88078123 0.50712556 0.40201977
1100 0.84910154 0.86902344 0.5138326 0.47206232
1200 0.876914 0.8602734 0.4440429 0.43906233
1300 0.84410155 0.86960936 0.5221144 0.5351238
1400 0.8720312 0.90199214 0.482296 0.3756349
1500 0.8815625 0.87324214 0.4720793 0.45056823
1600 0.9103515 0.8879687 0.38361242 0.37585488
1700 0.92222655 0.8800781 0.36601463 0.4366532
1800 0.8698437 0.91054684 0.38637078 0.38063183
1900 0.90558594 0.89335936 0.3719517 0.31824312
2000 0.8975781 0.90160155 0.3575186 0.37494323
2100 0.90847653 0.87742186 0.33438185 