# 03 - Evolving an MNIST CNN
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp



In [None]:
from evosax import Augmented_RS, NetworkMapper, ParameterReshaper
from evosax.problems import SupervisedFitness

# Master PRNG key for the notebook. Every consumer gets its own split so that no
# key is used twice (reusing a JAX key yields correlated "random" draws).
rng = jax.random.PRNGKey(0)
rng, rng_params, rng_call = jax.random.split(rng, 3)

# All_CNN_C network from evosax's NetworkMapper, sized for 28x28 single-channel
# inputs (see the placeholder below) and 10 output units (one per MNIST class).
network = NetworkMapper["All_CNN_C"](
    depth_1=1,
    depth_2=1,
    features_1=16,
    features_2=8,
    kernel_1=3,
    kernel_2=5,
    strides_1=1,
    strides_2=1,
    final_window=(28, 28),  # presumably matches the 28x28 input resolution — confirm
    num_output_units=10,
)

# Dummy batch holding one 28x28 single-channel image. It is only used to trace
# shapes during init (NHWC layout assumed — confirm against All_CNN_C).
pholder = jnp.zeros((1, 28, 28, 1))
params = network.init(
    rng_params,
    x=pholder,
    rng=rng_call,  # extra kwarg forwarded by flax to the network's __call__
)

# Converts flat ES candidate vectors (length `total_params`) into the network's
# parameter tree, as used by `param_reshaper.reshape` in the training loop.
param_reshaper = ParameterReshaper(params["params"])

In [None]:
# MNIST supervised-learning fitness task evaluated with batch_size=128, wired to
# the network's forward function.
evaluator = SupervisedFitness("MNIST", batch_size=128)
evaluator.set_apply_fn(network.apply)

# Vectorise the rollout over the population, then jit-compile it:
#  - the first argument (the evaluation rng) is broadcast unchanged to every
#    candidate (in_axes=None);
#  - the second argument (the parameter tree) is mapped over according to
#    `param_reshaper.vmap_dict`.
rollout = jax.jit(
    jax.vmap(
        evaluator.rollout,
        in_axes=(None, param_reshaper.vmap_dict),
    )
)

In [None]:
# Augmented Random Search over the flat weight vector (one search dimension per
# network parameter), with 100 candidates per generation, an elite ratio of 0.2
# and the ClipUp optimizer for the outer-loop update.
strategy = Augmented_RS(
    popsize=100,
    num_dims=param_reshaper.total_params,
    elite_ratio=0.2,
    opt_name="clipup",
)

# NOTE(review): this dict is passed as the full set of strategy parameters. If
# Augmented_RS or the ClipUp optimizer read keys not listed here (e.g. momentum
# or max-speed settings), start from `strategy.default_params` and override
# instead — confirm against the installed evosax version.
es_params = {
    "lrate_init": 0.01,  # initial step size of the ClipUp outer-loop optimizer
    "lrate_decay": 0.999,  # presumably multiplicative decay of the step size per update
    "lrate_limit": 0.001,  # presumably the floor the decayed step size cannot go below
    "sigma_init": 0.1,  # initial perturbation scale
    "sigma_decay": 0.999,  # presumably multiplicative decay of sigma per update
    "sigma_limit": 0.01,  # presumably the floor for the decayed sigma
    "init_min": -0.1,  # presumably lower bound for initialising the search mean
    "init_max": 0.1,  # presumably upper bound for initialising the search mean
}

In [None]:
num_generations = 350
num_rollouts = 20

# Give strategy initialisation its own key. Otherwise `rng` would be consumed
# directly here and then split again in the loop below (key reuse).
rng, rng_init = jax.random.split(rng)
state = strategy.initialize(rng_init, es_params)

for gen in range(num_generations):
    # Fresh keys every generation: one to sample candidates, one to evaluate them.
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)

    # Sample a population of flat parameter vectors, then unflatten them into
    # network parameter trees.
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)

    # `num_rollouts` evaluation keys, averaged over axis 1. This assumes the
    # vmapped rollout returns one value per (candidate, rollout key) — TODO
    # confirm against SupervisedFitness.rollout.
    batch_rng = jax.random.split(rng_eval, num_rollouts)
    fitness = rollout(batch_rng, reshaped_params).mean(axis=1)

    state = strategy.tell(x, fitness, state, es_params)