# 03 - Evolving an MNIST CNN with OpenES [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb)

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, you can install JAX for an NVIDIA GPU with:

In [None]:
# Install/upgrade JAX with NVIDIA CUDA support (adapt this line for CPU/TPU setups)
%pip install -U "jax[cuda]"

Then, install `evosax` from PyPI:

In [None]:
# Install/upgrade evosax together with its optional `examples` extra dependencies
%pip install -U "evosax[examples]"

In [None]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper, OpenES, ParameterReshaper
from evosax.problems import VisionProblem

rng = jax.random.key(0)
# Split off a dedicated key for network initialization so that `rng` itself is
# never consumed twice (later cells split `rng` again; JAX keys must not be reused).
rng, rng_init = jax.random.split(rng)

# The CNN architecture uses two differently sized (kernel etc.) conv blocks
network = NetworkMapper["CNN"](
    depth_1=1,  # number of conv layers in block 1
    depth_2=1,  # number of conv layers in block 2
    features_1=8,
    features_2=16,
    kernel_1=5,
    kernel_2=5,
    strides_1=1,
    strides_2=1,
    num_linear_layers=0,
    num_output_units=10,
)
# Placeholder batch holding one 28x28x1 input; it is only used to trace the
# network and build its parameter pytree (`solution`)
dummy_input = jnp.zeros((1, 28, 28, 1))
solution = network.init(
    rng_init,
    x=dummy_input,
    rng=rng_init,
)

# Note: ParameterReshaper automatically detects if multiple devices are available
# and uses pmap to reshape raw parameter vectors onto different devices.
# For testing (evaluating the strategy mean), on the other hand, we always use a
# single device.
param_reshaper = ParameterReshaper(solution)
test_param_reshaper = ParameterReshaper(solution, n_devices=1)

In [None]:
# Build the MNIST data pipelines used for batch evaluation (this may take a moment).
# Training draws mini-batches of 1024 images on all available devices; testing uses
# one batch of 10000 images on a single device.
train_problem = VisionProblem("MNIST", batch_size=1024, test=False)
test_problem = VisionProblem("MNIST", batch_size=10000, test=True, n_devices=1)

# Both problems score candidate parameters with the same CNN forward pass
for problem in (train_problem, test_problem):
    problem.set_apply_fn(network.apply)

In [3]:
# OpenES samples a population of 100 candidate vectors per generation, each of
# length `param_reshaper.total_params` (all flattened CNN parameters).
# `opt_name="adam"` selects Adam as the optimizer used by the strategy's update.
strategy = OpenES(
    population_size=100, num_dims=param_reshaper.total_params, opt_name="adam"
)
# Use the default OpenES hyperparameters unchanged
params = strategy.default_params

In [4]:
from evosax import FitnessShaper

# Transform raw per-candidate training losses into the fitness values fed to OpenES
fit_shaper = FitnessShaper(
    maximize=False,  # lower loss is better, so fitness is minimized
    w_decay=0.1,  # weight-decay penalty on the candidates (x is passed to `apply`)
    z_score=False,  # no z-scoring of the fitness values
    centered_rank=True,  # replace raw losses by centered ranks
)

In [None]:
num_generations = 2500
print_every_k_gens = 100

# Initialize the strategy with a fresh key instead of reusing `rng`, which is
# split again inside the loop (JAX keys must never be reused).
rng, rng_init = jax.random.split(rng)
state = strategy.init(rng_init, params)

for generation in range(num_generations):
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)
    # Sample flat candidate vectors and reshape them into CNN parameter pytrees
    x, state = strategy.ask(rng_ask, state, params)
    reshaped_params = param_reshaper.reshape(x)
    # Evaluation returns the training loss and accuracy of every candidate
    train_loss, train_acc = train_problem.eval(rng_eval, reshaped_params)
    # Average out axis 1 to get one fitness value per candidate
    # (assumes train_loss is (population_size, batch) - TODO confirm), then shape it
    fit_re = fit_shaper.apply(x, train_loss.mean(axis=1))
    state = strategy.tell(x, fit_re, state, params)

    # Also report the final generation, which the modulo check alone would skip
    is_last_generation = generation == num_generations - 1
    if generation % print_every_k_gens == 0 or is_last_generation:
        # Evaluate the current strategy mean on the test set, using its own key
        # rather than reusing the training evaluation key
        rng, rng_test = jax.random.split(rng)
        mean_params = state.mean.reshape(1, -1)
        reshaped_test_params = test_param_reshaper.reshape(mean_params)
        test_loss, test_acc = test_problem.eval(rng_test, reshaped_test_params)
        print(
            f"Generation: {generation} | Train Acc: {train_acc.mean()} | Test Acc: {test_acc.mean()}"
        )