In [1]:
import flashbax as fbx
import pandas as pd
from typing import NamedTuple
from tqdm.auto import tqdm
import haiku as hk
import jax
from jax import random, jit, vmap, tree_map, lax
from jax_tqdm import loop_tqdm
import jax.numpy as jnp
import plotly.express as px
import optax
import rlax
import chex
import gymnax

### ***Data Structures***

In [2]:
def get_network_fn(num_outputs: int):
    """Build a Haiku-transformed MLP mapping an observation to `num_outputs` values.

    The network flattens its input, applies two hidden linear layers
    (256 then 128 units) each followed by a leaky ReLU, and ends with a
    linear layer of width `num_outputs`.

    Returns:
        The `hk.transform`-ed network wrapped with `hk.without_apply_rng`.
    """

    def network_fn(obs: chex.Array, rng: chex.PRNGKey) -> chex.Array:
        # `rng` is accepted for signature compatibility but is not used.
        hidden_widths = (256, 128)

        layers = [hk.Flatten()]
        for width in hidden_widths:
            layers.extend((hk.Linear(width), jax.nn.leaky_relu))
        layers.append(hk.Linear(num_outputs))

        return hk.Sequential(layers)(obs)

    transformed = hk.transform(network_fn)
    return hk.without_apply_rng(transformed)


class TrainState(NamedTuple):
    """Learner state carried through the training loop."""

    # Online Q-network parameters; these are the ones differentiated in `update`.
    params: hk.Params
    # Target-network parameters; refreshed from `params` in `update_step`.
    target_params: hk.Params
    # Optimizer state created by `optim.init` and advanced by `optim.update`.
    opt_state: optax.OptState


@chex.dataclass(frozen=True)
class TimeStep:
    """Single item written to the replay buffer.

    The loss in `update` reads two consecutive items (`first`, `second`)
    and uses `first.observation`, `first.action`, `first.reward`,
    `first.discount` together with `second.observation`.
    """

    # Environment observation.
    observation: chex.Array
    # Action selected by the agent.
    action: chex.Array
    # Continuation factor; the rollout stores 0.0 when the step ended the episode.
    discount: chex.Array
    # Reward returned by the environment step.
    reward: chex.Array

In [3]:
# Experiment configuration: every tunable value of the notebook lives here.
env_id = "CartPole-v1"  # environment name handed to `gymnax.make`
seed = 1  # seeds the PRNG key used for network initialisation
num_envs = 1  # NOTE(review): not referenced by the code visible in this notebook

total_timesteps = 50_000  # intended training length, in environment steps
learning_starts = 1_000  # gradient updates only run once current_step exceeds this
train_frequency = 5  # a gradient update runs when current_step % train_frequency == 0
target_network_frequency = 500  # target params refresh when current_step % this == 0

tau = 1.0  # step size given to `optax.incremental_update` (1.0 copies the online params)
learning_rate = 1e-3  # Adam learning rate
start_e = 1.0  # initial epsilon of the epsilon-greedy schedule
end_e = 0.01  # floor of the epsilon schedule
exploration_fraction = 0.5  # fraction of the run over which epsilon is annealed
gamma = 0.99  # discount factor applied inside the loss
importance_sampling_exponent = 0.6  # exponent applied to the inverse priorities in the loss

# Keyword arguments forwarded verbatim to `fbx.make_prioritised_flat_buffer`.
buffer_params = {
    "max_length": 50_000,
    "min_length": 128,
    "sample_batch_size": 128,
    "add_sequences": False,
    "add_batch_size": None,
    "priority_exponent": 0.6,
}

In [4]:
# Instantiate the gymnax environment and its parameters; the number of
# discrete actions sets the width of the Q-network output layer.
env, env_params = gymnax.make(env_id)
num_actions = env.num_actions

### ***DQN and Optimizer initialization***

In [5]:
# Seed the PRNG and reserve a sub-key for network initialisation.
key = random.PRNGKey(seed)
key, q_key = random.split(key, 2)

q_network = get_network_fn(num_actions)
optim = optax.adam(learning_rate=learning_rate)

# One environment reset provides an example observation for parameter
# initialisation (and, in the next cell, for the buffer's dummy timestep).
dummy_obs, dummy_env_state = env.reset(key)
params = q_network.init(q_key, dummy_obs.astype(jnp.float32), None)
opt_state = optim.init(params)
# The target network starts as an exact copy of the online network.
q_state = TrainState(
    params=params,
    target_params=params,
    opt_state=opt_state,
)

### ***Flashbax Buffer initialization***

In [6]:
# Build the prioritised flat buffer and jit-compile its entry points.
buffer = fbx.make_prioritised_flat_buffer(**buffer_params)
buffer = buffer.replace(
    init=jax.jit(buffer.init),
    # donate_argnums=0 donates the first positional argument (the buffer
    # state) to the compiled `add` call.
    add=jax.jit(buffer.add, donate_argnums=0),
    sample=jax.jit(buffer.sample),
    can_sample=jax.jit(buffer.can_sample),
)

# Example item handed to `buffer.init`; presumably used as a shape/dtype
# template for the stored experience — TODO confirm against flashbax docs.
dummy_timestep = TimeStep(
    observation=dummy_obs,
    action=jnp.int32(0),
    reward=jnp.float32(0.0),
    discount=jnp.float32(0.0),
)
buffer_state = buffer.init(dummy_timestep)



In [7]:
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Anneal epsilon linearly from `start_e` at t=0 and clamp it at `end_e`.

    The value changes by `(end_e - start_e) / duration` per step and is
    never allowed to drop below `end_e`.
    """
    per_step_change = (end_e - start_e) / duration
    annealed = per_step_change * t + start_e
    return jnp.maximum(annealed, end_e)




@jit
def update(q_state: TrainState, buffer_state, batch: TimeStep):
    """
    Applies one importance-weighted double Q-learning gradient step.

    `batch` is the object returned by `buffer.sample` (despite the
    `TimeStep` annotation): this function reads `batch.experience.first`,
    `batch.experience.second`, `batch.priorities` and `batch.indices`.

    Returns the train state with updated `params` / `opt_state`, and the
    buffer state whose priorities at the sampled indices have been
    replaced by the new absolute TD errors.
    """

    def q_values_for(net_params: dict, observations: jnp.ndarray):
        # Map the single-observation network over the leading batch axis.
        def apply_single(single_obs):
            return q_network.apply(net_params, single_obs, None)

        return vmap(apply_single)(observations)

    def loss_fn(online_params: dict, frozen_params: dict, sampled):
        """Importance-weighted L2 loss on the double-Q TD error, plus new priorities."""
        current = sampled.experience.first
        following = sampled.experience.second

        q_tm1 = q_values_for(online_params, current.observation)
        q_t_value = q_values_for(frozen_params, following.observation)
        q_t_selector = q_values_for(online_params, following.observation)
        discount_t = current.discount * gamma

        td_error = vmap(rlax.double_q_learning)(
            q_tm1,
            current.action,
            current.reward,
            discount_t,
            q_t_value,
            q_t_selector,
        )
        per_sample_loss = rlax.l2_loss(td_error)

        # Inverse priorities raised to the IS exponent, normalised by their max.
        inverse_priorities = (1.0 / sampled.priorities).astype(jnp.float32)
        raw_weights = inverse_priorities**importance_sampling_exponent
        weights = raw_weights / jnp.max(raw_weights)

        weighted_loss = jnp.mean(weights * per_sample_loss)
        # Small offset keeps every refreshed priority strictly positive.
        refreshed_priorities = jnp.abs(td_error) + 1e-7
        return weighted_loss, refreshed_priorities

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, refreshed_priorities = grad_fn(q_state.params, q_state.target_params, batch)

    updates, next_opt_state = optim.update(grads, q_state.opt_state)
    next_params = optax.apply_updates(q_state.params, updates)

    next_q_state = q_state._replace(params=next_params, opt_state=next_opt_state)
    next_buffer_state = buffer.set_priorities(
        buffer_state, batch.indices, refreshed_priorities
    )
    return next_q_state, next_buffer_state


@jit
def action_select_fn(q_state: TrainState, obs: chex.Array):
    """Return the greedy action for a single observation.

    Args:
        q_state: train state whose online `params` are used for the forward pass.
        obs: raw environment observation (an array, not a `TimeStep`); it is
            passed directly to the Q-network.

    Returns:
        Index of the largest Q-value along the last axis.
    """
    q_values = q_network.apply(q_state.params, obs, None)
    return jnp.argmax(q_values, axis=-1)


@jit
def perform_update(
    q_state: TrainState,
    buffer_state,
    sample_key: random.PRNGKey,
):
    """Sample one batch from the buffer with `sample_key` and apply `update` to it.

    Returns the `(q_state, buffer_state)` pair produced by `update`.
    """
    sampled_batch = buffer.sample(buffer_state, sample_key)
    return update(q_state, buffer_state, sampled_batch)

In [8]:
def update_step(
    current_step,
    learning_starts,
    train_frequency,
    buffer_state,
    key,
    q_state,
    target_network_frequency,
    tau,
):
    """Conditionally run a gradient update and a target-network refresh.

    A gradient update runs when `current_step > learning_starts`,
    `current_step` is a multiple of `train_frequency`, and the buffer
    reports it can be sampled. Independently, the target parameters are
    refreshed whenever `current_step` is a multiple of
    `target_network_frequency`.

    Returns:
        The (possibly updated) `q_state` and `buffer_state`.
    """

    def _train(operand):
        step_key, state, buf = operand
        _, sample_key = jax.random.split(step_key)
        return perform_update(state, buf, sample_key)

    def _skip_train(operand):
        """Pass the states through untouched."""
        _, state, buf = operand
        return state, buf

    def _refresh_target(state):
        blended = optax.incremental_update(state.params, state.target_params, tau)
        return state._replace(target_params=blended)

    def _keep_target(state):
        """Leave the target parameters as they are."""
        return state

    # Training predicate: past warm-up, on a training step, buffer ready.
    past_warmup = current_step > learning_starts
    on_train_step = current_step % train_frequency == 0
    should_train = past_warmup & on_train_step & buffer.can_sample(buffer_state)
    q_state, buffer_state = lax.cond(
        should_train,
        _train,
        _skip_train,
        (key, q_state, buffer_state),
    )

    # Target-network predicate.
    should_refresh = current_step % target_network_frequency == 0
    q_state = lax.cond(should_refresh, _refresh_target, _keep_target, q_state)

    return q_state, buffer_state


def rollout(
    rng: random.PRNGKey,
    total_timesteps: int,
    q_state: TrainState,
    buffer_state,
):
    """Run the epsilon-greedy training loop for `total_timesteps` environment steps.

    Each iteration selects an action, steps the environment, writes the
    transition to the replay buffer, and calls `update_step`.

    Args:
        rng: PRNG key driving the initial reset and every per-step draw.
        total_timesteps: number of environment steps; also sets the length of
            the epsilon schedule and of the returned logs.
        q_state: initial train state.
        buffer_state: initial replay-buffer state.

    Returns:
        `(q_state, buffer_state, logs)` where `logs["rewards"]` and
        `logs["dones"]` hold one entry per environment step.
    """

    def _conditional_reset(key):
        obs, env_state = env.reset(key)
        return obs, env_state

    @jit
    @loop_tqdm(total_timesteps)
    def _fori_body(current_step: int, val: tuple):
        (obs, env_state, q_state, buffer_state, rng, logs) = val
        # One dedicated sub-key per random consumer, so no key is used twice.
        (
            rng,
            explore_key,
            action_key,
            step_key,
            reset_key,
            update_key,
        ) = random.split(rng, num=6)
        epsilon = linear_schedule(
            start_e, end_e, exploration_fraction * total_timesteps, current_step
        )

        explore = random.uniform(explore_key) < epsilon
        action = lax.cond(
            explore,
            lambda _: env.action_space(env_params).sample(action_key),
            lambda _: action_select_fn(q_state, obs),
            operand=None,
        )
        # Keep `obs` (the state the action was chosen in) intact: the loss
        # treats `first.observation` as that state.
        next_obs, env_state, reward, done, _ = env.step(step_key, env_state, action)

        logs = {
            "rewards": logs["rewards"].at[current_step].set(reward),
            "dones": logs["dones"].at[current_step].set(done),
        }

        timestep = TimeStep(
            observation=obs,
            action=action,
            reward=reward,
            # Store only the termination mask; `gamma` is applied once, in the loss.
            discount=lax.select(done, 0.0, 1.0),
        )
        buffer_state = buffer.add(buffer_state, timestep)

        q_state, buffer_state = update_step(
            current_step,
            learning_starts,
            train_frequency,
            buffer_state,
            update_key,
            q_state,
            target_network_frequency,
            tau,
        )

        # Reset if done, otherwise carry the post-step observation forward.
        # NOTE(review): gymnax's `step` may already auto-reset on termination;
        # the explicit reset is kept to preserve the original control flow.
        obs, env_state = lax.cond(
            done,
            lambda _: _conditional_reset(reset_key),
            lambda _: (next_obs, env_state),
            operand=None,
        )

        return (obs, env_state, q_state, buffer_state, rng, logs)

    logs = {
        "rewards": jnp.zeros(total_timesteps),
        "dones": jnp.zeros(total_timesteps),
    }
    obs, env_state = env.reset(rng)
    init_val = (obs, env_state, q_state, buffer_state, rng, logs)
    (obs, env_state, q_state, buffer_state, rng, logs) = lax.fori_loop(
        0, total_timesteps, _fori_body, init_val
    )

    return q_state, buffer_state, logs


# Train for the number of steps set in the configuration cell rather than a
# duplicated literal, so changing `total_timesteps` actually changes the run.
q_state, buffer_state, logs = rollout(
    random.PRNGKey(0), total_timesteps, q_state, buffer_state
)

  return lax_numpy.astype(arr, dtype)


  0%|          | 0/50000 [00:00<?, ?it/s]

In [9]:
# Tag each step with a running count of terminations, then shift by one so
# the terminating step is counted in the episode it ends (first step -> 0).
df = pd.DataFrame({"episode": logs["dones"].cumsum(), "reward": logs["rewards"]})
df = df.assign(episode=df["episode"].shift().fillna(0))

# Episode return = sum of rewards over the steps sharing an episode id.
episodes_df = df.groupby("episode").sum()

px.line(episodes_df, y="reward", title=f"Performances of DQN on {env_id}")

### ***Performance Evaluation***

In [10]:
print("Evaluating...")

test_steps = 10_000
rng = hk.PRNGSequence(0)

# Per-step reward / termination flags collected during evaluation.
logs_test = {
    "rewards": jnp.zeros(test_steps),
    "dones": jnp.zeros(test_steps),
}

obs, env_state = env.reset(next(rng))
for current_step in tqdm(range(test_steps)):
    # Purely greedy policy: no epsilon exploration during evaluation.
    action = action_select_fn(q_state, obs)
    obs, env_state, reward, done, _ = env.step(next(rng), env_state, action)
    logs_test["dones"] = logs_test["dones"].at[current_step].set(done)
    logs_test["rewards"] = logs_test["rewards"].at[current_step].set(reward)

    if done:
        obs, env_state = env.reset(next(rng))

# Running count of terminations, shifted by one so the terminating step is
# counted in the episode it ends.
df = pd.DataFrame(
    data={
        "episode": logs_test["dones"].cumsum(),
        "reward": logs_test["rewards"],
    },
)
df["episode"] = df["episode"].shift().fillna(0)
episodes_df = df.groupby("episode").agg("sum")

# Distinct title so this figure cannot be mistaken for the training curve.
px.line(
    episodes_df,
    y="reward",
    title=f"Evaluation performances of DQN on {env_id}",
)

Evaluating...


  0%|          | 0/10000 [00:00<?, ?it/s]