In [1]:
%load_ext autoreload
%autoreload 2

import os, sys, random
path, _ = os.path.split(os.getcwd())
sys.path.append(path)

In [55]:
import numpy as np

class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = numpy.zeros( 2*capacity - 1 )
        self.data = numpy.zeros( capacity, dtype=object )

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s-self.tree[left])
    
    def __len__(self):
        return len(self.tree)

    def total(self):
        return self.tree[0]

    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

In [62]:

#-------------------- MEMORY --------------------------
class SumTreeMemory:   # stored as ( s, a, r, s_ ) in SumTree
    e = 0.01
    a = 0.6

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def _getPriority(self, error):
        return (error + self.e) ** self.a

    def add(self, error, sample):
        p = self._getPriority(error)
        self.tree.add(p, sample) 

    def sample(self, n):
        batch = []
        segment = self.tree.total() / n

        for i in range(n):
            a = segment * i
            b = segment * (i + 1)

            s = np.random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            batch.append( (idx, data) )

        return batch

    def update(self, idx, error):
        p = self._getPriority(error)
        self.tree.update(idx, p)
        
    def __len__(self):
        self.tree.total()

In [63]:
mem = Memory(101)

In [58]:
from collections import namedtuple
Test = namedtuple("Test", ["x", "y", "z"])


In [64]:
for i in range(300):
    test = Test(x=i, y = i%10, z = i%23)
    mem.add(sample=test, error= ((i + 1) / (i + 1)**i))

In [65]:
mem.sample(5)

[(132, Test(x=234, y=4, z=4), 0.337960148228581),
 (156, Test(x=258, y=8, z=5), 1.8334716605426205),
 (183, Test(x=285, y=5, z=9), 3.5912757267240534),
 (193, Test(x=295, y=5, z=19), 4.190174923038327),
 (110, Test(x=212, y=2, z=5), 5.326255046067846)]

In [66]:
len(mem.tree)

201