In [1]:
# Show matplotlib figures inline in the notebook output.
%matplotlib inline
# Enable autoreload so edits to imported modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for 'retina' resolution figures.
%config InlineBackend.figure_format = 'retina'

In [2]:
import os

import jax
import jax.numpy as jnp
import haiku as hk
from flax import linen as nn

import gymnax
from gymnax.experimental.dojos import EvaluationDojo

# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# gymnax.make hands back a PRNG key, the reset/step functions and the env parameters.
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
print(env_params)

# Rollout configuration: number of steps per rollout and number of parallel episodes.
num_steps = 200
parallel_episodes = 10
# Carry `rng` forward and derive dedicated keys for network init, episode rollout and env reset.
rng, rng_net, rng_episode, rng_reset = jax.random.split(rng, 4)
# One key per parallel episode, used by the batched rollouts.
rng_batch = jax.random.split(rng, parallel_episodes)



FrozenDict({
    max_speed: 8,
    max_torque: 2.0,
    dt: 0.05,
    g: 10.0,
    m: 1.0,
    l: 1.0,
    max_steps_in_episode: 200,
})


In [3]:
class MinimalEvaluationAgent:
    """ Thin evaluation-time wrapper that pairs a policy function with the
        actor-step / actor-state interface used by the rollout code.
    """

    def __init__(self, policy):
        """ Keep a reference to the policy forward function.

            The policy is called as `policy(agent_params, obs)`. This agent is
            deterministic; a stochastic agent or an exploitation schedule
            could be wired in at this point as well.
        """
        self.policy = policy

    def init_actor_state(self, evaluate):
        """ This agent carries no actor state, so the initial state is None. """
        return None

    def actor_step(self, key, agent_params, obs, actor_state):
        """ Run the policy on `obs` and return `(action, actor_state)`.

            `key` is accepted for interface compatibility but not used, and
            `actor_state` is passed through unchanged.
        """
        return self.policy(agent_params, obs), actor_state

# Simple Plain JAX MLP Policy

In [4]:
def init_policy_mlp(rng_input, sizes, scale=1e-2):
    """ Initialize Gaussian weights and biases for every layer of an MLP.

    Args:
        rng_input: JAX PRNG key from which one subkey per layer is taken.
        sizes: Layer widths `[input, hidden..., output]`; at least two entries.
        scale: Multiplier applied to the standard-normal draws.

    Returns:
        Dict with entries "W{i}" of shape `(sizes[i], sizes[i-1])` and
        "b{i}" of shape `(sizes[i],)` for `i = 1 .. len(sizes) - 1`.

    Raises:
        ValueError: If `sizes` has fewer than two entries.
    """
    if len(sizes) < 2:
        raise ValueError("sizes needs at least an input and an output width")

    # Initialize a single layer with Gaussian weights - helper function
    def initialize_layer(m, n, key, scale):
        w_key, b_key = jax.random.split(key)
        return (scale * jax.random.normal(w_key, (n, m)),
                scale * jax.random.normal(b_key, (n,)))

    # The split count stays at len(sizes)+1 and layer i uses keys[i-1], so a
    # three-entry `sizes` reproduces the weights of the former two-layer code.
    keys = jax.random.split(rng_input, len(sizes) + 1)
    params = {}
    layer_dims = zip(sizes[:-1], sizes[1:])
    for idx, (fan_in, fan_out) in enumerate(layer_dims, start=1):
        W, b = initialize_layer(fan_in, fan_out, keys[idx - 1], scale)
        params[f"W{idx}"] = W
        params[f"b{idx}"] = b
    return params


def PolicyJAX(params, obs):
    """ Compute forward pass and return action from deterministic policy.

    Args:
        params: Dict holding "W{i}" / "b{i}" entries for layers i = 1..N.
            All layers but the last are followed by a ReLU; the last layer
            is linear.
        obs: A single observation vector (no batch dimension).

    Returns:
        Output of the final linear layer (the policy mean / action).
    """
    def relu_layer(W, b, x):
        """ Simple ReLU layer for single sample """
        return jnp.maximum(0, (jnp.dot(W, x) + b))

    # Count the layers from the weight entries so deeper MLPs are not
    # silently truncated to their first two layers.
    num_layers = sum(1 for name in params if name.startswith("W"))
    hidden = obs
    for idx in range(1, num_layers):
        hidden = relu_layer(params[f"W{idx}"], params[f"b{idx}"], hidden)
    mean_policy = (jnp.dot(params[f"W{num_layers}"], hidden)
                   + params[f"b{num_layers}"])
    return mean_policy

# Observation dimensionality fed into the first layer.
input_dim = 3
# Parameters for an Obs -> 16 hidden units -> 1 action MLP, seeded with rng_net.
policy_params = init_policy_mlp(rng_net, sizes=[input_dim, 16, 1])
# Wrap the plain-JAX forward function in the evaluation agent interface.
agent = MinimalEvaluationAgent(PolicyJAX)

In [12]:
# Build the evaluation dojo from the agent and env functions, then hand it the policy parameters.
collector = EvaluationDojo(agent, step, reset, env_params)
collector.init_dojo(policy_params)
# One rollout of num_steps with a single key, then a batched rollout with one key per episode.
trace, reward = collector.steps_rollout(rng_episode, num_steps)
traces, rewards = collector.batch_rollout(rng_batch, num_steps)

{'W1': DeviceArray([[-0.00144986, -0.00069665, -0.00387892],
             [ 0.01257899,  0.01681408, -0.01116779],
             [-0.00631701,  0.01585951,  0.01501916],
             [-0.00559352, -0.01859516,  0.00729775],
             [ 0.01136356, -0.02308235,  0.00455822],
             [ 0.00179699, -0.00209346, -0.00381558],
             [-0.00845862, -0.01149663, -0.01519296],
             [ 0.01073758, -0.0068458 , -0.01145768],
             [ 0.00275306,  0.0001691 ,  0.0185804 ],
             [ 0.00978198,  0.0066713 , -0.00369183],
             [ 0.01377987, -0.00318499,  0.00491256],
             [ 0.00328942,  0.0002241 , -0.00772289],
             [ 0.00664143, -0.00185791,  0.0073858 ],
             [ 0.00764628, -0.00054423, -0.00770936],
             [ 0.00191605,  0.00185747, -0.00518309],
             [-0.01456241, -0.00582085, -0.02329833]], dtype=float32), 'b1': DeviceArray([-0.00135108, -0.01499715,  0.00870826,  0.00763346,
              0.01009146,  0.02091292,  0

In [15]:
%timeit traces, rewards = collector.steps_rollout(rng_episode, num_steps)
%timeit trace, reward = collector.batch_rollout(rng_batch, num_steps)

153 µs ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
697 µs ± 55.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Haiku MLP Policy

In [None]:
def policy_fct(x):
    """ MLP policy: flatten -> Linear(16) -> ReLU -> Linear(1). """
    # Modules are created in the same order as they are applied.
    layers = [
        hk.Flatten(),
        hk.Linear(16),
        jax.nn.relu,
        hk.Linear(1),
    ]
    network = hk.Sequential(layers)
    return network(x)


# Turn the Haiku module function into a pure init/apply pair; apply needs no rng.
PolicyHaiku = hk.without_apply_rng(hk.transform(policy_fct))
# Derive independent subkeys so the env reset and the parameter init do not
# consume the very same key.
key_obs, key_init = jax.random.split(rng_net)
# The reset observation serves as the example input for parameter initialization.
obs, state = reset(key_obs, env_params)
policy_params = PolicyHaiku.init(key_init, obs)
agent = MinimalEvaluationAgent(PolicyHaiku.apply)

In [None]:
# Rebuild the dojo around the current agent and register the current policy parameters.
collector = EvaluationDojo(agent, step, reset, env_params)
collector.init_dojo(policy_params)
# One rollout of num_steps with a single key, then a batched rollout with one key per episode.
trace, reward = collector.steps_rollout(rng_episode, num_steps)
traces, rewards = collector.batch_rollout(rng_batch, num_steps)

In [None]:
# Time the single rollout and the batched rollout for the current collector.
# NOTE(review): JAX dispatches work asynchronously; unless the dojo blocks on its
# outputs these timings may mostly reflect dispatch time - TODO confirm.
%timeit traces, rewards = collector.steps_rollout(rng_episode, num_steps)
%timeit traces, rewards = collector.batch_rollout(rng_batch, num_steps)

# Flax MLP Policy

In [None]:
class PolicyFLAX(nn.Module):
    """ Flax MLP policy: Dense(16) 'fc1' -> ReLU -> Dense(1) 'fc2'. """

    @nn.compact
    def __call__(self, x):
        """ Map the input `x` to the output of the final dense layer. """
        hidden = nn.relu(nn.Dense(16, name='fc1')(x))
        return nn.Dense(1, name='fc2')(hidden)

# Derive independent subkeys so the env reset and the parameter init do not
# consume the very same key.
key_obs, key_init = jax.random.split(rng_net)
# The reset observation serves as the example input for parameter initialization.
obs, state = reset(key_obs, env_params)
policy_params = PolicyFLAX().init(key_init, obs)
agent = MinimalEvaluationAgent(PolicyFLAX().apply)

In [None]:
# Rebuild the dojo around the current agent and register the current policy parameters.
collector = EvaluationDojo(agent, step, reset, env_params)
collector.init_dojo(policy_params)
# One rollout of num_steps with a single key, then a batched rollout with one key per episode.
trace, reward = collector.steps_rollout(rng_episode, num_steps)
traces, rewards = collector.batch_rollout(rng_batch, num_steps)

In [None]:
%timeit traces, rewards = collector.steps_rollout(rng_episode, num_steps)
%timeit trace, reward = collector.batch_rollout(rng_batch, num_steps)

# Trax MLP Policy

In [None]:
# Switch for the optional Trax section; everything below is skipped while False.
run_trax = False
if run_trax:
    # Trax import takes forever!!!
    # But in terms of runtime it is a lot faster than Haiku/Flax
    import trax
    from trax import layers as tl

    # Problem with Trax: Takes input differently
    # ---- Haiku, JAX, Flax model(params, input)
    # ---- Trax model(input, params)
    # Need helper function that re-routes inputs

    # NOTE(review): when run_trax is True this rebinds the name `policy_fct`
    # that the Haiku section defined earlier in the notebook.
    def policy_fct():
        """ Build the Trax MLP: Dense(16) -> Relu -> Dense(1). """
        model = tl.Serial(
          tl.Dense(16),
          tl.Relu(),
          tl.Dense(1),
        )
        return model

    policy = policy_fct()
    # `obs` is taken from an earlier cell and only used for its signature here.
    policy_params, _ = policy.init(trax.shapes.signature(obs))
    def PolicyTrax(params, obs):
        """ Helper for correct mapping of policy params & input"""
        # Swap the argument order to the (params, obs) convention the agent uses.
        return policy(obs, params)
    agent = MinimalEvaluationAgent(PolicyTrax)

In [None]:
# Only run the Trax rollouts when the Trax section is enabled.
if run_trax:
    # Rebuild the dojo around the current agent and register the current policy parameters.
    collector = EvaluationDojo(agent, step, reset, env_params)
    collector.init_dojo(policy_params)
    # One rollout of num_steps with a single key, then a batched rollout with one key per episode.
    trace, reward = collector.steps_rollout(rng_episode, num_steps)
    traces, rewards = collector.batch_rollout(rng_batch, num_steps)

In [None]:
if run_trax:
    %timeit traces, rewards = collector.steps_rollout(rng_episode, num_steps)
    %timeit trace, reward = collector.batch_rollout(rng_batch, num_steps)

# Replay Buffer Tryout

In [None]:
from gymnax.utils import init_buffer, push_buffer, sample_buffer
from gymnax.dojos import InterleavedDojo
from gymnax.agents import MinimalInterleavedAgent

In [None]:
# Make a dummy step transition & store it in buffer
# Re-creating the environment rebinds rng/reset/step/env_params from the earlier cells.
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
# Carry `rng` forward and take one key for the reset and one for the step.
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)
# Fixed action used for every step in this section.
# NOTE(review): integer-valued array for a continuous torque action - confirm the dtype is intended.
action = jnp.array([1])
next_obs, next_state, reward, done, _ = step(key_step, env_params,
                                             state, action)


In [None]:
capacity = 11
# `env_params` prints as an immutable FrozenDict, so item assignment raises;
# build an updated copy with the shortened episode length instead.
env_params = env_params.copy({"max_steps_in_episode": 5})
# Initialize buffer with templates
buffer = init_buffer(state, obs, action, capacity)
# Push exactly `capacity` transitions so the buffer is filled once.
for _push_idx in range(capacity):
    # Draw fresh subkeys every iteration instead of reusing one step/reset key.
    rng, key_step, key_reset = jax.random.split(rng, 3)
    next_obs, next_state, reward, done, _ = step(key_step, env_params,
                                                 state, action)
    step_experience = {"state": state,
                       "next_state": next_state,
                       "obs": obs,
                       "next_obs": next_obs,
                       "action": action,
                       "reward": reward,
                       "done": done}
    buffer = push_buffer(buffer, step_experience)
    # Auto-reset environment and use obs/state if episode terminated
    obs_reset, state_reset = reset(key_reset, env_params)
    next_obs = done * obs_reset + (1 - done) * next_obs
    next_state = done * state_reset + (1 - done) * next_state
    # Update state/obs for next step
    state = next_state
    obs = next_obs

In [None]:
# Display the buffer after the pushes.
buffer

In [None]:
# Carry `rng` forward and take a dedicated key for sampling.
rng, key_sample = jax.random.split(rng)
# Sample from the buffer; the last argument is presumably the number of samples - TODO confirm.
sample_buffer(key_sample, buffer, 3)

In [None]:
# PolicyJAX indexes its parameters by "W1"/"b1"/..., but `policy_params` was
# overwritten by the Haiku/Flax sections; re-create the matching plain-JAX params.
policy_params = init_policy_mlp(rng_net, sizes=[input_dim, 16, 1])
agent = MinimalInterleavedAgent(PolicyJAX)
# The interleaved dojo additionally receives the buffer and its push/sample functions.
collector = InterleavedDojo(agent, buffer, push_buffer, sample_buffer,
                            step, reset, env_params)
collector.init_dojo(policy_params)
trace, reward = collector.steps_rollout(rng_episode, num_steps)

In [None]:
# Display the buffer held by the collector after the rollout.
collector.buffer

In [None]:
# Shell escape: print the kernel's current working directory.
!pwd

In [None]:
import zipfile

def zipdir(path: str, zip_fname: str) -> None:
    """ Zip a directory to upload afterwards to GCloud Storage.

    Args:
        path: Directory that is walked recursively; every file found is added.
        zip_fname: Path of the zip archive to create (opened in write mode).

    Files are stored under the path obtained by joining the walked directory
    and the file name, i.e. `path` itself is part of the archive names.
    """
    # The context manager closes the archive even if writing a file fails.
    with zipfile.ZipFile(zip_fname, 'w', zipfile.ZIP_DEFLATED) as ziph:
        for root, _dirs, files in os.walk(path):
            for fname in files:
                ziph.write(os.path.join(root, fname))

In [None]:
import os
# The original call passed no argument to os.path.split, which raises TypeError.
# NOTE(review): splitting the current working directory is assumed from the
# preceding `!pwd` cell - confirm the intended path.
path, file = os.path.split(os.getcwd())

In [None]:
# Display the directory part of the split.
path

In [None]:
# Display the final path component of the split.
file

In [None]:
# NOTE(review): hardcoded absolute, machine-specific path - consider a configurable or relative directory.
path = "/Users/rtl/Dropbox/core-code/mle-toolbox/examples"
# Number of leading characters to strip from walked directories to get paths relative to `path`.
prefix_len = len(path)

In [None]:
# Walk the directory tree and print, for each directory, its raw walk entry
# followed by the path of every contained file relative to `path`.
for root, dirs, files in os.walk(path):
    print(root, dirs, files)
    # Iterate over this directory's own files instead of a stale global `file`.
    for fname in files:
        # +1 also drops the path separator that follows the stripped prefix.
        print(os.path.join(root[prefix_len+1:], fname))