# Example using DQN and Flashbax in gym environments

### [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/flashbax/blob/main/examples/gym_dqn_example.ipynb)

In [1]:
import random
import time
from typing import NamedTuple
from tqdm.auto import tqdm
from jax_tqdm import loop_tqdm
import plotly.express as px
import haiku as hk
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import chex


In [2]:
import flashbax as fbx

#### Define Network and data classes

In [3]:
def get_network_fn(num_outputs: int):
    """Define a fully connected multi-layer haiku network."""

    def network_fn(obs: chex.Array) -> chex.Array:
        conv1 = hk.Conv2D(output_channels=32, kernel_shape=8, stride=4)
        conv2 = hk.Conv2D(output_channels=64, kernel_shape=4, stride=2)
        conv3 = hk.Conv2D(output_channels=64, kernel_shape=3, stride=1)
        fc = hk.nets.MLP(
            output_sizes=[512, num_outputs],
            activation=jax.nn.relu,
            activate_final=False,
        )

        x = conv1(obs.astype(jnp.float32))
        x = jax.nn.relu(conv2(x))
        x = jax.nn.relu(conv3(x))
        x = x.reshape(-1)
        x = fc(x)

        return x

    return hk.without_apply_rng(hk.transform(network_fn))



class TrainState(NamedTuple):
    params: hk.Params
    target_params: hk.Params
    opt_state: optax.OptState



@chex.dataclass(frozen=True)

class TimeStep:
    observation: chex.Array
    action: chex.Array
    discount: chex.Array
    reward: chex.Array

#### Training Parameters

In [4]:
env_id = "ALE/Seaquest-v5"
seed = 42
num_envs = 1

total_timesteps = 100_000
learning_starts = 1_000

DISCOUNT = 0.99
BATCH_SIZE = 32
BUFFER_SIZE = 100_000
TARGET_NET_UPDATE_FREQ = 8_000
TRAIN_FREQ = 4

TAU = 1.0
GAMMA = 0.99
ALPHA, BETA = 0.5, 0.5

OPTIMIZER_PARAMS = {
    "learning_rate": 0.00025,
    "decay": 0.95,  # named `smoothing constant` in the paper
    "centered": True,
    "eps": 0.00001,
}

EPSILON_DECAY_PARAMS = {
    "epsilon_start": 0.1,
    "epsilon_end": 0,
    "decay_period": 100_000,
}

buffer_params = {
    "max_length": BUFFER_SIZE,
    "min_length": BATCH_SIZE,
    "sample_batch_size": BATCH_SIZE,
    "add_sequences": False,
    "add_batch_size": None,
    "priority_exponent": ALPHA,
}

#### Set up environment

In [5]:
# We then set up the environments
def make_env(env_id, seed):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.AutoResetWrapper(env)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)

        return env

    return thunk


if num_envs == 1:
    envs = make_env(env_id, seed)()
    assert isinstance(
        envs.action_space, gym.spaces.Discrete
    ), "only discrete action space is supported"
    num_actions = envs.action_space.n
else:
    envs = gym.vector.SyncVectorEnv(
        [make_env(env_id, seed + i) for i in range(num_envs)]
    )
    assert isinstance(
        envs.single_action_space, gym.spaces.Discrete
    ), "only discrete action space is supported"
    num_actions = envs.single_action_space.n

#### Train DQN agent

In [6]:
random.seed(seed)
np.random.seed(seed)
key = jax.random.PRNGKey(seed)
key, q_key = jax.random.split(key, 2)

q_network = get_network_fn(num_actions)
optim = optax.rmsprop(**OPTIMIZER_PARAMS)

dummy_obs, _ = envs.reset(seed=seed)
if num_envs > 1:
    dummy_obs = dummy_obs[0]
params = q_network.init(q_key, dummy_obs.astype(jnp.float32))
opt_state = optim.init(params)
q_state = TrainState(params=params, target_params=params, opt_state=opt_state)

buffer = fbx.make_prioritised_flat_buffer(**buffer_params)
buffer = buffer.replace(
    init=jax.jit(buffer.init),
    add=jax.jit(buffer.add, donate_argnums=0),
    sample=jax.jit(buffer.sample),
    can_sample=jax.jit(buffer.can_sample),
)

dummy_timestep = TimeStep(
    observation=dummy_obs,
    action=jnp.int32(0),
    reward=jnp.float32(0.0),
    discount=jnp.float32(0.0),
)
buffer_state = buffer.init(dummy_timestep)

def linear_decay(
    epsilon_start: float,
    epsilon_end: float,
    current_step: int,
    decay_period: int,
) -> float:
    decay_rate = (epsilon_start - epsilon_end) / decay_period
    new_epsilon = epsilon_start - current_step * decay_rate
    return jnp.maximum(jnp.float32(epsilon_end), new_epsilon)

@jax.jit
def update(q_state: TrainState, buffer_state, batch: TimeStep):
    """
    Computes the updated model parameters and optimizer states
    for a batch of experience.
    """

    def batch_apply(params: dict, observations: jnp.ndarray):
        return jax.vmap(q_network.apply, in_axes=(None, 0))(
            params,
            observations,
        )

    def loss_fn(params: dict, target_params: dict, batch):
        """Computes the Q-learning TD error for a batch of timesteps"""
        # TD error
        q_tm1 = batch_apply(params, batch.experience.first.observation)
        a_tm1 = batch.experience.first.action
        r_t = batch.experience.first.reward
        d_t = batch.experience.first.discount * GAMMA
        q_t = batch_apply(target_params, batch.experience.second.observation)
        q_t_select = batch_apply(params, batch.experience.second.observation)
        td_error = jax.vmap(rlax.double_q_learning)(
            q_tm1, a_tm1, r_t, d_t, q_t, q_t_select
        )

        # Priorities
        batch_loss = rlax.l2_loss(td_error)
        importance_weights = (1.0 / batch.priorities).astype(jnp.float32)
        importance_weights **= BETA
        importance_weights /= jnp.max(importance_weights)

        # Loss
        loss = jnp.mean(importance_weights * batch_loss)
        new_priorities = jnp.abs(td_error) + 1e-7
        return loss, new_priorities

    grads, new_priorities = jax.grad(loss_fn, has_aux=True)(
        q_state.params, q_state.target_params, batch
    )
    updates, new_opt_state = optim.update(grads, q_state.opt_state)
    new_params = optax.apply_updates(q_state.params, updates)
    q_state = q_state._replace(params=new_params, opt_state=new_opt_state)
    buffer_state = buffer.set_priorities(
        buffer_state, batch.indices, new_priorities
    )

    return q_state, buffer_state

@jax.jit
def action_select_fn(q_state: TrainState, obs: TimeStep):
    q_values = q_network.apply(q_state.params, obs)
    action = jnp.argmax(q_values, axis=-1)

    return action

@jax.jit
def perform_update(
    q_state: TrainState,
    buffer_state,
    sample_key: jax.random.PRNGKey,
):
    """Samples a batch from the replay buffer and updates network parameters."""
    batch = buffer.sample(buffer_state, sample_key)
    q_state, buffer_state = update(q_state, buffer_state, batch)

    return q_state, buffer_state



In [11]:
def update_step(
    current_step,
    learning_starts,
    train_frequency,
    buffer_state,
    key,
    q_state,
    target_network_frequency,
    tau,
):
    def train_update_fn(args):
        key, q_state, buffer_state = args
        key, sample_key = jax.random.split(key)
        q_state, buffer_state = perform_update(q_state, buffer_state, sample_key)
        return q_state, buffer_state

    def no_train_update_fn(args):
        """Bypasses the update step"""
        key, q_state, buffer_state = args
        return q_state, buffer_state

    def update_target_network_fn(q_state):
        q_state = q_state._replace(
            target_params=optax.incremental_update(
                q_state.params, q_state.target_params, tau
            )
        )
        return q_state

    def no_update_target_network_fn(q_state):
        """Bypasses the target network update"""
        return q_state

    # Check for training condition
    q_state, buffer_state = jax.lax.cond(
        (current_step > learning_starts)
        & (current_step % train_frequency == 0)
        & buffer.can_sample(buffer_state),
        train_update_fn,
        no_train_update_fn,
        operand=(key, q_state, buffer_state),
    )

    # Check for target network update condition
    q_state = jax.lax.cond(
        current_step % TARGET_NET_UPDATE_FREQ == 0,
        update_target_network_fn,
        no_update_target_network_fn,
        operand=q_state,
    )

    return q_state, buffer_state


def rollout(
    rng: jax.random.PRNGKey,
    total_timesteps: int,
    q_state: TrainState,
    buffer_state,
):
    def _conditional_reset(key):
        key, subkey = jax.random.split(key)
        obs = envs.reset()[0]
        return obs

    @jax.jit
    @loop_tqdm(total_timesteps)
    def _fori_body(current_step: int, val: tuple):
        (obs, q_state, buffer_state, rng, logs) = val
        rng, env_key, action_key, step_key = jax.random.split(rng, num=4)
        epsilon = linear_decay(current_step=current_step, **EPSILON_DECAY_PARAMS)

        explore = jax.random.uniform(env_key) < epsilon
        action = jax.lax.cond(
            explore,
            lambda _: envs.action_space.sample(),
            lambda _: action_select_fn(q_state, obs),
            operand=None,
        )
        obs, reward, done, _, _ = envs.step(action)

        logs["rewards"] = logs["rewards"].at[current_step].set(reward)
        logs["dones"] = logs["dones"].at[current_step].set(done)

        timestep = TimeStep(
            observation=obs,
            action=action,
            reward=reward,
            discount=jax.lax.select(done, 0.0, DISCOUNT),
        )
        buffer_state = buffer.add(buffer_state, timestep)

        q_state, buffer_state = update_step(
            current_step,
            learning_starts,
            TRAIN_FREQ,
            buffer_state,
            rng,
            q_state,
            TARGET_NET_UPDATE_FREQ,
            TAU,
        )

        # reset if done
        obs = jax.lax.cond(
            done,
            lambda _: _conditional_reset(env_key),
            lambda _: (obs),
            operand=None,
        )

        return (obs, q_state, buffer_state, rng, logs)

    logs = {
        "rewards": jnp.zeros(total_timesteps),
        "dones": jnp.zeros(total_timesteps),
    }
    obs = envs.reset()[0]
    init_val = (obs, q_state, buffer_state, rng, logs)
    (obs, q_state, buffer_state, rng, logs) = jax.lax.fori_loop(
        0, total_timesteps, _fori_body, init_val
    )

    return q_state, buffer_state, logs


q_state, buffer_state, logs = rollout(
    jax.random.PRNGKey(0), total_timesteps, q_state, buffer_state
)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
The error occurred while tracing the function wrapper_progress_bar at c:\Users\ryanp\anaconda3\lib\site-packages\jax_tqdm\pbar.py:87 for jit. This concrete value was not available in Python because it depends on the values of the arguments i, val[0], val[1].params['conv2_d']['b'], val[1].params['conv2_d']['w'], val[1].params['conv2_d_1']['b'], val[1].params['conv2_d_1']['w'], val[1].params['conv2_d_2']['b'], val[1].params['conv2_d_2']['w'], val[1].params['mlp/~/linear_0']['b'], val[1].params['mlp/~/linear_0']['w'], val[1].params['mlp/~/linear_1']['b'], val[1].params['mlp/~/linear_1']['w'], val[1].target_params['conv2_d']['b'], val[1].target_params['conv2_d']['w'], val[1].target_params['conv2_d_1']['b'], val[1].target_params['conv2_d_1']['w'], val[1].target_params['conv2_d_2']['b'], val[1].target_params['conv2_d_2']['w'], val[1].target_params['mlp/~/linear_0']['b'], val[1].target_params['mlp/~/linear_0']['w'], val[1].target_params['mlp/~/linear_1']['b'], val[1].target_params['mlp/~/linear_1']['w'], val[1].opt_state[0].mu['conv2_d']['b'], val[1].opt_state[0].mu['conv2_d']['w'], val[1].opt_state[0].mu['conv2_d_1']['b'], val[1].opt_state[0].mu['conv2_d_1']['w'], val[1].opt_state[0].mu['conv2_d_2']['b'], val[1].opt_state[0].mu['conv2_d_2']['w'], val[1].opt_state[0].mu['mlp/~/linear_0']['b'], val[1].opt_state[0].mu['mlp/~/linear_0']['w'], val[1].opt_state[0].mu['mlp/~/linear_1']['b'], val[1].opt_state[0].mu['mlp/~/linear_1']['w'], val[1].opt_state[0].nu['conv2_d']['b'], val[1].opt_state[0].nu['conv2_d']['w'], val[1].opt_state[0].nu['conv2_d_1']['b'], val[1].opt_state[0].nu['conv2_d_1']['w'], val[1].opt_state[0].nu['conv2_d_2']['b'], val[1].opt_state[0].nu['conv2_d_2']['w'], val[1].opt_state[0].nu['mlp/~/linear_0']['b'], val[1].opt_state[0].nu['mlp/~/linear_0']['w'], val[1].opt_state[0].nu['mlp/~/linear_1']['b'], val[1].opt_state[0].nu['mlp/~/linear_1']['w'], and val[3].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

#### Performance Evaluation

In [None]:
print("Evaluating...")
envs = make_env(env_id, seed)()
test_returns = []
obs, _ = envs.reset(seed=seed) # obs = np.array
for current_step in tqdm(range(10_000)):
    actions = action_select_fn(q_state, obs) # obs = np.array -> jnp.array
    actions = jax.device_get(actions) # actions = jnp.array -> np.array

    next_obs, rewards, terminated, truncated, infos = envs.step(actions)

    # Get Episode Return Statistics
    if "final_info" in infos:
        if isinstance(infos["final_info"], dict):
            # print(f"Evaluating Step : {global_step}, episodic_return={infos['episode']['r'][0]}")
            test_returns.append(infos['episode']['r'][0])

    # Update the observation
    obs = next_obs

envs.close()

Evaluating...


  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
px.line(test_returns)