# 03 - Evolving a MNIST CNN with OpenES
### [Last Update: February 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install git+https://github.com/RobertTLange/evosax.git@main

In [2]:
import jax
import jax.numpy as jnp

from evosax import OpenES, ParameterReshaper, NetworkMapper
from evosax.problems import VisionFitness

rng = jax.random.PRNGKey(0)
network = NetworkMapper["CNN"](
        depth_1=1,
        depth_2=1,
        features_1=8,
        features_2=16,
        kernel_1=5,
        kernel_2=5,
        strides_1=1,
        strides_2=1,
        num_linear_layers=0,
        num_output_units=10,
    )
pholder = jnp.zeros((1, 28, 28, 1))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

param_reshaper = ParameterReshaper(params['params'])
test_param_reshaper = ParameterReshaper(params['params'], n_devices=1)

ParameterReshaper: More than one device detected. Please make sure that the ES population size divides evenly across the number of devices to pmap/parallelize over.


In [3]:
# Set up the dataloader for batch evaluations (may take a sec)
train_evaluator = VisionFitness("MNIST", batch_size=1024, test=False)
test_evaluator = VisionFitness("MNIST", batch_size=10000, test=True, n_devices=1)

train_evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)
test_evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


SupervisedFitness: More than one device detected. Please make sure that the ES population size divides evenly across the number of devices to pmap/parallelize over.


In [4]:
from evosax import OpenES
strategy = OpenES(popsize=100, num_dims=param_reshaper.total_params, opt_name="adam")
es_params = {
    "sigma_init": 0.01,  # Initial scale of isotropic Gaussian noise
    "sigma_decay": 0.999,  # Multiplicative decay factor
    "sigma_limit": 0.01,  # Smallest possible scale
    "lrate_init": 0.001,  # Initial learning rate
    "lrate_decay": 0.9999,  # Multiplicative decay factor
    "lrate_limit": 0.0001,  # Smallest possible lrate
    "beta_1": 0.99,   # Adam - beta_1
    "beta_2": 0.999,  # Adam - beta_2
    "eps": 1e-8,  # eps constant,
    "init_min": 0.0,  # Range of parameter archive initialization - Min
    "init_max": 0.0,  # Range of parameter archive initialization - Max
    "clip_min": -10,  # Range of parameter proposals - Min
    "clip_max": 10  # Range of parameter proposals - Max
}

In [5]:
from evosax import FitnessShaper
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=True,
                           w_decay=0.1,
                           maximize=True)

In [6]:
num_generations = 2500
print_every_k_gens = 100
state = strategy.initialize(rng, es_params)

for gen in range(num_generations):
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)
    # rollout will pytree w. train_acc, train_loss
    train_loss, train_acc = train_evaluator.rollout(rng_eval, reshaped_params)
    fit_re = fit_shaper.apply(x, train_loss.mean(axis=1))
    state = strategy.tell(x, fit_re, state, es_params)

    if gen % print_every_k_gens == 0:
        # Perform evaluation for best and mean members
        mean_params = state["mean"].reshape(1, -1)
        reshaped_test_params = test_param_reshaper.reshape(mean_params)
        test_loss, test_acc = test_evaluator.rollout(
            rng_eval, reshaped_test_params
        )
        print(f"Generation: {gen} | Train Acc: {train_acc.mean()} | Test Acc: {test_acc.mean()}")



Generation: 0 | Train Acc: 0.10170897841453552 | Test Acc: 0.10279999673366547
Generation: 100 | Train Acc: 0.13413085043430328 | Test Acc: 0.11589999496936798
Generation: 200 | Train Acc: 0.27433592081069946 | Test Acc: 0.274399995803833
Generation: 300 | Train Acc: 0.6929785013198853 | Test Acc: 0.7301999926567078
Generation: 400 | Train Acc: 0.7566698789596558 | Test Acc: 0.7777999639511108
Generation: 500 | Train Acc: 0.7677441239356995 | Test Acc: 0.8165000081062317
Generation: 600 | Train Acc: 0.8138769268989563 | Test Acc: 0.8417999744415283
Generation: 700 | Train Acc: 0.82777339220047 | Test Acc: 0.852899968624115
Generation: 800 | Train Acc: 0.8446386456489563 | Test Acc: 0.8614999651908875
Generation: 900 | Train Acc: 0.8374804258346558 | Test Acc: 0.8689000010490417
Generation: 1000 | Train Acc: 0.8595898151397705 | Test Acc: 0.8758999705314636
Generation: 1100 | Train Acc: 0.8499706983566284 | Test Acc: 0.8810999989509583
Generation: 1200 | Train Acc: 0.8672558665275574 | 