# `gymnax`: Classic Gym Environments in JAX
### Author: [@RobertTLange](https://twitter.com/RobertTLange) [Last Update: November 2021][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/getting_started.ipynb)
<a href="https://github.com/RobertTLange/gymnax/blob/main/docs/gymnax_logo.png?raw=true"><img src="https://github.com/RobertTLange/gymnax/blob/main/docs/gymnax_logo.png?raw=true" width="200" align="right" /></a>

## Basic API: `gymnax.make()`, `env.reset()`, `env.step()`

In [6]:
import jax
import jax.numpy as jnp
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

In [7]:
# Add simple vmap for step/reset

## Jitted Episode Rollouts via `lax.scan`

In [12]:
from flax import linen as nn


class MLP(nn.Module):
    """Simple ReLU MLP."""

    num_hidden_units: int
    num_hidden_layers: int
    num_output_units: int

    @nn.compact
    def __call__(self, x, rng):
        for l in range(self.num_hidden_layers):
            x = nn.Dense(features=self.num_hidden_units)(x)
            x = nn.relu(x)
        x = nn.Dense(features=self.num_output_units)(x)
        return x
    

network = MLP(48, 1, 1)
policy_params = network.init(rng, jnp.zeros(3), None)["params"]

In [13]:
def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        action = network.apply({"params": policy_params}, obs, rng_net)
        next_o, next_s, reward, done, _ = env.step(
          rng_step, state, action, env_params
        )
        carry = [next_o.squeeze(), next_s, policy_params, rng]
        return carry, [reward, done]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
      policy_step,
      [obs, state, policy_params, rng_episode],
      [jnp.zeros((num_env_steps, 2))],
    )
    # Return masked sum of rewards accumulated by agent in episode
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)

In [14]:
# Jit-Compiled Episode Rollout
jit_rollout = jax.jit(rollout, static_argnums=3)
jit_rollout(rng, policy_params, env_params, 200)

DeviceArray(-1600.6174, dtype=float32)

In [18]:
from gymnax.experimental.rollout import EnvRollout

roller = EnvRollout(model_forward=network.apply,
                    env_name="Pendulum-v1",
                    num_env_steps=200,
                    num_episodes=10)

roller.collect(rng, policy_params)

DeviceArray([-1076.0648, -1299.151 , -1502.7104, -1210.023 , -1287.3069,
             -1267.0559, -1529.9363, -1517.0801, -1809.3643, -1565.3792],            dtype=float32)

# Batch Rollouts via `jax.vmap`/`jax.pmap`

In [None]:
# jax.vmap across random keys for batch rollout
batch_rollout = jax.vmap(jit_rollout, in_axes=(0, None, None, None))

In [None]:
# jax.vmap across network params for "population" rollouts
pop_rollout = jax.vmap(batch_rollout, in_axes=(None, vmap_dict, None, None))

In [None]:
# jax.vmap across network params for meta-batch rollouts
meta_rollout = jax.vmap(jit_rollout, in_axes=(None, None, vmap_dict, None))

# Distributed Anakin Agent


Adapted from Hessel et al. (2021) and DeepMind's [Example Colab](https://colab.research.google.com/drive/1974D-qP17fd5mLxy6QZv-ic4yxlPJp-G?usp=sharing#scrollTo=lhnJkrYLOvcs)

In [1]:
# Imports
import chex
import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"

import jax
import haiku as hk
from jax import lax
from jax import random
from jax import numpy as jnp
import jax.numpy as jnp
import optax
import rlax
import timeit

jax.devices()



[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

# Import `gymnax` and make `Catch-bsuite` environment transition/reset

In [2]:
import gymnax
rng, env = gymnax.make("Catch-bsuite")
obs, state = env.reset(rng)
action = env.action_space.sample(rng)
obs, state, reward, terminal, info = env.step(rng, state, action)

  lax._check_user_dtype_supported(dtype, "astype")


# Anakin Agent Setup

In [3]:
@chex.dataclass(frozen=True)
class TimeStep:
    q_values: chex.Array
    action: chex.Array
    discount: chex.Array
    reward: chex.Array

def get_network_fn(num_outputs: int):
    """Define a fully connected multi-layer haiku network."""
    def network_fn(obs: chex.Array) -> chex.Array:
        return hk.Sequential([  # flatten, hidden layer, relu, output layer.
            hk.Flatten(), hk.Linear(256), jax.nn.relu, hk.Linear(num_outputs)])(obs)
    return hk.without_apply_rng(hk.transform(network_fn))

def get_learner_fn(
    env, forward_pass, opt_update, rollout_len, agent_discount,
    lambda_, iterations):
    """Define the minimal unit of computation in Anakin."""

    def loss_fn(params, outer_rng, env_state):
        """Compute the loss on a single trajectory."""

        def step_fn(env_state, rng):
            obs = env.get_obs(env_state)
            q_values = forward_pass(params, obs[None,])[0]  # forward pass.
            action = jnp.argmax(q_values)  # greedy policy.
            obs, state, reward, terminal, info = env.step(rng, env_state, action)  # step environment.
            return env_state, TimeStep(  # return env state and transition data.
              q_values=q_values, action=action, discount=1.-terminal, reward=reward)

        step_rngs = random.split(outer_rng, rollout_len)
        env_state, rollout = lax.scan(step_fn, env_state, step_rngs)  # trajectory.
        qa_tm1 = rlax.batched_index(rollout.q_values[:-1], rollout.action[:-1])
        td_error = rlax.td_lambda(  # compute multi-step temporal diff error.
            v_tm1=qa_tm1,  # predictions.
            r_t=rollout.reward[1:],  # rewards.
            discount_t=agent_discount * rollout.discount[1:],  # discount.
            v_t=jnp.max(rollout.q_values[1:], axis=-1),  # bootstrap values.
            lambda_=lambda_)  # mixing hyper-parameter lambda.
        return jnp.mean(td_error**2), env_state

    def update_fn(params, opt_state, rng, env_state):
        """Compute a gradient update from a single trajectory."""
        rng, loss_rng = random.split(rng)
        grads, new_env_state = jax.grad(  # compute gradient on a single trajectory.
            loss_fn, has_aux=True)(params, loss_rng, env_state)
        grads = lax.pmean(grads, axis_name='j')  # reduce mean across cores.
        grads = lax.pmean(grads, axis_name='i')  # reduce mean across batch.
        updates, new_opt_state = opt_update(grads, opt_state)  # transform grads.
        new_params = optax.apply_updates(params, updates)  # update parameters.
        return new_params, new_opt_state, rng, new_env_state

    def learner_fn(params, opt_state, rngs, env_states):
        """Vectorise and repeat the update."""
        batched_update_fn = jax.vmap(update_fn, axis_name='j')  # vectorize across batch.
        def iterate_fn(_, val):  # repeat many times to avoid going back to Python.
            params, opt_state, rngs, env_states = val
            return batched_update_fn(params, opt_state, rngs, env_states)
        return lax.fori_loop(0, iterations, iterate_fn, (
            params, opt_state, rngs, env_states))

    return learner_fn

# Rollout/Step the Anakin Agent in Parallel

In [4]:
class TimeIt():
    def __init__(self, tag, frames=None):
        self.tag = tag
        self.frames = frames

    def __enter__(self):
        self.start = timeit.default_timer()
        return self

    def __exit__(self, *args):
        self.elapsed_secs = timeit.default_timer() - self.start
        msg = self.tag + (': Elapsed time=%.2fs' % self.elapsed_secs)
        if self.frames:
            msg += ', FPS=%.2e' % (self.frames / self.elapsed_secs)
        print(msg)


def run_experiment(env, batch_size, rollout_len, step_size, iterations, seed):
    """Runs experiment."""
    cores_count = len(jax.devices())  # get available TPU cores.
    network = get_network_fn(env.action_space.num_categories)  # define network.
    optim = optax.adam(step_size)  # define optimiser.

    rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3)  # prng keys.
    obs, state = env.reset(rng_e)
    dummy_obs = obs[None,]  # dummy for net init.
    params = network.init(rng_p, dummy_obs)  # initialise params.
    opt_state = optim.init(params)  # initialise optimiser stats.

    learn = get_learner_fn(  # get batched iterated update.
      env, network.apply, optim.update, rollout_len=rollout_len,
      agent_discount=1, lambda_=0.99, iterations=iterations)
    learn = jax.pmap(learn, axis_name='i')  # replicate over multiple cores.

    broadcast = lambda x: jnp.broadcast_to(x, (cores_count, batch_size) + x.shape)
    params = jax.tree_map(broadcast, params)  # broadcast to cores and batch.
    opt_state = jax.tree_map(broadcast, opt_state)  # broadcast to cores and batch

    rng, *env_rngs = jax.random.split(rng, cores_count * batch_size + 1)
    env_obs, env_states = jax.vmap(env.reset)(jnp.stack(env_rngs))  # init envs.
    rng, *step_rngs = jax.random.split(rng, cores_count * batch_size + 1)

    reshape = lambda x: x.reshape((cores_count, batch_size) + x.shape[1:])
    step_rngs = reshape(jnp.stack(step_rngs))  # add dimension to pmap over.
    env_obs = reshape(env_obs)  # add dimension to pmap over.
    env_states = {k: reshape(env_states[k]) for k in env_states.keys()}

    with TimeIt(tag='COMPILATION'):
        learn(params, opt_state, step_rngs, env_states)  # compiles

    num_frames = cores_count * iterations * rollout_len * batch_size
    with TimeIt(tag='EXECUTION', frames=num_frames):
        params, opt_state, step_rngs, env_states = learn(  # runs compiled fn
            params, opt_state, step_rngs, env_states)

In [5]:
print('Running on', len(jax.devices()), 'cores.', flush=True)  # !expected 8!
run_experiment(env, 128, 16, 1e-4, 100, 42)

Running on 8 cores.
COMPILATION: Elapsed time=10.24s
EXECUTION: Elapsed time=12.18s, FPS=1.35e+05


# Vectorized Population Evaluation for CMA-ES