In [1]:
# Notebook setup: inline plots, automatic module reloading, hi-res figures.
%matplotlib inline
%load_ext autoreload
# Mode 2 re-imports all modules before each cell runs, so edits to local
# packages are picked up without restarting the kernel.
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import haiku as hk
from flax import linen as nn

import gymnax
from gymnax.rollouts import DeterministicRollouts

# Pendulum: 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# `gymnax.make` hands back a PRNG key, the reset/step functions and the
# environment parameters (printed below).
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
print(env_params)

# Number of episodes passed to the batched rollout calls.
parallel_episodes = 10
# Carve out separate keys: one for network init, one for a single episode,
# and one key per parallel episode.
rng, rng_net, rng_episode = jax.random.split(rng, 3)
rng_batch = jax.random.split(rng, parallel_episodes)



{'max_speed': 8, 'max_torque': 2.0, 'dt': 0.05, 'g': 10.0, 'm': 1.0, 'l': 1.0, 'max_steps_in_episode': 200}


# Simple Plain JAX MLP Policy

In [None]:
def init_policy_mlp(rng_input, sizes, scale=1e-2):
    """ Initialize Gaussian weights & biases for every layer of an MLP policy.

    Args:
        rng_input: JAX PRNG key from which the per-layer keys are derived.
        sizes: Layer widths `[input, hidden_1, ..., output]`; needs at least
            two entries.
        scale: Factor multiplied onto the standard-normal draws.

    Returns:
        Dict with entries "W1", "b1", ..., "WL", "bL" for the
        L = len(sizes) - 1 layers. "Wk" has shape (sizes[k], sizes[k-1]) and
        "bk" has shape (sizes[k],).

    Raises:
        ValueError: If `sizes` has fewer than two entries.
    """
    if len(sizes) < 2:
        raise ValueError("sizes must contain at least an input and an "
                         "output dimension")

    # Initialize a single layer with Gaussian weights - helper function
    def initialize_layer(m, n, key, scale):
        w_key, b_key = jax.random.split(key)
        return (scale * jax.random.normal(w_key, (n, m)),
                scale * jax.random.normal(b_key, (n,)))

    # Keep the historical split count so that existing seeds reproduce the
    # same parameters; only the first len(sizes) - 1 keys are consumed.
    keys = jax.random.split(rng_input, len(sizes)+1)
    params = {}
    for layer in range(1, len(sizes)):
        W, b = initialize_layer(sizes[layer - 1], sizes[layer],
                                keys[layer - 1], scale)
        params[f"W{layer}"] = W
        params[f"b{layer}"] = b
    return params


def PolicyJAX(params, obs):
    """ Compute forward pass and return action from deterministic policy.

    Args:
        params: Dict with entries "W1", "b1", ..., "WL", "bL" (the layout
            produced by `init_policy_mlp`).
        obs: Single observation vector (no batch dimension).

    Returns:
        Output of the final linear layer, used as the policy's action.
    """
    def relu_layer(W, b, x):
        """ Simple ReLu layer for single sample """
        return jnp.maximum(0, (jnp.dot(W, x) + b))

    # Every layer contributes one weight and one bias entry.
    num_layers = len(params) // 2
    # Obs -> Hidden (ReLU) x (num_layers - 1) -> Action (linear)
    activations = obs
    for layer in range(1, num_layers):
        activations = relu_layer(params[f"W{layer}"], params[f"b{layer}"],
                                 activations)
    mean_policy = (jnp.dot(params[f"W{num_layers}"], activations)
                   + params[f"b{num_layers}"])
    return mean_policy

# 3 observation dims -> 16 hidden units -> 1 action dim (Pendulum torque).
policy_params = init_policy_mlp(rng_net, sizes=[3, 16, 1])

In [None]:
# Wrap the plain-JAX policy in the rollout collector, then run it once for a
# single episode key and once for the batch of keys.
# NOTE(review): these first calls presumably also trigger JIT compilation so
# that the timing cell below measures steady-state speed - confirm.
collector = DeterministicRollouts(PolicyJAX, step, reset, env_params)
collector.init_collector(policy_params)
trace, reward = collector.episode_rollout(rng_episode)
traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
%timeit traces, rewards = collector.episode_rollout(rng_episode)
%timeit trace, reward = collector.batch_rollout(rng_batch)

# Haiku MLP Policy

In [None]:
def policy_fct(x):
    """ Haiku MLP policy: flatten the input, 16 ReLU units, 1 linear output."""
    layers = [
        hk.Flatten(),
        hk.Linear(16),
        jax.nn.relu,
        hk.Linear(1),
    ]
    network = hk.Sequential(layers)
    return network(x)


# Turn the policy function into a pure init/apply pair whose `apply` takes no
# rng argument.
PolicyHaiku = hk.without_apply_rng(hk.transform(policy_fct))
# Reset once to obtain an observation that is fed to the parameter init.
# NOTE(review): `rng_net` is reused for the reset, this init and the plain-JAX
# init above; fine for a speed comparison, but split the key if the draws
# need to be independent.
obs, state = reset(rng_net, env_params)
policy_params = PolicyHaiku.init(rng_net, obs)

In [None]:
# Same collector workflow as for the plain-JAX policy, now with the Haiku
# apply function and its parameters.
# NOTE(review): these first calls presumably also trigger JIT compilation
# ahead of the timing cell - confirm.
collector = DeterministicRollouts(PolicyHaiku.apply, step, reset, env_params)
collector.init_collector(policy_params)
trace, reward = collector.episode_rollout(rng_episode)
traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
%timeit traces, rewards = collector.episode_rollout(rng_episode)
%timeit traces, rewards = collector.batch_rollout(rng_batch)

# Flax MLP Policy

In [None]:
class PolicyFLAX(nn.Module):
    """ Flax MLP policy: 16 ReLU units ('fc1') followed by 1 linear output ('fc2')."""

    @nn.compact
    def __call__(self, x):
        """ Map an observation to the policy's action output."""
        hidden = nn.relu(nn.Dense(16, name='fc1')(x))
        return nn.Dense(1, name='fc2')(hidden)

# Reset once to obtain an observation that is fed to the parameter init.
# NOTE(review): `rng_net` is reused for both the reset and the init (and by
# the earlier policies); fine for a speed comparison, but split the key if
# the draws need to be independent.
obs, state = reset(rng_net, env_params)
policy_params = PolicyFLAX().init(rng_net, obs)

In [None]:
# Same collector workflow, now with the Flax module's apply function and its
# parameters.
# NOTE(review): these first calls presumably also trigger JIT compilation
# ahead of the timing cell - confirm.
collector = DeterministicRollouts(PolicyFLAX().apply, step, reset, env_params)
collector.init_collector(policy_params)
trace, reward = collector.episode_rollout(rng_episode)
traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
%timeit traces, rewards = collector.episode_rollout(rng_episode)
%timeit trace, reward = collector.batch_rollout(rng_batch)

# Trax MLP Policy

In [None]:
run_trax = True
if run_trax:
    # Trax import takes forever, so it is deferred behind the `run_trax` flag.
    import trax
    from trax import layers as tl

    # Problem with Trax: Takes input differently
    # ---- Haiku, JAX, Flax model(params, input)
    # ---- Trax model(input, params)
    # Need helper function that re-routes inputs

    # Named differently from the Haiku `policy_fct` so that re-running the
    # Haiku cell afterwards still picks up the Haiku definition.
    def make_trax_policy():
        """ Build the Trax MLP policy: 16 ReLU units followed by 1 linear output."""
        return tl.Serial(
          tl.Dense(16),
          tl.Relu(),
          tl.Dense(1),
        )

    policy = make_trax_policy()
    # `obs` is the observation left over from the earlier reset call; only its
    # signature is passed to the init.
    policy_params, _ = policy.init(trax.shapes.signature(obs))
    def PolicyTrax(params, obs):
        """ Helper for correct mapping of policy params & input"""
        return policy(obs, params)

In [None]:
if run_trax:
    # Same collector workflow, now with the argument-swapping Trax wrapper.
    # NOTE(review): these first calls presumably also trigger JIT compilation
    # ahead of the timing cell - confirm.
    collector = DeterministicRollouts(PolicyTrax, step, reset, env_params)
    collector.init_collector(policy_params)
    trace, reward = collector.episode_rollout(rng_episode)
    traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
if run_trax:
    %timeit traces, rewards = collector.episode_rollout(rng_episode)
    %timeit trace, reward = collector.batch_rollout(rng_batch)

# Replay Buffer Tryout

In [3]:
from gymnax.rollouts import ReplayBuffer

In [4]:
# Make a dummy step transition & store it in buffer
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)
# Torque is a continuous action, so use a float array rather than an integer
# one; this array is also handed to the replay buffer as its action template.
action = jnp.array([1.0])
next_obs, next_state, reward, done, _ = step(key_step, env_params,
                                             state, action)


In [5]:
capacity = 5
# Initialize buffer with templates
buffer = ReplayBuffer(state, obs, action, capacity)
# Push the same dummy transition 10 times, i.e. twice the capacity.
# NOTE(review): presumably meant to exercise overwriting once the buffer is
# full - confirm against ReplayBuffer.push.
for i in range(10):
    buffer.push(state, next_state, obs, next_obs, action, reward, done)

In [6]:
# Draw a sample of size 3 from the buffer with a fresh key.
# NOTE(review): this call currently fails with the TypeError recorded in the
# cell output ("ReplayBuffer ... is not a valid JAX type"). It looks like the
# buffer instance itself reaches a JAX-transformed function as an argument -
# verify in gymnax.rollouts.replay_buffer.
rng, key_sample = jax.random.split(rng)
buffer.sample(key_sample, 3)

TypeError: Argument '<gymnax.rollouts.replay_buffer.ReplayBuffer object at 0x7f805f411d30>' of type <class 'gymnax.rollouts.replay_buffer.ReplayBuffer'> is not a valid JAX type.

# Play around with reshaping of updated params after vmapping

In [None]:
# Inspect entry 4 of the single-episode trace, leaf by leaf.
# NOTE(review): entry 4 presumably holds the policy parameters carried through
# the rollout - confirm against DeterministicRollouts.
print(trace[4][0][0])
print(trace[4][0][1])
print(trace[4][1])
print(trace[4][2][0])
print(trace[4][2][1])

In [None]:
# Same inspection for the batched traces, taking index 0 of each nested leaf.
# NOTE(review): index 0 presumably selects the first parallel episode along
# the axis added by vmap - confirm. The middle entry is printed in full.
print(traces[4][0][0][0])
print(traces[4][0][1][0])
print(traces[4][1])
print(traces[4][2][0][0])
print(traces[4][2][1][0])

In [None]:
# Display the most recently assigned policy parameters for comparison with
# the trace entries printed above.
policy_params