In [65]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from functools import partial
from jax import random, vmap, lax, tree_map
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple, List

sys.path.append("../")
from jym import (
    Breakout,
    DQN,
    UniformReplayBuffer,
    minatar_rollout,
    BaseReplayBuffer,
    SumTree,
)

In [54]:
# Replay-buffer configuration used by the rest of the notebook.
BUFFER_SIZE = 64  # number of transitions the buffer can hold
BATCH_SIZE = 8  # number of transitions drawn per sample
STATE_SHAPE = (10, 10, 4)  # shape of a single observation

# Pre-allocated storage: one array per experience field, indexed by buffer slot
# along the leading axis.
buffer_state = {
    "state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    # float32 (not int32): rewards are written as floats, and storing them in
    # an integer column would silently drop any fractional part.
    "reward": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
    "next_state": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "done": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
    "priority": jnp.empty((BUFFER_SIZE,), dtype=jnp.float32),
}
# Sanity check: display the shape of every column.
# `jax.tree_util.tree_map` is the stable spelling of the deprecated `jax.tree_map`.
jax.tree_util.tree_map(lambda x: x.shape, buffer_state)

{'action': (64,),
 'done': (64,),
 'next_state': (64, 10, 10, 4),
 'priority': (64,),
 'reward': (64,),
 'state': (64, 10, 10, 4)}

In [60]:
@dataclass
class Experience:
    """
    A single transition to be stored in the replay buffer.

    Declared with `chex.dataclass`. `PrioritizedExperienceReplay.add` iterates
    over an instance to get its field names and reads each value with
    `experience[field]`, so every field name must also be a key of the
    buffer-state dict.
    """

    # observation the action was taken from
    state: jnp.ndarray
    # action taken in `state`
    action: int
    # reward received for the transition
    reward: float
    # observation reached after the action
    next_state: jnp.ndarray
    # episode-termination flag for the transition
    done: bool
    # sampling priority; defaults to 0 and is overwritten by
    # `PrioritizedExperienceReplay.add` when the experience is inserted
    priority: float = jnp.float32(0.0)


class PrioritizedExperienceReplay(BaseReplayBuffer):
    """
    Prioritized Experience Replay Buffer

    Transitions are stored in a dict of pre-allocated arrays (``buffer_state``)
    and their priorities in a flat sum-tree array (``tree_state``). Both are
    passed to, and returned from, every method instead of being stored on the
    instance. In ``tree_state`` the root (index 0) is read as the total
    priority and the last ``buffer_size`` entries are read as the per-slot
    priorities.

    Source: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(
        self, buffer_size: int, batch_size: int, alpha: float, beta: float
    ) -> None:
        """
        Args:
            buffer_size: capacity of the buffer; also sizes the sum tree.
            batch_size: number of transitions returned by `sample`.
            alpha: prioritization exponent (stored, not applied yet).
            beta: importance-sampling exponent (stored, not applied yet).
        """
        super().__init__(buffer_size, batch_size)
        self.sum_tree = SumTree(buffer_size)
        self.alpha = alpha
        self.beta = beta

    def add(
        self,
        tree_state: jnp.ndarray,
        buffer_state: dict,
        experience: Experience,
        idx: int,
    ) -> Tuple[dict, jnp.ndarray]:
        """
        Adds an experience to the replay buffer and
        its priority to the sum tree.

        Args:
            tree_state: current sum-tree array.
            buffer_state: dict of per-field storage arrays; not modified.
            experience: transition to store; its `priority` is overridden.
            idx: insertion counter, wrapped modulo `buffer_size`.

        Returns:
            `(new_buffer_state, new_tree_state)`.
        """
        # assigns maximal priority to the new experience
        # (falls back to 1.0 while every stored priority is still zero)
        priorities = tree_state[-self.buffer_size :]
        max_priority = lax.select(
            jnp.count_nonzero(priorities) > 0,
            jnp.max(priorities),
            1.0,
        )
        experience = experience.replace(priority=max_priority)

        # add the experience to the sum tree and the replay buffer
        idx = idx % self.buffer_size
        # NOTE(review): the notebook's recorded output shows the value returned
        # here as an (array, int) pair rather than a bare array -- confirm what
        # `SumTree.add` returns and unpack it if needed.
        tree_state = self.sum_tree.add(tree_state, max_priority, idx)

        # Write every experience field into slot `idx`. A new dict is built so
        # the caller's `buffer_state` is left untouched (functional update).
        updated_fields = {
            field: buffer_state[field].at[idx].set(experience[field])
            for field in experience
        }
        buffer_state = {**buffer_state, **updated_fields}

        return buffer_state, tree_state

    def sample(
        self,
        key: random.PRNGKey,
        buffer_state: dict,
        tree_state: jnp.ndarray,
    ) -> Tuple[dict]:
        """
        Draws `batch_size` transitions by sampling the sum tree.

        Args:
            key: PRNG key used to draw the sampling values.
            buffer_state: dict of per-field storage arrays.
            tree_state: current sum-tree array.

        Returns:
            A 1-tuple holding a dict with the same keys as `buffer_state`,
            each entry gathered at the sampled indexes.
        """

        @partial(vmap, in_axes=(0, None))
        def sample_batch(indexes, buffer) -> dict:
            # gather one slot from every field; vmapped over the sampled indexes
            return tree_map(lambda x: x[indexes], buffer)

        # the root of the sum tree holds the sum of all priorities
        total_priority = tree_state[0]

        # TODO: determine how probabilities interact with the sum tree.
        # The values below are drawn against the raw priorities stored in the
        # tree; the exponent from the paper,
        # P(i) = p_i**alpha / sum_k p_k**alpha, is not applied yet.
        values = random.uniform(
            key,
            shape=(self.batch_size,),
            minval=0,
            maxval=total_priority,
        )
        # presumably maps each drawn value to the index of a leaf -- verify
        # against SumTree.sample_batch
        samples_idx = self.sum_tree.sample_batch(tree_state, values)

        return (sample_batch(samples_idx, buffer_state),)

    # TODO: importance-sampling weights (not implemented yet). Per the paper:
    # w_i = (buffer_size * P(i)) ** -beta, normalised by max_j w_j.

    @staticmethod
    def _compute_td_error(
        model: hk.Transformed,
        online_net_params: dict,
        target_net_params: dict,
        discount: float,
        experience: Experience,
    ):
        """
        Computes the TD error of a single transition:
        `reward + (1 - done) * discount * max(Q_target(next_state)) - Q_online(state)[action]`.

        Declared as a static method because it takes no `self`; it can be
        called on the class or on an instance.

        Args:
            model: transformed network exposing `apply(params, rng, x)`.
            online_net_params: parameters used for the prediction.
            target_net_params: parameters used for the bootstrap value.
            discount: discount factor applied to the bootstrap value.
            experience: the transition to evaluate.
        """
        # Fields are read by name: iterating an Experience yields its field
        # names (see `add`), so positional unpacking would not give the values.
        # (1 - done) zeroes the bootstrap term on terminal transitions.
        bootstrap = (
            (1 - experience.done)
            * discount
            * jnp.max(model.apply(target_net_params, None, experience.next_state))
        )
        # assumes the network output for one state is indexable by action --
        # TODO confirm for batched inputs
        prediction = model.apply(online_net_params, None, experience.state)[
            experience.action
        ]
        return experience.reward + bootstrap - prediction

In [61]:
# Build the environment, then reset it with a key derived from a fixed seed.
env = Breakout()
key = random.PRNGKey(0)
state, obs, env_key = env.reset(key)



In [62]:
# Dummy transition for a smoke test: the reset observation is used as both
# `state` and `next_state`.
exp = Experience(
    state=obs,
    action=jnp.int32(1),
    reward=jnp.float32(1),
    next_state=obs,
    done=jnp.bool_(False),
)

# Prioritized buffer with alpha = beta = 0.5.
per = PrioritizedExperienceReplay(
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    alpha=0.5,
    beta=0.5,
)

In [64]:
# Start from an all-zero sum tree (2 * BUFFER_SIZE - 1 nodes) and insert the
# dummy experience at slot 0, then display the resulting tree state.
n_tree_nodes = 2 * BUFFER_SIZE - 1
tree_state = jnp.zeros(n_tree_nodes)
buffer_state, tree_state = per.add(
    tree_state=tree_state,
    buffer_state=buffer_state,
    experience=exp,
    idx=0,
)
tree_state



(Array([1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array(1, dtype=int32, weak_type=True))