In [1]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline
# Load the autoreload extension and enable mode 2 so edited modules are
# re-imported without restarting the kernel
%load_ext autoreload
%autoreload 2
# Ask the inline backend for 'retina' figure output
%config InlineBackend.figure_format = 'retina'

In [3]:
import jax
import jax.numpy as jnp
import haiku as hk
from flax import linen as nn

import gymnax
from gymnax.dojos import EvaluationDojo

# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# gymnax.make returns a PRNG key, the reset/step functions and the env params
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
print(env_params)

# Number of episodes rolled out in parallel by the batch rollouts below
parallel_episodes = 10
# Carve out dedicated keys for network init and the single-episode rollout
rng, rng_net, rng_episode = jax.random.split(rng, 3)
# One key per parallel episode
rng_batch = jax.random.split(rng, parallel_episodes)



{'max_speed': 8, 'max_torque': 2.0, 'dt': 0.05, 'g': 10.0, 'm': 1.0, 'l': 1.0, 'max_steps_in_episode': 200}


In [4]:
class MinimalEvaluationAgent:
    """ Stateless evaluation agent that wraps a policy forward function.

        The wrapped callable is invoked as `policy(agent_params, obs)`.
        The agent here is deterministic, but a stochastic one could make
        use of the `key` argument passed to `actor_step`.
    """

    def __init__(self, policy):
        """ Store the policy forward function used by `actor_step`. """
        self.policy = policy

    def init_actor_state(self):
        """ This minimal agent keeps no actor state. """
        return None

    def actor_step(self, key, agent_params, obs, actor_state):
        """ Run the policy on `obs`; hand back (action, unchanged state).

            `key` is accepted for interface compatibility and is not used.
        """
        return self.policy(agent_params, obs), actor_state

# Simple Plain JAX MLP Policy

In [5]:
def init_policy_mlp(rng_input, sizes, scale=1e-2):
    """ Initialize Gaussian weights and biases for every layer of an MLP.

        Args:
            rng_input: JAX PRNG key used to derive one key per layer.
            sizes: Layer widths `[input_dim, hidden_1, ..., output_dim]`.
                   Needs at least two entries.
            scale: Standard deviation multiplier for the normal draws.

        Returns:
            Dict with entries "W1", "b1", ..., "Wk", "bk" for the
            k = len(sizes) - 1 layers. Each W has shape (n_out, n_in) and
            each b has shape (n_out,).

        Raises:
            ValueError: If `sizes` has fewer than two entries.
    """
    if len(sizes) < 2:
        raise ValueError(
            "sizes must contain at least an input and an output dimension")

    # Initialize a single layer with Gaussian weights - helper function
    def initialize_layer(m, n, key, scale):
        w_key, b_key = jax.random.split(key)
        return (scale * jax.random.normal(w_key, (n, m)),
                scale * jax.random.normal(b_key, (n,)))

    # Keep the historical split count (len(sizes) + 1): the derived keys
    # depend on it, so a 3-entry `sizes` reproduces the earlier draws.
    keys = jax.random.split(rng_input, len(sizes) + 1)
    params = {}
    layer_dims = zip(sizes[:-1], sizes[1:])
    for idx, (fan_in, fan_out) in enumerate(layer_dims, start=1):
        W, b = initialize_layer(fan_in, fan_out, keys[idx - 1], scale)
        params[f"W{idx}"] = W
        params[f"b{idx}"] = b
    return params


def PolicyJAX(params, obs):
    """ Compute forward pass and return action from deterministic policy.

        Args:
            params: Dict with entries "W1", "b1", ..., "Wk", "bk". All
                    layers but the last are followed by a ReLU; the last
                    layer is linear. Other keys are ignored.
            obs: Single (unbatched) observation vector.

        Returns:
            Output of the final linear layer (the policy mean / action).
    """
    def relu_layer(W, b, x):
        """ Simple ReLu layer for single sample """
        return jnp.maximum(0, (jnp.dot(W, x) + b))

    # Apply ReLU layers for as long as a further layer follows; with only
    # W1/W2 present this is exactly Obs -> Hidden -> Action.
    activations = obs
    idx = 1
    while f"W{idx + 1}" in params:
        activations = relu_layer(params[f"W{idx}"], params[f"b{idx}"],
                                 activations)
        idx += 1
    mean_policy = jnp.dot(params[f"W{idx}"], activations) + params[f"b{idx}"]
    return mean_policy

# Observation dimension fed into the first layer
input_dim = 3
# Obs (3) -> Hidden (16) -> Action (1), initialized from the network key
policy_params = init_policy_mlp(rng_net, sizes=[input_dim, 16, 1])
# Wrap the plain-JAX forward function in the evaluation agent
agent = MinimalEvaluationAgent(PolicyJAX)

In [None]:
# Build the evaluation dojo around the agent and the environment functions
collector = EvaluationDojo(agent, step, reset, env_params)
# Hand the current policy parameters to the dojo
collector.init_dojo(policy_params)
# Roll out a single episode, then a batch using one PRNG key per episode
trace, reward = collector.episode_rollout(rng_episode)
traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
%timeit traces, rewards = collector.episode_rollout(rng_episode)
%timeit trace, reward = collector.batch_rollout(rng_batch)

# Haiku MLP Policy

In [None]:
def policy_fct(x):
    """ Haiku MLP policy: Flatten -> Linear(16) -> ReLU -> Linear(1).

        Meant to be wrapped by `hk.transform`; returns the network output
        for input `x`.
    """
    # Modules are created in the same order as they are applied
    layers = [hk.Flatten(), hk.Linear(16), jax.nn.relu, hk.Linear(1)]
    network = hk.Sequential(layers)
    return network(x)


# Pure (init, apply) pair; apply needs no rng since the policy is deterministic
PolicyHaiku = hk.without_apply_rng(hk.transform(policy_fct))
# Derive independent keys instead of feeding the same key to both the
# environment reset and the parameter initialization (PRNG keys should not
# be reused across independent random draws).
rng_reset, rng_init = jax.random.split(rng_net)
# The observation only serves as an input template for parameter init
obs, state = reset(rng_reset, env_params)
policy_params = PolicyHaiku.init(rng_init, obs)
agent = MinimalEvaluationAgent(PolicyHaiku.apply)

In [None]:
# Build the evaluation dojo around the agent and the environment functions
collector = EvaluationDojo(agent, step, reset, env_params)
# Hand the current policy parameters to the dojo
collector.init_dojo(policy_params)
# Roll out a single episode, then a batch using one PRNG key per episode
trace, reward = collector.episode_rollout(rng_episode)
traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
%timeit traces, rewards = collector.episode_rollout(rng_episode)
%timeit traces, rewards = collector.batch_rollout(rng_batch)

# Flax MLP Policy

In [None]:
class PolicyFLAX(nn.Module):
    """ Flax MLP policy: Dense(16) 'fc1' -> ReLU -> Dense(1) 'fc2'. """

    @nn.compact
    def __call__(self, x):
        # Hidden representation after the first dense layer and ReLU
        hidden = nn.relu(nn.Dense(16, name='fc1')(x))
        # Linear output layer producing the action
        return nn.Dense(1, name='fc2')(hidden)

# Derive independent keys instead of feeding the same key to both the
# environment reset and the parameter initialization (PRNG keys should not
# be reused across independent random draws).
rng_reset, rng_init = jax.random.split(rng_net)
# The observation only serves as an input template for parameter init
obs, state = reset(rng_reset, env_params)
policy_params = PolicyFLAX().init(rng_init, obs)
agent = MinimalEvaluationAgent(PolicyFLAX().apply)

In [None]:
# Build the evaluation dojo around the agent and the environment functions
collector = EvaluationDojo(agent, step, reset, env_params)
# Hand the current policy parameters to the dojo
collector.init_dojo(policy_params)
# Roll out a single episode, then a batch using one PRNG key per episode
trace, reward = collector.episode_rollout(rng_episode)
traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
%timeit traces, rewards = collector.episode_rollout(rng_episode)
%timeit trace, reward = collector.batch_rollout(rng_batch)

# Trax MLP Policy

In [None]:
# Toggle for the Trax section (its import is slow)
run_trax = True
if run_trax:
    # Imported lazily: the Trax import takes a long time
    import trax
    from trax import layers as tl

    # Argument order differs between the libraries used in this notebook:
    # ---- Haiku, JAX, Flax: model(params, input)
    # ---- Trax:             model(input, params)
    # so a small adapter is needed to re-route the inputs.

    def policy_fct():
        """ Build the Trax MLP: Dense(16) -> Relu -> Dense(1). """
        return tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(1))

    policy = policy_fct()
    # init returns a pair; only its first element (the weights) is kept
    policy_params, _ = policy.init(trax.shapes.signature(obs))

    def PolicyTrax(params, obs):
        """ Adapter mapping (params, obs) onto Trax's (obs, params) call. """
        return policy(obs, params)

    agent = MinimalEvaluationAgent(PolicyTrax)

In [None]:
if run_trax:
    # Build the evaluation dojo around the agent and the environment functions
    collector = EvaluationDojo(agent, step, reset, env_params)
    # Hand the current policy parameters to the dojo
    collector.init_dojo(policy_params)
    # Roll out a single episode, then a batch using one PRNG key per episode
    trace, reward = collector.episode_rollout(rng_episode)
    traces, rewards = collector.batch_rollout(rng_batch)

In [None]:
if run_trax:
    %timeit traces, rewards = collector.episode_rollout(rng_episode)
    %timeit trace, reward = collector.batch_rollout(rng_batch)

# Replay Buffer Tryout

In [None]:
from gymnax.utils import init_buffer, push_buffer, sample_buffer

In [None]:
# Make a dummy step transition & store it in buffer
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)
# The action is a continuous torque, so use a float rather than an integer:
# this array is also used as the action template when the buffer is created.
# NOTE(review): presumably the buffer takes its action dtype from this
# template - confirm against init_buffer.
action = jnp.array([1.0])
next_obs, next_state, reward, done, _ = step(key_step, env_params,
                                             state, action)


In [None]:
capacity = 5
# Allocate the replay buffer from template state / observation / action
buffer = init_buffer(state, obs, action, capacity)
# Bundle the dummy transition under the keys handed to push_buffer
step_experience = dict(
    state=state,
    next_state=next_state,
    obs=obs,
    next_obs=next_obs,
    action=action,
    reward=reward,
    done=done,
)
# Push the same transition 10 times, i.e. more often than the capacity of 5
for i in range(10):
    buffer = push_buffer(buffer, step_experience)

In [None]:
# Draw a fresh key for sampling and keep `rng` as the carry key
rng, key_sample = jax.random.split(rng)
# Sample from the buffer with the argument 3; the result is shown as cell output
# NOTE(review): 3 is presumably the number of sampled transitions - confirm
sample_buffer(key_sample, buffer, 3)

In [None]:
from gymnax.dojos import InterleavedDojo

class MinimalInterleavedAgent:
    """ Minimal agent skeleton for interleaved acting and learning.

        A full agent would typically bundle:
            - Policy/Value function network forward function
            - Optimizer to use in learner_step (e.g. optax)
            - Exploration schedule to use in actor_step
        This version only stores the policy and performs no learning.
    """

    def __init__(self, policy):
        """ Store the policy forward function used by `actor_step`. """
        self.policy = policy

    def init_actor_state(self):
        """ This minimal agent keeps no actor state. """
        return None

    def init_learner_state(self, agent_params):
        """ This minimal agent keeps no learner (e.g. optimizer) state. """
        return None

    def actor_step(self, key, agent_params, obs, actor_state):
        """ Run the policy on `obs`; hand back (action, unchanged state). """
        return self.policy(agent_params, obs), actor_state

    def learner_step(self, key, agent_params, learner_state):
        """ Placeholder update: params and learner state pass through. """
        return agent_params, learner_state

agent = MinimalInterleavedAgent(PolicyJAX)
# PolicyJAX indexes its parameters by "W1"/"b1"/"W2"/"b2", but `policy_params`
# was last rebound by the Flax / Trax sections above. Re-create the matching
# plain-JAX MLP parameters so this cell works on a top-to-bottom run.
policy_params = init_policy_mlp(rng_net, sizes=[input_dim, 16, 1])
# The dojo receives the pre-filled buffer plus the push / sample functions
collector = InterleavedDojo(agent, buffer, push_buffer, sample_buffer,
                            step, reset, env_params)
collector.init_dojo(policy_params)
trace, reward = collector.episode_rollout(rng_episode)

In [None]:
# Display the dojo's buffer attribute after the rollout
collector.buffer