In [1]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from functools import partial
from jax import random, vmap, lax, tree_map
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple, List, Dict

sys.path.append("../../")
from jym import (
    Breakout,
    DQN_PER,
    per_rollout,
    SumTree,
    Experience,
    PrioritizedExperienceReplay,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# MinAtar Breakout params
DISCOUNT = 0.99
BATCH_SIZE = 32
BUFFER_SIZE = 100_000
TARGET_NET_UPDATE_FREQ = 1000

# Replay buffer params
ALPHA, BETA = 0.5, 0.5

# other params
RANDOM_SEED = 0
STATE_SHAPE = (10, 10, 4)
N_ACTIONS = 3

CONV_LAYER_PARAMS = {
    "output_channels": 16,
    "kernel_shape": 3,
    "stride": 1,
}
MLP_PARAMS = {
    "output_sizes": [128, N_ACTIONS],
    "activation": jax.nn.relu,
    "activate_final": False,
}
OPTIMIZER_PARAMS = {
    "learning_rate": 1e-4,
    "decay": 0.95,  # named `smoothing constant` in the paper
    "centered": True,
    "eps": 10e-2,
}
EPSILON_DECAY_PARAMS = {
    "epsilon_start": 0.1,
    "epsilon_end": 0,
    "decay_period": 100_000,
}

buffer_state = {
    "state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "reward": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "done": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
    "priority": jnp.empty((BUFFER_SIZE), dtype=jnp.float32),
}
jax.tree_map(lambda x: x.shape, buffer_state)

{'action': (100000,),
 'done': (100000,),
 'next_state': (100000, 10, 10, 4),
 'priority': (100000,),
 'reward': (100000,),
 'state': (100000, 10, 10, 4)}

In [3]:
key = random.PRNGKey(0)
env = Breakout()
state, obs, env_key = env.reset(key)



In [4]:
per = PrioritizedExperienceReplay(BUFFER_SIZE, BATCH_SIZE, ALPHA, BETA)
tree_state = jnp.zeros(2 * BUFFER_SIZE - 1)
sum_tree = SumTree(BUFFER_SIZE, BATCH_SIZE)


@hk.transform
def model(x):
    """
    MinAtar version of DQN
    ref: https://github.com/kenjyoung/MinAtar/blob/master/examples/dqn.py
    """
    conv_layer = hk.Conv2D(**CONV_LAYER_PARAMS)
    fc = hk.nets.MLP(**MLP_PARAMS)

    x = jax.nn.relu(conv_layer(x))
    x = x.reshape(-1)
    return fc(x)


def linear_decay(
    epsilon_start: float,
    epsilon_end: float,
    current_step: int,
    decay_period: int,
) -> float:
    decay_rate = (epsilon_start - epsilon_end) / decay_period
    new_epsilon = epsilon_start - current_step * decay_rate
    return jnp.maximum(jnp.float32(epsilon_end), new_epsilon)


online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)

online_net_params = model.init(online_key, random.normal(online_key, env.obs_shape))
target_net_params = model.init(target_key, random.normal(target_key, env.obs_shape))

optimizer = optax.rmsprop(**OPTIMIZER_PARAMS)
optimizer_state = optimizer.init(online_net_params)

agent = DQN_PER(model, DISCOUNT, len(env.actions))

In [14]:
rollout_params = {
    "timesteps": 100_000,
    "random_seed": RANDOM_SEED,
    "target_net_update_freq": TARGET_NET_UPDATE_FREQ,
    "model": model,
    "optimizer": optimizer,
    "buffer_state": buffer_state,
    "tree_state": tree_state,
    "agent": agent,
    "env": env,
    "state_shape": STATE_SHAPE,
    "buffer_size": BUFFER_SIZE,
    "batch_size": BATCH_SIZE,
    "alpha": ALPHA,
    "beta": BETA,
    "discount": DISCOUNT,
    "epsilon_decay_fn": linear_decay,
    "decay_params": EPSILON_DECAY_PARAMS,
}

out = per_rollout(**rollout_params)


scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=bool with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bool with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.

Running for 100,000 iterations: 100%|██████████| 100000/100000 [06:11<00:00, 269.08it/s]


In [16]:
fig = px.line(out["losses"], title="Loss during training")
fig.show()

df = pd.DataFrame(
    data={
        "episode": out["all_done"].cumsum(),
        "reward": out["all_rewards"],
    },
)
df["episode"] = df["episode"].shift().fillna(0)
episodes_df = df.groupby("episode").agg("sum")

fig = px.line(episodes_df, y="reward", title=f"Performances of DQN on the Breakout Environment", )
fig.show()

### ***To fix: Episodes always last for 1000 steps***

In [None]:
df[df["episode"]==0]

Unnamed: 0,episode,reward
0,0.0,0.0
1,0.0,0.0
2,0.0,0.0
3,0.0,0.0
4,0.0,0.0
...,...,...
995,0.0,0.0
996,0.0,0.0
997,0.0,0.0
998,0.0,0.0


In [10]:
assert

SyntaxError: invalid syntax (2389114725.py, line 1)

## ***Debugging section***

@partial(vmap, in_axes=(None, None, None, None))
def compute_td_error(
    model: hk.Transformed,
    online_net_params: dict,
    target_net_params: dict,
    discount: float,
    state: jnp.ndarray,
    action: jnp.ndarray,
    reward: jnp.ndarray,
    next_state: jnp.ndarray,
    done: jnp.ndarray,
    priority: jnp.ndarray, # unused
) -> List[float]:
    td_target = (
        (1 - done)
        * discount
        * jnp.max(model.apply(target_net_params, None, next_state))
    )
    prediction = model.apply(online_net_params, None, state)[action]
    return reward + td_target - prediction

random_seed = 1
timesteps = 10
state_shape = (10, 10, 4)
buffer_size = 32

i = 0

init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + random_seed)
state, obs, env_key = env.reset(init_key)
all_actions = jnp.zeros([timesteps])
all_obs = jnp.zeros([timesteps, *state_shape])
all_rewards = jnp.zeros([timesteps], dtype=jnp.float32)
all_done = jnp.zeros([timesteps], dtype=jnp.bool_)
losses = jnp.zeros([timesteps], dtype=jnp.float32)

online_net_params = model.init(init_key, jnp.zeros(state_shape))
target_net_params = model.init(action_key, jnp.zeros(state_shape))
optimizer_state = optimizer.init(online_net_params)
replay_buffer = PrioritizedExperienceReplay(buffer_size, 8, 0.5, 0.5)

replay_buffer.add(tree_state, buffer_state, i, experience)

epsilon = linear_decay(current_step=i, **EPSILON_DECAY_PARAMS)
action, action_key = agent.act(action_key, online_net_params, obs, epsilon)
state, next_state, reward, done, env_key = env.step(state, env_key, action)
experience = Experience(
    state=env._get_obs(state),
    action=action,
    reward=reward,
    next_state=next_state,
    done=done,
)

buffer_state, tree_state = replay_buffer.add(tree_state, buffer_state, i, experience)

(
    experiences_batch,
    sample_indexes,
    importance_weights,
    buffer_key,
) = replay_buffer.sample(buffer_key, buffer_state, tree_state)



(
    experiences_batch,
    sample_indexes,
    importance_weights,
    buffer_key,
) = replay_buffer.sample(
    buffer_key,
    buffer_state,
    tree_state,
)
experiences_batch =  experiences_batch[0]
td_errors = compute_td_error(
    model,
    online_net_params,
    target_net_params,
    0.99,
    **experiences_batch # (dict)[0] => dict
)
tree_state = replay_buffer.sum_tree.batch_update(
    tree_state, sample_indexes, jnp.abs(td_errors)
)

online_net_params, optimizer_state, loss = agent.update(
    online_net_params,
    target_net_params,
    optimizer,
    optimizer_state,
    importance_weights,
    experiences_batch,
)

# update the target parameters every ``target_net_update_freq`` steps
target_net_params = lax.cond(
    i % 10 == 0,
    lambda _: online_net_params,
    lambda _: target_net_params,
    operand=None,
)