In [14]:
import jax
import jax.numpy as jnp
import gymnax
import equinox as eqx

In [2]:
jax.device_count(), jax.devices()

(1, [CpuDevice(id=0)])

In [38]:
rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)

# Create the Pendulum-v1 environment
env, env_params = gymnax.make("Pendulum-v1")

# Inspect default environment settings
env_params

EnvParams(max_speed=8.0, max_torque=2.0, dt=0.05, g=10.0, m=1.0, l=1.0, max_steps_in_episode=200)

In [6]:
obs, state = env.reset(key_reset, env_params)
obs, state

(Array([-0.939326  , -0.34302574, -0.6520283 ], dtype=float32),
 EnvState(theta=Array(-2.7914565, dtype=float32), theta_dot=Array(-0.6520283, dtype=float32), last_u=Array(0., dtype=float32, weak_type=True), time=Array(0, dtype=int32, weak_type=True)))

In [7]:
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
n_obs, n_state, reward, done

(Array([-0.9494436 , -0.31393763, -0.6159719 ], dtype=float32),
 EnvState(theta=Array(-2.8222551, dtype=float32), theta_dot=Array(-0.6159719, dtype=float32), last_u=Array(1.9555049, dtype=float32), time=Array(1, dtype=int32, weak_type=True)),
 Array(-7.8385677, dtype=float32),
 Array(False, dtype=bool, weak_type=True))

In [8]:
vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))

In [9]:
num_envs = 8
vmap_keys = jax.random.split(rng, num_envs)

In [12]:
obs, state = vmap_reset(vmap_keys, env_params)
obs.shape

(8, 3)

In [13]:
n_obs, n_state, reward, done, _ = vmap_step(vmap_keys, state, jnp.zeros(num_envs), env_params)
print(n_obs.shape)

(8, 3)


In [60]:
class MLP(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, in_size, out_size, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=in_size, out_size=out_size, width_size=32, depth=2,
                             key=key)
    
    def __call__(self, x, *, key):
        return self.mlp(x)

key = jax.random.PRNGKey(32)
policy = MLP(in_size=env.observation_space(params=env_params).shape[0], 
             out_size=env.action_space(params=env_params).shape[0],
                key=key)

In [162]:
def rollout(rng_input, policy, env, env_params, steps_in_episode, epoch):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = policy(obs, key=rng_net)
        next_obs, next_state, reward, done, _ = env.step(
          rng_step, state, action, env_params
        )
        carry = [next_obs, next_state, rng]
        return carry, [obs, action, reward, next_obs, done, state]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
      policy_step,
      [obs, state, rng_episode],
      (),
      steps_in_episode
    )
    # Return masked sum of rewards accumulated by agent in episode
    obs, action, reward, next_obs, done, states = scan_out
    return obs, action, reward, next_obs, done, states

In [163]:
jit_rollout = eqx.filter_jit(rollout)

In [164]:
obs, action, reward, next_obs, done, states = jit_rollout(rng, policy, env, env_params, 200, 0)
obs.shape, reward.shape, jnp.sum(reward)

((200, 3), (200,), Array(-933.41895, dtype=float32))

In [165]:
obs, action, reward, next_obs, done, states = eqx.filter_vmap(jit_rollout, in_axes=(None, None, None, None, None, 0))(rng, policy, env, env_params, 200, jnp.arange(30))
obs.shape, reward.shape, jnp.sum(reward)

((30, 200, 3), (30, 200), Array(-28002.578, dtype=float32))

In [168]:
obs, action, reward, next_obs, done, states = jax.vmap(jit_rollout, in_axes=(None, None, None, None, None, 0))(rng, policy, env, env_params, 2000, jnp.arange(30000))
obs.shape, reward.shape, jnp.sum(reward)

((30000, 2000, 3), (30000, 2000), Array(-3.9451293e+08, dtype=float32))