In [1]:
!ls
from google.colab import drive
drive.mount('/content/gdrive')
!git clone https://username:password@github.com/RobertTLange/evosax.git

gdrive	sample_data
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
%cd gdrive/My Drive/evosax
! pip install -e .

/content/gdrive/My Drive/evo_benchmark/evosax
Obtaining file:///content/gdrive/My%20Drive/evo_benchmark/evosax
Installing collected packages: evosax
  Found existing installation: evosax 0.0.1
    Can't uninstall 'evosax'. No files were found to uninstall.
  Running setup.py develop for evosax
Successfully installed evosax


In [3]:
!pip install jax==0.2.4
!pip install jaxlib==0.1.57
!pip install commentjson



In [4]:
import requests
import os
import numpy as np
from jax.config import config

# TPU JAX Colab Setup

In [5]:
TPU_DRIVER_MODE = 0

def setup_tpu():
    """Point JAX at the Colab TPU via the TPU driver backend.

    While the module-level ``TPU_DRIVER_MODE`` flag is falsy, a POST request
    asking for the ``tpu_driver_nightly`` version is sent to port 8475 of the
    host named in the ``COLAB_TPU_ADDR`` environment variable, and the flag is
    then set to 1 so later calls skip the request. The two JAX flags are
    written on every call.

    Note: make sure the Colab Runtime is set to Accelerator: TPU.

    Raises:
        KeyError: if ``COLAB_TPU_ADDR`` is not present in the environment.
    """
    global TPU_DRIVER_MODE

    driver_requested = bool(TPU_DRIVER_MODE)
    if not driver_requested:
        # Only the text before the first ':' of COLAB_TPU_ADDR is used here
        # (presumably the host part of "<host>:<port>" -- TODO confirm).
        colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
        requests.post(
            f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver_nightly'
        )
        TPU_DRIVER_MODE = 1

    # Both flags are required for JAX to use the TPU driver as its backend.
    config.FLAGS.jax_xla_backend = "tpu_driver"
    grpc_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
    config.FLAGS.jax_backend_target = grpc_target

setup_tpu()

In [6]:
import jax
jax.devices()

[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]

# Import benchmark function from script

In [7]:
from evosax.strategies.cma_es import init_strategy, ask, tell
from evosax.utils import flat_to_mlp

In [8]:
import sys
sys.path.append('/content/gdrive/My Drive/evo_benchmark/evosax/examples')
from ffw_pendulum import generation_rollout

In [12]:
import commentjson
def load_config(config_fname: str):
    """Load a config JSON file and return it as a dot-accessible dictionary.

    The file is parsed with ``commentjson`` and wrapped in ``DotDic``.

    Args:
        config_fname: Path of the JSON config file to read.

    Returns:
        A ``DotDic`` holding the parsed configuration; top-level values that
        are dictionaries are ``DotDic`` instances as well.
    """
    # The context manager closes the file even when parsing raises.
    with open(config_fname, 'r') as config_file:
        json_config = commentjson.loads(config_file.read())
    dict_config = DotDic(json_config)

    # Make inner dictionaries indexable like a class
    for key, value in dict_config.items():
        if isinstance(value, dict):
            dict_config[key] = DotDic(value)
    return dict_config


class DotDic(dict):
    """
    A dictionary that supports dot notation
    as well as dictionary access notation.

    usage: d = DotDic() or d = DotDic({'val1': 'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']

    Values that expose ``keys`` (nested mappings) are converted to ``DotDic``
    when the instance is built. Reading an attribute that is not a key
    returns ``None`` instead of raising, because attribute lookup falls back
    to ``dict.get``.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __deepcopy__(self, memo=None):
        """Return a deep copy of this mapping, wrapped in a new ``DotDic``."""
        # Imported locally so the method works even when the surrounding
        # notebook never imported ``copy`` at module level.
        import copy
        return DotDic(copy.deepcopy(dict(self), memo=memo))

    def __init__(self, dct=None):
        """Populate the mapping from ``dct``; ``None`` gives an empty one."""
        if dct is None:
            return
        for key, value in dct.items():
            # Wrap nested mappings so dot access works at every level.
            if hasattr(value, 'keys'):
                value = DotDic(value)
            self[key] = value


In [10]:
def benchmark_accelerator(num_evaluations, population_size, hidden_size,
                          net_config, train_config, log_config, use_jit=True,
                          use_tpu=False):
    """Time single strategy generations for one population/hidden size.

    A fresh strategy is initialised for every evaluation and one call to
    ``run_single_generation`` is timed. The first evaluation is reported
    separately (presumably it includes compilation -- TODO confirm); the
    remaining ones are aggregated.

    Args:
        num_evaluations: Number of timed generations, must be >= 1.
        population_size: Written to ``train_config.pop_size``.
        hidden_size: Written to ``train_config.hidden_size`` and, as an int,
            to ``net_config.network_size[1]``.
        net_config: Config exposing ``network_size`` (indexed at 0, 1, 2).
        train_config: Config exposing ``seed_id``, ``sigma_init``,
            ``elite_percentage``, ``num_evals_per_gen``, ``num_env_steps``
            and ``env_params``.
        log_config: Accepted for interface compatibility; not read here.
        use_jit: Forwarded to ``run_single_generation``.
        use_tpu: Forwarded to ``run_single_generation``.

    Returns:
        Tuple ``(mean, std, jit_time)``: mean and standard deviation of the
        wall-clock seconds of every evaluation after the first, and the
        wall-clock seconds of the first one. ``mean`` and ``std`` are NaN
        when ``num_evaluations`` is 1.

    Raises:
        ValueError: if ``num_evaluations`` is smaller than 1.
    """
    # Validate before touching the configs so a bad call has no side effects.
    if num_evaluations < 1:
        raise ValueError(
            f"num_evaluations must be >= 1, got {num_evaluations}")

    # Want to eval gains from acc over different population size + archs.
    # NOTE: these assignments mutate the caller's config objects in place.
    train_config.pop_size = population_size
    train_config.hidden_size = hidden_size

    # Start by setting the random seeds for reproducibility
    rng = jax.random.PRNGKey(train_config.seed_id)

    # Define logger, CMA-ES strategy
    net_config.network_size[1] = int(train_config.hidden_size)
    layer_sizes = net_config.network_size
    # Weights + biases of a network with one hidden layer.
    num_params = (layer_sizes[0] * layer_sizes[1]
                  + layer_sizes[1]
                  + layer_sizes[1] * layer_sizes[2]
                  + layer_sizes[2])
    mean_init = jnp.zeros(num_params)
    elite_size = int(train_config.pop_size * train_config.elite_percentage)

    jit_time = None
    generation_times = []
    for eval_idx in range(num_evaluations):
        es_params, es_memory = init_strategy(mean_init,
                                             train_config.sigma_init,
                                             train_config.pop_size,
                                             elite_size)
        # Only track time for actual generation ask-tell inference.
        # perf_counter is monotonic, so clock adjustments cannot skew it.
        # NOTE(review): JAX dispatches device work asynchronously and
        # run_single_generation returns nothing to block on, so this may
        # under-measure -- confirm that tell() forces synchronisation.
        start_t = time.perf_counter()

        # Train the network using the training loop
        run_single_generation(rng, elite_size,
                              train_config.num_evals_per_gen,
                              train_config.num_env_steps,
                              net_config.network_size,
                              dict(train_config.env_params),
                              es_params, es_memory, use_jit, use_tpu)

        # Save wall-clock time for evaluation
        elapsed = time.perf_counter() - start_t
        if eval_idx == 0:
            jit_time = elapsed
        else:
            generation_times.append(elapsed)

    # With a single evaluation there is nothing to aggregate; return NaN
    # directly instead of letting numpy warn about an empty slice.
    if not generation_times:
        return np.nan, np.nan, jit_time
    return np.mean(generation_times), np.std(generation_times), jit_time


def run_single_generation(rng, elite_size, num_evals_per_gen,
                          num_env_steps, network_size, env_params, es_params,
                          es_memory, use_jit=True, use_tpu=False):
    """Run one ask -> evaluate -> tell cycle of the strategy.

    Candidates are requested through ``ask``, converted to network
    parameters with ``flat_to_mlp`` and rolled out by ``generation_rollout``
    using ``num_evals_per_gen`` PRNG keys. The negated mean return along
    axis 1 is then passed to ``tell`` (presumably because the strategy
    minimises -- TODO confirm).

    The updated strategy memory is not handed back: the function always
    returns ``None``. ``use_jit`` and ``use_tpu`` are accepted but not read.
    """
    # Sample the candidate solutions for this generation.
    rng, ask_key = jax.random.split(rng)
    candidates, es_memory = ask(ask_key, es_memory, es_params)
    member_params = flat_to_mlp(candidates, sizes=network_size)

    # One PRNG key per evaluation of the generation members.
    rng, eval_key = jax.random.split(rng)
    member_keys = jax.random.split(eval_key, num_evals_per_gen)
    returns = generation_rollout(member_keys, member_params,
                                 env_params, num_env_steps)

    fitness = -returns.mean(axis=1)

    # Update the strategy; the result is intentionally not propagated.
    tell(candidates, fitness, elite_size, es_params, es_memory)
    return None


In [21]:
import jax.numpy as jnp
import time
# NOTE(review): the file name says "no_jit" but use_jit=True is passed to the
# benchmark below -- confirm which of the two is intended.
save_fname = "benchmarks/results/tpu_speed_no_jit"

print(f"JAX device: {jax.devices()}, {save_fname}")
# NOTE(review): this rebinds `config`, hiding the object imported from
# jax.config that setup_tpu() writes its flags to; re-running setup_tpu()
# after this cell would therefore fail. The name is kept for compatibility.
config = load_config("examples/cma_config.json")
train_config, net_config, log_config = (config.train_config,
                                        config.net_config,
                                        config.log_config)
# Dont use first - used for compilation. With a single timed run left over,
# the reported standard deviation is always 0.0.
num_evaluations = 1 + 1
population_sizes = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
#network_sizes = [16, 48, 80, 112, 144]
network_sizes = [48]
# Axis 0 holds: first-run (jit) time, mean, std.
store_times = np.zeros((3, len(network_sizes), len(population_sizes)))
for i, hidden_size in enumerate(network_sizes):
    # Rows are stored in reverse order of network_sizes.
    row = len(network_sizes) - 1 - i
    for j, pop_size in enumerate(population_sizes):
        mean, std, jit_t = benchmark_accelerator(num_evaluations, pop_size,
                                                 hidden_size, net_config,
                                                 train_config, log_config,
                                                 use_jit=True, use_tpu=False)
        # Jitted time, mean, std
        store_times[0, row, j] = jit_t
        store_times[1, row, j] = mean
        store_times[2, row, j] = std
        print(hidden_size, pop_size, mean, std, jit_t)

# Make sure the target directory exists so the finished sweep is not lost
# to a missing folder when saving.
os.makedirs(os.path.dirname(save_fname), exist_ok=True)
np.save(save_fname, store_times)


JAX device: [TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)], benchmarks/results/tpu_speed_no_jit
48 100 1.6559696197509766 0.0 1.6062541007995605
48 200 1.730271339416504 0.0 1.6364848613739014
48 300 1.9260773658752441 0.0 8.203169107437134
48 400 1.8260862827301025 0.0 11.239391803741455
48 500 1.945971965789795 0.0 15.34292984008789
48 600 2.058655023574829 0.0 19.204785346984863
48 700 2.101370096206665 0.0 23.883458852767944
48 800 2.2158946990966797 0.0 30.245986223220825
48 900 2.4260644912719727 0.0 38.08869552612305
48 1000 2.586442232131958 0.0 45.572