# Benchmarking Predator-Prey Coevolution Using JAX

This notebook benchmarks Maelstrom on the classic predator-prey problem, using JAX to implement accelerated coevolution.

First, install `maelstrom` and `snake-eyes` from PyPI (as `maelstrom-evolution` and `snake-eyes-parser`) if they are not already installed. Depending on your environment, you may also need to install JAX.

In [1]:
# !pip install maelstrom-evolution snake-eyes-parser
# !pip install jax jaxlib
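
Before installing anything, it can help to check which packages are already importable. The following is a minimal sketch using only the standard library. The module names are the import names used later in this notebook (`maelstrom`, `snake_eyes`) plus `jax`, and the helper `missing_modules` is a hypothetical name introduced here.

```python
import importlib.util

# Import names (not pip package names): maelstrom-evolution installs `maelstrom`,
# snake-eyes-parser installs `snake_eyes`; `jax` is the JAX import name.
REQUIRED_MODULES = ["maelstrom", "snake_eyes", "jax"]


def missing_modules(names):
    """Return the module names that cannot be found by the current interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are importable.")
```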

With installation complete, we can now import `maelstrom` and `snake_eyes` (the module provided by `snake-eyes`) for use within this notebook. We also import `time` and `random` for timing and seeding, respectively.

In [2]:
from maelstrom import Maelstrom
from snake_eyes import read_config
import time
import random

## Primitive Definition
With `maelstrom` imported, we can now define the strongly typed primitives for this example. For brevity and reusability, these primitives are defined in a separate module, `primitives`, and simply imported into this notebook.

In [3]:
from primitives import *
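
The contents of `primitives` are not shown here. To illustrate what "strongly typed" means in genetic programming, here is a hypothetical sketch that is not Maelstrom's primitive API: each primitive declares the types of its inputs and its output, and a tree is valid only when every child returns the type its parent's slot expects. `Angle` mirrors the `output_type` in the configuration; `Distance` and every primitive name below are invented for the example.

```python
import math
from dataclasses import dataclass
from typing import Callable, Tuple

# Hypothetical illustration only; this is NOT Maelstrom's primitive API.


@dataclass(frozen=True)
class Primitive:
    name: str
    func: Callable
    input_types: Tuple[str, ...]  # type each child must return
    output_type: str              # type this node returns


PRIMITIVES = {
    p.name: p
    for p in [
        # Terminals (no inputs): invented sensor readings.
        Primitive("prey_bearing", lambda: 0.0, (), "Angle"),
        Primitive("prey_distance", lambda: 1.0, (), "Distance"),
        # Functions over angles.
        Primitive("add_angles", lambda a, b: (a + b) % (2 * math.pi), ("Angle", "Angle"), "Angle"),
        Primitive("negate_angle", lambda a: (-a) % (2 * math.pi), ("Angle",), "Angle"),
    ]
}


def well_typed(node, expected):
    """Check a tree written as (name, child, child, ...) against an expected type."""
    name, *children = node
    prim = PRIMITIVES[name]
    if prim.output_type != expected or len(children) != len(prim.input_types):
        return False
    return all(well_typed(child, t) for child, t in zip(children, prim.input_types))


good = ("add_angles", ("prey_bearing",), ("negate_angle", ("prey_bearing",)))
bad = ("add_angles", ("prey_distance",), ("prey_bearing",))  # Distance in an Angle slot
print(well_typed(good, "Angle"), well_typed(bad, "Angle"))  # True False
```

A strongly typed GP system applies this kind of check when generating, mutating, and recombining trees, so ill-typed programs such as `bad` are never produced.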

## Fitness Function Definition
We also need a fitness function. Like the primitives, it is defined in a separate module, `predator_prey`, imported into this notebook, and then referenced from within the configuration file.

In [4]:
from predator_prey import *
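
The implementation of `round_robin_evaluation` lives in `predator_prey` and is not reproduced here. As a hypothetical sketch of the general technique, a round-robin evaluation pairs every predator with every prey and averages each individual's scores. The `play` callback is an assumed stand-in for the actual pursuit simulation, which the configuration parameterizes with `predator_move_speed`, `prey_move_speed`, `agent_radius`, and `time_limit`.

```python
from itertools import product


def round_robin(predators, prey, play):
    """Hypothetical round-robin evaluation: every predator meets every prey once.

    `play(predator, prey)` is assumed to return the predator's score in [0, 1];
    the prey is credited with the complement. Fitness is the mean score.
    """
    pred_scores = [0.0] * len(predators)
    prey_scores = [0.0] * len(prey)
    for (i, pred), (j, target) in product(enumerate(predators), enumerate(prey)):
        score = play(pred, target)
        pred_scores[i] += score
        prey_scores[j] += 1.0 - score
    pred_fitness = [s / len(prey) for s in pred_scores]
    prey_fitness = [s / len(predators) for s in prey_scores]
    return pred_fitness, prey_fitness


# Toy usage: individuals are just speeds, and the faster agent "wins" the game.
pred_fit, prey_fit = round_robin([0.06, 0.2], [0.1, 0.15],
                                 lambda p, q: 1.0 if p > q else 0.0)
print(pred_fit, prey_fit)
```

Because every pairing is played, the number of games is the product of the two population sizes, which is why larger populations make each generation's competition much more expensive.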

## Configuration
Now we load a configuration file for this example using the `snake-eyes` configuration parser. For transparency, we print its contents, since they show how constants and functions defined in this notebook can be referenced from within the configuration file.

In [5]:
config = read_config('./configs/maelstrom.cfg', globals(), locals())
for section, params in config.items():
    print(section)
    if params == {}:
        print(f"  {params}")
    else:
        for key, value in params.items():
            print(f"  {key}: {value}")

DEFAULT
  {}
MAELSTROM
  islands: {'1': 'ISLAND'}
  migration_edges: []
  evaluations: 1000
GENERAL
  runs: 2
  pop_size: 10
  num_children: 10
  parent_selection: k_tournament
  k_parent: 10
  mutation: 0.1
  survival_strategy: plus
  survival_selection: truncation
  depth_limit: 5
  hard_limit: 20
  depth_min: 1
  output_type: Angle
ISLAND
  populations: {'predators': 'predators', 'prey': 'prey'}
  evaluation_function: <function round_robin_evaluation at 0x7f15d04203a0>
  evaluation_kwargs: {'predator_move_speed': 0.06, 'prey_move_speed': 0.1, 'agent_radius': 0.1, 'time_limit': 200}
  champions_per_generation: 5
predators
  roles: General
  output_type: Angle
  depth_limit: 5
  hard_limit: 20
  depth_min: 1
  pop_size: 10
  num_children: 10
  parent_selection: k_tournament
  k_parent: 10
  mutation: 0.1
  survival_strategy: plus
  survival_selection: truncation
prey
  roles: General
  output_type: Angle
  depth_limit: 5
  hard_limit: 20
  depth_min: 1
  pop_size: 10
  num_children: 1
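
In the printed configuration, `predators` mirrors the values in `GENERAL`, and `evaluation_function` has been resolved to a function object rather than left as text. How `snake-eyes` achieves this is not shown here, and its real syntax may differ. The following hypothetical sketch reproduces both behaviours with the standard library's `configparser`: `ExtendedInterpolation` handles cross-section references, and a namespace lookup resolves names, similar in spirit to passing `globals()` to `read_config`. The helpers `convert` and `load` and the placeholder `stub_evaluation` are invented for illustration.

```python
import ast
import configparser

# Sketch using the standard-library configparser, NOT snake-eyes.
# Section and option names mirror the printed configuration above.
CONFIG_TEXT = """
[GENERAL]
pop_size = 10
mutation = 0.1
output_type = Angle

[predators]
pop_size = ${GENERAL:pop_size}
mutation = ${GENERAL:mutation}
output_type = ${GENERAL:output_type}

[ISLAND]
evaluation_function = round_robin_evaluation
"""


def stub_evaluation(*args, **kwargs):
    """Hypothetical placeholder; the real function comes from predator_prey."""
    raise NotImplementedError


def convert(raw, namespace):
    """Parse Python literals; otherwise look the text up in `namespace`,
    falling back to the raw string if the name is not defined there."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return namespace.get(raw, raw)


def load(text, namespace):
    parser = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
    parser.read_string(text)
    # Reading through the section proxy applies ${section:option} interpolation.
    return {
        section: {key: convert(value, namespace) for key, value in parser[section].items()}
        for section in parser.sections()
    }


# A separate namespace avoids shadowing the real round_robin_evaluation.
config_sketch = load(CONFIG_TEXT, {"round_robin_evaluation": stub_evaluation})
print(config_sketch["predators"])
```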

## Benchmarking
With everything imported and configured, we now set the parameters to vary during benchmarking: a list of population sizes, the number of samples to collect for each size, and a dictionary to store the timings. For both the predator and prey populations, each size value is used for both $\mu$ and $\lambda$ in a $(\mu+\lambda)$ evolutionary strategy that performs round-robin competitions at each generation. Because every run of evolution has the same fixed evaluation limit, larger populations should yield fewer generations, each with larger competitions. This allows us to meaningfully quantify the impact of running the same experiment with and without accelerator utilization in JAX.

In [9]:
sizes = [100, 200, 500]
samples = 2
data = {size:[] for size in sizes}
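
To make the expected trade-off concrete, here is a back-of-the-envelope sketch under an assumed accounting that this notebook does not confirm: the initial population consumes $\mu$ of the 1000 evaluations, each later generation consumes $\lambda$ (one per child), and a generation's competition is a full round robin between the $n$ predators and $n$ prey being evaluated. Maelstrom may count evaluations differently (for example, per game), so treat the numbers as illustrative. Both helper names are hypothetical.

```python
def estimate_generations(evaluations, mu, lam):
    """Hypothetical budget model: the initial population of `mu` individuals
    costs `mu` evaluations and each generation costs `lam` (one per child)."""
    return max(0, (evaluations - mu) // lam)


def games_per_generation(n_predators, n_prey):
    """A full round robin pairs every predator with every prey once."""
    return n_predators * n_prey


for n in [100, 200, 500]:  # same values as `sizes`
    print(f"size {n}: ~{estimate_generations(1000, n, n)} generations, "
          f"{games_per_generation(n, n):,} games per generation")
```

Under these assumptions, going from size 100 to 500 cuts the number of generations from about 9 to 1 while multiplying the games per generation by 25.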

In [10]:
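for size in sizes:
    # Use the same value for mu (pop_size) and lambda (num_children) in both populations
    config['predators']['pop_size'] = size
    config['predators']['num_children'] = size
    config['prey']['pop_size'] = size
    config['prey']['num_children'] = size
    for _ in range(samples):
        # Reseed so that every sample starts from the same random state
        random.seed(42)
        start = time.time()
        evolver = Maelstrom(**config['MAELSTROM'], **config)
        evolver.run()
        # Record wall-clock time covering both construction and the full run
        data[size].append(time.time()-start)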

Process ForkPoolWorker-184:
Traceback (most recent call last):
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/multiprocessing/pool.py", line 114, in worker
    task = get()
KeyboardInterrupt
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/multiprocessing/queues.py", line 365, in get
    res = self._reader.recv_bytes()
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/seals/anaconda3/envs/publishtest/lib/python3.10/

KeyboardInterrupt: 
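
The timing loop above can be generalized into a reusable helper. In the sketch below, the name `benchmark` and its `run_once` callback are hypothetical. It uses `time.perf_counter`, which is intended for measuring elapsed intervals, and catches `KeyboardInterrupt` so that an interrupted sweep, like the one above, still returns the timings collected so far.

```python
import random
import time


def benchmark(run_once, sizes, samples, base_seed=42):
    """Time `run_once(size)` `samples` times per size, keeping partial results.

    A KeyboardInterrupt stops the sweep but returns whatever was collected.
    """
    timings = {size: [] for size in sizes}
    try:
        for size in sizes:
            for i in range(samples):
                random.seed(base_seed + i)  # distinct but reproducible seed per sample
                start = time.perf_counter()
                run_once(size)
                timings[size].append(time.perf_counter() - start)
    except KeyboardInterrupt:
        pass  # keep the samples gathered so far
    return timings


# Toy usage with a cheap stand-in workload.
print(benchmark(lambda n: sum(range(n)), [10, 1000], samples=3))
```

With Maelstrom, `run_once` could set the four `pop_size`/`num_children` entries as in the loop above and then call `Maelstrom(**config['MAELSTROM'], **config).run()`. Seeding each sample with `base_seed + i` gives different but reproducible runs; reseeding with the same value every time, as the loop above does, instead aims to repeat the same run so that the spread mainly reflects timing noise (assuming Maelstrom draws its randomness from the `random` module).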

In [11]:
from statistics import mean, stdev

for size, values in data.items():
    print(f'population size: {size} (n = {len(values)})')
    if values:
        print(f'mean time (s): {mean(values)}')
    # statistics.stdev raises StatisticsError with fewer than two samples
    sd = stdev(values) if len(values) > 1 else 'n/a (fewer than two samples)'
    print(f'standard deviation: {sd}')

population size: 100 (n = 2)
mean time (s): 155.77811908721924
standard deviation: 13.928719492912036
population size: 200 (n = 2)
mean time (s): 322.87042367458344
standard deviation: 48.829607460772394
population size: 500 (n = 1)
mean time (s): 882.1611678600311
standard deviation: n/a (fewer than two samples)
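
To see how run time scales with population size, each mean can be normalized by the size-100 mean. The sketch below uses only the means reported above (the helper name `scaling_ratios` is hypothetical). With two samples at sizes 100 and 200 and a single sample at 500, the ratios are rough indicators and should not be over-interpreted.

```python
# Mean wall-clock times (s) reported in the output above.
reported_means = {
    100: 155.77811908721924,
    200: 322.87042367458344,
    500: 882.1611678600311,  # single sample only
}


def scaling_ratios(means):
    """Divide each mean time by the mean time of the smallest size."""
    base = min(means)
    return {size: means[size] / means[base] for size in sorted(means)}


for size, ratio in scaling_ratios(reported_means).items():
    print(f"size {size}: {size // 100}x the population, {ratio:.2f}x the time")
```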