In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import haiku as hk
from haiku import nets
import optax
import rlax
import collections

import gymnax
from gymnax.dojos import InterleavedDojo, EvaluationDojo
from gymnax.utils import init_buffer, push_buffer, sample_buffer



# Import gymnax catch bsuite environment + Init replay buffer

In [3]:
rng, reset, step, env_params = gymnax.make("Catch-bsuite")
print(env_params)
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)
action = jnp.array([0])

capacity = 2000
buffer = init_buffer(state, obs, action, capacity)

{'max_steps_in_episode': 2000}


# Define the DQN Agent

In [4]:
Params = collections.namedtuple("Params", "online target")
ActorState = collections.namedtuple("ActorState", "count evaluation")
ActorOutput = collections.namedtuple("ActorOutput", "actions q_values")
LearnerState = collections.namedtuple("LearnerState", "count opt_state")


def build_network(num_actions: int) -> hk.Transformed:
    """Factory for a simple MLP network for approximating Q-values."""
    def q(obs):
        network = hk.Sequential(
            [hk.Flatten(),
             nets.MLP([50, num_actions])])
        return network(obs)
    return hk.without_apply_rng(hk.transform(q, apply_rng=True))


class DQN:
    """A simple DQN agent."""
    def __init__(self, obs_template, num_actions, epsilon_cfg,
                 target_period, learning_rate):
        self._obs_template = obs_template
        self._num_actions = num_actions
        self._target_period = target_period
        # Neural net and optimiser.
        self._network = build_network(num_actions)
        self._optimizer = optax.adam(learning_rate)
        self._epsilon_by_frame = optax.polynomial_schedule(**epsilon_cfg)

    def initial_params(self, key):
        sample_input = jnp.expand_dims(self._obs_template, 0)
        online_params = self._network.init(key, sample_input)
        return Params(online_params, online_params)

    def init_actor_state(self, evaluate=False):
        actor_count = jnp.zeros((), dtype=jnp.float32)
        return ActorState(actor_count, evaluate)

    def init_learner_state(self, params):
        learner_count = jnp.zeros((), dtype=jnp.float32)
        opt_state = self._optimizer.init(params.online)
        return LearnerState(learner_count, opt_state)

    def actor_step(self, key, params, obs, actor_state):
        obs = jnp.expand_dims(obs, 0)  # add dummy batch
        q = self._network.apply(params.online, obs)[0]    # remove dummy batch
        epsilon = self._epsilon_by_frame(actor_state.count)
        train_a = rlax.epsilon_greedy(epsilon).sample(key, q)
        eval_a = rlax.greedy().sample(key, q)
        a = jax.lax.select(actor_state.evaluation, eval_a, train_a)
        return (a, ActorState(actor_state.count + 1, bool(0)))

    def learner_step(self, key, params, learner_state, data):
        target_params = rlax.periodic_update(
            params.online, params.target,
            learner_state.count, self._target_period)
        dloss_dtheta = jax.grad(self._loss)(params.online,
                                            target_params,
                                           data["obs"], data["action"],
                                           data["reward"], data["done"],
                                           data["next_obs"])
        updates, opt_state = self._optimizer.update(dloss_dtheta,
                                                    learner_state.opt_state)
        online_params = optax.apply_updates(params.online, updates)
        return (Params(online_params, target_params),
                LearnerState(learner_state.count + 1, opt_state))

    def _loss(self, online_params, target_params,
              obs_tm1, a_tm1, r_t, discount_t, obs_t):
        q_tm1 = self._network.apply(online_params, obs_tm1)
        q_t_val = self._network.apply(target_params, obs_t)
        q_t_select = self._network.apply(online_params, obs_t)
        batched_loss = jax.vmap(rlax.double_q_learning)
        # TODO: Problems with chex in rlax function!
        # Rank compatibility = squeeze inputs
        # Type compatibility = make actions of type int
        td_error = batched_loss(q_tm1, a_tm1.squeeze().astype(int), r_t.squeeze(),
                                discount_t.squeeze(), q_t_val, q_t_select)
        return jnp.mean(rlax.l2_loss(td_error))

# Init Rollout/Learning Collector

In [5]:
num_actions = 3
epsilon_cfg = dict(init_value=1,
                   end_value=0.01,
                   transition_steps=1000,
                   power=1.)
target_period = 50
learning_rate = 0.005
num_steps = 2000

rng, rng_net, rng_episode = jax.random.split(rng, 3)
agent = DQN(obs, num_actions, epsilon_cfg,
            target_period, learning_rate)
agent_params = agent.initial_params(rng_net)
print(agent_params.online['mlp/~/linear_0']['b'])

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]


In [6]:
collector = InterleavedDojo(agent, buffer,
                            push_buffer, sample_buffer,
                            step, reset, env_params)
collector.init_dojo(agent_params)

evaluator = EvaluationDojo(agent, step, reset, env_params)
evaluator.init_dojo()

trace, reward = collector.steps_rollout(rng_episode, num_steps)

In [7]:
rng_evals = jax.random.split(rng, 10)
_, reward = evaluator.batch_rollout(rng_evals, 9,
                                    collector.agent_params)

In [8]:
jnp.sum(reward, axis=1).mean()

DeviceArray(-0.8, dtype=float32)

In [9]:
# rng, obs, state, env_params, agent_params, actor_state, learner_state
trace[4].online['mlp/~/linear_0']['b']

DeviceArray([-0.05110177, -0.07870402, -0.20903331,  0.00914063,
             -0.12843364, -0.2133825 , -0.3431929 , -0.17656963,
             -0.03682008, -0.15827283, -0.16298941, -0.14577693,
             -0.11531142, -0.03487734, -0.09916329, -0.09522908,
             -0.23822062, -0.1428581 ,  0.0028815 , -0.13443884,
             -0.10411434, -0.15547672, -0.20288026, -0.08027034,
             -0.1311672 , -0.09274957, -0.17738967,  0.00871049,
             -0.13293147, -0.18503276, -0.2511608 , -0.01409638,
             -0.02259834, -0.15740563, -0.13881627, -0.23443146,
             -0.12356479, -0.20620756, -0.15213154, -0.18423879,
             -0.06680408, -0.26913723, -0.2541214 , -0.06354713,
             -0.07263365, -0.25095072, -0.12604818, -0.2254671 ,
             -0.02294958, -0.22215189], dtype=float32)

In [10]:
trace[4].target['mlp/~/linear_0']['b']

DeviceArray([-0.05408095, -0.07936773, -0.20892192,  0.00642701,
             -0.12937585, -0.2133825 , -0.3431929 , -0.17696036,
             -0.03570051, -0.15827283, -0.16146871, -0.14700289,
             -0.11531142, -0.03505527, -0.0974686 , -0.09397108,
             -0.23811772, -0.14034916,  0.00101197, -0.13442558,
             -0.10348068, -0.15548068, -0.20295212, -0.08296234,
             -0.13158248, -0.08579965, -0.18577208,  0.00947669,
             -0.13223554, -0.18503276, -0.24886422, -0.01140631,
             -0.02271254, -0.15776223, -0.134745  , -0.23800932,
             -0.12370464, -0.20493427, -0.15213154, -0.18423711,
             -0.06602716, -0.27030075, -0.2541113 , -0.06344025,
             -0.07198612, -0.25095072, -0.12604818, -0.22617814,
             -0.02164434, -0.2215555 ], dtype=float32)

In [11]:
collector.agent_params.online['mlp/~/linear_0']['b']

DeviceArray([-0.05110177, -0.07870402, -0.20903331,  0.00914063,
             -0.12843364, -0.2133825 , -0.3431929 , -0.17656963,
             -0.03682008, -0.15827283, -0.16298941, -0.14577693,
             -0.11531142, -0.03487734, -0.09916329, -0.09522908,
             -0.23822062, -0.1428581 ,  0.0028815 , -0.13443884,
             -0.10411434, -0.15547672, -0.20288026, -0.08027034,
             -0.1311672 , -0.09274957, -0.17738967,  0.00871049,
             -0.13293147, -0.18503276, -0.2511608 , -0.01409638,
             -0.02259834, -0.15740563, -0.13881627, -0.23443146,
             -0.12356479, -0.20620756, -0.15213154, -0.18423879,
             -0.06680408, -0.26913723, -0.2541214 , -0.06354713,
             -0.07263365, -0.25095072, -0.12604818, -0.2254671 ,
             -0.02294958, -0.22215189], dtype=float32)

In [12]:
import shlex

In [17]:
enable_conda = ("source $(conda info --base)/etc/profile.d/conda.sh "
                "&& conda activate {remote_env_name}")

In [18]:
enable_conda

'source $(conda info --base)/etc/profile.d/conda.sh && conda activate {remote_env_name}'