In [39]:
import os
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'

import jux
from rich import print
from jux.state import State
from jux.config import JuxBufferConfig
import jax
import jax.numpy as jnp

# Benchmark Lux

In [40]:
# Load Kaggle episode 45715004: an env to replay on, plus the recorded actions to feed it.
env, actions = jux.utils.load_replay(
    "https://www.kaggleusercontent.com/episodes/45715004.json"
)
env.env_cfg.verbose = False  # turn off the env's verbose output during the replay
def lux_step(env, actions):
    """Replay every action in ``actions`` on ``env``, calling ``env.step`` once per item.

    Used to time a full episode with the reference Lux implementation.

    Args:
        env: object exposing ``step(action)`` (here, the env returned by ``load_replay``).
        actions: iterable of per-step actions; an iterator is exhausted by this call.

    Returns:
        The number of steps taken (0 when ``actions`` is empty).
    """
    # Bound up front so an empty `actions` doesn't leave the counter undefined.
    n_steps = 0
    for n_steps, act in enumerate(actions, start=1):
        env.step(act)
    return n_steps
# Wall-clock time to replay the whole episode, one env.step per recorded action.
%time lux_step(env, actions)

CPU times: user 3.93 s, sys: 0 ns, total: 3.93 s
Wall time: 3.93 s


# Benchmark Jux

In [41]:
# Prepare a mid-game env. Reload the replay, because the env above was advanced
# through the whole episode by the Lux benchmark. Then apply the recorded actions
# until step 100.
env, actions = jux.utils.load_replay("https://www.kaggleusercontent.com/episodes/45715004.json")
env.env_cfg.verbose = False
while env.env_steps < 100:
    env.step(next(actions))

# The action the replay takes *from* this state, so that (env.state, act) is a
# genuine state/action pair for the benchmarks below. The last action applied
# inside the loop has already been consumed by the env.
act = next(actions)

# JIT-compile the late-game step, plus a batched variant: jax.vmap maps it over
# a new leading axis of both the state and action pytrees, and the outer jax.jit
# compiles the batched function.
# NOTE(review): "late game" presumably means after the bidding/factory-placement
# phase, which is why the env is advanced to step 100 first — confirm.
_state_step_late_game = jax.jit(State._step_late_game)
_state_step_late_game_vmap = jax.jit(jax.vmap(_state_step_late_game))

In [42]:
# Benchmark configuration.
# Jux buffer config; MAX_N_UNITS presumably caps the unit slots in Jux's
# fixed-size state arrays — TODO confirm.
buf_cfg = JuxBufferConfig(
    MAX_N_UNITS=100,
)
# Number of copies of the state stepped in parallel by the vmapped benchmark.
B = 20_000

## Without vmap

In [43]:
# Convert the mid-game Lux state and the next action into Jux's representation.
jux_state = State.from_lux(env.state, buf_cfg)
jux_act = jux_state.parse_actions_from_dict(act)

# Warm up: the first call triggers JIT compilation. JAX dispatches asynchronously,
# so block on every output leaf to make sure the warm-up run has actually finished
# before timing starts.
jax.tree_util.tree_map(lambda x: x.block_until_ready(), _state_step_late_game(jux_state, jux_act));

In [44]:
%timeit _state_step_late_game(jux_state, jux_act); jnp.array(0).block_until_ready()

1.78 ms ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## With vmap

In [45]:
# Replicate the single state/action B times along a new leading axis. This is the
# batch axis that jax.vmap maps over.
jux_state_batch = jax.tree_util.tree_map(lambda x: x[None].repeat(B, axis=0), jux_state)
jux_act_batch = jax.tree_util.tree_map(lambda x: x[None].repeat(B, axis=0), jux_act)

# Warm up (compiles the vmapped step) and block on every output leaf, so that no
# execution is still in flight when timing starts.
jax.tree_util.tree_map(lambda x: x.block_until_ready(), _state_step_late_game_vmap(jux_state_batch, jux_act_batch));

In [46]:
%timeit _state_step_late_game_vmap(jux_state_batch, jux_act_batch); jnp.array(0).block_until_ready()

365 ms ± 88.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
# Rough speedup of batched Jux over Lux, measured in env-steps per unit of wall time.
#   Lux: LUX_MS_PER_STEP ms per env-step.
#   Jux: B env-steps per JUX_VMAP_MS_PER_CALL ms (one vmapped call).
# NOTE(review): both values are hand-copied from earlier outputs.
#   - 365 matches the vmap %timeit above.
#   - 3.85 is presumably Lux's per-step time in ms (episode wall time / number of steps).
#   - The latest %time of the Lux replay reports 3.93 s, so confirm and update 3.85.
LUX_MS_PER_STEP = 3.85
JUX_VMAP_MS_PER_CALL = 365
B * LUX_MS_PER_STEP / JUX_VMAP_MS_PER_CALL

210.95890410958904

In [48]:
# Drop the B-times replicated batches so their arrays can be freed once nothing else references them.
del jux_state_batch, jux_act_batch