In [1]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from functools import partial
from jax import random, vmap, lax, tree_map
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple, List, Dict

sys.path.append("../")
from jym import (
    Breakout,
    DQN,
    UniformReplayBuffer,
    minatar_rollout,
    SumTree,
    Experience,
    PrioritizedExperienceReplay,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BUFFER_SIZE = 64
BATCH_SIZE = 8
DISCOUNT = 0.99
STATE_SHAPE = (10, 10, 4)
N_ACTIONS = 3
RANDOM_SEED = 0

CONV_LAYER_PARAMS = {
    "output_channels": 16,
    "kernel_shape": 3,
    "stride": 1,
}

MLP_PARAMS = {
    "output_sizes": [128, N_ACTIONS],
    "activation": jax.nn.relu,
    "activate_final": False,
}

OPTIMIZER_PARAMS = {
    "learning_rate": 1e-4,
    "decay": 0.95,  # named `smoothing constant` in the paper
    "centered": True,
    "eps": 10e-2,
}

EPSILON_DECAY_PARAMS = {
    "epsilon_start": 0.1,
    "epsilon_end": 0,
    "decay_period": 100_000,
}

buffer_state = {
    "state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "reward": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "done": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
    "priority": jnp.empty((BUFFER_SIZE), dtype=jnp.float32),
}
jax.tree_map(lambda x: x.shape, buffer_state)

{'action': (64,),
 'done': (64,),
 'next_state': (64, 10, 10, 4),
 'priority': (64,),
 'reward': (64,),
 'state': (64, 10, 10, 4)}

In [3]:
key = random.PRNGKey(0)
env = Breakout()
state, obs, env_key = env.reset(key)



In [4]:
per = PrioritizedExperienceReplay(BUFFER_SIZE, BATCH_SIZE, 0.5, 0.5)
tree_state = jnp.zeros(2 * BUFFER_SIZE - 1)


@hk.transform
def model(x):
    """
    MinAtar version of DQN
    ref: https://github.com/kenjyoung/MinAtar/blob/master/examples/dqn.py
    """
    conv_layer = hk.Conv2D(**CONV_LAYER_PARAMS)
    fc = hk.nets.MLP(**MLP_PARAMS)

    x = jax.nn.relu(conv_layer(x))
    x = x.reshape(-1)
    return fc(x)


def linear_decay(
    epsilon_start: float,
    epsilon_end: float,
    current_step: int,
    decay_period: int,
) -> float:
    decay_rate = (epsilon_start - epsilon_end) / decay_period
    new_epsilon = epsilon_start - current_step * decay_rate
    return jnp.maximum(jnp.float32(epsilon_end), new_epsilon)


online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)

online_net_params = model.init(online_key, random.normal(online_key, env.obs_shape))
target_net_params = model.init(target_key, random.normal(target_key, env.obs_shape))

optimizer = optax.rmsprop(**OPTIMIZER_PARAMS)
optimizer_state = optimizer.init(online_net_params)

agent = DQN(model, DISCOUNT, len(env.actions))

In [5]:
random_seed = 1
timesteps = 10
state_shape = (10, 10, 4)

i = 0

init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + random_seed)
state, obs, env_key = env.reset(init_key)
all_actions = jnp.zeros([timesteps])
all_obs = jnp.zeros([timesteps, *state_shape])
all_rewards = jnp.zeros([timesteps], dtype=jnp.float32)
all_done = jnp.zeros([timesteps], dtype=jnp.bool_)
losses = jnp.zeros([timesteps], dtype=jnp.float32)

model_params = model.init(init_key, jnp.zeros(state_shape))
target_net_params = model.init(action_key, jnp.zeros(state_shape))
optimizer_state = optimizer.init(model_params)
replay_buffer = PrioritizedExperienceReplay(BUFFER_SIZE, BATCH_SIZE, 0.5, 0.5)

In [6]:
for i in range(BUFFER_SIZE):
    action, action_key = agent.act(action_key, model_params, obs, 1.0)
    state, next_state, reward, done, env_key = env.step(state, env_key, action)
    experience = Experience(
        state=env._get_obs(state),
        action=action,
        reward=reward,
        next_state=next_state,
        done=done,
    )

    buffer_state, tree_state = replay_buffer.add(
        tree_state, buffer_state, i, experience
    )



In [7]:
experiences_batch, sample_indexes, importance_weights, buffer_key = replay_buffer.sample(
    buffer_key,
    buffer_state,
    tree_state,
)

In [8]:
@partial(vmap, in_axes=(None, None, None, None))
def compute_td_error(
    model: hk.Transformed,
    online_net_params: dict,
    target_net_params: dict,
    discount: float,
    state: jnp.ndarray,
    action: jnp.ndarray,
    reward: jnp.ndarray,
    next_state: jnp.ndarray,
    done: jnp.ndarray,
    priority: jnp.ndarray,  # unused
) -> List[float]:
    td_target = (
        (1 - done)
        * discount
        * jnp.max(model.apply(target_net_params, None, next_state))
    )
    prediction = model.apply(online_net_params, None, state)[action]
    return reward + td_target - prediction


td_errors = compute_td_error(
    model, online_net_params, target_net_params, DISCOUNT, **experiences_batch
)
td_errors

Array([-0.00016929, -0.01929694,  0.08988249,  0.05142684, -0.09290352,
        0.0850728 ,  0.11186595,  0.09584332], dtype=float32)

In [9]:
sum_tree = SumTree(BUFFER_SIZE, BATCH_SIZE)


Array([ 5.63217201e+01,  2.90928230e+01,  2.72288990e+01,  1.50898829e+01,
        1.40029402e+01,  1.21438265e+01,  1.50850725e+01,  8.00000000e+00,
        7.08988237e+00,  8.00000000e+00,  6.00293970e+00,  6.11169672e+00,
        6.03213024e+00,  8.00000000e+00,  7.08507299e+00,  4.00000000e+00,
        4.00000000e+00,  3.08988237e+00,  4.00000000e+00,  4.00000000e+00,
        4.00000000e+00,  4.00000000e+00,  2.00293970e+00,  4.00000000e+00,
        2.11169672e+00,  3.05142689e+00,  2.98070312e+00,  4.00000000e+00,
        4.00000000e+00,  4.00000000e+00,  3.08507276e+00,  2.00000000e+00,
        2.00000000e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00,
        1.08988249e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00,
        2.00000000e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00,
        2.00000000e+00,  2.93982029e-03,  2.00000000e+00,  2.00000000e+00,
        2.00000000e+00,  9.99830723e-01,  1.11186600e+00,  2.00000000e+00,
        1.05142689e+00,  