In [1]:
# Notebook setup: inline figures rendered at retina resolution, and automatic
# reloading of edited modules without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import haiku as hk
from flax import linen as nn

import gymnax
from gymnax.rollouts import DeterministicRollouts

# Pendulum-v0: 2D state space, 3D observation space,
# 1D continuous action space (torque).
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
print(env_params)

# Number of PRNG keys in `rng_batch` (presumably one per episode in
# `batch_rollout` -- confirm against DeterministicRollouts).
parallel_episodes = 10
# Separate keys for network initialisation and the single evaluation episode;
# the refreshed `rng` is then split into the batch of episode keys.
rng, rng_net, rng_episode = jax.random.split(rng, 3)
rng_batch = jax.random.split(rng, parallel_episodes)



{'max_speed': 8, 'max_torque': 2.0, 'dt': 0.05, 'g': 10.0, 'm': 1.0, 'l': 1.0, 'max_steps_in_episode': 200}


# Plain JAX MLP Policy

In [3]:
def init_policy_mlp(rng_input, sizes, scale=1e-2):
    """ Initialize Gaussian weights and biases for a ReLU MLP policy.

    Args:
        rng_input: JAX PRNG key.
        sizes: Layer widths ``[in_dim, hidden_1, ..., out_dim]``. Needs at
            least two entries; ``len(sizes) - 1`` dense layers are created.
        scale: Multiplier applied to the standard-normal draws.

    Returns:
        Dict with keys ``"W1", "b1", ..., "W{L}", "b{L}"``. Layer ``i`` has
        ``W{i}`` of shape ``(sizes[i], sizes[i-1])`` and ``b{i}`` of shape
        ``(sizes[i],)``, i.e. the layout consumed as ``W @ x + b``.

    Raises:
        ValueError: if ``sizes`` has fewer than two entries.
    """
    if len(sizes) < 2:
        raise ValueError(f"sizes needs at least two entries, got {sizes!r}")

    def initialize_layer(m, n, key):
        """ Draw an (n, m) weight matrix and an (n,) bias, scaled by `scale`. """
        w_key, b_key = jax.random.split(key)
        return (scale * jax.random.normal(w_key, (n, m)),
                scale * jax.random.normal(b_key, (n,)))

    # Keep the original split count and use keys[i] for layer i + 1, so the
    # default 3-entry `sizes` yields exactly the same parameters as before.
    keys = jax.random.split(rng_input, len(sizes) + 1)
    params = {}
    for i, (m, n) in enumerate(zip(sizes[:-1], sizes[1:])):
        params[f"W{i + 1}"], params[f"b{i + 1}"] = initialize_layer(m, n, keys[i])
    return params


def PolicyJAX(params, obs):
    """ Deterministic MLP forward pass: ReLU hidden layers, linear output.

    Args:
        params: Dict with ``"W1", "b1", ..., "W{L}", "b{L}"`` (as built by
            ``init_policy_mlp``); each layer is applied as ``W @ x + b``.
        obs: A single (unbatched) observation vector.

    Returns:
        Raw output of the final linear layer (no activation, no clipping).

    Raises:
        KeyError: if ``params`` does not contain ``"W1"``.
    """
    def relu_layer(W, b, x):
        """ Affine map followed by ReLU, for a single sample. """
        return jnp.maximum(0, jnp.dot(W, x) + b)

    # Count consecutive layers W1, W2, ... instead of hard-coding two layers.
    n_layers = 0
    while f"W{n_layers + 1}" in params:
        n_layers += 1
    if n_layers == 0:
        raise KeyError("W1")

    x = obs
    for i in range(1, n_layers):
        x = relu_layer(params[f"W{i}"], params[f"b{i}"], x)
    # The last layer stays linear: it yields the unbounded policy mean.
    mean_policy = jnp.dot(params[f"W{n_layers}"], x) + params[f"b{n_layers}"]
    return mean_policy

# Plain-JAX policy: 3-d observation -> 16 ReLU units -> 1 action output.
policy_params = init_policy_mlp(rng_net, sizes=[3, 16, 1])

In [4]:
# Wrap the plain-JAX policy in the rollout collector, then run one episode and
# one batch (keys from `rng_batch`). These calls presumably also trigger JIT
# compilation (TODO confirm in DeterministicRollouts), keeping that cost out
# of the %timeit measurement that follows.
collector = DeterministicRollouts(PolicyJAX, step, reset, env_params)
trace, reward = collector.episode_rollout(rng_episode, policy_params)
traces, rewards = collector.batch_rollout(rng_batch, policy_params)

In [5]:
%timeit trace, reward = collector.episode_rollout(rng_episode, policy_params)

561 µs ± 42.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Haiku MLP Policy

In [6]:
def policy_fct(x):
    """ Haiku forward function: Flatten -> Linear(16) -> ReLU -> Linear(1).

    Meant to be wrapped with ``hk.transform``; returns the output of the
    final linear layer without any activation.
    """
    # Modules are created in the same order as a literal Sequential list,
    # so Haiku assigns identical module names (and hence parameter keys).
    layers = [
        hk.Flatten(),
        hk.Linear(16),
        jax.nn.relu,
        hk.Linear(1),
    ]
    network = hk.Sequential(layers)
    return network(x)


PolicyHaiku = hk.without_apply_rng(hk.transform(policy_fct))
# Sample one observation so Haiku's `init` can infer the input shape.
# Split rng_net first: reusing one PRNG key for both the env reset and the
# parameter init would correlate the two random draws.
rng_obs, rng_init = jax.random.split(rng_net)
obs, state = reset(rng_obs, env_params)
policy_params = PolicyHaiku.init(rng_init, obs)

In [7]:
# Same rollout setup as for plain JAX, with Haiku's pure `apply(params, obs)`
# as the policy. The warm-up calls presumably also compile the rollout
# (TODO confirm), so the %timeit that follows measures steady-state cost.
collector = DeterministicRollouts(PolicyHaiku.apply, step, reset, env_params)
trace, reward = collector.episode_rollout(rng_episode, policy_params)
traces, rewards = collector.batch_rollout(rng_batch, policy_params)

In [8]:
%timeit trace, reward = collector.episode_rollout(rng_episode, policy_params)

1.94 ms ± 273 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Flax MLP Policy

In [9]:
class PolicyFLAX(nn.Module):
    """ Single-hidden-layer MLP policy: Dense -> ReLU -> Dense.

    Attributes:
        hidden_size: Width of the hidden layer (default 16, as before).
        action_dim: Size of the output layer (default 1, the torque action).
    """
    hidden_size: int = 16
    action_dim: int = 1

    @nn.compact
    def __call__(self, x):
        """ Map an observation `x` to the raw output of the last Dense layer. """
        # Explicit names 'fc1'/'fc2' keep the params tree identical to before.
        x = nn.Dense(self.hidden_size, name='fc1')(x)
        x = nn.relu(x)
        action = nn.Dense(self.action_dim, name='fc2')(x)
        return action

# Sample one observation so Flax's `init` can infer the input shape.
# Split rng_net first: reusing one PRNG key for both the env reset and the
# parameter init would correlate the two random draws.
rng_obs, rng_init = jax.random.split(rng_net)
obs, state = reset(rng_obs, env_params)
policy_params = PolicyFLAX().init(rng_init, obs)

In [10]:
# Same rollout setup again, now with the Flax module's `apply(params, obs)`.
# The warm-up calls presumably also compile the rollout (TODO confirm), so the
# %timeit that follows measures steady-state cost.
collector = DeterministicRollouts(PolicyFLAX().apply, step, reset, env_params)
trace, reward = collector.episode_rollout(rng_episode, policy_params)
traces, rewards = collector.batch_rollout(rng_batch, policy_params)

In [11]:
%timeit trace, reward = collector.episode_rollout(rng_episode, policy_params)

1.86 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Trax MLP Policy (disabled by default: slow import and a different call signature)

In [15]:
if False:
    # Trax import takes forever!!!
    import trax
    from trax import layers as tl

    # Problem with Trax: Takes input differently
    # ---- Haiku, JAX, Flax model(params, input)
    # ---- Trax model(input, params)

    def policy_fct():
        model = tl.Serial(
          tl.Dense(16),
          tl.Relu(),
          tl.Dense(1),
        )
        return model

    PolicyTrax = policy_fct()
    policy_params, _ = PolicyTrax.init(trax.shapes.signature(obs))

    collector = DeterministicRollouts(PolicyTrax, step, reset, env_params)
    trace, reward = collector.episode_rollout(rng_episode, policy_params)
    traces, rewards = collector.batch_rollout(rng_batch, policy_params)

    %timeit trace, reward = collector.episode_rollout(rng_episode, policy_params)

LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/<ipython-input-15-75af8b860cb3>, line 5
  layer input shapes: ((ShapeDtype{shape:(3, 16), dtype:float32}, ShapeDtype{shape:(16,), dtype:float32}), ((), (), ()), (ShapeDtype{shape:(16, 1), dtype:float32}, ShapeDtype{shape:(1,), dtype:float32}))

  File [...]/trax/layers/base.py, line 412, in weights
    sublayer.weights = sublayer_weights

  File [...]/trax/layers/base.py, line 400, in weights
    for w in weights:

  File [...]/site-packages/jax/core.py, line 470, in __iter__
    return iter(self.aval._iter(self))

  File [...]/_src/lax/lax.py, line 1940, in _iter
    raise TypeError("iteration over a 0-d array")  # same as numpy error

TypeError: iteration over a 0-d array