In [1]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from jax import random, vmap, lax
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple

sys.path.append("../")
from src import Breakout, DQN, UniformReplayBuffer, minatar_rollout, BaseReplayBuffer

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
@dataclass
class Experience:
    state: jnp.ndarray
    action: int
    reward: float
    next_state: jnp.ndarray
    done: bool
    td_error: float


class PrioritizedExperienceReplay(BaseReplayBuffer):
    """
    Source: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(
        self, buffer_size: int, batch_size: int, alpha: float, beta: float
    ) -> None:
        super().__init__(buffer_size, batch_size)
        self.alpha = alpha
        self.beta = beta

    def add(self, buffer_state: dict, experience: Experience, idx: int):
        for field in experience:
            # sumTree.add => tree, cursor => idx = cursor
            buffer_state[field] = buffer_state[field].at[idx].set(experience[field])
        return buffer_state

    def sample(self):
        # implement sampling prob
        # compute importance sampling weight
        # compute td error
        # update transition priority
        # accumulate weight change
        pass

    # def _compute_importance_sampling(self, prob:float):
        # return jnp.power((self.buffer_size*prob), -self.beta) / jnp.max()

    def _compute_td_error(
        model: hk.Transformed,
        online_net_params: dict,
        target_net_params: dict,
        discount: float,
        experience: Experience,
    ):
        state, action, reward, next_state, done = experience
        # TODO: check wheter 1-done belongs here
        td_target = (
            (1 - done)
            * discount
            * jnp.max(model.apply(target_net_params, None, next_state))
        )
        prediction = model.apply(online_net_params, None, state)[action]
        return reward + td_target - prediction


class SumTree:
    def __init__(self, capacity: int) -> None:
        """
        Args:
            capacity (int): The maximum number of leaves (priorities/experiences)
            the tree can hold.
        """
        self.capacity = capacity

    def add(
        self, tree: jnp.ndarray, priority: float, cursor: int
    ) -> Tuple[jnp.ndarray, int]:
        """
        Add a new priority to the tree and update the cursor position.

        Args:
            tree (jnp.ndarray): The current state of the sum tree.
            priority (float): The priority value of the new experience.
            cursor (int): The current write cursor in the tree.

        Returns:
            Tuple[jnp.ndarray, int]: The updated tree and cursor.
        """
        idx = cursor + self.capacity - 1
        tree = self.update(tree, idx, priority)
        cursor = lax.select(cursor + 1 >= self.capacity, cursor + 1, 0)
        return tree, cursor

    @staticmethod
    def _propagate(tree: jnp.ndarray, idx: int, change: float) -> jnp.ndarray:
        """
        Propagate the changes in priority up the tree from a given index.

        Args:
            tree (jnp.ndarray): The current state of the sum tree.
            idx (int): The index of the tree where the priority was updated.
            change (float): The amount of change in priority.

        Returns:
            jnp.ndarray: The updated tree after propagation.
        """

        def _cond_fn(val: tuple):
            idx, _ = val
            return idx != 0

        def _while_body(val: tuple):
            idx, tree = val
            parent_idx = (idx - 1) // 2
            tree = tree.at[parent_idx].add(change)
            return parent_idx, tree

        val_init = (idx, tree)
        _, tree = lax.while_loop(_cond_fn, _while_body, val_init)
        return tree

    def update(self, tree: jnp.ndarray, idx: int, priority: float) -> jnp.ndarray:
        """
        Update a priority in the tree at a specific index and propagate the change.

        Args:
            tree (jnp.ndarray): The current state of the sum tree.
            idx (int): The index in the tree where the priority is to be updated.
            priority (float): The new priority value.

        Returns:
            jnp.ndarray: The updated tree after the priority change.
        """
        change = priority - tree.at[idx].get()
        tree = tree.at[idx].set(priority)
        return self._propagate(tree, idx, change)

    def get_leaf(self, tree: jnp.ndarray, value: float) -> Tuple[int, int, float]:
        """
        Retrieve the index and value of a leaf based on a given value.

        Args:
            tree (jnp.ndarray): The current state of the sum tree.
            value (float): A value to query the tree with.

        Returns:
            Tuple[int, int, float]: The index of the tree, index of the sample, and value of the leaf.
        """
        idx = self._retrieve(tree, 0, value)
        sample_idx = idx - len(tree) + 1
        return idx, sample_idx, tree[idx]

    def _retrieve(self, tree: jnp.ndarray, idx: int, value: float):
        def is_leaf():
            return idx

        def is_not_leaf():
            return lax.cond(
                value <= tree[left],
                lambda: self.retrieve(tree, left, value),
                lambda: self.retrieve(tree, right, value - tree[left]),
            )

        left = 2 * idx + 1
        right = left + 1
        is_leaf_node = left >= len(tree)

        return lax.cond(
            is_leaf_node,
            is_leaf,
            is_not_leaf,
        )

In [14]:
tree = jnp.zeros(8)
sum_tree = SumTree(len(tree))
# tree = sum_tree.propagate(tree, 5, 1)
print(tree)
tree = sum_tree.update(tree, 6, 5)
print(tree)
sum_tree.add(tree, 2, 4)

[0. 0. 0. 0. 0. 0. 0. 0.]
[5. 0. 5. 0. 0. 0. 5. 0.]


(Array([7., 0., 7., 0., 0., 2., 5., 0.], dtype=float32),
 Array(0, dtype=int32, weak_type=True))

In [None]:
key = random.PRNGKey(0)
env = Breakout()
state, obs, env_key = env.reset(key)



In [None]:
exp = Experience(
    state=obs,
    action=jnp.int32(
        1,
    ),
    reward=jnp.float32(
        1,
    ),
    next_state=obs,
    done=jnp.bool_(False),
    td_error=jnp.float32(
        0.5,
    ),
)

In [None]:
BUFFER_SIZE = 100_000
STATE_SHAPE = (10, 10, 4)

per = PrioritizedExperienceReplay(BUFFER_SIZE, 32)
buffer_state = {
    "state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "reward": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "done": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
    "td_error": jnp.empty((BUFFER_SIZE), dtype=jnp.float32),
}
jax.tree_map(lambda x: x.shape, buffer_state)

{'action': (100000,),
 'done': (100000,),
 'next_state': (100000, 10, 10, 4),
 'reward': (100000,),
 'state': (100000, 10, 10, 4),
 'td_error': (100000,)}