In [63]:
import flashbax as fbx
import pandas as pd
from typing import NamedTuple
from tqdm.auto import tqdm
import haiku as hk
import jax
from jax import random, jit, vmap, tree_map, lax
from jax_tqdm import loop_tqdm
import jax.numpy as jnp
import plotly.express as px
import optax
import rlax
import chex
import gymnax

### ***Data Structures***

In [64]:
def get_network_fn(num_outputs: int):
    """Define a fully connected multi-layer haiku network."""

    def network_fn(obs: chex.Array) -> chex.Array:
        conv = hk.Conv2D(output_channels=16, kernel_shape=3, stride=1)
        fc = hk.nets.MLP(
            output_sizes=[128, num_outputs],
            activation=jax.nn.relu,
            activate_final=False,
        )
        x = jax.nn.relu(conv(obs))
        x = x.reshape(-1)
        x = fc(x)

        return x

    return hk.without_apply_rng(hk.transform(network_fn))


class TrainState(NamedTuple):
    params: hk.Params
    target_params: hk.Params
    opt_state: optax.OptState


@chex.dataclass(frozen=True)
class TimeStep:
    observation: chex.Array
    action: chex.Array
    discount: chex.Array
    reward: chex.Array

In [65]:
# We specify our parameters
env_id = "Freeway-MinAtar"
seed = 1
num_envs = 1

total_timesteps = 3000
learning_starts = 5_000
train_frequency = 1
target_network_frequency = 1000

tau = 1.0
learning_rate = 2.5e-4
start_e = 1.0
end_e = 0.1
duration = 100_000
gamma = 0.99

buffer_params = {
    "max_length": 100_000,
    "min_length": 32,
    "sample_batch_size": 32,
    "add_sequences": False,
    "add_batch_size": None,
    "priority_exponent": 0.5,
}

In [66]:
env, env_params = gymnax.make(env_id)
num_actions = env.num_actions

### ***DQN and Optimizer initialization***

In [67]:
key = random.PRNGKey(seed)
key, q_key = random.split(key, 2)

q_network = get_network_fn(num_actions)
optim = optax.adam(learning_rate=learning_rate)

dummy_obs, dummy_env_state = env.reset(key)
params = q_network.init(q_key, dummy_obs.astype(jnp.float32))
opt_state = optim.init(params)
q_state = TrainState(
    params=params,
    target_params=params,
    opt_state=opt_state,
)


Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.


scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=bool with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



### ***Flashbax Buffer initialization***

In [68]:
buffer = fbx.make_prioritised_flat_buffer(**buffer_params)
buffer = buffer.replace(
    init=jax.jit(buffer.init),
    add=jax.jit(buffer.add, donate_argnums=0),
    sample=jax.jit(buffer.sample),
    can_sample=jax.jit(buffer.can_sample),
)

dummy_timestep = TimeStep(
    observation=dummy_obs,
    action=jnp.int32(0),
    reward=jnp.float32(0.0),
    discount=jnp.float32(0.0),
)
buffer_state = buffer.init(dummy_timestep)


Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = 100000`.This allows one to control exactly how many timesteps are stored in the buffer.Note that this overrides the `max_length_time_axis` argument.



In [69]:
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linear schedule function for the epsilon greedy exploration."""
    slope = (end_e - start_e) / duration
    return jnp.maximum(slope * t + start_e, end_e)


def huber_loss(x: chex.Array, delta: float = 1.0) -> chex.Array:
    """Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

    See "Robust Estimation of a Location Parameter" by Huber.
    (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).

    Args:
      x: a vector of arbitrary shape.
      delta: the bounds for the huber loss transformation, defaults at 1.

    Note `grad(huber_loss(x))` is equivalent to `grad(0.5 * clip_gradient(x)**2)`.

    Returns:
      a vector of same shape of `x`.
    """
    chex.assert_type(x, float)

    # 0.5 * x^2                  if |x| <= d
    # 0.5 * d^2 + d * (|x| - d)  if |x| > d
    abs_x = jnp.abs(x)
    quadratic = jnp.minimum(abs_x, delta)
    # Same as max(abs_x - delta, 0) but avoids potentially doubling gradient.
    linear = abs_x - quadratic
    return 0.5 * quadratic**2 + delta * linear


@jit
def update(q_state: TrainState, batch: TimeStep):
    """
    Computes the updated model parameters and optimizer states
    for a batch of experience.
    """

    def batch_apply(params: dict, observations: jnp.ndarray):
        return vmap(q_network.apply, in_axes=(None, 0))(params, observations)

    def loss_fn(params: dict, target_params: dict, batch):
        """Computes the Q-learning TD error for a batch of timesteps"""
        q_tm1 = batch_apply(params, batch.first.observation)
        a_tm1 = batch.first.action
        r_t = batch.first.reward
        d_t = batch.first.discount * gamma
        q_t = batch_apply(target_params, batch.second.observation)
        td_error = vmap(rlax.q_learning)(q_tm1, a_tm1, r_t, d_t, q_t)
        
        return jnp.mean(huber_loss(td_error))

    loss, grads = jax.value_and_grad(loss_fn)(
        q_state.params, q_state.target_params, batch
    )
    updates, new_opt_state = optim.update(grads, q_state.opt_state)
    new_params = optax.apply_updates(q_state.params, updates)
    q_state = q_state._replace(params=new_params, opt_state=new_opt_state)

    return loss, q_state


@jit
def action_select_fn(q_state: TrainState, obs: TimeStep):
    q_values = q_network.apply(q_state.params, obs)
    action = jnp.argmax(q_values, axis=-1)

    return action


@jit
def perform_update(
    q_state: TrainState,
    buffer_state,
    sample_key: random.PRNGKey,
):
    data = buffer.sample(buffer_state, sample_key)
    loss, q_state = update(q_state, data.experience)

    return loss, q_state

In [70]:
def fill_buffer(
    rng: random.PRNGKey,
    total_timesteps: int,
    q_state: TrainState,
    buffer_state,
):
    def _conditional_reset(key):
        key, subkey = random.split(key)
        obs, env_state = env.reset(subkey)
        return obs, env_state

    @jit
    @loop_tqdm(total_timesteps, print_rate=int(total_timesteps / 100))
    def _fori_body(current_step: int, val: tuple):
        (obs, env_state, buffer_state, rng) = val
        rng, env_key, action_key, step_key = random.split(rng, num=4)

        action = env.action_space(env_params).sample(action_key)
        obs, env_state, reward, done, _ = env.step(step_key, env_state, action)

        timestep = TimeStep(
            observation=obs,
            action=action,
            reward=reward,
            discount=lax.select(done, 0.0, 0.99),
        )
        buffer_state = buffer.add(buffer_state, timestep)

        # reset if done
        obs, env_state = lax.cond(
            done,
            lambda _: _conditional_reset(env_key),
            lambda _: (obs, env_state),
            operand=None,
        )

        return (obs, env_state, buffer_state, rng)

    obs, env_state = env.reset(rng)
    init_val = (obs, env_state, buffer_state, rng)
    (obs, env_state, buffer_state, rng) = lax.fori_loop(
        0, total_timesteps, _fori_body, init_val
    )

    return buffer_state

In [71]:
buffer_state = fill_buffer(random.PRNGKey(0), 5_000, q_state, buffer_state)


Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.


Explicitly requested dtype <class 'jax.numpy.int64'> requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.


scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=bool with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/5000 [00:00<?, ?it/s]

In [72]:
def tree_shape(tree):
    return tree_map(lambda x: x.shape, tree)

## ***GER***

In [73]:
def unwrap_batch(batch):
    """
    Converts a batch of experiences to batches of observations,
    actions, rewards and done flags. 
    """
    obs_tm1 = batch.first.observation
    obs_t = batch.second.observation
    a_tm1 = batch.first.action
    r_t = batch.first.reward
    d_t = batch.first.discount * gamma

    return {
        "obs_tm1": obs_tm1,
        "obs_t": obs_t,
        "a_tm1": a_tm1,
        "r_t": r_t,
        "d_t": d_t,
    }


def single_sample_loss(params, target_params, obs_tm1, obs_t, a_tm1, r_t, d_t):
    q_tm1 = q_network.apply(params, obs_tm1)
    q_t = q_network.apply(target_params, obs_t)
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, d_t, q_t)

    return huber_loss(td_error)

In [74]:
data = buffer.sample(buffer_state, random.PRNGKey(0))
per_sample_loss, per_sample_grads = vmap(
    jax.value_and_grad(single_sample_loss), in_axes=(None, None)
)(
    q_state.params,
    q_state.target_params,
    **unwrap_batch(data.experience),
)
tree_shape(per_sample_grads) # all parameters have an additional batch dimension

{'conv2_d': {'b': (32, 16), 'w': (32, 3, 3, 7, 16)},
 'mlp/~/linear_0': {'b': (32, 128), 'w': (32, 1600, 128)},
 'mlp/~/linear_1': {'b': (32, 3), 'w': (32, 128, 3)}}

In [75]:
grads_per = tree_map(
    lambda x: jnp.mean(x, axis=0),
    per_sample_grads,
)
grads_per

{'conv2_d': {'b': Array([ 0.0024511 ,  0.00449101,  0.00045069, -0.00088667,  0.00264114,
         -0.00114079,  0.00130348, -0.00130828,  0.00205221,  0.00036082,
          0.00605653,  0.00343748,  0.00446917,  0.00139386, -0.00043764,
          0.00256009], dtype=float32),
  'w': Array([[[[ 2.29266356e-04, -8.42906593e-05, -1.05393774e-05, ...,
            -2.49934783e-05, -7.44493518e-05,  5.17422777e-05],
           [ 9.14558070e-04,  1.74292261e-04, -5.73871890e-04, ...,
            -7.32992339e-05,  4.10810782e-04,  4.71305015e-04],
           [-9.74045979e-05,  4.58379625e-04,  2.31650527e-04, ...,
             8.21635331e-05,  0.00000000e+00, -4.73958280e-05],
           ...,
           [ 1.11865251e-04,  6.87608117e-05,  7.15275091e-05, ...,
            -2.90360884e-04,  0.00000000e+00,  1.70633339e-05],
           [ 4.08641281e-05,  1.11928784e-05, -4.21751320e-05, ...,
             1.35443497e-05,  1.95196801e-04, -1.16156982e-04],
           [-2.30927428e-04,  1.19236145e-

In [76]:
per_sample_norms = tree_map(
    lambda x: jnp.sqrt(
        jnp.sum(
            x**2,
            axis=tuple(range(1, x.ndim)), # sum over all dims except batch dim
        ),
    ),
    per_sample_grads,
)

tree_shape(per_sample_norms)

{'conv2_d': {'b': (32,), 'w': (32,)},
 'mlp/~/linear_0': {'b': (32,), 'w': (32,)},
 'mlp/~/linear_1': {'b': (32,), 'w': (32,)}}

In [77]:
tree, _ = jax.tree_util.tree_flatten(per_sample_norms)
jnp.array(tree).mean(axis=0)

Array([0.06935442, 0.01756447, 0.01155434, 0.03667615, 0.06808832,
       0.03059145, 0.08253931, 0.06296004, 0.00396752, 0.04660731,
       0.12405833, 0.05548818, 0.00023469, 0.10483947, 0.03016827,
       0.12912917, 0.09218165, 0.06446829, 0.0802656 , 0.09330662,
       0.09774294, 0.00037563, 0.01540068, 0.00419724, 0.08082651,
       0.03862018, 0.15389241, 0.14672226, 0.03395448, 0.0530236 ,
       0.1369601 , 0.03773179], dtype=float32)

## ***Normal Update***

In [78]:
@jit
def update(q_state: TrainState, batch: TimeStep):
    """
    Computes the updated model parameters and optimizer states
    for a batch of experience.
    """

    def batch_apply(params: dict, observations: jnp.ndarray):
        return vmap(q_network.apply, in_axes=(None, 0))(params, observations)

    def loss_fn(params: dict, target_params: dict, batch):
        """Computes the Q-learning TD error for a batch of timesteps"""
        q_tm1 = batch_apply(params, batch.first.observation)
        a_tm1 = batch.first.action
        r_t = batch.first.reward
        d_t = batch.first.discount * gamma
        q_t = batch_apply(target_params, batch.second.observation)
        td_error = vmap(rlax.q_learning)(q_tm1, a_tm1, r_t, d_t, q_t)
        
        return jnp.mean(huber_loss(td_error))

    loss, grads = jax.value_and_grad(loss_fn)(
        q_state.params, q_state.target_params, batch
    )
    updates, new_opt_state = optim.update(grads, q_state.opt_state)
    new_params = optax.apply_updates(q_state.params, updates)
    q_state = q_state._replace(params=new_params, opt_state=new_opt_state)

    return grads

In [79]:
grads = update(q_state, data.experience)
tree_shape(grads)

{'conv2_d': {'b': (16,), 'w': (3, 3, 7, 16)},
 'mlp/~/linear_0': {'b': (128,), 'w': (1600, 128)},
 'mlp/~/linear_1': {'b': (3,), 'w': (128, 3)}}

## ***Checking closeness between PER updates and DQN gradients***

In [80]:
# Some elements are not exactly equal (maybe because of rounding ?)
chex.assert_trees_all_close(grads, grads_per, rtol=1e-2)

In [81]:
# when taking the gradient means, we get equality with 1e-6 precision
grads_means = tree_map(lambda x: jnp.mean(x), grads)
grads_per_means = tree_map(lambda x: jnp.mean(x), grads_per)
chex.assert_trees_all_close(grads_means, grads_per_means, rtol=1e-6)

## ***Full GER update function***

In [99]:
@jit
def ger_update(q_state: TrainState, buffer_state, batch: TimeStep):
    """
    Computes the updated model parameters and optimizer states
    for a batch of experience.
    """

    def unwrap_batch(batch: TimeStep):
        """
        Converts a batch of experiences to batches of observations,
        actions, rewards and done flags.
        """
        obs_tm1 = batch.first.observation
        obs_t = batch.second.observation
        a_tm1 = batch.first.action
        r_t = batch.first.reward
        d_t = batch.first.discount * gamma

        return {
            "obs_tm1": obs_tm1,
            "obs_t": obs_t,
            "a_tm1": a_tm1,
            "r_t": r_t,
            "d_t": d_t,
        }

    def single_sample_loss(params, target_params, obs_tm1, obs_t, a_tm1, r_t, d_t):
        """Returns the huber loss of a single experience."""
        q_tm1 = q_network.apply(params, obs_tm1)
        q_t = q_network.apply(target_params, obs_t)
        td_error = rlax.q_learning(q_tm1, a_tm1, r_t, d_t, q_t)

        return huber_loss(td_error)

    # get per-sample losses and gradients
    per_sample_loss, per_sample_grads = vmap(
        jax.value_and_grad(single_sample_loss), in_axes=(None, None)
    )(q_state.params, q_state.target_params, **unwrap_batch(batch.experience))

    # compute per-sample gradient norms
    per_sample_norms = tree_map(
        lambda x: jnp.sqrt(
            jnp.sum(
                x**2,
                axis=tuple(range(1, x.ndim)),  # sum over all dims except batch dim
            ),
        ),
        per_sample_grads,
    )
    per_sample_norms, _ = jax.tree_util.tree_flatten(per_sample_norms)
    per_sample_norms = jnp.array(per_sample_norms).mean(axis=0)

    # importance sampling weights
    importance_weights = (1.0 / batch.priorities).astype(jnp.float32)
    importance_weights **= 0.5  # beta
    importance_weights /= jnp.max(importance_weights)

    # updating priorities
    new_priorities = per_sample_norms / importance_weights

    # grads = average per-sample grads across the batch dimension
    grads = tree_map(lambda x: jnp.mean(x, axis=0), per_sample_grads)
    # loss = average per-sample loss weighted by importance
    loss = jnp.mean(per_sample_loss * importance_weights)

    updates, new_opt_state = optim.update(grads, q_state.opt_state)
    new_params = optax.apply_updates(q_state.params, updates)
    q_state = q_state._replace(params=new_params, opt_state=new_opt_state)
    buffer_state = buffer.set_priorities(buffer_state, batch.indices, new_priorities)

    return loss, q_state, buffer_state

In [102]:
loss, q_state, buffer_state = ger_update(q_state, buffer_state, data)