In [1]:
import numpy


# SumTree
# a binary tree data structure where the parent’s value is the sum of its children
class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = numpy.zeros(2 * capacity - 1)
        self.data = numpy.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

In [12]:
st = SumTree(5)

In [3]:
st

<__main__.SumTree at 0x7f0564452128>

In [4]:
st.data

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=object)

In [5]:
st.tree

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.])

In [13]:
st.add(.05, '5')
st.add(.2, '20')
st.add(.4, '40')
st.add(.25, '25')
st.add(.1, '10')

In [14]:
st.data

array(['5', '20', '40', '25', '10'], dtype=object)

In [15]:
st.tree

array([1.  , 0.4 , 0.6 , 0.35, 0.05, 0.2 , 0.4 , 0.25, 0.1 ])

In [9]:
st.add(2, 'bonjour')

In [10]:
st.data

array(['allo', 'bonjour', 0, 0, 0, 0, 0, 0, 0, 0], dtype=object)

In [11]:
st.tree

array([3., 3., 0., 0., 3., 0., 0., 0., 0., 1., 2., 0., 0., 0., 0., 0., 0.,
       0., 0.])

In [31]:
import random
from collections import defaultdict
d = defaultdict(int)
for _ in range(10000000):
    val = st.get(random.uniform(0, 1))[2]
    d[val] += 1

In [32]:
d

defaultdict(int,
            {'40': 3999673,
             '25': 2501319,
             '20': 2000137,
             '10': 998304,
             '5': 500567})

In [33]:
1+1

2

In [34]:
state

NameError: name 'state' is not defined