In [1]:
import jax
import gymnax

# Seed a root PRNG key and derive one sub-key per random operation below.
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, num=4)

# Build the environment together with its default parameters.
env, env_params = gymnax.make("Breakout-MinAtar")

# Start an episode: yields the first observation and the environment state.
obs, state = env.reset(key_reset, env_params)

# Draw a random action from the environment's action space.
action_space = env.action_space(env_params)
action = action_space.sample(key_act)

# Advance the environment by a single transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

  ).astype(self.dtype)


In [2]:
# Display the action set exposed by the environment created in the first cell.
env.action_set

Array([0, 1, 3], dtype=int32)

In [3]:
import sys

sys.path.append("../")

from functools import partial
from typing import Callable, List
import plotly.express as px
import pandas as pd
import haiku as hk
import jax.numpy as jnp
import optax
from jax import jit, lax, random, vmap
from jax_tqdm import loop_tqdm

from jym import (
    BaseDeepRLAgent,
    Experience,
    PrioritizedExperienceReplay,
    SumTree,
    DQN_PER,
)


@partial(vmap, in_axes=(None, None, None, None))
def compute_td_error(
    model: hk.Transformed,
    online_net_params: dict,
    target_net_params: dict,
    discount: float,
    state: jnp.ndarray,
    action: jnp.ndarray,
    reward: jnp.ndarray,
    next_state: jnp.ndarray,
    done: jnp.ndarray,
    priority: jnp.ndarray,  # unused
) -> jnp.ndarray:
    """
    Computes the clipped TD error of every experience in a batch.

    The function body is written for a single experience and is vectorised
    with ``vmap``. ``in_axes`` only describes the first four positional
    arguments (all broadcast); the per-experience fields must therefore be
    passed as keyword arguments (e.g. ``**experiences_batch``), which ``vmap``
    maps over their leading axis.

    Args:
        model: object exposing ``apply(params, rng, x)`` that returns one
            value per action.
        online_net_params: parameters used for the prediction ``Q(s, a)``.
        target_net_params: parameters used for the bootstrap ``max_a' Q(s', a')``.
        discount: discount factor applied to the bootstrap term.
        state: observation before the transition.
        action: index into the model output selecting the taken action.
        reward: reward received for the transition.
        next_state: observation after the transition.
        done: termination flag; when set, the bootstrap term is zeroed.
        priority: accepted so a full experience batch can be splatted in;
            not used by the computation.

    Returns:
        Array with one TD error per experience, clipped to [-1, 1] for
        stability reasons.
    """
    next_q_values = model.apply(target_net_params, None, next_state)
    # Bootstrap from the target network only for non-terminal transitions.
    td_target = reward + (1 - done) * discount * jnp.max(next_q_values)
    prediction = model.apply(online_net_params, None, state)[action]
    # Bounds are passed positionally: the ``a_min``/``a_max`` keyword names are
    # deprecated in recent JAX releases, positional bounds work everywhere.
    return jnp.clip(td_target - prediction, -1, 1)


def gymnax_rollout(
    timesteps: int,
    random_seed: int,
    target_net_update_freq: int,
    model: hk.Transformed,
    optimizer: optax.GradientTransformation,
    buffer_state: dict,
    tree_state: jnp.ndarray,
    agent: BaseDeepRLAgent,
    env_name: str,
    state_shape: tuple,
    buffer_size: int,
    batch_size: int,
    alpha: float,
    beta: float,
    discount: float,
    epsilon_decay_fn: Callable,
    decay_params: dict,
) -> dict:
    """
    Runs ``timesteps`` environment steps with online training inside a single
    ``lax.fori_loop`` and returns the final loop carry as a dictionary.

    Each iteration: selects an action with ``agent.act``, steps the
    environment, stores the transition in the replay buffer, samples a batch,
    refreshes the sampled priorities with the absolute TD errors, performs one
    ``agent.update`` and periodically copies the online parameters into the
    target parameters.

    Args:
        timesteps: number of loop iterations (environment steps).
        random_seed: seed of the PRNG key driving env resets, network
            initialisation, action selection and buffer sampling.
        target_net_update_freq: the target parameters are replaced by the
            online parameters whenever ``i % target_net_update_freq == 0``.
        model: network used for initialisation and TD-error computation.
        optimizer: optax optimizer handed to ``agent.update``.
        buffer_state: initial replay buffer storage.
        tree_state: initial priority tree storage.
        agent: agent providing ``act`` and ``update``.
        env_name: environment id passed to ``gymnax.make``.
        state_shape: shape of a single observation.
        buffer_size, batch_size, alpha, beta: forwarded to
            ``PrioritizedExperienceReplay``.
        discount: discount factor used for the TD errors.
        epsilon_decay_fn: called as
            ``epsilon_decay_fn(current_step=i, **decay_params)``.
        decay_params: extra keyword arguments for ``epsilon_decay_fn``.

    Returns:
        Dictionary with the keys ``online_net_params``, ``target_net_params``,
        ``optimizer_state``, ``buffer_state``, ``tree_state``, ``rng``,
        ``state``, ``obs``, ``all_actions``, ``all_obs``, ``all_rewards``,
        ``all_done`` and ``losses``.
    """
    # Created before the loop body so its closure dependencies are explicit.
    env, env_params = gymnax.make(env_name)
    replay_buffer = PrioritizedExperienceReplay(buffer_size, batch_size, alpha, beta)

    @loop_tqdm(timesteps)
    @jit
    def _fori_body(i: int, val: tuple):
        (
            online_net_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            tree_state,
            rng,
            env_state,
            obs,
            all_actions,
            all_obs,
            all_rewards,
            all_done,
            losses,
        ) = val

        rng, env_key, action_key, buffer_key = jax.random.split(rng, 4)
        epsilon = epsilon_decay_fn(current_step=i, **decay_params)
        # `action` is an index into the network outputs. It is handed to
        # `env.step` as-is (like the action-space sample in the first cell) and
        # stored unchanged so that later `q_values[action]` lookups stay in
        # bounds.
        action, _ = agent.act(action_key, online_net_params, obs, epsilon)

        next_obs, next_env_state, reward, done, _ = env.step(
            env_key, env_state, action, env_params
        )

        # `obs` is the observation the action was chosen from, `next_obs` the
        # one produced by the transition.
        experience = Experience(
            state=obs,
            action=action,
            reward=reward,
            next_state=next_obs,
            done=done,
        )

        buffer_state, tree_state = replay_buffer.add(
            tree_state, buffer_state, i, experience
        )

        (
            experiences_batch,
            sample_indexes,
            importance_weights,
            _,
        ) = replay_buffer.sample(buffer_key, buffer_state, tree_state)

        # compute individual td errors for the sampled batch and
        # update the tree state using the batched absolute td errors
        td_errors = compute_td_error(
            model, online_net_params, target_net_params, discount, **experiences_batch
        )
        tree_state = replay_buffer.sum_tree.batch_update(
            tree_state, sample_indexes, jnp.abs(td_errors)
        )

        online_net_params, optimizer_state, loss = agent.update(
            online_net_params,
            target_net_params,
            optimizer,
            optimizer_state,
            importance_weights,
            experiences_batch,
        )

        # update the target parameters every ``target_net_update_freq`` steps
        target_net_params = lax.cond(
            i % target_net_update_freq == 0,
            lambda _: online_net_params,
            lambda _: target_net_params,
            operand=None,
        )

        all_actions = all_actions.at[i].set(action)
        all_obs = all_obs.at[i].set(next_obs)
        all_rewards = all_rewards.at[i].set(reward)
        all_done = all_done.at[i].set(done)
        losses = losses.at[i].set(loss)

        # Carry the *new* env state and observation into the next iteration so
        # the agent acts on what it currently sees.
        return (
            online_net_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            tree_state,
            rng,
            next_env_state,
            next_obs,
            all_actions,
            all_obs,
            all_rewards,
            all_done,
            losses,
        )

    # Independent keys for the loop, the env reset and the network init.
    key = random.PRNGKey(random_seed)
    rng, reset_key, params_key = jax.random.split(key, 3)

    obs, state = env.reset(reset_key, env_params)
    all_actions = jnp.zeros([timesteps])
    all_obs = jnp.zeros([timesteps, *state_shape])
    all_rewards = jnp.zeros([timesteps], dtype=jnp.float32)
    all_done = jnp.zeros([timesteps], dtype=jnp.bool_)
    losses = jnp.zeros([timesteps], dtype=jnp.float32)

    # The target network starts as an exact copy of the online network.
    online_net_params = model.init(params_key, jnp.zeros(state_shape))
    target_net_params = online_net_params
    optimizer_state = optimizer.init(online_net_params)

    val_init = (
        online_net_params,
        target_net_params,
        optimizer_state,
        buffer_state,
        tree_state,
        rng,
        state,
        obs,
        all_actions,
        all_obs,
        all_rewards,
        all_done,
        losses,
    )

    vals = lax.fori_loop(0, timesteps, _fori_body, val_init)

    keys = [
        "online_net_params",
        "target_net_params",
        "optimizer_state",
        "buffer_state",
        "tree_state",
        "rng",
        "state",
        "obs",
        "all_actions",
        "all_obs",
        "all_rewards",
        "all_done",
        "losses",
    ]
    return dict(zip(keys, vals))

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# MinAtar Breakout params
DISCOUNT = 0.99
BATCH_SIZE = 32
BUFFER_SIZE = 100_000
TARGET_NET_UPDATE_FREQ = 1000

# Replay buffer params
ALPHA, BETA = 0.5, 1

# other params
RANDOM_SEED = 0
STATE_SHAPE = (10, 10, 4)
N_ACTIONS = 3

# Keyword arguments of the convolutional layer of the Q-network.
CONV_LAYER_PARAMS = {
    "output_channels": 16,
    "kernel_shape": 3,
    "stride": 1,
}
# Keyword arguments of the fully connected head (one output per action).
MLP_PARAMS = {
    "output_sizes": [128, N_ACTIONS],
    "activation": jax.nn.relu,
    "activate_final": False,
}
OPTIMIZER_PARAMS = {
    "learning_rate": 1e-4,
    "decay": 0.95,  # named `smoothing constant` in the paper
    "centered": True,
    # NOTE(review): 10e-2 evaluates to 0.1, not 0.01 -- confirm this is the
    # intended epsilon.
    "eps": 10e-2,
}

# Pre-allocated storage of the replay buffer: one row per stored experience.
buffer_state = {
    "state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "reward": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "done": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
    "priority": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
}
# `jax.tree_util.tree_map` is the stable spelling; the top-level
# `jax.tree_map` alias is deprecated and absent from recent JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, buffer_state)

{'action': (100000,),
 'done': (100000,),
 'next_state': (100000, 10, 10, 4),
 'priority': (100000,),
 'reward': (100000,),
 'state': (100000, 10, 10, 4)}

In [5]:
# Stand-alone replay buffer instance built from the notebook hyper-parameters.
per = PrioritizedExperienceReplay(BUFFER_SIZE, BATCH_SIZE, ALPHA, BETA)
# Zero-initialised flat array of length 2 * BUFFER_SIZE - 1; it is handed to the
# rollout through `rollout_params["tree_state"]`.
# NOTE(review): presumably one entry per node of a binary tree with
# BUFFER_SIZE leaves -- confirm against the SumTree implementation.
tree_state = jnp.zeros(2 * BUFFER_SIZE - 1)
sum_tree = SumTree(BUFFER_SIZE, BATCH_SIZE)


@hk.transform
def model(x):
    """
    Q-network for a single observation (MinAtar variant of DQN): one ReLU
    convolution, flattened, followed by an MLP head.
    ref: https://github.com/kenjyoung/MinAtar/blob/master/examples/dqn.py
    """
    # Convolutional feature extractor with a ReLU non-linearity.
    features = jax.nn.relu(hk.Conv2D(**CONV_LAYER_PARAMS)(x))
    # Collapse every dimension into one vector before the dense head.
    flat_features = jnp.ravel(features)
    return hk.nets.MLP(**MLP_PARAMS)(flat_features)


# def linear_decay(
#     epsilon_start: float,
#     epsilon_end: float,
#     current_step: int,
#     decay_period: int,
# ) -> float:
#     decay_rate = (epsilon_start - epsilon_end) / decay_period
#     new_epsilon = epsilon_start - current_step * decay_rate
#     return jnp.maximum(jnp.float32(epsilon_end), new_epsilon)

def no_decay(
    epsilon_start: float,
    epsilon_end: float,
    current_step: int,
    decay_period: int,
) -> float:
    """
    Constant schedule: always returns ``epsilon_start``.

    ``epsilon_end``, ``current_step`` and ``decay_period`` are accepted but
    ignored.
    """
    del epsilon_end, current_step, decay_period  # intentionally unused
    constant_epsilon = epsilon_start
    return constant_epsilon


# One PRNG key per network, built from two consecutive seeds.
seeds = RANDOM_SEED + jnp.arange(2)
online_key, target_key = vmap(random.PRNGKey)(seeds)

# Initialise each network on a random dummy observation of the env's shape.
online_dummy_obs = random.normal(online_key, env.obs_shape)
target_dummy_obs = random.normal(target_key, env.obs_shape)
online_net_params = model.init(online_key, online_dummy_obs)
target_net_params = model.init(target_key, target_dummy_obs)

# RMSProp optimizer and its state for the online parameters.
optimizer = optax.rmsprop(**OPTIMIZER_PARAMS)
optimizer_state = optimizer.init(online_net_params)

# The agent is told how many entries the environment's action set has.
n_actions = len(env.action_set)
agent = DQN_PER(model, DISCOUNT, n_actions)

In [6]:
# Keyword arguments forwarded to the epsilon schedule at every step.
EPSILON_DECAY_PARAMS = dict(
    epsilon_start=0.1,
    epsilon_end=0,
    decay_period=100_000,
)

# Full configuration of the training rollout, assembled from the cells above.
rollout_params = dict(
    timesteps=100_000,
    random_seed=RANDOM_SEED,
    target_net_update_freq=TARGET_NET_UPDATE_FREQ,
    model=model,
    optimizer=optimizer,
    buffer_state=buffer_state,
    tree_state=tree_state,
    agent=agent,
    env_name="Breakout-MinAtar",
    state_shape=STATE_SHAPE,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    alpha=ALPHA,
    beta=BETA,
    discount=DISCOUNT,
    epsilon_decay_fn=no_decay,
    decay_params=EPSILON_DECAY_PARAMS,
)

# Train; `out` holds the final loop state and the per-step logs.
out = gymnax_rollout(**rollout_params)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [06:51<00:00, 243.13it/s]


In [7]:
# Per-step training loss.
fig = px.line(out["losses"], title="Loss during training")
fig.show()

# Tag every step with the number of episodes finished before it: the cumulative
# count of `done` flags, shifted by one step so a terminal step stays in the
# episode it ends.
df = pd.DataFrame(
    data={
        "episode": out["all_done"].cumsum(),
        "reward": out["all_rewards"],
    },
).assign(episode=lambda frame: frame["episode"].shift().fillna(0))

# Total reward collected in each episode.
episodes_df = df.groupby("episode").sum()

fig = px.line(
    episodes_df,
    y="reward",
    title="Performances of DQN on the Breakout Environment",
)
fig.show()