In [1]:
import jax, jax.numpy as jnp
from tensorneat.utils import State
from problem.rl_env import BraxEnv


def random_policy(state: State, obs, randkey):
    return jax.random.uniform(randkey, (8,)) * 2 - 1

In [2]:
# single evaluation without recording episode
randkey = jax.random.key(0)
env_key, policy_key = jax.random.split(randkey)
problem = BraxEnv(env_name="ant", max_step=100)
state = problem.setup()
evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,
                                                                                   policy_key)
score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)
score

Array(24.975231, dtype=float32)

In [3]:
# batch evaluation without recording episode
batch = 10
env_keys = jax.random.split(env_key, batch)
policy_keys = jax.random.split(policy_key, batch)

score = jax.jit(
    jax.vmap(
        evaluate_using_random_policy_without_record, 
        in_axes=(None, 0, 0)
    ))(
        state, env_keys, policy_keys
    )
score

Array([  -3.274895 ,   -6.016205 ,   -6.9032974,    9.187286 ,
       -120.19688  ,   12.389805 ,   -4.6393256,  -50.27197  ,
          9.650737 ,  -73.77956  ], dtype=float32)

In [4]:
# single evaluation with recording episode
randkey = jax.random.key(0)
env_key, policy_key = jax.random.split(randkey)
problem = BraxEnv(env_name="ant", max_step=100, record_episode=True)
evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,
                                                                                   policy_key)
score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)
score, episode["obs"].shape, episode["action"].shape, episode["reward"].shape

(Array(18.354952, dtype=float32), (100, 27), (100, 8), (100,))

In [5]:
# batch evaluation without recording episode
batch = 10
env_keys = jax.random.split(env_key, batch)
policy_keys = jax.random.split(policy_key, batch)

scores, episodes = jax.jit(
    jax.vmap(
        evaluate_using_random_policy_with_record, 
        in_axes=(None, 0, 0)
    ))(
        state, env_keys, policy_keys
    )
score, episodes["obs"].shape, episodes["action"].shape, episodes["reward"].shape

(Array(18.354952, dtype=float32), (10, 100, 27), (10, 100, 8), (10, 100))

In [6]:
evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)
evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)
evaluate_using_random_policy_with_record(state, env_key, policy_key)
evaluate_using_random_policy_without_record(state, env_key, policy_key)

Array(18.354952, dtype=float32)

In [7]:
for _ in range(20):
    evaluate_using_random_policy_with_record(state, env_key, policy_key)
# 47s384ms

In [8]:
for _ in range(20):
    evaluate_using_random_policy_without_record(state, env_key, policy_key)
# 48s559ms

In [12]:
# single evaluation without recording episode
from problem.rl_env import GymNaxEnv

def random_policy(state: State, obs, randkey):
    return jax.random.uniform(randkey, ()) 

randkey = jax.random.key(0)
env_key, policy_key = jax.random.split(randkey)
problem = GymNaxEnv(env_name="CartPole-v1", max_step=500)
state = problem.setup()
evaluate_using_random_policy_without_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,
                                                                                   policy_key)
score = jax.jit(evaluate_using_random_policy_without_record)(state, env_key, policy_key)
score

Array(9., dtype=float32, weak_type=True)

In [13]:
# batch evaluation without recording episode
batch = 10
env_keys = jax.random.split(env_key, batch)
policy_keys = jax.random.split(policy_key, batch)

score = jax.jit(
    jax.vmap(
        evaluate_using_random_policy_without_record, 
        in_axes=(None, 0, 0)
    ))(
        state, env_keys, policy_keys
    )
score

Array([13., 19., 11., 12., 14., 21., 13., 11., 11., 28.],      dtype=float32, weak_type=True)

In [14]:
# single evaluation with recording episode
randkey = jax.random.key(0)
env_key, policy_key = jax.random.split(randkey)
problem = GymNaxEnv(env_name="CartPole-v1", max_step=500, record_episode=True)
evaluate_using_random_policy_with_record = lambda state, env_key, policy_key: problem.evaluate(state, env_key, random_policy,
                                                                                   policy_key)
score, episode = jax.jit(evaluate_using_random_policy_with_record)(state, env_key, policy_key)
score, episode["obs"].shape, episode["action"].shape, episode["reward"].shape

(Array(9., dtype=float32, weak_type=True), (500, 4), (500,), (500,))

In [15]:
# batch evaluation without recording episode
batch = 10
env_keys = jax.random.split(env_key, batch)
policy_keys = jax.random.split(policy_key, batch)

scores, episodes = jax.jit(
    jax.vmap(
        evaluate_using_random_policy_with_record, 
        in_axes=(None, 0, 0)
    ))(
        state, env_keys, policy_keys
    )
score, episodes["obs"].shape, episodes["action"].shape, episodes["reward"].shape

(Array(9., dtype=float32, weak_type=True), (10, 500, 4), (10, 500), (10, 500))

In [16]:
evaluate_using_random_policy_with_record = jax.jit(evaluate_using_random_policy_with_record)
evaluate_using_random_policy_without_record = jax.jit(evaluate_using_random_policy_without_record)
evaluate_using_random_policy_with_record(state, env_key, policy_key)
evaluate_using_random_policy_without_record(state, env_key, policy_key)

Array(9., dtype=float32, weak_type=True)

In [19]:
for _ in range(20):
    evaluate_using_random_policy_with_record(state, env_key, policy_key)
# 48ms

In [20]:
for _ in range(20):
    evaluate_using_random_policy_without_record(state, env_key, policy_key)
# 43ms