In [1]:
from heapq import heappush, heappop

In [2]:
heap = []

In [3]:
heappush(heap, (1, 'top'))
heappush(heap, (2, 'topish'))
heappush(heap, (3, 'bottomish'))
heappush(heap, (4, 'bottom'))

In [4]:
heap

[(1, 'top'), (2, 'topish'), (3, 'bottomish'), (4, 'bottom')]

In [6]:
heap[:32]

[(1, 'top'), (2, 'topish'), (3, 'bottomish'), (4, 'bottom')]

In [13]:
heap = []

In [15]:
for i in range(64):
    heappush(heap, (i, 1 / (i + 0.000001)))

In [16]:
heap

[(0, 1000000.0),
 (1, 0.9999990000010001),
 (2, 0.499999750000125),
 (3, 0.33333322222225925),
 (4, 0.24999993750001562),
 (5, 0.199999960000008),
 (6, 0.16666663888889352),
 (7, 0.1428571224489825),
 (8, 0.12499998437500197),
 (9, 0.11111109876543349),
 (10, 0.09999999000000101),
 (11, 0.09090908264462885),
 (12, 0.08333332638888948),
 (13, 0.07692307100591762),
 (14, 0.07142856632653098),
 (15, 0.06666666222222252),
 (16, 0.06249999609375024),
 (17, 0.05882352595155729),
 (18, 0.055555552469135974),
 (19, 0.05263157617728546),
 (20, 0.04999999750000012),
 (21, 0.04761904535147403),
 (22, 0.045454543388429844),
 (23, 0.04347825897920613),
 (24, 0.041666664930555625),
 (25, 0.03999999840000006),
 (26, 0.03846153698224858),
 (27, 0.037037035665294975),
 (28, 0.035714284438775556),
 (29, 0.034482757431629055),
 (30, 0.033333332222222255),
 (31, 0.032258063475546335),
 (32, 0.031249999023437534),
 (33, 0.030303029384756687),
 (34, 0.02941176384083048),
 (35, 0.028571427755102068),
 (36, 0

In [18]:
[x[1] for x in heap[:32]]

[1000000.0,
 0.9999990000010001,
 0.499999750000125,
 0.33333322222225925,
 0.24999993750001562,
 0.199999960000008,
 0.16666663888889352,
 0.1428571224489825,
 0.12499998437500197,
 0.11111109876543349,
 0.09999999000000101,
 0.09090908264462885,
 0.08333332638888948,
 0.07692307100591762,
 0.07142856632653098,
 0.06666666222222252,
 0.06249999609375024,
 0.05882352595155729,
 0.055555552469135974,
 0.05263157617728546,
 0.04999999750000012,
 0.04761904535147403,
 0.045454543388429844,
 0.04347825897920613,
 0.041666664930555625,
 0.03999999840000006,
 0.03846153698224858,
 0.037037035665294975,
 0.035714284438775556,
 0.034482757431629055,
 0.033333332222222255,
 0.032258063475546335]

In [172]:
"""A priority queue for storing previous experiences to sample from."""
from heapq import heappush, heappushpop
import numpy as np


class PrioritizedReplayQueue(object):
    """A prioritized replay queue for replaying previous experiences."""

    def __init__(self, size: int) -> None:
        """
        Initialize a new prioritized replay buffer with a given size.

        Args:
            size: the max number of experiences to store in the queue

        Returns:
            None

        """
        # type check the size parameter
        if not isinstance(size, int):
            raise TypeError('`size` must be of type int')
        # ensure the size is within a legal range of values
        if size <= 0:
            raise ValueError('`size` must be > 0')
        self.size = size
        # initialize the priority queue as a heap
        self.heap = []

    def __repr__(self) -> str:
        """Return an executable string representation of priority queue."""
        return '{}(size={})'.format(self.__class__.__name__, self.size)

    def push(self,
        s: np.ndarray,
        a: int,
        r: int,
        d: bool,
        s2: np.ndarray,
        priority: float
    ) -> None:
        """
        Push a new experience onto the queue.

        Args:
            s: the current state
            a: the action to get from current state `s` to next state `s2`
            r: the reward resulting from taking action `a` in state `s`
            d: the flag denoting whether the episode ended after action `a`
            s2: the next state from taking action `a` in state `s`
            priority: the priority of the item to push to the queue

        Returns:
            None

        """
        # if the heap has arrived at capacity, use push pop to add new items
        if len(self.heap) == self.size:
            heappushpop(self.heap, (priority, (s, a, r, d, s2)))
        # otherwise heap push the item onto the queue
        else:
            heappush(self.heap, (priority, (s, a, r, d, s2)))

    def sample(self, size: int=32) -> bool:
        """
        Return a random sample of items from the queue.

        Args:
            size: the number of items to sample and return

        Returns:
            A random sample from the queue sampled uniformly

        """
        # extract a sample from the heap (priorities are in increasing order)
        # i.e. the lowest priority value is the first item in the sample
        sample_batch = [x[1] for x in self.heap[-size:]]
        # initialize lists for each component of the batch
        s = [None] * len(sample_batch)
        a = [None] * len(sample_batch)
        r = [None] * len(sample_batch)
        d = [None] * len(sample_batch)
        s2 = [None] * len(sample_batch)
        # iterate over the indexes and copy references to the arrays
        for batch, sample in enumerate(sample_batch):
            _s, _a, _r, _d, _s2 = sample
            s[batch] = np.array(_s, copy=False)
            a[batch] = _a
            r[batch] = _r
            d[batch] = _d
            s2[batch] = np.array(_s2, copy=False)
        # convert the lists to arrays for returning for training
        return (
            np.array(s),
            np.array(a, dtype=np.uint8),
            np.array(r, dtype=np.int8),
            np.array(d, dtype=np.bool),
            np.array(s2),
        )


# explicitly define the outward facing API of this module
__all__ = [PrioritizedReplayQueue.__name__]


In [195]:
q = PrioritizedReplayQueue(10)

In [196]:
def random_tuple():
    s = np.random.random((1, 1, 4))
    a = np.random.randint(low=0, high=10)
    r = np.random.random() * 100
    d = bool(np.random.randint(low=0, high=2))
    s2 = np.random.random((1, 1, 4))
    return s, a, r, d, s2

In [197]:
for i in range(1, 20):
    q.push(*random_tuple(), priority=i)

In [198]:
len(q.heap)

10

In [199]:
q.heap

[(10,
  (array([[[0.89768666, 0.08018801, 0.10921835, 0.13151239]]]),
   6,
   34.49071486953673,
   True,
   array([[[0.10291338, 0.20737608, 0.52643718, 0.8765454 ]]]))),
 (11,
  (array([[[0.94195482, 0.16278842, 0.79449866, 0.07406484]]]),
   1,
   19.151353247570658,
   False,
   array([[[0.63640296, 0.07241644, 0.85594163, 0.56062317]]]))),
 (12,
  (array([[[0.64947839, 0.56147468, 0.37459193, 0.92035482]]]),
   7,
   57.504834137531304,
   False,
   array([[[0.92825072, 0.49151014, 0.10880605, 0.86987566]]]))),
 (14,
  (array([[[0.39882194, 0.37802063, 0.81945337, 0.39264481]]]),
   1,
   87.73986609079458,
   True,
   array([[[0.80591086, 0.20331002, 0.44784248, 0.66413382]]]))),
 (13,
  (array([[[0.93978472, 0.22903392, 0.96407705, 0.74970004]]]),
   0,
   3.08778315365269,
   True,
   array([[[0.39237506, 0.56643398, 0.07347348, 0.73485008]]]))),
 (16,
  (array([[[0.5586956 , 0.70812655, 0.97536463, 0.00831798]]]),
   5,
   21.072053867938244,
   True,
   array([[[0.81476428, 

In [200]:
s, a, r, d, s2 = q.sample(5)

In [201]:
s

array([[[[0.5586956 , 0.70812655, 0.97536463, 0.00831798]]],


       [[[0.2539993 , 0.97002454, 0.55672638, 0.64779817]]],


       [[[0.60129811, 0.50520402, 0.49664635, 0.36332866]]],


       [[[0.34308517, 0.26919868, 0.83251812, 0.59681577]]],


       [[[0.84523251, 0.29486836, 0.07159574, 0.98895607]]]])

In [202]:
a

array([5, 0, 3, 1, 4], dtype=uint8)

In [203]:
r

array([21, 30, 55, 18, 16], dtype=int8)

In [204]:
d

array([ True, False, False, False,  True])

In [205]:
s2

array([[[[0.81476428, 0.87760079, 0.76452989, 0.71162329]]],


       [[[0.72156479, 0.8033102 , 0.14247161, 0.48794647]]],


       [[[0.76395245, 0.23939793, 0.40208338, 0.61407713]]],


       [[[0.28542885, 0.42125413, 0.84564135, 0.2590048 ]]],


       [[[0.78238487, 0.49850343, 0.53255337, 0.59578296]]]])