In [1]:
# Notebook setup: render figures inline at retina resolution and
# automatically reload imported modules (autoreload mode 2).
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import haiku as hk
from flax import linen as nn

import gymnax
from gymnax.rollouts import BaseWrapper

# Rollout configuration: step cap per episode and number of episodes
# collected in the batched rollout.
max_steps_in_episode = 200
parallel_episodes = 10

# Pendulum: 2D state space, 3D observation space, 1D continuous action (torque).
rng, reset, step, env_params = gymnax.make("Pendulum-v0")

# Carry `rng` forward and derive separate keys for network initialisation and
# the single-episode rollout; then derive `parallel_episodes` keys for the batch.
rng, rng_net, rng_episode = jax.random.split(rng, num=3)
rng_batch = jax.random.split(rng, num=parallel_episodes)



# Simple Plain JAX MLP Policy

In [None]:
def init_policy_mlp(rng_input, sizes, scale=1e-2):
    """Initialise Gaussian weights and biases for a fully connected MLP.

    One (W, b) pair is created for every consecutive pair of entries in
    ``sizes``; layer ``i`` (1-indexed) maps ``sizes[i-1]`` inputs to
    ``sizes[i]`` outputs.

    Args:
        rng_input: JAX PRNG key used to derive the per-layer keys.
        sizes: Sequence of layer widths, e.g. ``[3, 16, 1]`` for
            observation -> hidden -> action. Needs at least two entries.
        scale: Multiplier applied to the standard-normal samples.

    Returns:
        Dict with keys ``"W1", "b1", "W2", "b2", ...`` where ``W{i}`` has
        shape ``(sizes[i], sizes[i-1])`` and ``b{i}`` has shape ``(sizes[i],)``.

    Raises:
        ValueError: If ``sizes`` has fewer than two entries.
    """
    if len(sizes) < 2:
        raise ValueError(
            "sizes must contain at least an input and an output width, "
            "got {}".format(list(sizes)))

    # Split into len(sizes) + 1 keys (one more than strictly needed) so that
    # the values drawn for a three-entry `sizes` stay identical to the
    # previous two-layer-only implementation.
    layer_keys = jax.random.split(rng_input, len(sizes) + 1)

    params = {}
    layer_dims = zip(sizes[:-1], sizes[1:])
    for idx, (fan_in, fan_out) in enumerate(layer_dims, start=1):
        # Independent keys for the weight matrix and the bias vector.
        w_key, b_key = jax.random.split(layer_keys[idx - 1])
        params["W{}".format(idx)] = scale * jax.random.normal(
            w_key, (fan_out, fan_in))
        params["b{}".format(idx)] = scale * jax.random.normal(
            b_key, (fan_out,))
    return params


def ffw_policy(params, obs):
    """Compute the forward pass of a deterministic MLP policy for one sample.

    Every layer except the last applies ``relu(W @ x + b)``; the last layer
    is purely linear. The number of layers is taken from the number of
    ``"W{i}"`` entries in ``params``, so the two-layer layout
    (``W1, b1, W2, b2``) behaves exactly as before.

    Args:
        params: Dict holding ``"W{i}"`` / ``"b{i}"`` arrays for
            ``i = 1..num_layers``.
        obs: Single observation vector fed to the first layer.

    Returns:
        Output of the final linear layer (the policy mean / action).
    """
    # Count layers from the dict structure; this is static, so the Python
    # loop below unrolls cleanly under jit/vmap tracing.
    num_layers = sum(1 for name in params if name.startswith("W"))

    hidden = obs
    for layer in range(1, num_layers):
        pre_activation = (jnp.dot(params["W{}".format(layer)], hidden)
                          + params["b{}".format(layer)])
        hidden = jnp.maximum(0, pre_activation)

    # Final layer: no non-linearity.
    mean_policy = (jnp.dot(params["W{}".format(num_layers)], hidden)
                   + params["b{}".format(num_layers)])
    return mean_policy

# Layer widths: 3-dim observation -> 16 hidden units -> 1-dim action.
policy_params = init_policy_mlp(rng_input=rng_net, sizes=[3, 16, 1])

In [None]:
# Bundle the plain-JAX policy function with the environment's step/reset
# functions, its parameters and the per-episode step limit.
# (Argument meaning is inferred from the names — confirm against BaseWrapper.)
collector = BaseWrapper(ffw_policy, step, reset,
                        env_params, max_steps_in_episode)

# Roll out once with a single key, then with the batch of keys in rng_batch
# (presumably one episode per key — verify against gymnax.rollouts).
trace, reward = collector.episode_rollout(rng_episode, policy_params)
traces, rewards = collector.batch_rollout(rng_batch, policy_params)

# Haiku MLP Policy

In [None]:
def policy_fct(x):
    """Haiku MLP policy: flatten -> Linear(16) -> ReLU -> Linear(1).

    Intended to be wrapped with ``hk.transform``; returns the network
    output for input ``x``.
    """
    # Modules are instantiated in the same order as they are applied.
    layers = [
        hk.Flatten(),
        hk.Linear(16),
        jax.nn.relu,
        hk.Linear(1),
    ]
    network = hk.Sequential(layers)
    return network(x)


# Transform the Haiku network function into an object with `init`/`apply`;
# `without_apply_rng` is used so `apply` does not take an rng argument.
# NOTE(review): this rebinds `ffw_policy`, shadowing the plain-JAX function of
# the same name defined earlier in the notebook — re-running the plain-JAX
# rollout cell after this one would pass this object instead. Consider a
# distinct name.
ffw_policy = hk.without_apply_rng(hk.transform(policy_fct))
# Obtain an observation from a fresh reset to use as the example input for init.
# NOTE(review): rng_net is used for both reset and init — confirm that reusing
# the key here is intended.
obs, state = reset(rng_net, env_params)
policy_params = ffw_policy.init(rng_net, obs)

In [None]:
# Same rollout setup as for the plain-JAX policy, but driven by the Haiku
# network's `apply` function.
collector = BaseWrapper(ffw_policy.apply, step, reset,
                        env_params, max_steps_in_episode)

# Single rollout with rng_episode, then a batched rollout over rng_batch
# (presumably one episode per key — verify against gymnax.rollouts).
trace, reward = collector.episode_rollout(rng_episode, policy_params)
traces, rewards = collector.batch_rollout(rng_batch, policy_params)

# Flax MLP Policy

In [None]:
class PolicyMLP(nn.Module):
    """Flax MLP policy: Dense(16) 'fc1' -> ReLU -> Dense(1) 'fc2'."""

    @nn.compact
    def __call__(self, x):
        """Return the output of the final dense layer for input ``x``."""
        hidden = nn.relu(nn.Dense(features=16, name='fc1')(x))
        return nn.Dense(features=1, name='fc2')(hidden)

# Obtain an observation from a fresh reset and use it as the example input
# when initialising the Flax module's parameters.
# NOTE(review): rng_net is used for both reset and init — confirm that reusing
# the key here is intended.
obs, state = reset(rng_net, env_params)
policy_params = PolicyMLP().init(rng_net, obs)

In [None]:
# Same rollout setup as for the other policies, but driven by the Flax
# module's `apply` function.
collector = BaseWrapper(PolicyMLP().apply, step, reset,
                        env_params, max_steps_in_episode)

# Single rollout with rng_episode, then a batched rollout over rng_batch
# (presumably one episode per key — verify against gymnax.rollouts).
trace, reward = collector.episode_rollout(rng_episode, policy_params)
traces, rewards = collector.batch_rollout(rng_batch, policy_params)

In [None]:
# Display the rewards from the most recent batch rollout (the Flax policy
# when the notebook is run top to bottom).
rewards