In [3]:
import numpy as np

In [12]:
class RingBuf:
    """Fixed-capacity FIFO buffer that overwrites its oldest element once full.

    Elements live in a plain list used as a circular array. ``start`` is the
    physical index of the oldest live element and ``end`` is the physical index
    where the next element will be written. Logical index 0 is always the
    oldest live element.
    """

    def __init__(self, size):
        """Create an empty buffer that can hold up to ``size`` elements."""
        # One extra slot is allocated on purpose: this way self.start == self.end
        # always means the buffer is EMPTY. With exactly `size` slots the same
        # condition could also mean FULL. The spare slot is never a live element.
        self.data = [None] * (size + 1)
        self.start = 0
        self.end = 0
        # Becomes True the first time an append evicts the oldest element and is
        # never reset afterwards.
        self.is_full = False

    def append(self, element):
        """Store ``element`` as the newest item, evicting the oldest one if needed."""
        self.data[self.end] = element
        self.end = (self.end + 1) % len(self.data)
        # end == start right after a write means the buffer now holds one element
        # too many. Drop the oldest element by advancing start.
        if self.end == self.start:
            self.start = (self.start + 1) % len(self.data)
            self.is_full = True

    def add(self, state, action, new_frame, reward, is_done):
        """Append one transition as the tuple (state, action, new_frame, reward, is_done)."""
        self.append((state, action, new_frame, reward, is_done))

    def sample_batch(self, length):
        """Draw ``length`` stored elements uniformly at random, with replacement.

        Only live elements are sampled; the spare slot (which may still hold an
        already-evicted element) is never returned. Because sampling is with
        replacement, ``length`` may exceed ``len(self)``.

        Returns:
            A lazy ``map`` yielding one numpy array per field of the stored
            tuples. For elements stored through ``add`` that is
            states_batch, action_batch, next_states_batch, reward_batch,
            done_batch. Unpack or iterate it once.

        Raises:
            ValueError: if the buffer is empty.
        """
        count = len(self)
        if count == 0:
            raise ValueError("cannot sample from an empty RingBuf")
        # Sample logical positions and resolve them through __getitem__, so the
        # same code is correct both before and after the buffer has wrapped.
        indices = np.random.randint(0, count, length)
        samples = [self[i] for i in indices]
        return map(np.array, zip(*samples))

    def __getitem__(self, idx):
        """Return the element at logical position ``idx`` (0 is the oldest).

        Negative indices count from the newest element, as with lists.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            raise IndexError("RingBuf index out of range")
        return self.data[(self.start + idx) % len(self.data)]

    def __len__(self):
        """Return the number of live elements currently stored."""
        if self.end < self.start:
            # The live region wraps around the end of the underlying list.
            return self.end + len(self.data) - self.start
        else:
            return self.end - self.start

    def __iter__(self):
        """Yield the live elements from oldest to newest."""
        for i in range(len(self)):
            yield self[i]

In [13]:
# Smoke test: push 120000 items into a 100000-element buffer so it wraps around
# and the oldest 20000 entries are overwritten.
test = RingBuf(100000)
for i in range(120000):
    test.append((np.arange(i, i + 4), "asd"))

In [14]:
# Draw a random batch of 32 stored elements. Each stored element is a 2-tuple
# (4-element array, "asd"), so the batch unpacks into two arrays: one stacking
# the sampled arrays and one holding the sampled strings.
i, i2 = test.sample_batch(32)
# Bare expression on the last line so the notebook displays both arrays.
i, i2


(array([[ 56654,  56655,  56656,  56657],
        [ 29560,  29561,  29562,  29563],
        [ 96534,  96535,  96536,  96537],
        [ 78087,  78088,  78089,  78090],
        [ 81222,  81223,  81224,  81225],
        [110131, 110132, 110133, 110134],
        [ 20638,  20639,  20640,  20641],
        [ 63420,  63421,  63422,  63423],
        [ 49345,  49346,  49347,  49348],
        [118230, 118231, 118232, 118233],
        [ 24001,  24002,  24003,  24004],
        [ 75161,  75162,  75163,  75164],
        [ 52051,  52052,  52053,  52054],
        [ 66334,  66335,  66336,  66337],
        [ 68050,  68051,  68052,  68053],
        [101528, 101529, 101530, 101531],
        [101883, 101884, 101885, 101886],
        [ 50465,  50466,  50467,  50468],
        [ 35407,  35408,  35409,  35410],
        [ 88834,  88835,  88836,  88837],
        [ 71481,  71482,  71483,  71484],
        [117099, 117100, 117101, 117102],
        [ 91874,  91875,  91876,  91877],
        [ 31140,  31141,  31142,  

In [15]:
# Remove the test buffer and the sampled batch from the notebook namespace so
# later cells cannot accidentally depend on them.
del test, i ,i2