In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload all imported modules before each
# cell executes, so edits to local packages are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Emit high-resolution ("retina") inline figures.
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import haiku as hk
from haiku import nets
import optax
import rlax
import collections

import gymnax
from gymnax.dojos import InterleavedDojo
from gymnax.utils import init_buffer, push_buffer, sample_buffer



# Load the gymnax Catch (bsuite) environment and initialise the replay buffer

In [3]:
# Build the Catch environment: gymnax.make hands back a PRNG key, the reset and
# step functions, and the environment parameters.
rng, reset, step, env_params = gymnax.make("Catch-bsuite")
print(env_params)
# Split off dedicated keys for the reset and for a step; `rng` is carried forward.
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)
# Placeholder action passed to init_buffer alongside state/obs
# (presumably as shape/dtype templates -- TODO confirm against gymnax.utils).
action = jnp.array([0])

# Replay buffer capacity handed to init_buffer.
capacity = 2000
buffer = init_buffer(state, obs, action, capacity)

{'max_steps_in_episode': 2000}


# Define the DQN Agent

In [36]:
# Lightweight immutable records for agent parameters and actor/learner state.
Params = collections.namedtuple("Params", ["online", "target"])
ActorState = collections.namedtuple("ActorState", ["count", "evaluation"])
ActorOutput = collections.namedtuple("ActorOutput", ["actions", "q_values"])
LearnerState = collections.namedtuple("LearnerState", ["count", "opt_state"])


def build_network(num_actions: int) -> hk.Transformed:
    """Create the Haiku-transformed Q-network.

    The network flattens its input and feeds it through an MLP with one hidden
    layer of 50 units and `num_actions` outputs (one Q-value per action).

    Args:
        num_actions: Size of the output layer.

    Returns:
        The transformed network, wrapped so that `apply` takes no RNG key.
    """
    def forward(obs):
        layers = [hk.Flatten(), nets.MLP([50, num_actions])]
        return hk.Sequential(layers)(obs)

    transformed = hk.transform(forward, apply_rng=True)
    return hk.without_apply_rng(transformed)


class DQN:
    """A simple double-DQN agent with an epsilon-greedy behaviour policy.

    The agent owns an MLP Q-network (see `build_network`), an Adam optimiser
    and a polynomial epsilon schedule. Parameters are carried externally as a
    `Params(online, target)` pair; counters and optimiser state live in
    `ActorState` / `LearnerState`.
    """
    def __init__(self, obs_template, num_actions, epsilon_cfg,
                 target_period, learning_rate, discount=0.99):
        """Set up the network, optimiser and exploration schedule.

        Args:
            obs_template: A single (unbatched) observation, used only to
                initialise the network parameters.
            num_actions: Number of discrete actions / Q-network outputs.
            epsilon_cfg: Keyword arguments for `optax.polynomial_schedule`.
            target_period: Period (in learner steps) passed to
                `rlax.periodic_update` for refreshing the target parameters.
            learning_rate: Adam learning rate.
            discount: Discount factor applied to the bootstrap value of
                non-terminal transitions.
        """
        self._obs_template = obs_template
        self._num_actions = num_actions
        self._target_period = target_period
        self._discount = discount
        # Neural net and optimiser.
        self._network = build_network(num_actions)
        self._optimizer = optax.adam(learning_rate)
        self._epsilon_by_frame = optax.polynomial_schedule(**epsilon_cfg)

    def initial_params(self, key):
        """Initialise the network; online and target start out identical."""
        sample_input = jnp.expand_dims(self._obs_template, 0)
        online_params = self._network.init(key, sample_input)
        return Params(online_params, online_params)

    def init_actor_state(self):
        """Return an actor state with a zero step count, in training mode."""
        actor_count = jnp.zeros((), dtype=jnp.float32)
        return ActorState(actor_count, False)

    def init_learner_state(self, params):
        """Return a learner state with a zero step count and fresh optimiser state."""
        learner_count = jnp.zeros((), dtype=jnp.float32)
        opt_state = self._optimizer.init(params.online)
        return LearnerState(learner_count, opt_state)

    def actor_step(self, key, params, obs, actor_state):
        """Select an action for a single (unbatched) observation.

        Uses the greedy action when `actor_state.evaluation` is set and an
        epsilon-greedy action otherwise, with epsilon taken from the schedule
        at `actor_state.count`.

        Returns:
            A tuple `(action, new_actor_state)` with the count incremented.
        """
        obs = jnp.expand_dims(obs, 0)  # add dummy batch
        q = self._network.apply(params.online, obs)[0]    # remove dummy batch
        epsilon = self._epsilon_by_frame(actor_state.count)
        train_a = rlax.epsilon_greedy(epsilon).sample(key, q)
        eval_a = rlax.greedy().sample(key, q)
        a = jax.lax.select(actor_state.evaluation, eval_a, train_a)
        # NOTE(review): the evaluation flag is always reset to False here, so
        # evaluation mode lasts a single step unless the caller sets it again.
        # Confirm this is what the dojo expects.
        return (a, ActorState(actor_state.count + 1, False))

    def learner_step(self, key, params, learner_state, data):
        """Run one gradient step on a batch of transitions.

        Args:
            key: PRNG key (unused; kept for interface compatibility).
            params: Current `Params(online, target)`.
            learner_state: Current `LearnerState(count, opt_state)`.
            data: Mapping with keys "obs", "action", "reward", "done" and
                "next_obs". "done" is treated as a termination flag
                (nonzero = episode ended) -- NOTE(review): confirm against
                what the replay buffer stores.

        Returns:
            A tuple `(new_params, new_learner_state)`.
        """
        target_params = rlax.periodic_update(
            params.online, params.target,
            learner_state.count, self._target_period)
        # The loss expects a bootstrap discount, not a done flag: passing the
        # raw flag would bootstrap only from terminal transitions.
        discount_t = self._discount_from_done(data["done"])
        dloss_dtheta = jax.grad(self._loss)(params.online,
                                            target_params,
                                            data["obs"], data["action"],
                                            data["reward"], discount_t,
                                            data["next_obs"])
        updates, opt_state = self._optimizer.update(dloss_dtheta,
                                                    learner_state.opt_state)
        online_params = optax.apply_updates(params.online, updates)
        return (Params(online_params, target_params),
                LearnerState(learner_state.count + 1, opt_state))

    def _discount_from_done(self, done):
        """Map termination flags to discounts: 0 where done, gamma elsewhere."""
        not_done = 1.0 - jnp.asarray(done, dtype=jnp.float32)
        return self._discount * not_done

    def _loss(self, online_params, target_params,
              obs_tm1, a_tm1, r_t, discount_t, obs_t):
        """Mean L2 loss of the double Q-learning TD errors over the batch.

        The online network selects the next action, the target network
        evaluates it (`rlax.double_q_learning`, vmapped over the batch).
        """
        q_tm1 = self._network.apply(online_params, obs_tm1)
        q_t_val = self._network.apply(target_params, obs_t)
        q_t_select = self._network.apply(online_params, obs_t)
        batched_loss = jax.vmap(rlax.double_q_learning)
        # rlax checks ranks and dtypes: per-transition scalars must have shape
        # (batch,) and actions must be integers. Flatten instead of a bare
        # squeeze(), which would also drop the batch axis for a batch of one
        # and leave vmap nothing to map over.
        # Assumes one scalar action/reward/discount per transition.
        actions = jnp.reshape(a_tm1, (-1,)).astype(int)
        rewards = jnp.reshape(r_t, (-1,))
        discounts = jnp.reshape(discount_t, (-1,))
        td_error = batched_loss(q_tm1, actions, rewards,
                                discounts, q_t_val, q_t_select)
        return jnp.mean(rlax.l2_loss(td_error))

# Initialise the rollout/learning collector

In [37]:
num_actions = 3
# Keyword arguments forwarded by DQN to optax.polynomial_schedule for epsilon;
# power=1. makes the decay from init_value to end_value linear.
epsilon_cfg = dict(init_value=1,
                   end_value=0.01,
                   transition_steps=1000,
                   power=1.)
# Passed to rlax.periodic_update as the target-network update period.
target_period = 50
# Adam learning rate.
learning_rate = 0.005

parallel_episodes = 5
# Split order matters: `rng` is replaced first, and the per-episode keys below
# are derived from the replaced key.
rng, rng_net, rng_episode = jax.random.split(rng, 3)
# One key per parallel episode.
# NOTE(review): rng_batch is not used in the cells shown here.
rng_batch = jax.random.split(rng, parallel_episodes)
# `obs` (from the environment reset) serves as the observation template for
# network initialisation.
agent = DQN(obs, num_actions, epsilon_cfg,
            target_period, learning_rate)
agent_params = agent.initial_params(rng_net)

In [38]:
# Wire the agent, the replay buffer with its push/sample functions, and the
# environment's step/reset functions and parameters into the dojo.
collector = InterleavedDojo(agent, buffer,
                            push_buffer, sample_buffer,
                            step, reset, env_params)
# Hand the freshly initialised agent parameters to the dojo.
collector.init_dojo(agent_params)

In [39]:
# Run one episode rollout with the dedicated episode key.
# NOTE(review): the structure of the two return values is defined by
# InterleavedDojo.episode_rollout -- confirm there.
trace, reward = collector.episode_rollout(rng_episode)