# 03 - Vision - Evolving CNN [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_vision.ipynb)

## Installation

You will need Python 3.10 or later, and a working JAX installation. For example, you can install JAX on NVIDIA GPU with:

In [None]:
%pip install -U "jax[cuda]"

Then, install `evosax` from PyPI:

In [None]:
%pip install -U "evosax[examples]"

## Import

In [157]:
import jax
import optax

In [2]:
# Root PRNG key for the notebook; later cells split it before each random call.
seed = 0
key = jax.random.key(seed)

## MNIST

In [10]:
from evosax.problems import CNN, identity_output_fn
from evosax.problems import TorchVisionProblem as Problem

# CNN configured with two filter banks (8 and 16 filters, 5x5 kernels,
# stride 1) and a single MLP layer of size 10; outputs pass through
# `identity_output_fn` unchanged.
network = CNN(
    num_filters=[8, 16],
    kernel_sizes=[(5, 5), (5, 5)],
    strides=[(1, 1), (1, 1)],
    mlp_layer_sizes=[10],
    output_fn=identity_output_fn,
)

problem = Problem(
    task_name="MNIST",
    network=network,
    batch_size=1024,
)

# Consume `subkey`, not `key`: `key` is split again right below, so passing
# it to `init` would use the same PRNG key twice.
key, subkey = jax.random.split(key)
problem_state = problem.init(subkey)

# Sample one solution pytree; it is the template handed to the ES later.
key, subkey = jax.random.split(key)
solution = problem.sample(subkey)

In [4]:
# Sum the element counts of every leaf array in the solution pytree.
print(f"Number of parameters: {sum(leaf.size for leaf in jax.tree.leaves(solution))}")

Number of pararmeters: 128874


In [5]:
from evosax.algorithms import Open_ES as ES

num_generations = 512

# Standard-deviation schedule handed to the ES: exponential decay from 0.05,
# scaled by 0.2 once `num_generations` schedule steps have elapsed.
std_schedule = optax.exponential_decay(
    init_value=0.05,
    transition_steps=num_generations,
    decay_rate=0.2,
)

# Adam learning-rate schedule: exponential decay from 0.01, scaled by 0.1
# once `num_generations` schedule steps have elapsed.
lr_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=num_generations,
    decay_rate=0.1,
)
optimizer = optax.adam(learning_rate=lr_schedule)

es = ES(
    population_size=256,
    solution=solution,
    optimizer=optimizer,
    std_schedule=std_schedule,
)

# Default hyperparameters of the strategy; threaded through ask/tell below.
params = es.default_params

### Run

In [6]:
def step(carry, key):
    """Advance the evolution strategy by one generation.

    Written as a `jax.lax.scan` body: it takes the running carry plus one
    PRNG key and returns the updated carry together with per-step metrics.

    Args:
        carry: Tuple `(state, params, problem_state)`.
        key: PRNG key for this generation; split into separate keys for the
            ask, evaluation and tell calls.

    Returns:
        `((state, params, problem_state), metrics)` where `metrics` is
        whatever `es.tell` reports for this generation.
    """
    es_state, es_params, eval_state = carry
    ask_key, eval_key, tell_key = jax.random.split(key, 3)

    # Sample candidate solutions from the current search distribution.
    candidates, es_state = es.ask(ask_key, es_state, es_params)

    # Score the candidates; the third return value is not needed here.
    scores, eval_state, _ = problem.eval(eval_key, candidates, eval_state)

    # Update the strategy with the scores (fitness is minimized).
    es_state, metrics = es.tell(tell_key, candidates, scores, es_state, es_params)

    next_carry = (es_state, es_params, eval_state)
    return next_carry, metrics

In [9]:
# Build the initial ES state from the sampled solution with a fresh subkey.
key, subkey = jax.random.split(key)
state = es.init(subkey, solution, params)

fitness_log = []  # NOTE(review): never appended to in this cell — confirm it is still needed
log_period = 32  # generations run per scan call between test evaluations
for i in range(num_generations // log_period):
    # One PRNG key per generation; `step` is scanned over all of them.
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, log_period)
    (state, params, problem_state), metrics = jax.lax.scan(
        step,
        (state, params, problem_state),
        keys,
    )

    mean = es.get_mean(state)

    # Evaluate the mean solution on the test split, adding a leading axis so
    # it is passed as a population of one. Pass `subkey`, not `key`: `key` is
    # split again at the top of the next iteration, so consuming it here
    # would reuse the same PRNG key.
    key, subkey = jax.random.split(key)
    fitness, problem_state, info = problem.eval_test(
        subkey, jax.tree.map(lambda x: x[None], mean), problem_state
    )
    print(
        f"Generation {(i + 1) * log_period:03d} | Mean fitness: {fitness.mean():.2f} | Accuracy: {info['accuracy'].mean():.2f}"
    )

Generation 032 | Mean fitness: 0.90 | Accuracy: 0.70
Generation 064 | Mean fitness: 0.34 | Accuracy: 0.89
Generation 096 | Mean fitness: 0.26 | Accuracy: 0.92
Generation 128 | Mean fitness: 0.18 | Accuracy: 0.95
Generation 160 | Mean fitness: 0.18 | Accuracy: 0.95
Generation 192 | Mean fitness: 0.14 | Accuracy: 0.96
Generation 224 | Mean fitness: 0.16 | Accuracy: 0.96
Generation 256 | Mean fitness: 0.15 | Accuracy: 0.96
Generation 288 | Mean fitness: 0.14 | Accuracy: 0.96
Generation 320 | Mean fitness: 0.15 | Accuracy: 0.96
Generation 352 | Mean fitness: 0.14 | Accuracy: 0.96
Generation 384 | Mean fitness: 0.12 | Accuracy: 0.97
Generation 416 | Mean fitness: 0.12 | Accuracy: 0.96
Generation 448 | Mean fitness: 0.10 | Accuracy: 0.97
Generation 480 | Mean fitness: 0.11 | Accuracy: 0.97
Generation 512 | Mean fitness: 0.11 | Accuracy: 0.97


## CIFAR-10

In [151]:
from evosax.problems import CNN, identity_output_fn
from evosax.problems import TorchVisionProblem as Problem

# CNN configured with two filter banks (8 and 16 filters, 5x5 kernels,
# stride 2) and MLP layers of sizes 128 and 10; outputs pass through
# `identity_output_fn` unchanged.
network = CNN(
    num_filters=[8, 16],
    kernel_sizes=[(5, 5), (5, 5)],
    strides=[(2, 2), (2, 2)],
    mlp_layer_sizes=[128, 10],
    output_fn=identity_output_fn,
)

problem = Problem(
    task_name="CIFAR10",
    network=network,
    batch_size=1024,
)

# Consume `subkey`, not `key`: `key` is split again right below, so passing
# it to `init` would use the same PRNG key twice.
key, subkey = jax.random.split(key)
problem_state = problem.init(subkey)

# Sample one solution pytree; it is the template handed to the ES later.
key, subkey = jax.random.split(key)
solution = problem.sample(subkey)

In [152]:
# Sum the element counts of every leaf array in the solution pytree.
print(f"Number of parameters: {sum(leaf.size for leaf in jax.tree.leaves(solution))}")

Number of pararmeters: 136314


In [153]:
from evosax.algorithms import Open_ES as ES

num_generations = 512

# Standard-deviation schedule handed to the ES: exponential decay from 0.05,
# scaled by 0.1 once `num_generations` schedule steps have elapsed.
std_schedule = optax.exponential_decay(
    init_value=0.05,
    transition_steps=num_generations,
    decay_rate=0.1,
)

# Adam learning-rate schedule: exponential decay from 0.01, scaled by 0.1
# once `num_generations` schedule steps have elapsed.
lr_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=num_generations,
    decay_rate=0.1,
)
optimizer = optax.adam(learning_rate=lr_schedule)

es = ES(
    population_size=512,
    solution=solution,
    optimizer=optimizer,
    std_schedule=std_schedule,
)

# Default hyperparameters of the strategy; threaded through ask/tell below.
params = es.default_params

### Run

In [154]:
def step(carry, key):
    """Advance the evolution strategy by one generation.

    Written as a `jax.lax.scan` body: it takes the running carry plus one
    PRNG key and returns the updated carry together with per-step metrics.

    Args:
        carry: Tuple `(state, params, problem_state)`.
        key: PRNG key for this generation; split into separate keys for the
            ask, evaluation and tell calls.

    Returns:
        `((state, params, problem_state), metrics)` where `metrics` is
        whatever `es.tell` reports for this generation.
    """
    es_state, es_params, eval_state = carry
    ask_key, eval_key, tell_key = jax.random.split(key, 3)

    # Sample candidate solutions from the current search distribution.
    candidates, es_state = es.ask(ask_key, es_state, es_params)

    # Score the candidates; the third return value is not needed here.
    scores, eval_state, _ = problem.eval(eval_key, candidates, eval_state)

    # Update the strategy with the scores (fitness is minimized).
    es_state, metrics = es.tell(tell_key, candidates, scores, es_state, es_params)

    next_carry = (es_state, es_params, eval_state)
    return next_carry, metrics

In [155]:
# Build the initial ES state from the sampled solution with a fresh subkey.
key, subkey = jax.random.split(key)
state = es.init(subkey, solution, params)

In [156]:
fitness_log = []  # NOTE(review): never appended to in this cell — confirm it is still needed
log_period = 32  # generations run per scan call between test evaluations
for i in range(num_generations // log_period):
    # One PRNG key per generation; `step` is scanned over all of them.
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, log_period)
    (state, params, problem_state), metrics = jax.lax.scan(
        step,
        (state, params, problem_state),
        keys,
    )

    mean = es.get_mean(state)

    # Evaluate the mean solution on the test split, adding a leading axis so
    # it is passed as a population of one. Pass `subkey`, not `key`: `key` is
    # split again at the top of the next iteration, so consuming it here
    # would reuse the same PRNG key.
    key, subkey = jax.random.split(key)
    fitness, problem_state, info = problem.eval_test(
        subkey, jax.tree.map(lambda x: x[None], mean), problem_state
    )
    print(
        f"Generation {(i + 1) * log_period:03d} | Mean fitness: {fitness.mean():.2f} | Accuracy: {info['accuracy'].mean():.2f}"
    )

Generation 032 | Mean fitness: 2.07 | Accuracy: 0.25
Generation 064 | Mean fitness: 1.88 | Accuracy: 0.32
Generation 096 | Mean fitness: 1.80 | Accuracy: 0.33
Generation 128 | Mean fitness: 1.72 | Accuracy: 0.37
Generation 160 | Mean fitness: 1.70 | Accuracy: 0.41
Generation 192 | Mean fitness: 1.67 | Accuracy: 0.39
Generation 224 | Mean fitness: 1.66 | Accuracy: 0.39
Generation 256 | Mean fitness: 1.63 | Accuracy: 0.41
Generation 288 | Mean fitness: 1.67 | Accuracy: 0.39
Generation 320 | Mean fitness: 1.60 | Accuracy: 0.41
Generation 352 | Mean fitness: 1.60 | Accuracy: 0.42
Generation 384 | Mean fitness: 1.57 | Accuracy: 0.44
Generation 416 | Mean fitness: 1.60 | Accuracy: 0.40
Generation 448 | Mean fitness: 1.55 | Accuracy: 0.44
Generation 480 | Mean fitness: 1.58 | Accuracy: 0.44
Generation 512 | Mean fitness: 1.56 | Accuracy: 0.41
