In [2]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from functools import partial
from jax import random, vmap, lax, tree_map
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple, List

sys.path.append("../")
from jym import (
    Breakout,
    DQN,
    UniformReplayBuffer,
    minatar_rollout,
    BaseReplayBuffer,
    SumTree,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Replay-buffer dimensions shared by the cells below.
BUFFER_SIZE = 64  # number of transitions the buffer can hold
BATCH_SIZE = 8  # number of transitions drawn per sample
STATE_SHAPE = (10, 10, 4)  # per-observation shape (presumably the Breakout obs -- confirm)

# Pre-allocated storage: one row per stored transition, one entry per
# `Experience` field. Rewards are stored as float32 so that the float rewards
# written through `Experience.reward` are not silently truncated to integers.
buffer_state = {
    "state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "reward": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
    "next_state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "done": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
    "priority": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
}
# Display the shape of every field (`jax.tree_map` is a deprecated alias of
# `jax.tree_util.tree_map`).
jax.tree_util.tree_map(lambda x: x.shape, buffer_state)

{'action': (64,),
 'done': (64,),
 'next_state': (64, 10, 10, 4),
 'priority': (64,),
 'reward': (64,),
 'state': (64, 10, 10, 4)}

In [4]:
@dataclass
class Experience:
    """
    A single transition stored in the replay buffer.

    The field names match the keys of the `buffer_state` dictionary, which
    is how `PrioritizedExperienceReplay.add` writes each field to its slot.
    """

    # observation before the action was taken
    state: jnp.ndarray
    # action taken in `state`
    action: int
    # reward received for the transition
    reward: float
    # observation after the action was taken
    next_state: jnp.ndarray
    # whether the transition ended the episode
    done: bool
    # sampling priority; defaults to 0 and is overwritten with the current
    # maximal priority by `PrioritizedExperienceReplay.add`
    priority: float = jnp.float32(0.0)


class PrioritizedExperienceReplay(BaseReplayBuffer):
    """
    Prioritized Experience Replay Buffer

    Transitions live in a dictionary of pre-allocated arrays (`buffer_state`)
    and their priorities in a flat sum tree (`tree_state`). Both are passed in
    and returned explicitly; the object itself only holds hyperparameters.

    Source: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(
        self, buffer_size: int, batch_size: int, alpha: float, beta: float
    ) -> None:
        """
        Args:
            buffer_size: maximal number of stored experiences.
            batch_size: number of experiences returned by `sample`.
            alpha: exponent applied to the absolute TD error in `update`.
            beta: exponent of the importance-sampling correction in `sample`.
        """
        # `self.buffer_size` / `self.batch_size` are read below and are
        # presumably set by the base-class constructor -- confirm.
        super().__init__(buffer_size, batch_size)
        self.sum_tree = SumTree(buffer_size, batch_size)
        self.alpha = alpha
        self.beta = beta

    def add(
        self,
        tree_state: jnp.ndarray,
        buffer_state: dict,
        idx: int,
        experience: Experience,
    ) -> Tuple[dict, jnp.ndarray]:
        """
        Adds an experience to the replay buffer and
        its priority to the sum tree.

        Args:
            tree_state: flat sum tree whose last `buffer_size` entries are
                the leaf priorities.
            buffer_state: dictionary of storage arrays, one per experience field.
            idx: insertion counter; wrapped modulo `buffer_size`.
            experience: the transition to store.

        Returns:
            Tuple[dict, jnp.ndarray]: the updated buffer_state (a new
            dictionary, the input is left untouched) and the updated tree_state.
        """
        # assigns maximal priority to the new experience
        # (1.0 while every leaf priority is still zero)
        priorities = tree_state[-self.buffer_size :]
        max_priority = lax.select(
            jnp.count_nonzero(priorities) > 0,
            jnp.max(priorities),
            1.0,
        )
        experience = experience.replace(priority=max_priority)

        # add the experience to the sum tree and the replay buffer
        idx = idx % self.buffer_size
        tree_state = self.sum_tree.add(tree_state, idx, max_priority)

        # set experience fields on a shallow copy so that the caller's
        # dictionary is not mutated as a side effect
        # (iterating the experience yields its field names; this relies on
        # the dataclass being mapping-like)
        buffer_state = dict(buffer_state)
        for field in experience:
            buffer_state[field] = buffer_state[field].at[idx].set(experience[field])

        return buffer_state, tree_state

    def update(self, tree_state: jnp.ndarray, td_error: float, idx: int) -> jnp.ndarray:
        """
        Updates the priority of an experience using alpha.

        The priority is `|td_error| ** alpha`: the magnitude of the error is
        used because a negative base raised to a fractional power is NaN.

        Returns:
            jnp.ndarray: the updated tree_state
        """
        # NOTE(review): the paper adds a small epsilon to |td_error| so that
        # zero-error transitions keep a non-zero priority -- consider adding it.
        priority = abs(td_error) ** self.alpha
        return self.sum_tree.update(tree_state, idx, priority)

    def sample(
        self,
        key: random.PRNGKey,
        buffer_state: dict,
        tree_state: jnp.ndarray,
    ) -> Tuple[dict, jnp.ndarray]:
        """
        Samples from the sum tree using the cumulative probability
        distribution.

        Args:
            key: PRNG key used to draw the cumulative-priority values.
            buffer_state: dictionary of storage arrays, one per experience field.
            tree_state: flat sum tree; entry 0 is read as the total priority
                and the last `buffer_size` entries as the leaf priorities.

        Returns:
            Tuple[dict, jnp.ndarray]: a dictionary with the same keys as
            `buffer_state` holding the `batch_size` sampled experiences, and
            their importance-sampling weights normalized by the largest weight.
        """
        # sample from the sum tree
        total_priority = tree_state[0]
        values = random.uniform(
            key,
            shape=(self.batch_size,),
            minval=0,
            maxval=total_priority,
        )
        # `leaf_values` are presumably the priorities stored at the sampled
        # leaves and `samples_idx` an integer index array -- confirm in SumTree
        _, samples_idx, leaf_values = self.sum_tree.sample_idx_batch(tree_state, values)

        # compute importance weights: w_i = (N * P(i)) ** -beta
        # with P(i) = p_i / sum_k p_k, so frequently sampled (high-priority)
        # experiences are down-weighted
        priorities = tree_state[-self.buffer_size :]
        N = jnp.count_nonzero(priorities)
        probabilities = leaf_values / total_priority
        importance_weights = (N * probabilities) ** -self.beta
        # normalize weights
        importance_weights /= importance_weights.max()

        # gather the sampled rows from every storage array at once
        experiences = jax.tree_util.tree_map(lambda x: x[samples_idx], buffer_state)

        return experiences, importance_weights

    @staticmethod
    def _compute_td_error(
        model: hk.Transformed,
        online_net_params: dict,
        target_net_params: dict,
        discount: float,
        experience: Experience,
    ) -> float:
        """
        Computes the TD error of a single experience.

        Uses the target network for the bootstrap value and the online
        network for the prediction:
        `reward + (1 - done) * discount * max_a Q_target(next_state, a)
        - Q_online(state, action)`.

        Returns:
            float: the signed TD error
        """
        # fields are read by name: the experience also carries `priority`,
        # so it cannot be unpacked positionally into five values
        bootstrap = (
            (1 - experience.done)
            * discount
            * jnp.max(model.apply(target_net_params, None, experience.next_state))
        )
        prediction = model.apply(online_net_params, None, experience.state)[
            experience.action
        ]
        return experience.reward + bootstrap - prediction

In [6]:
# Build the environment, then reset it from a fixed PRNG seed.
env = Breakout()
key = random.PRNGKey(0)
state, obs, env_key = env.reset(key)



In [7]:
# Prioritized replay buffer with both exponents set to 0.5.
per = PrioritizedExperienceReplay(
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    alpha=0.5,
    beta=0.5,
)

# Dummy transition built from the reset observation (state == next_state).
exp = Experience(
    state=obs,
    action=jnp.int32(1),
    reward=jnp.float32(1),
    next_state=obs,
    done=jnp.bool_(False),
)

In [8]:
# Empty sum tree: BUFFER_SIZE leaf priorities preceded by BUFFER_SIZE - 1
# other entries, all starting at zero.
tree_state = jnp.zeros((2 * BUFFER_SIZE - 1,))
tree_state

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)