# Prioritized Experience Replay

- toc: true 
- badges: true
- comments: true
- categories: [RL]

# Experience Replay

* Online Reinforcement Learning (RL) agents update their parameters as they make observations.
* Two issues:
  * Strongly correlated updates (breaks the i.i.d. assumption of many stochastic gradient-based optimizers)
  * Rare events that could be useful later are immediately discarded
* _Experience Replay_ (*ER*) {% cite er %} stores the observations in a replay memory
  * Can mix recent and older observations (breaks the temporal correlations)
  * Rare events can be reused for multiple updates
* ER reduces the number of environment interactions required to learn, replacing them with more computation and memory, which are usually cheaper
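As a point of reference, here is a minimal sketch of a uniform replay memory. The class name `UniformReplayBuffer` is hypothetical, and transitions are assumed to be arbitrary Python objects. It stores observations in a fixed-size cyclic buffer and samples them uniformly, which is the baseline that prioritized replay changes.

```python
import random


class UniformReplayBuffer:
    """Fixed-size cyclic replay memory with uniform sampling (illustrative sketch)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        # Slot that the next transition will be written to once the buffer is full
        self.next_idx = 0

    def add(self, transition):
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            # Overwrite the oldest transition instead of discarding the new one
            self.data[self.next_idx] = transition
        self.next_idx = (self.next_idx + 1) % self.capacity

    def sample(self, batch_size):
        # Uniform sampling mixes recent and old transitions, breaking temporal correlation
        return random.sample(self.data, batch_size)


buf = UniformReplayBuffer(capacity=3)
for t in range(5):
    buf.add(t)
batch = buf.sample(2)
```

Because `random.sample` draws without replacement, a batch never contains the same stored transition twice.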


# Prioritized ER
The intuition behind prioritized experience replay {% cite per %} is:
* Some transitions are more informative than others
  * Transitions may be more or less surprising, redundant, or task-relevant
  * Some transitions may not be immediately useful, but can become useful once the agent is more competent
* Replay _important transitions_ more frequently and therefore learn more efficiently.

## Temporal Difference as proxy measurement
* The magnitude of the temporal-difference (TD) error $\delta$ measures how far the value is from its next-step bootstrap estimate {% cite Andre98generalizedprioritized %}.
* It is a reasonable proxy for how much can be learned from a transition, since it indicates how "surprising" or unexpected the transition is.

> Warning: TD error can be a poor estimate when rewards or transitions are inherently stochastic.

> Warning: Under partial observability, some transitions may be effectively unlearnable.
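For one-step Q-learning, the TD error of a transition $(s, a, r, s')$ is $\delta = r + \gamma \max_{a'} Q(s', a') - Q(s, a)$, with no bootstrap term after a terminal transition. The sketch below assumes a tabular $Q$ stored as a NumPy array of shape `[n_states, n_actions]` and a discount factor `gamma`; the function name `td_error` is illustrative. Its absolute value $|\delta|$ is the priority proxy.

```python
import numpy as np


def td_error(q, state, action, reward, next_state, done, gamma=0.99):
    """One-step Q-learning TD error for a tabular Q of shape [n_states, n_actions]."""
    # Bootstrap target r + gamma * max_a' Q(s', a'); `done` = 1 removes the bootstrap term
    target = reward + gamma * (1.0 - done) * np.max(q[next_state])
    return target - q[state, action]


q = np.zeros((3, 2))
q[1] = [0.5, 2.0]
delta = td_error(q, state=0, action=1, reward=1.0, next_state=1, done=0, gamma=0.5)
priority = abs(delta)
```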


## Greedy TD-error prioritization

* Algorithm
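  * Store the last encountered TD error of each transition along with it in the replay memory
  * Replay the transition with the largest absolute TD error; the Q-learning update is proportional to this TD error
  * New transitions, whose TD error is not yet known, get the highest priority so that every transition is seen at least once
* Substantially reduces training effort compared to uniform replay
* Implementation: a binary heap serves as the priority queue
  * Finding the maximum-priority transition (sampling): $O(1)$
  * Updating a priority: $O(\log N)$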
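Greedy TD-error prioritization can be sketched on top of Python's `heapq` module, which implements a binary min-heap, so priorities are stored negated. The class `GreedyTDReplay` and its method names are hypothetical; this illustrates the idea and is not the paper's implementation.

```python
import heapq
import itertools


class GreedyTDReplay:
    """Greedy TD-error prioritization on a binary min-heap (priorities negated)."""

    def __init__(self):
        self.heap = []                     # entries: (-priority, tie_breaker, transition)
        self.counter = itertools.count()   # tie-breaker so transitions are never compared
        self.max_priority = 1.0

    def add(self, transition):
        # New transitions get the highest priority seen so far
        heapq.heappush(self.heap, (-self.max_priority, next(self.counter), transition))

    def peek(self):
        # O(1): the transition with the largest |delta| sits at the root
        return self.heap[0][2]

    def update_top(self, new_td_error):
        # O(log N): only the replayed (root) transition gets a fresh priority
        priority = abs(new_td_error)
        self.max_priority = max(self.max_priority, priority)
        _, _, transition = self.heap[0]
        heapq.heapreplace(self.heap, (-priority, next(self.counter), transition))


replay = GreedyTDReplay()
replay.add("a")
replay.add("b")
first = replay.peek()
replay.update_top(0.1)
second = replay.peek()
```

`peek` only reads the root, while `update_top` uses `heapq.heapreplace` to pop the replayed transition and push it back with its new priority in a single $O(\log N)$ operation. Transitions below the root keep their stale priorities, which is the staleness problem of greedy prioritization.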

### Issues
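* To avoid expensive sweeps over the entire replay memory, TD errors are only updated for the transitions that are replayed
  > Warning: Transitions with a low TD error on their first visit may not be replayed for a long time.

* Sensitive to noise spikes (e.g., stochastic rewards), which bootstrapping can exacerbate
* Focuses on a small subset of the experience: errors shrink slowly (especially with function approximation), so initially high-error transitions get replayed frequently; this lack of diversity makes the system prone to over-fitting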

---

# Stochastic Prioritization

* Interpolates between pure greedy prioritization and uniform random sampling
* The probability of sampling transition $i$ is


$$
P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}
$$

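* where
  * $p_i > 0$ is the **priority of transition $i$**
  * the hyperparameter $\alpha$ determines how much prioritization is used ($\alpha = 0$ corresponds to uniform sampling)
* The sampling probability is monotonic in a transition's priority, yet even the lowest-priority transition has a non-zero probability of being sampled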
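A direct NumPy translation of $P(i)$, assuming the priorities are already available as strictly positive numbers; `sampling_probabilities` is a hypothetical helper name. Drawing a minibatch with `Generator.choice` then gives stochastic prioritized sampling.

```python
import numpy as np


def sampling_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha for strictly positive priorities p_i."""
    scaled = np.asarray(priorities, dtype=float) ** alpha
    return scaled / scaled.sum()


rng = np.random.default_rng(0)
probs = sampling_probabilities([1.0, 2.0, 4.0], alpha=0.6)
# Sample transition indexes with replacement according to P(i)
batch = rng.choice(len(probs), size=32, p=probs)
```

With `alpha=0.0` every transition gets probability $1/3$, and with `alpha=1.0` the probabilities are simply the normalized priorities.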

## Variants
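  * *Proportional prioritization*: $p_i = |\delta_i| + \epsilon$, where $\epsilon$ is a small positive constant that keeps transitions with zero TD error sampleable
  * *Rank-based prioritization*: $p_i = \frac{1}{\operatorname{rank}(i)}$, where $\operatorname{rank}(i)$ is the rank of transition $i$ when the replay memory is sorted by $|\delta_i|$ in descending order
    * Insensitive to outliers and hence likely more robust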
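Both priority definitions, sketched in NumPy. The value `eps=1e-6` is an assumed default, and ties in $|\delta_i|$ are broken arbitrarily by `argsort`. In the example, the outlier TD error of $100$ dominates the proportional priorities, while its rank-based priority is only twice that of the second-ranked transition.

```python
import numpy as np


def proportional_priorities(td_errors, eps=1e-6):
    # p_i = |delta_i| + eps; eps keeps zero-error transitions sampleable
    return np.abs(td_errors) + eps


def rank_priorities(td_errors):
    # rank(i) = 1 for the largest |delta_i|, 2 for the next largest, ...
    order = np.argsort(-np.abs(td_errors))
    ranks = np.empty(len(td_errors), dtype=float)
    ranks[order] = np.arange(1, len(td_errors) + 1)
    return 1.0 / ranks


td = np.array([0.5, -3.0, 0.0, 100.0])
p_prop = proportional_priorities(td)
p_rank = rank_priorities(td)
```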

## Bias
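* Estimating an expectation with stochastic updates relies on the updates following the same distribution as that expectation
* Prioritized replay introduces bias because it changes this distribution, and therefore changes the solution the estimates converge to
* Correct the bias with importance-sampling (IS) weights
$$w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$$
* where
  * $N$ is the size of the replay memory and $\beta$ is a hyperparameter; the weights fully compensate for the non-uniform probabilities $P(i)$ when $\beta = 1$
  * For stability, the weights are normalized by $\frac{1}{\max_i w_i}$, so they only ever scale the update downwards
  * The weights enter the Q-learning update as $w_i\delta_i$ instead of $\delta_i$ (weighted IS)
* Unbiased updates matter most near convergence at the end of training

> Note: The learning process is highly non-stationary anyway (changing policy, state distribution, and bootstrapped targets), so a small bias can be tolerated early on.

* Therefore, anneal the amount of correction over time, e.g. by increasing $\beta$ linearly from an initial value $\beta_0$ to $1$ at the end of training
* Prioritization and IS correction complement each other: prioritization ensures high-error transitions are seen many times, while the IS correction reduces their gradient magnitudes (and thus the effective step size)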
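A sketch of the IS weights and a linear $\beta$ schedule. Following the implementation later in this post, the normalizer $\max_i w_i$ is taken over the whole replay memory, i.e. computed from $\min_i P(i)$. The helper names and the starting value `beta0=0.4` are assumptions for illustration.

```python
import numpy as np


def importance_weights(probs, sampled_idx, beta):
    """w_i = (N * P(i))^(-beta), normalized by max_i w_i over the whole memory."""
    probs = np.asarray(probs, dtype=float)
    n = len(probs)
    weights = (n * probs[sampled_idx]) ** (-beta)
    # The largest weight belongs to the least likely transition
    max_weight = (n * probs.min()) ** (-beta)
    return weights / max_weight


def linear_beta(step, total_steps, beta0=0.4):
    # Anneal beta linearly from beta0 (assumed value) to 1 over training
    return min(1.0, beta0 + (1.0 - beta0) * step / total_steps)


probs = np.array([0.1, 0.2, 0.7])
w = importance_weights(probs, np.array([0, 2]), beta=1.0)
```

The weighted TD errors, `weights * td_errors`, would then replace $\delta_i$ in the Q-learning update.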

---

# Implementation

## Rank based prioritization

* Approximate the cumulative distribution function (CDF) of $P(i)$ with a piecewise linear function with $k$ segments of equal probability
* First sample a segment, then sample uniformly among the transitions within it
* For minibatch-based learning:
  * choose $k$ equal to the minibatch size
  * sample exactly one transition from each segment (a form of stratified sampling)
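A sketch of the segment-based sampler for rank-based prioritization, with $P(i) \propto \operatorname{rank}(i)^{-\alpha}$. The helper names are hypothetical, and for simplicity it sorts the TD errors with `argsort` on every call ($O(N \log N)$); since $N$ and $k$ are fixed, it would be cheaper to precompute the segment boundaries once. Neighbouring segments may share a boundary transition, which is acceptable for this approximation.

```python
import numpy as np


def segment_bounds(n, k, alpha):
    """Split ranks 0..n-1 (0 = largest |delta|) into k segments of ~equal probability mass.

    Returns inclusive (start, end) rank indexes for each segment.
    """
    # P(rank) proportional to rank^(-alpha), with ranks counted from 1
    p = np.arange(1, n + 1, dtype=float) ** (-alpha)
    cdf = np.cumsum(p) / p.sum()
    edges = np.arange(k + 1) / k
    # First rank whose cumulative mass exceeds the segment's lower edge
    starts = np.searchsorted(cdf, edges[:-1], side="right")
    # First rank whose cumulative mass reaches the segment's upper edge
    ends = np.searchsorted(cdf, edges[1:], side="left")
    # Guard against floating-point round-off pushing an index past the end
    return np.minimum(starts, n - 1), np.minimum(ends, n - 1)


def sample_rank_based(td_errors, k, alpha, rng):
    """Draw one transition index from each of the k segments (stratified sampling)."""
    starts, ends = segment_bounds(len(td_errors), k, alpha)
    # Uniform draw within each segment; `integers` excludes `high`, hence ends + 1
    ranks = rng.integers(starts, ends + 1)
    # order[r] is the index of the transition with the (r + 1)-th largest |delta|
    order = np.argsort(-np.abs(td_errors))
    return order[ranks]


rng = np.random.default_rng(0)
td_errors = rng.normal(size=100)
batch = sample_rank_based(td_errors, k=8, alpha=0.7, rng=rng)
```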

## Proportional prioritization

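* Use a sum tree: each internal node stores the sum of its children and the leaves store the priorities, so sampling and updating a priority both take $O(\log N)$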


### Binary Segment Tree

* Efficiently calculate $\sum_{k=1}^{i} p_k^\alpha$, the (unnormalized) cumulative probability, which is needed for sampling.
* Efficiently find $\min_i p_i^\alpha$, which is needed to compute $\frac{1}{\max_i w_i}$.

> Note: We can also use a min-heap for this.

A binary segment tree lets us calculate these in $\mathcal{O}(\log n)$
time, which is far more efficient than the naive $\mathcal{O}(n)$
approach.
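For comparison, the naive $\mathcal{O}(n)$ approach to proportional sampling rebuilds the cumulative sum on every draw and then searches it. The function name is hypothetical; `priorities_alpha` is assumed to hold the values $p_i^\alpha$.

```python
import numpy as np


def naive_proportional_sample(priorities_alpha, rng):
    """O(n) proportional sampling: build the cumulative sum, then search it."""
    prefix = np.cumsum(priorities_alpha)   # recomputed on every call: O(n)
    target = rng.random() * prefix[-1]     # uniform in [0, sum_k p_k^alpha)
    # First index whose cumulative sum exceeds the target
    return int(np.searchsorted(prefix, target, side="right"))


rng = np.random.default_rng(0)
idx = naive_proportional_sample(np.array([0.0, 1.0, 0.0]), rng)
```

The segment tree answers the same query (the first index whose cumulative sum exceeds the target) by walking from the root to a leaf, and updating one priority touches only $\mathcal{O}(\log n)$ nodes instead of all prefix sums.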

A binary segment tree is a data structure that answers range queries over an array efficiently,
while still allowing elements of the array to be modified.

This is how a binary segment tree works for sums (minimums work the same way):
* Let $x_i$, where $i \in \{1, 2, \dots, N\}$, be the list of $N$ values we want to represent; assume $N$ is a power of $2$.
* Let $b_{h,j}$ be the $j^{\text{th}}$ node at height $h$ of the binary tree, with the root at height $1$.
* The two children of node $b_{h,j}$ are $b_{h+1,2j-1}$ and $b_{h+1,2j}$.
* The leaf nodes, at height $H = \left\lceil 1 + \log_2 N \right\rceil$, hold the values of $x$, i.e. $b_{H,j} = x_j$.

Every node keeps the sum of its two child nodes. That is, the root node keeps the sum of the entire
array of values, its left and right children keep the sums of the first and second halves of the
array, respectively, and so on.

$$b_{h,j} = \sum_{k = (j - 1) \cdot 2^{H - h} + 1}^{j \cdot 2^{H - h}} x_k$$
The number of nodes at height $h$ is
$$N_h = \left\lceil \frac{N}{2^{H - h}} \right\rceil = 2^{h - 1}$$
This is one more than the total number of nodes at all heights above $h$ ($1 + 2 + \cdots + 2^{h-2} = 2^{h-1} - 1$).
So we can use a single array $a$ to store the tree, where
$$b_{h,j} \rightarrow a_{N_h + j - 1}$$

Then the child nodes of $a_i$ are $a_{2i}$ and $a_{2i + 1}$.
That is,
$$a_i = a_{2i} + a_{2i + 1}$$

This way of maintaining binary trees is very easy to program.
> Note: Indexing starts from 1; $a_1$ is the root and $a_0$ is unused.

> Note: We use the same structure to compute the minimum, replacing the sum with $\min$.
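To make the array layout concrete, here is a small sketch that builds a tree bottom-up from $N = 4$ leaf values, once for sums and once for minimums. `build_tree` is a hypothetical helper; like the buffer implementation in this post, it assumes $N$ is a power of $2$.

```python
import math
import operator


def build_tree(values, combine, empty):
    """Array-backed binary segment tree, 1-indexed: a[1] is the root, a[0] is unused.

    Assumes len(values) is a power of 2, so the leaves occupy a[N], ..., a[2N - 1].
    """
    n = len(values)
    a = [empty] * (2 * n)
    a[n:] = values                       # leaf b_{H,j} -> a_{N + j - 1}
    for i in range(n - 1, 0, -1):        # fill parents bottom-up: a_i = combine(a_2i, a_2i+1)
        a[i] = combine(a[2 * i], a[2 * i + 1])
    return a


sum_tree = build_tree([1, 2, 3, 4], operator.add, 0)   # [0, 10, 3, 7, 1, 2, 3, 4]
min_tree = build_tree([1, 2, 3, 4], min, math.inf)     # [inf, 1, 1, 3, 1, 2, 3, 4]
```

For example, $a_3 = 7$ is node $b_{2,2}$ ($N_2 = 2$, $j = 2$), which holds $x_3 + x_4$.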

```python
import random
from dataclasses import dataclass, field

import numpy as np


@dataclass
class Transition:
    # Example transition layout; states are assumed to be (1, 10) float arrays.
    # `default_factory` gives every instance its own arrays (and avoids the removed `np.float`).
    current_state: np.ndarray = field(default_factory=lambda: np.zeros(shape=(1, 10), dtype=float))
    action: int = 0
    reward: float = 0.0
    next_state: np.ndarray = field(default_factory=lambda: np.zeros(shape=(1, 10), dtype=float))
    done: int = 0


class PriorityExperienceReplayBuffer:
    def __init__(self, capacity, alpha):
        """
        ### Initialize
        """
        # We use a power of $2$ for capacity because it simplifies the code and debugging;
        # `find_prefix_sum_idx` relies on all leaves being on the same level
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be a power of 2"
        self.capacity = capacity
        # $\alpha$
        self.alpha = alpha

        # Maintain segment binary trees to take sum and find minimum over a range
        # (index 0 is unused; leaves live at indexes capacity .. 2 * capacity - 1)
        self.priority_sum = [0.0 for _ in range(2 * self.capacity)]
        self.priority_min = [float('inf') for _ in range(2 * self.capacity)]

        # Current max priority, $p$, to be assigned to new transitions
        self.max_priority = 1.

        # Arrays for buffer
        self.data = [Transition() for _ in range(capacity)]
        # We use cyclic buffers to store data, and `next_idx` keeps the index of the next empty
        # slot
        self.next_idx = 0

        # Size of the buffer
        self.size = 0

    def add(self, transition: Transition):
        """
        ### Add sample to queue
        """

        # Get next available slot
        idx = self.next_idx

        # store in the queue
        self.data[idx] = transition

        # Increment next available slot
        self.next_idx = (idx + 1) % self.capacity
        # Calculate the size
        self.size = min(self.capacity, self.size + 1)

        # $p_i^\alpha$, new samples get `max_priority`
        priority_alpha = self.max_priority ** self.alpha
        # Update the two segment trees for sum and minimum
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        """
        #### Set priority in binary segment tree for minimum
        """

        # Leaf of the binary tree
        idx += self.capacity
        self.priority_min[idx] = priority_alpha

        # Update tree, by traversing along ancestors.
        # Continue until the root of the tree.
        while idx >= 2:
            # Get the index of the parent node
            idx //= 2
            # Value of the parent node is the minimum of its two children
            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        """
        #### Set priority in binary segment tree for sum
        """

        # Leaf of the binary tree
        idx += self.capacity
        # Set the priority at the leaf
        self.priority_sum[idx] = priority

        # Update tree, by traversing along ancestors.
        # Continue until the root of the tree.
        while idx >= 2:
            # Get the index of the parent node
            idx //= 2
            # Value of the parent node is the sum of its two children
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
        """
        #### $\sum_k p_k^\alpha$
        """

        # The root node keeps the sum of all values
        return self.priority_sum[1]

    def _min(self):
        """
        #### $\min_k p_k^\alpha$
        """

        # The root node keeps the minimum of all values
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        """
        #### Find largest $i$ such that $\sum_{k=1}^{i} p_k^\alpha  \le P$
        """

        # Start from the root
        idx = 1
        while idx < self.capacity:
            # If the sum of the left branch is higher than required sum
            if self.priority_sum[idx * 2] > prefix_sum:
                # Go to left branch of the tree
                idx = 2 * idx
            else:
                # Otherwise go to right branch and reduce the sum of left
                #  branch from required sum
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        # We are at the leaf node. Subtract the capacity by the index in the tree
        # to get the index of actual value
        return idx - self.capacity

    def sample(self, batch_size, beta):
        """
        ### Sample from buffer
        """

        # Initialize samples
        samples = {
            'weights': np.zeros(shape=batch_size, dtype=np.float32),
            'indexes': np.zeros(shape=batch_size, dtype=np.int32),
        }

        # Get sample indexes
        for i in range(batch_size):
            p = random.random() * self._sum()
            idx = self.find_prefix_sum_idx(p)
            samples['indexes'][i] = idx

        # $\min_i P(i) = \frac{\min_i p_i^\alpha}{\sum_k p_k^\alpha}$
        prob_min = self._min() / self._sum()
        # $\max_i w_i = \bigg(\frac{1}{N} \frac{1}{\min_i P(i)}\bigg)^\beta$
        max_weight = (prob_min * self.size) ** (-beta)

        for i in range(batch_size):
            idx = samples['indexes'][i]
            # $P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$
            prob = self.priority_sum[idx + self.capacity] / self._sum()
            # $w_i = \bigg(\frac{1}{N} \frac{1}{P(i)}\bigg)^\beta$
            weight = (prob * self.size) ** (-beta)
            # Normalize by $\frac{1}{\max_i w_i}$,
            #  which also cancels off the $\frac{1}{N}$ term
            samples['weights'][i] = weight / max_weight

        # Get samples data
        samples['transitions'] = [self.data[idx] for idx in samples['indexes']]

        return samples

    def update_priorities(self, indexes, priorities):
        """
        ### Update priorities
        """

        # `priorities` must be strictly positive, e.g. $|\delta_i| + \epsilon$
        for idx, priority in zip(indexes, priorities):
            # Set current max priority
            self.max_priority = max(self.max_priority, priority)

            # Calculate $p_i^\alpha$
            priority_alpha = priority ** self.alpha
            # Update the trees
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        """
        ### Whether the buffer is full
        """
        return self.capacity == self.size
```

# References
{% bibliography --cited %}