# 03 - Evolving a MNIST CNN with OpenES
### [Last Update: February 2022] [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

!pip install -q git+https://github.com/RobertTLange/evosax.git@main
!pip install -q torch torchvision

In [1]:
import jax
import jax.numpy as jnp
from evosax import NetworkMapper, OpenES, ParameterReshaper
from evosax.problems import VisionFitness

rng = jax.random.key(0)

# The CNN consists of two conv blocks that are configured separately
# (depth, number of features, kernel size, strides)
network = NetworkMapper["CNN"](
    depth_1=1,  # number of conv layers in block 1
    depth_2=1,  # number of conv layers in block 2
    features_1=8,
    features_2=16,
    kernel_1=5,
    kernel_2=5,
    strides_1=1,
    strides_2=1,
    num_linear_layers=0,
    num_output_units=10,
)
pholder = jnp.zeros((1, 28, 28, 1))
params = network.init(
    rng,
    x=pholder,
    rng=rng,
)

# Note: ParameterReshaper automatically detects whether multiple devices are available
# and uses pmap to reshape raw parameter vectors across them.
# For testing (mean and best-so-far members), we always use a single device.
param_reshaper = ParameterReshaper(params)
test_param_reshaper = ParameterReshaper(params, n_devices=1)

ParameterReshaper: 11274 parameters detected for optimization.
ParameterReshaper: 11274 parameters detected for optimization.
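
The reported count of 11274 parameters can be checked by hand. The sketch below is only an illustration: it assumes `SAME` padding with stride 1 and a 2x2 pooling step after each conv block (inferred from the count, not confirmed by the code above), and the helper names `conv_params`, `dense_params` and `reshape_population` are hypothetical. It also shows, in plain NumPy, how a flat population matrix could be split back into per-layer arrays; the leaf ordering used by the actual `ParameterReshaper` may differ.

```python
import numpy as np


def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of one 2D convolution layer."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(in_units, out_units):
    """Weights plus biases of one dense layer."""
    return in_units * out_units + out_units


conv1 = conv_params(kernel=5, in_ch=1, out_ch=8)
conv2 = conv_params(kernel=5, in_ch=8, out_ch=16)
# Assumption: the spatial size shrinks 28 -> 14 -> 7 through two 2x2 poolings
flat_features = 7 * 7 * 16
readout = dense_params(flat_features, 10)
total_params = conv1 + conv2 + readout
print(total_params)

# Hypothetical per-layer shapes matching the count above
shapes = [(5, 5, 1, 8), (8,), (5, 5, 8, 16), (16,), (784, 10), (10,)]


def reshape_population(x, shapes):
    """Split a (popsize, total_params) matrix into per-layer arrays."""
    sizes = [int(np.prod(s)) for s in shapes]
    split_points = np.cumsum(sizes)[:-1]
    chunks = np.split(x, split_points, axis=1)
    # Every array keeps the population dimension as its leading axis
    return [c.reshape((x.shape[0], *s)) for c, s in zip(chunks, shapes)]


population = np.zeros((4, total_params))
layers = reshape_population(population, shapes)
print([layer.shape for layer in layers])
```

Keeping the population as the leading axis is what lets the whole batch of candidate networks be evaluated in one vectorized call.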


In [2]:
# Set up the dataloader for batch evaluations (may take a sec)
train_evaluator = VisionFitness("MNIST", batch_size=1024, test=False)
test_evaluator = VisionFitness("MNIST", batch_size=10000, test=True, n_devices=1)

train_evaluator.set_apply_fn(network.apply)
test_evaluator.set_apply_fn(network.apply)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/rob/data/MNIST/raw/train-images-idx3-ubyte.gz


9913344it [00:00, 37979258.61it/s]                             


Extracting /Users/rob/data/MNIST/raw/train-images-idx3-ubyte.gz to /Users/rob/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/rob/data/MNIST/raw/train-labels-idx1-ubyte.gz


29696it [00:00, 35135134.44it/s]         

Extracting /Users/rob/data/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/rob/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/rob/data/MNIST/raw/t10k-images-idx3-ubyte.gz


1649664it [00:00, 24493843.31it/s]         

Extracting /Users/rob/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/rob/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/rob/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


5120it [00:00, 13190931.50it/s]         


Extracting /Users/rob/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/rob/data/MNIST/raw



In [3]:
strategy = OpenES(
    population_size=100, num_dims=param_reshaper.total_params, opt_name="adam"
)
# Use the default hyperparameters of the OpenES strategy
es_params = strategy.default_params

In [4]:
from evosax import FitnessShaper

fit_shaper = FitnessShaper(
    centered_rank=True, z_score=False, w_decay=0.1, maximize=False
)
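
As a rough illustration of what these options do, the following NumPy sketch applies an L2 penalty on the candidate parameters and then maps the result to centered ranks in [-0.5, 0.5]. The order of the two steps and the exact form of the penalty (mean of squared parameters) are assumptions here and may not match the evosax `FitnessShaper` internals; `centered_rank` and `shape_fitness` are hypothetical names.

```python
import numpy as np


def centered_rank(fitness):
    """Map fitness values to evenly spaced ranks in [-0.5, 0.5]."""
    ranks = np.argsort(np.argsort(fitness))
    return ranks / (fitness.shape[0] - 1) - 0.5


def shape_fitness(x, fitness, w_decay=0.0, maximize=False):
    """Optionally flip the sign, add an L2 penalty, then rank."""
    # The update step is assumed to minimize, so maximization flips the sign
    fitness = -fitness if maximize else fitness
    if w_decay > 0.0:
        # Assumption: penalty is the mean squared parameter value per member
        fitness = fitness + w_decay * np.mean(x**2, axis=1)
    return centered_rank(fitness)


losses = np.array([3.0, 1.0, 2.0])
candidates = np.zeros((3, 4))  # zero parameters, so the penalty vanishes
shaped = shape_fitness(candidates, losses, w_decay=0.1)
shaped_max = shape_fitness(candidates, losses, maximize=True)
print(shaped, shaped_max)
```

Rank-based shaping makes the update invariant to the scale of the raw loss, so a single outlier batch cannot dominate the gradient estimate.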

In [5]:
num_generations = 2500
print_every_k_gens = 100
state = strategy.initialize(rng, es_params)

for gen in range(num_generations):
    rng, rng_ask, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_ask, state, es_params)
    reshaped_params = param_reshaper.reshape(x)
    # rollout returns the training loss and accuracy of each population member
    train_loss, train_acc = train_evaluator.rollout(rng_eval, reshaped_params)
    fit_re = fit_shaper.apply(x, train_loss.mean(axis=1))
    state = strategy.tell(x, fit_re, state, es_params)

    if gen % print_every_k_gens == 0:
        # Evaluate the current mean parameters on the test set
        mean_params = state.mean.reshape(1, -1)
        reshaped_test_params = test_param_reshaper.reshape(mean_params)
        test_loss, test_acc = test_evaluator.rollout(rng_eval, reshaped_test_params)
        print(
            f"Generation: {gen} | Train Acc: {train_acc.mean()} | Test Acc: {test_acc.mean()}"
        )

Generation: 0 | Train Acc: 0.10171874612569809 | Test Acc: 0.10279999673366547


KeyboardInterrupt: 
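
The accuracy of roughly 0.10 at generation 0 is what chance level looks like with ten classes. For readers who want to see how a loss and accuracy pair can be derived from network outputs, here is a small NumPy sketch. It assumes a softmax cross-entropy loss, which the notebook does not confirm for `VisionFitness`, and `loss_and_accuracy` is a hypothetical helper.

```python
import numpy as np


def loss_and_accuracy(logits, labels):
    """Mean softmax cross-entropy and accuracy for a batch of logits."""
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(labels.shape[0]), labels].mean()
    acc = (logits.argmax(axis=1) == labels).mean()
    return loss, acc


labels = np.array([0, 1, 2, 3])

# An uninformed network: identical logits for all ten classes
uniform_loss, _ = loss_and_accuracy(np.zeros((4, 10)), labels)

# A confident and correct network: a large logit on the true class
confident_logits = 5.0 * np.eye(10)[labels]
confident_loss, confident_acc = loss_and_accuracy(confident_logits, labels)
print(uniform_loss, confident_loss, confident_acc)
```

With identical logits the loss equals log(10), about 2.30, which is a useful reference value when checking the first generations of a run.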