In [None]:
!pip install jax==0.2.4
!pip install jaxlib==0.1.57

In [8]:
import requests
import os
import numpy as np
from jax.config import config

# TPU JAX Colab Setup

In [None]:
# Set to 1 once the TPU driver version has been requested successfully, so re-running
# setup_tpu() in the same kernel does not send the request again.
TPU_DRIVER_MODE = 0


def setup_tpu(tpu_driver_version='tpu_driver_nightly'):
    """Sets up Colab to run on TPU.

    Note: make sure the Colab Runtime is set to Accelerator: TPU.

    On the first successful call this POSTs a driver-version request to the
    Colab TPU host; every call then points JAX's XLA backend at the TPU Driver
    via a ``grpc://`` target built from ``COLAB_TPU_ADDR``.

    Args:
        tpu_driver_version: driver version requested from the TPU host. Only
            used on the first successful call in a kernel session.

    Raises:
        RuntimeError: if ``COLAB_TPU_ADDR`` is not set (the runtime has no TPU).
        requests.HTTPError: if the TPU host rejects the version request.
    """
    global TPU_DRIVER_MODE
    # Imported locally under an alias so a later cell that rebinds the global
    # name `config` cannot break this function when the cell is re-run.
    from jax.config import config as jax_config

    tpu_addr = os.environ.get('COLAB_TPU_ADDR')
    if not tpu_addr:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set: switch the Colab runtime to "
            "Accelerator: TPU before calling setup_tpu().")

    if not TPU_DRIVER_MODE:
        tpu_host = tpu_addr.split(':')[0]
        url = f'http://{tpu_host}:8475/requestversion/{tpu_driver_version}'
        # Bounded timeout so an unresponsive TPU host cannot hang the notebook forever.
        response = requests.post(url, timeout=120)
        # Only mark the request as done after it succeeded, so a failure is retried next call.
        response.raise_for_status()
        TPU_DRIVER_MODE = 1

    # The following is required to use TPU Driver as JAX's backend.
    jax_config.FLAGS.jax_xla_backend = "tpu_driver"
    jax_config.FLAGS.jax_backend_target = "grpc://" + tpu_addr


# NOTE(review): presumably this must run before JAX initialises its backend
# (e.g. before the first jax.devices() call) for the flags to take effect — confirm.
setup_tpu()

In [9]:
import jax

# Sanity check of the backend JAX will use. The recorded output lists only a
# CpuDevice, i.e. the TPU setup above was not active for that run.
available_devices = jax.devices()
available_devices

[CpuDevice(id=0)]

# Import the benchmark helpers (`load_config`, `benchmark_accelerator`) from the `speed_accelerator` script

In [10]:
from speed_accelerator import load_config, benchmark_accelerator

In [None]:
use_tpu = False
# NOTE(review): the file name says "tpu" and "no_jit", but use_tpu is False and the
# positional `True` passed to benchmark_accelerator below may be a jit flag — confirm
# which backend/mode these results are meant to describe.
save_fname = "results/tpu_speed_no_jit"

print(f"JAX device: {jax.devices()}, {save_fname}")
# Named `bench_config` (not `config`) so it does not shadow `config` imported from
# `jax.config` at the top of the notebook.
bench_config = load_config("../examples/configs/train/cma_config.json")
train_config, net_config, log_config = (bench_config.train_config,
                                        bench_config.net_config,
                                        bench_config.log_config)
num_evaluations = 1000 + 1  # Don't use the first evaluation - it is used for compilation
population_sizes = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
# Full sweep: network_sizes = [16, 48, 80, 112, 144]
network_sizes = [48]

# Axis 0: (jit time, mean, std); axis 1: network size in reverse list order
# (last entry of network_sizes in row 0); axis 2: population size.
store_times = np.zeros((3, len(network_sizes), len(population_sizes)))
for i, hidden_size in enumerate(network_sizes):
    row = len(network_sizes) - 1 - i
    for j, pop_size in enumerate(population_sizes):
        mean, std, jit_t = benchmark_accelerator(num_evaluations, pop_size, hidden_size,
                                                  net_config, train_config, log_config,
                                                  True, use_tpu)
        # Jitted time, mean, std
        store_times[0, row, j] = jit_t
        store_times[1, row, j] = mean
        store_times[2, row, j] = std
        print(hidden_size, pop_size, mean, std, jit_t)

# np.save does not create parent directories; make sure the target exists so a
# long sweep is not lost at the final save.
os.makedirs(os.path.dirname(save_fname) or ".", exist_ok=True)
np.save(save_fname, store_times)


JAX device: [CpuDevice(id=0)], results/tpu_speed_no_jit
