# Benchmarking Predator Prey Coevolution using JAX

This notebook benchmarks Maelstrom using JAX to accelerate coevolution on the classic predator-prey problem.

First, install `maelstrom` and `snake-eyes` from PyPI if they are not already installed. Depending on your environment, you may also need to install JAX.

In [None]:
# !pip install maelstrom-evolution snake-eyes-parser
# !pip install jax jaxlib
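Before running the commented-out `pip` commands, a quick check can report which modules are already importable. This is a minimal sketch that uses only the standard library. It assumes the import names `maelstrom`, `snake_eyes`, `jax`, and `jaxlib`. The first two match the imports used later in this notebook.

```python
import importlib.util

# Import names to check. ASSUMPTION: these correspond to the pip packages above
# (maelstrom-evolution -> maelstrom, snake-eyes-parser -> snake_eyes).
MODULES = ["maelstrom", "snake_eyes", "jax", "jaxlib"]


def missing_modules(names=MODULES):
    """Return the subset of top-level module `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules()
if missing:
    print("Not installed:", ", ".join(missing))
else:
    print("All required modules are importable.")
```

`find_spec` locates a module without importing it, so the check stays fast even for large packages like JAX.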

With installation complete, we can now import `maelstrom` and `snake-eyes` for use within this notebook. In this example, we also import `time` and `random` for timing and seeding, respectively.

In [None]:
from maelstrom import Maelstrom
from snake_eyes import read_config
import time
import random

## Primitive Definition
With `maelstrom` imported, we can now define the strongly typed primitives for this example. For brevity and reusability, these primitives are defined in a separate file and simply imported into this notebook.

In [None]:
from primitives import *

## Fitness Function Definition
We also need to define a fitness function. Again, this function is defined in a separate file, imported into this notebook, and then referenced from within the configuration file for this example.

In [None]:
from predator_prey import *

## Configuration
Now, we load the configuration file for this example using the `snake-eyes` configuration parser. For transparency, we print the contents of this file, since they show how the constants and functions defined above may be referenced from within the configuration file.

In [None]:
config = read_config('./configs/maelstrom.cfg', globals(), locals())
for section, params in config.items():
    print(section)
    if params == {}:
        print(f"  {params}")
    else:
        for key, value in params.items():
            print(f"  {key}: {value}")

## Benchmarking
With everything imported and configured, we now define the parameters we intend to vary during benchmarking: a list of population sizes, the number of samples to collect for each size, and a dictionary to store the results. Each size value is used for both $\mu$ and $\lambda$ in a $(\mu+\lambda)$ evolution strategy that performs round-robin competitions at each generation. Because each run of evolution has a constant evaluation limit, larger population sizes should produce fewer generations, each with larger competitions. This lets us meaningfully quantify the impact of accelerator utilization in JAX by repeating the experiment with and without it.
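To make the trade-off between population size and generation count concrete, the arithmetic can be sketched in Python. The cost model is an assumption for illustration only, not taken from Maelstrom's documentation: it supposes that every new predator plays every new prey each generation ($\lambda^2$ evaluations), that initialization costs $\mu^2$ evaluations, and that `EVAL_LIMIT` is a hypothetical budget rather than the value in `maelstrom.cfg`.

```python
# Hypothetical sketch: how a fixed evaluation budget maps to generations.
# ASSUMPTIONS (not from Maelstrom's documentation):
#   * one generation is a full round-robin of lam children vs. lam children,
#     costing lam * lam evaluations;
#   * evaluating the initial populations costs mu * mu evaluations;
#   * EVAL_LIMIT is an illustrative budget, not the configured value.
EVAL_LIMIT = 1_000_000


def generations_within_budget(mu, lam, eval_limit=EVAL_LIMIT):
    """Number of complete generations that fit in the budget after initialization."""
    remaining = eval_limit - mu * mu
    if remaining < 0:
        return 0
    return remaining // (lam * lam)


for size in [100, 200, 300]:
    gens = generations_within_budget(size, size)
    print(f"size={size}: {size * size} games/generation, {gens} generations")
```

Under these assumptions, tripling the population size multiplies the work per generation by nine, which is the kind of large, regular batch of evaluations that an accelerator is most likely to speed up.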

In [None]:
sizes = [100, 200, 300]
samples = 2
data = {size:[] for size in sizes}

In [None]:
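for size in sizes:
    # Use the same size for mu (pop_size) and lambda (num_children) in both populations
    config['predators']['pop_size'] = size
    config['predators']['num_children'] = size
    config['prey']['pop_size'] = size
    config['prey']['num_children'] = size
    for _ in range(samples):
        random.seed(42)  # reseed so every sample starts from the same state
        start = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
        evolver = Maelstrom(**config['MAELSTROM'], **config)
        evolver.run()
        data[size].append(time.perf_counter() - start)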
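JAX compiles functions the first time they are traced, so the first timed run can include compilation overhead that later runs avoid. One option is to make a few untimed warm-up calls before collecting samples. The helper below is a minimal, hypothetical sketch (`benchmark` is not part of Maelstrom or snake-eyes). Whether compiled code is reused across separate `Maelstrom` instances depends on how the library structures its JAX calls, so warm-up may or may not change the numbers here. Also, if `run()` could return before asynchronously dispatched JAX work finishes, the timed call would need to block on its results, for example with an array's `block_until_ready()` method.

```python
import time


def benchmark(fn, samples=2, warmup=1):
    """Time `fn()` `samples` times after `warmup` untimed calls.

    Warm-up calls absorb one-time costs (for JAX, for example, JIT compilation
    on first use), so the recorded samples reflect steady-state run time.
    """
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(samples):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return times
```

In this notebook it could be called as `benchmark(lambda: Maelstrom(**config['MAELSTROM'], **config).run(), samples=samples)` inside the loop over `sizes`, with `random.seed(42)` moved into the lambda if each sample should start from the same state.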

In [None]:
from statistics import mean, stdev

for size, values in data.items():
    print(f'population size: {size} (n = {len(values)})')
    print(f'mean time (s): {mean(values)}')
    print(f'standard deviation (s): {stdev(values)}')
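To compare runs with and without the accelerator, the benchmark can be repeated once per configuration, and the mean times for each population size can be divided. The sketch below assumes two dictionaries shaped like `data`. The names `cpu_data` and `gpu_data` used in the note after the block are hypothetical. In recent JAX versions, one way to collect the CPU-only baseline is to set the `JAX_PLATFORMS=cpu` environment variable before JAX is imported.

```python
from statistics import mean


def speedup_by_size(baseline, accelerated):
    """Ratio of mean baseline time to mean accelerated time for each shared size.

    Values above 1.0 mean the accelerated runs were faster on average.
    Sizes missing from either dictionary, or with no samples, are skipped.
    """
    return {
        size: mean(baseline[size]) / mean(accelerated[size])
        for size in baseline
        if size in accelerated and baseline[size] and accelerated[size]
    }
```

For example, `speedup_by_size(cpu_data, gpu_data)` would report one ratio per population size. With only two samples per size, these ratios are rough estimates and should be read alongside the standard deviations printed above.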