In [1]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from jax import random, vmap, lax
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple, List

sys.path.append("../")
from src import Breakout, DQN, UniformReplayBuffer, minatar_rollout, BaseReplayBuffer, SumTree

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BUFFER_SIZE = 100_000
STATE_SHAPE = (10, 10, 4)


def init_buffer_state(buffer_size: int, state_shape: tuple) -> dict:
    """Pre-allocate one array per transition field for a replay buffer.

    The keys match the field names of ``Experience`` so a transition can be
    written field-by-field into slot ``idx`` of every array.

    Args:
        buffer_size: number of transitions the buffer can hold.
        state_shape: shape of a single observation.

    Returns:
        Dict of arrays whose leading axis has length ``buffer_size``.
    """
    return {
        "state": jnp.empty((buffer_size, *state_shape), dtype=jnp.float32),
        "action": jnp.empty((buffer_size,), dtype=jnp.int32),
        # float32 to match ``Experience.reward: float``; an int32 array would
        # silently truncate (and JAX warns about) float rewards written via .at[].set
        "reward": jnp.empty((buffer_size,), dtype=jnp.float32),
        "next_state": jnp.empty((buffer_size, *state_shape), dtype=jnp.float32),
        "done": jnp.empty((buffer_size,), dtype=jnp.bool_),
        "td_error": jnp.empty((buffer_size,), dtype=jnp.float32),
    }


buffer_state = init_buffer_state(BUFFER_SIZE, STATE_SHAPE)
# jax.tree_map is deprecated; jax.tree_util.tree_map is the supported spelling.
jax.tree_util.tree_map(lambda x: x.shape, buffer_state)

{'action': (100000,),
 'done': (100000,),
 'next_state': (100000, 10, 10, 4),
 'reward': (100000,),
 'state': (100000, 10, 10, 4),
 'td_error': (100000,)}

In [3]:
@dataclass
class Experience:
    """A single transition to be written into the replay buffer.

    Field names mirror the keys of ``buffer_state``;
    ``PrioritizedExperienceReplay.add`` iterates over these field names and
    writes each value into the buffer array with the same key.

    NOTE(review): the annotations are nominal — in this notebook the scalar
    fields are built as 0-d JAX arrays (``jnp.int32``, ``jnp.float32``,
    ``jnp.bool_``), not Python ``int``/``float``/``bool``.
    """

    # Observation before the action; presumably shaped like STATE_SHAPE — TODO confirm.
    state: jnp.ndarray
    # Index of the action taken.
    action: int
    # Reward received for the transition.
    reward: float
    # Observation after the action.
    next_state: jnp.ndarray
    # Whether the episode ended on this transition.
    done: bool
    # TD error of the transition; presumably used as its sampling priority — verify.
    td_error: float


class PrioritizedExperienceReplay(BaseReplayBuffer):
    """
    Prioritized experience replay buffer (work in progress).

    Transitions live in a dict of pre-allocated arrays (``buffer_state``) and
    their priorities in a ``SumTree``. Sampling probabilities follow
    P(i) = p_i^alpha / sum_k p_k^alpha, and importance-sampling weights are
    w_i = (N * P(i))^-beta, normalised by their maximum.

    Source: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(
        self, buffer_size: int, batch_size: int, alpha: float, beta: float
    ) -> None:
        """
        Args:
            buffer_size: number of transitions the buffer (and sum tree) holds.
            batch_size: forwarded to ``BaseReplayBuffer``.
            alpha: prioritization exponent (0 gives uniform sampling).
            beta: importance-sampling correction exponent.
        """
        super().__init__(buffer_size, batch_size)
        self.sum_tree = SumTree(buffer_size)
        self.alpha = alpha
        self.beta = beta

    def add(self, buffer_state: dict, experience: Experience, idx: int) -> dict:
        """
        Write ``experience`` into slot ``idx`` of every buffer array.

        Iterates over the experience's field names (chex dataclasses behave as
        mappings), so every field needs a same-named key in ``buffer_state``.
        ``buffer_state`` is updated in place and also returned.

        TODO: also insert the transition's priority into the sum tree
        (``SumTree.add`` returns ``(tree, cursor)``; ``idx`` would be the cursor).
        """
        for field in experience:
            buffer_state[field] = buffer_state[field].at[idx].set(experience[field])
        return buffer_state

    def sample(self, tree: jnp.ndarray) -> List[Experience]:
        """
        Sample a batch of transitions according to their priorities.

        Work in progress: only the sampling distribution is computed so far and
        the method still returns ``None``.

        Args:
            tree: flat sum-tree array whose last ``buffer_size`` entries are the
                leaf priorities.
        """
        priorities = tree[-self.buffer_size :]
        # N in the paper: number of transitions currently stored.
        N = jnp.count_nonzero(priorities)
        # Mask empty slots explicitly: with alpha == 0, 0**0 == 1 would give
        # never-filled slots a non-zero probability.
        probs = jnp.where(priorities > 0, priorities**self.alpha, 0.0)
        # Normalise by the sum of the *exponentiated* priorities. The root
        # tree[0] holds sum(p_i), which equals sum(p_i**alpha) only if alpha == 1.
        P = probs / probs.sum()

        # TODO: select sample indices based on P (e.g. via self.sum_tree.get_leaf)
        # TODO: importance weights via self._compute_importance_sampling(P[indices], N)
        # TODO: compute TD errors, update transition priorities,
        #       accumulate weight change
        return None

    def _compute_importance_sampling(self, probs: jnp.ndarray, n: int) -> jnp.ndarray:
        """
        Importance-sampling weights w_i = (N * P(i))^-beta / max_j w_j.

        Args:
            probs: sampling probabilities P(i) of the sampled transitions; must
                be > 0, otherwise the corresponding weight is infinite.
            n: number of transitions currently stored (N in the paper).

        Returns:
            Weights scaled so that the largest one equals 1.
        """
        weights = jnp.power(n * probs, -self.beta)
        return weights / jnp.max(weights)

    @staticmethod
    def _compute_td_error(
        model: hk.Transformed,
        online_net_params: dict,
        target_net_params: dict,
        discount: float,
        experience: Experience,
    ) -> jnp.ndarray:
        """
        TD error r + (1 - done) * discount * max_a Q_target(s', a) - Q_online(s, a).

        A static method: it uses no buffer state, and without it an instance
        call would bind the buffer to ``model``.

        NOTE(review): assumes a single, unbatched transition — the max is taken
        over all target-network outputs and ``[action]`` indexes the first axis.
        TODO confirm before using with batches.
        """
        # Access fields by name: unpacking a chex (mappable) dataclass iterates
        # over its keys, and Experience has six fields, not five.
        next_q = model.apply(target_net_params, None, experience.next_state)
        # The (1 - done) mask belongs here: a terminal transition has no
        # successor state to bootstrap from.
        bootstrap = (1 - experience.done) * discount * jnp.max(next_q)
        prediction = model.apply(online_net_params, None, experience.state)[
            experience.action
        ]
        return experience.reward + bootstrap - prediction


In [4]:
capacity = 10  # number of leaf nodes
# Flat tree layout: capacity - 1 internal (sum) nodes followed by `capacity`
# leaves, so the first leaf lives at index capacity - 1 (= 9 here).
tree = jnp.zeros(2 * capacity - 1)
sum_tree = SumTree(capacity)

demo_priorities = (0.5, 0.7, 1.2, 1.0)
cursor = 0
for demo_priority in demo_priorities:
    tree, cursor = sum_tree.add(tree, demo_priority, cursor)
# Invariant: the root holds the total priority mass of all leaves.
assert jnp.isclose(tree[0], sum(demo_priorities)), tree[0]
print(tree)

# get_leaf returns (tree index, data index, leaf priority). 1.3 lies past the
# cumulative mass 0.5 + 0.7 = 1.2, so it lands on the third leaf:
# tree index 11, data index 2, priority 1.2.
tree_idx, sample_idx, priority = sum_tree.get_leaf(tree, 1.3)
tree_idx, sample_idx, priority

[3.4 1.2 2.2 0.  1.2 2.2 0.  0.  0.  0.5 0.7 1.2 1.  0.  0.  0.  0.  0.
 0. ]


(Array(11, dtype=int32, weak_type=True),
 Array(2, dtype=int32, weak_type=True),
 Array(1.2, dtype=float32))

In [5]:
# Query the sum tree at several cumulative-priority values in one vectorized
# call; the tree is shared (in_axes=None), only the query value is batched.
# NOTE(review): the total mass tree[0] is 3.4 here, so queries >= 4 fall outside
# it and all resolve to the same empty leaf (tree index 14, priority 0). Query
# 0.0 also lands on an empty leaf (tree index 15) rather than data index 0 —
# get_leaf's boundary comparison looks worth checking.
batched_get_leaf = vmap(sum_tree.get_leaf, in_axes=(None, 0))
query_values = jnp.arange(10, dtype=jnp.float32)
batched_get_leaf(tree, query_values)

(Array([15, 10, 11, 12, 14, 14, 14, 14, 14, 14], dtype=int32, weak_type=True),
 Array([6, 1, 2, 3, 5, 5, 5, 5, 5, 5], dtype=int32, weak_type=True),
 Array([0. , 0.7, 1.2, 1. , 0. , 0. , 0. , 0. , 0. , 0. ], dtype=float32))

In [6]:
# Build the MinAtar Breakout environment and reset it with a fixed seed.
env = Breakout()
key = random.PRNGKey(0)
# reset yields three values; presumably (env state, observation, new key) — verify.
state, obs, env_key = env.reset(key)



In [7]:
# Example transition built from the reset observation (used as both the
# current and next state), with JAX scalar fields.
exp = Experience(
    state=obs,
    action=jnp.int32(1),
    reward=jnp.float32(1),
    next_state=obs,
    done=jnp.bool_(False),
    td_error=jnp.float32(0.5),
)