# Predator Prey Coevolution using JAX

This example demonstrates how one may use Maelstrom with JAX to implement accelerated coevolution on the classic predator-prey problem.

First, install `maelstrom` and `snake-eyes` from PyPI (published as `maelstrom-evolution` and `snake-eyes-parser`) if they are not already installed. Depending on your environment, you may also need to install JAX.

In [None]:
# !pip install maelstrom-evolution snake-eyes-parser
# !pip install jax jaxlib
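JAX falls back to CPU when it finds no GPU or TPU, so before a long run it can help to check which backend it picked. This check uses only standard JAX calls (`jax.default_backend` and `jax.devices`) and does not depend on Maelstrom.

```python
import jax

# Report the backend JAX selected ("cpu", "gpu" or "tpu") and the devices it can see.
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
```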

With installation complete, we can now import `maelstrom` and `snake-eyes` for use within this notebook.

In [1]:
from maelstrom import Maelstrom
from snake_eyes import read_config

## Primitive Definition
With `maelstrom` imported, we can now define the strongly typed primitives for this example. For brevity and reusability, these primitives are defined in a separate file and simply imported into this notebook.
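The `primitives` module itself is not shown here. As a rough, hypothetical illustration of the idea, the sketch below defines a few strongly typed primitives in plain JAX that take positions and return headings in radians. The type names `Position` and `Heading` and the functions `angle_to`, `add_headings` and `opposite` are assumptions, not the contents of `primitives.py`, and the sketch does not show how Maelstrom registers primitives. The types are deliberately not called `Angle`, so running this cell in the notebook will not shadow the `Angle` type the configuration refers to.

```python
from typing import NewType

import jax
import jax.numpy as jnp

# Hypothetical strong types (assumed names, not from the real primitives module).
Position = NewType("Position", jax.Array)
Heading = NewType("Heading", jax.Array)


def angle_to(source: Position, target: Position) -> Heading:
    """Heading in radians that points from `source` toward `target`."""
    delta = jnp.asarray(target) - jnp.asarray(source)
    return Heading(jnp.arctan2(delta[1], delta[0]))


def add_headings(a: Heading, b: Heading) -> Heading:
    """Sum of two headings, wrapped into the interval [-pi, pi)."""
    return Heading(jnp.mod(a + b + jnp.pi, 2 * jnp.pi) - jnp.pi)


def opposite(a: Heading) -> Heading:
    """The heading pointing in the reverse direction."""
    return add_headings(a, Heading(jnp.asarray(jnp.pi)))


# A tree such as opposite(angle_to(me, other)) encodes a simple "flee" rule.
me = Position(jnp.array([0.0, 0.0]))
other = Position(jnp.array([1.0, 0.0]))
flee_heading = opposite(angle_to(me, other))
print(float(flee_heading))  # about -3.1416: move along -x, away from `other`
```

Because every primitive declares what it consumes and returns, a strongly typed GP system only needs to connect a `Heading`-producing subtree where a `Heading` is expected, which keeps every generated tree well-formed.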

In [2]:
from primitives import *

## Fitness Function Definition
We also need to define a fitness function. Again, we define this function in a separate file, import it into this notebook, and reference it from within the configuration file for this example.
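The `predator_prey` module is not shown in this notebook. The sketch below is a hypothetical illustration of what a single predator-versus-prey episode could look like in JAX, reusing the parameter names from the `evaluation_kwargs` in the configuration output (`predator_move_speed`, `prey_move_speed`, `agent_radius`, `time_limit`). The start positions, the capture rule (centers closer than `2 * agent_radius`), and the choice to return the capture step are all assumptions.

```python
import jax
import jax.numpy as jnp


def heading_vector(theta):
    """Unit vector for a heading angle in radians."""
    return jnp.stack([jnp.cos(theta), jnp.sin(theta)])


def simulate_episode(predator_policy, prey_policy, predator_move_speed=0.06,
                     prey_move_speed=0.1, agent_radius=0.1, time_limit=200):
    """Hypothetical episode. Each policy maps (own position, opponent position)
    to a heading angle. Returns the step at which the agents first touch,
    or `time_limit` if the prey is never caught."""

    def step(carry, t):
        predator, prey, caught_at = carry
        # Both agents choose a heading from the same (pre-move) state.
        predator_theta = predator_policy(predator, prey)
        prey_theta = prey_policy(prey, predator)
        predator = predator + predator_move_speed * heading_vector(predator_theta)
        prey = prey + prey_move_speed * heading_vector(prey_theta)
        touching = jnp.linalg.norm(predator - prey) < 2 * agent_radius
        # Record only the first capture; later steps leave caught_at unchanged.
        caught_at = jnp.where(touching & (caught_at == time_limit), t, caught_at)
        return (predator, prey, caught_at), None

    # Assumed start: predator at the origin, prey one unit away along +x.
    init = (jnp.array([0.0, 0.0]), jnp.array([1.0, 0.0]),
            jnp.array(time_limit, dtype=jnp.int32))
    (_, _, caught_at), _ = jax.lax.scan(step, init, jnp.arange(time_limit, dtype=jnp.int32))
    return caught_at


def chase(me, other):
    return jnp.arctan2(other[1] - me[1], other[0] - me[0])


def flee(me, other):
    return jnp.arctan2(me[1] - other[1], me[0] - other[0])


print(int(simulate_episode(chase, flee)))  # 200: the faster prey is never caught
```

Using `jax.lax.scan` keeps the whole episode inside one traced loop, so JAX compiles the time steps as a single loop rather than dispatching each step separately from Python.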

In [3]:
from predator_prey import round_robin_evaluation
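The name `round_robin_evaluation` suggests that every predator is evaluated against every prey. A minimal sketch of that pairing, under loud assumptions, follows: the hypothetical `play` callback returns a capture step (or the time limit), predators score `time_limit - capture_step`, prey score `capture_step`, and each score is averaged over all opponents. This zero-sum split is an assumption; it is merely consistent with the first logged averages in this notebook (58.61 for predators and 141.39 for prey) summing to the 200-step time limit. The toy players here are plain numbers, not evolved trees.

```python
import jax.numpy as jnp


def round_robin(predators, prey, play, time_limit=200):
    """Hypothetical round-robin scoring: every predator meets every prey once."""
    # capture_steps[i, j] = step at which predator i caught prey j (time_limit if never).
    capture_steps = jnp.array([[play(p, q) for q in prey] for p in predators],
                              dtype=jnp.float32)
    predator_fitness = (time_limit - capture_steps).mean(axis=1)  # average over prey
    prey_fitness = capture_steps.mean(axis=0)  # average over predators
    return predator_fitness, prey_fitness


# Toy stand-ins: a "predator" is a speed, a "prey" is a head start.
toy_predators = [1, 2]
toy_prey = [50, 300]
pred_fit, prey_fit = round_robin(toy_predators, toy_prey,
                                 play=lambda p, q: min(200, q // p))
print(pred_fit.tolist())  # [75.0, 112.5]
print(prey_fit.tolist())  # [37.5, 175.0]
```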

## Configuration
Now, we load the configuration file for this example using the `snake-eyes` configuration parser. For transparency, we print the contents of this file, since it shows how constants and functions defined in this notebook can be referenced from within the configuration file.
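`snake-eyes` takes care of resolving those references. To make the idea concrete, here is a hypothetical stand-alone sketch using Python's standard `configparser` and `eval`: bare names resolve to objects in a supplied namespace, while literals evaluate to themselves. It is not the `snake-eyes` implementation, and the `[ISLAND]` snippet is a trimmed, assumed example modeled on the printed output rather than the actual contents of `maelstrom.cfg`. The names `demo_fitness` and `demo_config` are chosen so they do not overwrite `round_robin_evaluation` or `config` in this notebook.

```python
import configparser


def load_config(text, namespace):
    """Hypothetical sketch: parse INI text and evaluate each value in `namespace`.
    Uses eval(), so only load trusted configuration files this way."""
    parser = configparser.ConfigParser()
    parser.optionxform = str  # preserve the case of keys
    parser.read_string(text)
    return {
        section: {key: eval(value, dict(namespace)) for key, value in parser[section].items()}
        for section in parser.sections()
    }


def demo_fitness(**kwargs):
    """Placeholder standing in for a real fitness function."""
    return kwargs


CONFIG_TEXT = """
[ISLAND]
evaluation_function = demo_fitness
evaluation_kwargs = {'time_limit': 200, 'agent_radius': 0.1}
champions_per_generation = 5
"""

demo_config = load_config(CONFIG_TEXT, globals())
print(demo_config["ISLAND"]["evaluation_function"] is demo_fitness)  # True
print(demo_config["ISLAND"]["evaluation_kwargs"]["time_limit"])  # 200
```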

In [4]:
config = read_config('./configs/maelstrom.cfg', globals(), locals())
for section, params in config.items():
    print(section)
    if params == {}:
        print(f"  {params}")
    else:
        for key, value in params.items():
            print(f"  {key}: {value}")

DEFAULT
  {}
MAELSTROM
  islands: {'1': 'ISLAND'}
  migration_edges: []
  evaluations: 5000
GENERAL
  runs: 2
  pop_size: 10
  num_children: 10
  parent_selection: k_tournament
  k_parent: 10
  mutation: 0.1
  survival_strategy: plus
  survival_selection: truncation
  depth_limit: 5
  hard_limit: 20
  depth_min: 1
  output_type: Angle
ISLAND
  populations: {'predators': 'predators', 'prey': 'prey'}
  evaluation_function: <function round_robin_evaluation at 0x7f27eb9f7ac0>
  evaluation_kwargs: {'predator_move_speed': 0.06, 'prey_move_speed': 0.1, 'agent_radius': 0.1, 'time_limit': 200}
  champions_per_generation: 5
predators
  roles: General
  output_type: Angle
  depth_limit: 5
  hard_limit: 20
  depth_min: 1
  pop_size: 10
  num_children: 10
  parent_selection: k_tournament
  k_parent: 10
  mutation: 0.1
  survival_strategy: plus
  survival_selection: truncation
prey
  roles: General
  output_type: Angle
  depth_limit: 5
  hard_limit: 20
  depth_min: 1
  pop_size: 10
  num_children: 1

## Experiment
With everything imported and configured, we can now instantiate a `Maelstrom` using parameters from the configuration file and run an experiment.

In [5]:
evolver = Maelstrom(**config['MAELSTROM'], **config)
evolver.run()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)



<maelstrom.Maelstrom at 0x7f27eb9fb550>

## Experiment Log
Note that the data returned by this example's fitness function is automatically stored in the `Maelstrom` instance's `log` attribute, organized by island.
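Each logged value is a 0-d JAX `Array`, as the printed log shows. A small helper such as the hypothetical `summarize_log` can stack those lists into 1-D arrays for quick inspection; it relies only on `jnp.stack` and `jnp.argmax`. The sample data is the first five `max_predator_fitness` entries from the printed log.

```python
import jax.numpy as jnp


def summarize_log(island_log):
    """Stack each logged list of 0-d arrays into a 1-D array and report its
    first value, last value, and the index of its maximum."""
    summary = {}
    for key, values in island_log.items():
        series = jnp.stack([jnp.asarray(v) for v in values])
        summary[key] = {
            "first": float(series[0]),
            "last": float(series[-1]),
            "argmax": int(jnp.argmax(series)),
        }
    return summary


sample_log = {
    "max_predator_fitness": [jnp.array(v, dtype=jnp.float32)
                             for v in (167.9, 174.55, 182.45, 183.35, 182.0)],
}
print(summarize_log(sample_log)["max_predator_fitness"]["argmax"])  # 3
```

In the notebook, the same helper could be applied per island with `for island, island_log in evolver.log.items(): print(island, summarize_log(island_log))`.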

In [6]:
for island, log in evolver.log.items():
    print(f'Island {island}')
    for key, value in log.items():
        print(key)
        print(value)

Island 1
avg_predator_fitness
[Array(58.609993, dtype=float32), Array(95.4525, dtype=float32), Array(136.29251, dtype=float32), Array(155.98251, dtype=float32), Array(169.7975, dtype=float32), Array(160.43002, dtype=float32), Array(176.30003, dtype=float32), Array(167.69748, dtype=float32), Array(173.60004, dtype=float32), Array(169.39996, dtype=float32), Array(165.16753, dtype=float32), Array(157.03748, dtype=float32), Array(167.33253, dtype=float32), Array(168.86996, dtype=float32)]
max_predator_fitness
[Array(167.9, dtype=float32), Array(174.55, dtype=float32), Array(182.45, dtype=float32), Array(183.35, dtype=float32), Array(182., dtype=float32), Array(179.35, dtype=float32), Array(176.3, dtype=float32), Array(175.65, dtype=float32), Array(173.6, dtype=float32), Array(173.4, dtype=float32), Array(173.35, dtype=float32), Array(173.65, dtype=float32), Array(175.3, dtype=float32), Array(172.45, dtype=float32)]
avg_prey_fitness
[Array(141.39, dtype=float32), Array(104.54751, dtype=floa