In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
# os.environ['CUDA_VISIBLE_DEVICES']='0'

import numpy as np
import pandas as pd
from tqdm import tqdm

import time
from typing import Callable

import traceback
import jux
import jux.utils
from jux.state import State
from jux.config import JuxBufferConfig
import jax
import jax.numpy as jnp
import warnings
import chex
from tqdm.auto import tqdm

warnings.filterwarnings('default')

def cuda_sync():
    """Block the host until a freshly dispatched trivial device computation is done.

    A scalar is placed on the default device, incremented by zero, and the
    result is awaited. NOTE(review): using this as a barrier for *earlier*
    work assumes the device executes dispatched work in order - confirm.

    Returns:
        None
    """
    marker = jax.device_put(0.) + 0
    marker.block_until_ready()


def timeit(func: Callable, setup=lambda: None, number=100, finalize=lambda: None, repeat=1):
    """Measure the mean wall-clock time of one ``func()`` call.

    Each of the ``repeat`` rounds runs ``setup()`` (untimed), then times
    ``number`` consecutive ``func()`` calls followed by a single ``finalize()``
    call. ``finalize`` is deliberately inside the timed region, so it can be
    used to wait for work that ``func`` only dispatched.

    Args:
        func: zero-argument callable being measured.
        setup: zero-argument callable run before each round, outside the timer.
        number: how many times ``func`` is called per round.
        finalize: zero-argument callable run once per round, inside the timer.
        repeat: number of rounds.

    Returns:
        Mean over rounds of (round duration / ``number``), in seconds.
    """
    per_call_seconds = []
    for _round in range(repeat):
        setup()
        t_begin = time.perf_counter()
        for _call in range(number):
            func()
        finalize()
        elapsed = time.perf_counter() - t_begin
        per_call_seconds.append(elapsed / number)
    return np.mean(per_call_seconds)

def mem_size(pytree):
    """Return the total byte size of all array leaves in ``pytree`` as a string.

    The size of each leaf is ``leaf.size * leaf.dtype.itemsize``; the sum is
    rendered with a binary (1024-based) unit: plain ``B`` below 1 KiB,
    otherwise ``KB``/``MB``/``GB`` with two decimals.

    Args:
        pytree: any pytree whose leaves expose ``size`` and ``dtype``.

    Returns:
        Human-readable size, e.g. ``"40B"`` or ``"1.50KB"``.
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    n_bytes = sum(leaf.size * leaf.dtype.itemsize for leaf in leaves)

    kib, mib, gib = 1024, 1024 ** 2, 1024 ** 3
    if n_bytes < kib:
        return f"{n_bytes}B"
    if n_bytes < mib:
        return f"{n_bytes / kib:.2f}KB"
    if n_bytes < gib:
        return f"{n_bytes / mib:.2f}MB"
    return f"{n_bytes / gib:.2f}GB"

In [2]:
# Kaggle episode ids of the replays used as benchmark inputs.
# For a quick run, swap in a single episode: test_id = ['45715004']
test_id = [
    '45715004', '45777510', '45779101', '45780455', '45780520',
    '45780686', '45781606', '45780751', '45780882', '45781046',
    '45781047', '45781050', '45781208', '45781608', '45781677',
    '45781212', '45781214', '45780845', '45781375', '45785597',
]
# One replay JSON URL per episode id, in the same order as test_id.
test_url = [
    f"https://www.kaggleusercontent.com/episodes/{episode_id}.json"
    for episode_id in test_id
]

# Benchmark Lux

In [3]:
# Mean per-step time of each replay, filled in by the loop below.
cpu_time_exec = []


def lux_step(env, actions):
    """Feed every action to ``env.step`` and return the mean time per step.

    The timed region covers iterating ``actions`` as well as the
    ``env.step`` calls, so any lazy work done by the iterable is included.

    Args:
        env: object exposing ``step(action)``.
        actions: iterable (may be a one-shot iterator) of actions.

    Returns:
        Elapsed wall-clock seconds divided by the number of steps taken.

    Raises:
        ValueError: if ``actions`` yields nothing, since a per-step average
            is undefined for zero steps.
    """
    n_steps = 0
    time0 = time.perf_counter()
    for act in actions:
        env.step(act)
        n_steps += 1
    elapsed = time.perf_counter() - time0
    if n_steps == 0:
        raise ValueError("lux_step received no actions, so no per-step time can be computed")
    return elapsed / n_steps

# Time the Lux environment on every replay and average the per-step times.
for url in test_url:
    env, actions = jux.utils.load_replay(url)
    env.env_cfg.verbose = False  # presumably silences env logging while timing - confirm
    replay_step_time = lux_step(env, actions)
    cpu_time_exec.append(replay_step_time)

cpu_mean_time = np.asarray(cpu_time_exec).mean()
print(f"{cpu_mean_time = }")

cpu_mean_time = 0.0036133938854050842


# Benchmark JUX

In [4]:
# Prepare one mid-game env per replay: replay actions until env.env_steps
# reaches N_prepare, then keep the env together with the next pending action.
N_prepare = 100
envs, acts = [], []
for url in test_url:
    env, actions = jux.utils.load_replay(url)
    env.env_cfg.verbose = False
    # fast-forward the replay
    while env.env_steps < N_prepare:
        env.step(next(actions))
    envs.append(env)
    # the action that the benchmarked step will consume
    acts.append(next(actions))

# Jitted step functions: a single-state version and a batched (vmap) version.
# assert_max_traces(n=1) presumably fails if the step gets traced more than once
# between chex.clear_trace_counter() calls - confirm against the chex docs.
_step_traced_once = chex.assert_max_traces(State._step_late_game, n=1)
_state_step_late_game = jax.jit(_step_traced_once)
_state_step_late_game_vmap = jax.jit(jax.vmap(_state_step_late_game))

In [5]:
# Benchmark grid: MAX_N_UNITS buffer settings and vmap batch sizes to sweep.
max_n_units_range = [100, *range(200, 1001, 200)]
batch_size_range = [1, 100, 1_000, 5_000, 10_000, 20_000]

## Without vmap

In [6]:
# Mean jitted step time per MAX_N_UNITS setting, filled in by the loop below.
jit_table = []


def jit_record(buf_cfg):
    """Time the jitted single-state step on every prepared replay.

    For each (env, action) pair in ``envs``/``acts`` the Lux state is converted
    with ``State.from_lux`` using ``buf_cfg``, the action dict is parsed, and
    the jitted step is timed over 100 calls after one warm-up call.

    Args:
        buf_cfg: buffer configuration passed to ``State.from_lux``.

    Returns:
        Mean over replays of the per-call time reported by ``timeit`` (seconds).
    """
    n_calls = 100
    per_replay = []
    for lux_env, lux_act, _episode in zip(envs, acts, test_id):
        # convert the Lux state/action into their JUX counterparts
        state = State.from_lux(lux_env.state, buf_cfg)
        action = state.parse_actions_from_dict(lux_act)

        def step_once(*_):
            return _state_step_late_game(state, action)

        # reset chex's trace counter before the warm-up call traces for this state
        chex.clear_trace_counter()
        elapsed = timeit(func=step_once, setup=step_once, number=n_calls, finalize=cuda_sync)
        per_replay.append(elapsed)

    return np.mean(per_replay)
    
# Sweep the buffer sizes and record one averaged timing per setting.
for n_units in tqdm(max_n_units_range):
    units_cfg = JuxBufferConfig(MAX_N_UNITS=n_units)
    jit_table.append(jit_record(units_cfg))

jit_mean = np.asarray(jit_table).mean()
print(f"{jit_mean = }")

  0%|          | 0/6 [00:00<?, ?it/s]

jit_mean = 0.0014724215596409827


In [7]:
# Report the jitted step time for each buffer setting, in milliseconds.
for units, seconds in zip(max_n_units_range, jit_table):
    millis = seconds * 1000
    print(f"MAX_N_UNITS={units}, t={millis:.3f}ms")

MAX_N_UNITS=100, t=1.457ms
MAX_N_UNITS=200, t=1.439ms
MAX_N_UNITS=400, t=1.485ms
MAX_N_UNITS=600, t=1.498ms
MAX_N_UNITS=800, t=1.484ms
MAX_N_UNITS=1000, t=1.472ms


## With vmap

In [None]:
# Timing table: one row per batch size, one column per MAX_N_UNITS setting.
vmap_table = pd.DataFrame(columns=[f"UNITS_{n_units}" for n_units in max_n_units_range], index=batch_size_range)


def vmap_record(buf_cfg, B, replay_index=1):
    """Time the jitted, vmapped step on a batch of ``B`` copies of one replay state.

    The state and action of the prepared replay at ``replay_index`` are
    converted to JUX form, every leaf is repeated ``B`` times along a new
    leading axis, and the batched step is timed over 100 calls after one
    warm-up call.

    Args:
        buf_cfg: buffer configuration passed to ``State.from_lux``.
        B: batch size, i.e. number of copies stacked along axis 0.
        replay_index: index into ``envs``/``acts`` of the replay to benchmark
            (defaults to 1, the value that used to be hard-coded).

    Returns:
        Per-call time of the batched step as reported by ``timeit`` (seconds).
    """
    N = 100
    # prepare state and action
    env, act = envs[replay_index], acts[replay_index]
    jux_state = State.from_lux(env.state, buf_cfg)
    jux_act = jux_state.parse_actions_from_dict(act)

    def _tile(x):
        # add a leading batch axis and repeat the leaf B times along it
        return x[None].repeat(B, axis=0)

    # jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
    # alias is deprecated. The batches are plain locals, so they are released
    # when this call returns and need no explicit `del`.
    jux_state_batch = jax.tree_util.tree_map(_tile, jux_state)
    jux_act_batch = jax.tree_util.tree_map(_tile, jux_act)

    def unit_jit_vmap():
        return _state_step_late_game_vmap(jux_state_batch, jux_act_batch)

    # reset chex's trace counter before the warm-up call traces for this shape
    chex.clear_trace_counter()
    exe_time = timeit(func=unit_jit_vmap, setup=unit_jit_vmap, number=N, finalize=cuda_sync)
    return np.mean([exe_time])

# Fill vmap_table: one batched timing per (batch size, MAX_N_UNITS) pair.
for MAX_N_UNITS in tqdm(max_n_units_range, desc="MAX_N_UNITS loop progress:", position=0):
    # both values depend only on the outer loop variable, so build them once
    buf_cfg_name = f"UNITS_{MAX_N_UNITS}"
    buf_cfg = JuxBufferConfig(MAX_N_UNITS=MAX_N_UNITS)
    for B in tqdm(batch_size_range, desc="batch size loop progress:", position=1):
        # Single .loc write instead of chained `table[col][row] = ...`, which
        # may assign into a temporary and is a no-op under pandas Copy-on-Write.
        vmap_table.loc[B, buf_cfg_name] = vmap_record(buf_cfg, B)

In [10]:
# Show the batched timings: rows are batch sizes, columns are MAX_N_UNITS
# settings, values are seconds per batched step call as measured by timeit.
vmap_table

Unnamed: 0,UNITS_100,UNITS_200,UNITS_400,UNITS_600,UNITS_800,UNITS_1000
1,0.001411,0.001431,0.001397,0.0015,0.001408,0.001427
100,0.001665,0.001748,0.00174,0.002034,0.002129,0.002165
1000,0.005962,0.006172,0.008174,0.009857,0.011653,0.014039
5000,0.021775,0.025962,0.035247,0.044279,0.052299,0.061447
10000,0.041783,0.050537,0.067829,0.085996,0.102093,0.119915
20000,0.08195,0.099187,0.13389,0.17101,0.202751,0.238558


In [11]:
# Ratio of Lux time to JUX time for the same number of env steps:
# (batch size * mean Lux step time) / time of one batched JUX step.
batch_sizes_column = np.array(batch_size_range).reshape(-1, 1)  # column vector, one row per batch size
cpu_mean_time * (batch_sizes_column / vmap_table)

Unnamed: 0,UNITS_100,UNITS_200,UNITS_400,UNITS_600,UNITS_800,UNITS_1000
1,2.561007,2.525758,2.586163,2.409614,2.565711,2.532051
100,217.019545,206.704587,207.660212,177.668597,169.712173,166.866293
1000,606.089101,585.438343,442.074918,366.575051,310.086621,257.377074
5000,829.694111,695.908349,512.583774,408.027223,345.456292,294.023138
10000,864.791769,714.993644,532.720617,420.179719,353.930721,301.330273
20000,881.855188,728.603563,539.756962,422.595193,356.436337,302.93567


In [12]:
# Record which accelerator the numbers above were measured on.
first_device = jax.devices()[0]
print(first_device.device_kind)

Tesla V100-SXM2-32GB
