In [None]:
import jux
from rich import print
from jux.state import State
from jux.config import JuxBufferConfig
import jax
import jax.numpy as jnp

# Benchmark Lux

In [None]:
# Load the reference Kaggle episode; `lux_step` below feeds each element of
# `actions` to `env.step`.
replay_url = "https://www.kaggleusercontent.com/episodes/45715004.json"
env, actions = jux.utils.load_replay(replay_url)
def lux_step(env, actions):
    """Replay ``actions`` on ``env`` one step at a time and print the step count.

    Args:
        env: object exposing ``step(action)``; each action is passed to it in order.
        actions: iterable of per-step actions (exhausted if it is an iterator).

    Prints the number of steps executed (0 when ``actions`` is empty).
    """
    # Initialised up front so an empty ``actions`` prints 0 instead of raising
    # UnboundLocalError; ``start=1`` makes it a count rather than the last index.
    n_steps = 0
    for n_steps, act in enumerate(actions, start=1):
        env.step(act)
    print(n_steps)
# Wall time for replaying the entire episode (not a single step); divide by the
# number of replayed steps to compare with the per-step Jux timings below.
%time lux_step(env, actions)

# Benchmark Jux

## Without vmap

In [None]:
# Benchmark configuration.
# NOTE(review): MAX_N_UNITS presumably sizes Jux's fixed-capacity unit buffers — confirm in JuxBufferConfig.
buf_cfg = JuxBufferConfig(MAX_N_UNITS=1000)
B = 500  # batch size used by the vmap benchmark
START_STEP = 100  # replay the episode up to this env step before snapshotting the state

# Prepare an env by replaying the reference episode up to START_STEP.
env, actions = jux.utils.load_replay("https://www.kaggleusercontent.com/episodes/45715004.json")
while env.env_steps < START_STEP:
    try:
        act = next(actions)
    except StopIteration:
        # A bare StopIteration gives no hint why setup failed; say what happened.
        raise RuntimeError(
            f"replay ran out of actions at env step {env.env_steps}, before START_STEP={START_STEP}"
        ) from None
    env.step(act)

# Convert the Lux state and the last replayed action into Jux structures.
# NOTE(review): `act` is the action that was already applied to reach this state,
# not the next action in the replay — fine for timing, but confirm that's intended.
jux_state = State.from_lux(env.state, buf_cfg)
jux_act = jux_state.parse_actions_from_dict(act)

# JIT-compile the step function. The first call traces and compiles; JAX then
# dispatches execution asynchronously, so block on the outputs to make sure the
# warm-up run has fully finished before the timing cell starts.
_state_step_late_game = jax.jit(State._step_late_game)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), _state_step_late_game(jux_state, jux_act));

In [None]:
%timeit _state_step_late_game(jux_state, jux_act); jnp.array(0).block_until_ready()

## With vmap

In [None]:
# Replicate the single state/action B times into batched pytrees for vmap.
# NOTE(review): jax.vmap maps over axis 0 by default, so this assumes batch_into_leaf
# stacks each leaf along a new leading axis of size B — confirm.
jux_state_batch = jux.tree_util.batch_into_leaf([jux_state] * B)
jux_act_batch = jux.tree_util.batch_into_leaf([jux_act] * B)
_state_step_late_game_vmap = jax.jit(jax.vmap(_state_step_late_game))
# Warm up (trace + compile) and block on the outputs so the asynchronously
# dispatched batched run has finished before the timing cell starts.
jax.tree_util.tree_map(lambda x: x.block_until_ready(), _state_step_late_game_vmap(jux_state_batch, jux_act_batch));

In [None]:
%timeit _state_step_late_game_vmap(jux_state_batch, jux_act_batch); jnp.array(0).block_until_ready()