In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='true'

import jux.utils
from jux.state import State
from jux.config import JuxBufferConfig
import jax
import jax.numpy as jnp

# Benchmark Lux

In [2]:
env, actions = jux.utils.load_replay("https://www.kaggleusercontent.com/episodes/45715004.json")
env.env_cfg.verbose = False
def lux_step(env, actions):
    for i, act in enumerate(actions):
        env.step(act)
    print(i)
%time lux_step(env, actions)

1004
CPU times: user 4.26 s, sys: 0 ns, total: 4.26 s
Wall time: 4.26 s


# Benchmark Jux

In [3]:
# prepare an env
env, actions = jux.utils.load_replay("https://www.kaggleusercontent.com/episodes/45715004.json")
env.env_cfg.verbose = False
while env.env_steps < 100:
    act = next(actions)
    # print(env.env_steps, act)
    env.step(act)

# jit
_state_step_late_game = jax.jit(State._step_late_game)
_state_step_late_game_vmap = jax.jit(jax.vmap(_state_step_late_game))

In [4]:
# config
buf_cfg = JuxBufferConfig(MAX_N_UNITS=1000)
B = 1000

## without vamp

In [5]:
# prepare state and action
jux_state = State.from_lux(env.state, buf_cfg)
jux_act = jux_state.parse_actions_from_dict(act)

# warm up jit
_state_step_late_game(jux_state, jux_act); 

In [6]:
%timeit _state_step_late_game(jux_state, jux_act); jnp.array(0).block_until_ready()

1.81 ms ± 56 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## With vmap

In [7]:
# prepare state and action
if "jux_act_batch" in vars():
    del jux_act_batch
if "jux_state_batch" in vars():
    del jux_state_batch
jux_state_batch = jax.tree_map(lambda x: x[None].repeat(B, axis=0), jux_state)
jux_act_batch = jax.tree_map(lambda x: x[None].repeat(B, axis=0), jux_act)

# warm up jit
_state_step_late_game_vmap(jux_state_batch, jux_act_batch);

In [8]:
%timeit (_state_step_late_game_vmap(jux_state_batch, jux_act_batch), jnp.array(0).block_until_ready())

13.9 ms ± 222 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
# please replace numbers before by above timeit output
cpu_time = 4.26
gpu_time = 13.9

print(f"We obtained {B*cpu_time/gpu_time:.0f}x speedup over native python implementation.")
print(f"The device is {jax.devices()[0].device_kind}, with batch_size = {B}, and MAX_N_UNITS = {buf_cfg.MAX_N_UNITS}")

We obtained 306x speedup over native python implementation.
The device is Tesla V100-SXM2-32GB, with batch_size = 1000, and MAX_N_UNITS = 1000


# Profile

In [None]:
jax.profiler.start_trace("/tmp/tensorboard")
# jax.vmap(State._step_late_game)(jux_state_batch, jux_act_batch)
_state_step_late_game_vmap(jux_state_batch, jux_act_batch)
jnp.array(0).block_until_ready()
jax.profiler.stop_trace()


In [11]:
lowered = _state_step_late_game_vmap.lower(jux_state_batch, jux_act_batch)
compiled = lowered.compile()

In [12]:
cost_analysis = compiled.cost_analysis()
cost_sum = sum(cost_analysis[0].values())
print(f"cost_sum={cost_sum / 2**30:.2f}G")

cost_sum=16.86G


In [13]:
cost = list(cost_analysis[0].items())
cost.sort(key=lambda x: x[1], reverse=True)
cost = cost[:10]
print("Top-10 costly operators:")
[(n, c/2**30) for n, c in cost]

Top-10 costly operators:


[('flops', 5.457577705383301),
 ('bytes accessed', 5.432867527008057),
 ('bytes accessed output {}', 1.4869519472122192),
 ('bytes accessed operand 4 {}', 0.6276771426200867),
 ('bytes accessed operand 0 {}', 0.5038005113601685),
 ('bytes accessed operand 2 {}', 0.48835527896881104),
 ('bytes accessed operand 1 {}', 0.47469067573547363),
 ('bytes accessed operand 5 {}', 0.4347339868545532),
 ('bytes accessed operand 3 {}', 0.27461785078048706),
 ('bytes accessed operand 6 {}', 0.2414938062429428)]