In [22]:
import numpy as np
import numpy.random as rd
import time

ask the ***now_tree_max_len*** first !

In [23]:
class BinarySearchTree:
    """Binary Tree for PER
    Contributor: Github GyChou
    Reference:
    https://github.com/kaixindelele/DRLib/tree/main/algos/pytorch/td3_sp
    https://github.com/jaromiru/AI-blog/blob/master/SumTree.py
    """

    def __init__(self, memo_len):
        self.memo_len = memo_len  # replay buffer len
        # SumTree size is 2 * buffer_len - 1, parent nodes is buffer_len-1, and leaves node is buffer_len.
        self.ps_tree = np.zeros(2 * memo_len - 1)
        self.now_max_tree_len = self.memo_len - 1
        self.indices = None

        self.per_alpha = 0.6
        self.per_beta = 0.4
        self.per_beta_increment_per_sampling = 0.001
        self.depth = int(np.log2(2 * memo_len - 1))

    def update(self, data_idx, prob=10):  # 10 is max_prob
        tree_idx = data_idx + self.memo_len - 1
        if self.now_max_tree_len == tree_idx:
            self.now_max_tree_len += 1

        delta = prob - self.ps_tree[tree_idx]
        self.ps_tree[tree_idx] = prob

        while tree_idx != 0:  # propagate the change through tree
            tree_idx = (tree_idx - 1) // 2  # faster than the recursive loop
            self.ps_tree[tree_idx] += delta

    def update_(self, data_ids, prob=10):  # 10 is max_prob
        ids = data_ids + self.memo_len - 1
        self.now_max_tree_len += (ids >= self.now_max_tree_len).sum()

        upper_step = self.depth - 1
        self.ps_tree[ids] = prob  # here, ids means the indices of 
                                  # given children (maybe the right ones or left ones)
        p_ids = (ids - 1) // 2

        while upper_step:  # propagate the change through tree
            ids = p_ids * 2 + 1  # in this while loop, ids means the indices of the left children
            self.ps_tree[p_ids] = self.ps_tree[ids] + self.ps_tree[ids + 1]
            p_ids = (p_ids - 1) // 2
            upper_step -= 1

        self.ps_tree[0] = self.ps_tree[1] + self.ps_tree[2]  # because we take depth-1 upper steps, ps_tree[0] need to be updated alone

    def get_leaf_id(self, v):
        """
        Tree structure and array storage:
        Tree index:
              0       -> storing priority sum
            |  |
          1     2
         | |   | |
        3  4  5  6    -> storing priority for transitions
        Array type for storing:
        [0,1,2,3,4,5,6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            l_idx = 2 * parent_idx + 1  # the leaf's left node
            r_idx = l_idx + 1  # the leaf's right node
            if l_idx >= (len(self.ps_tree)):  # reach bottom, end search
                leaf_idx = parent_idx
                break
            else:  # downward search, always search for a higher priority node
                if v <= self.ps_tree[l_idx]:
                    parent_idx = l_idx
                else:
                    v -= self.ps_tree[l_idx]
                    parent_idx = r_idx
        return min(leaf_idx, self.now_max_tree_len - 2)  # leaf_idx

    def get_indices_is_weights(self, batch_size, beg, end):
        self.per_beta = np.min([1., self.per_beta + self.per_beta_increment_per_sampling])  # max = 1

        # get random values for searching indices with proportional prioritization
        values = (rd.rand(batch_size) + np.arange(batch_size)) * (self.ps_tree[0] / batch_size)

        # get proportional prioritization
        leaf_ids = np.array([self.get_leaf_id(v) for v in values])
        self.indices = leaf_ids - (self.memo_len - 1)

        probs = self.ps_tree[leaf_ids] / self.ps_tree[beg:end].min()
        is_weights = np.power(probs, -self.per_beta)  # important sampling weights
        return self.indices, is_weights

    def td_error_update(self, td_error):  # td_error = (q-q).detach_().abs()
        prob = td_error.clamp(1e-6, 10).pow(self.per_alpha)
        prob = prob.cpu().numpy()
        for data_idx, p in zip(self.indices, prob):
            self.update(data_idx, p)


In [24]:
my_tree1 = BinarySearchTree(1000)
my_tree2 = BinarySearchTree(1000)

In [25]:
my_tree1.now_max_tree_len, my_tree2.now_max_tree_len

(999, 999)

In [26]:
# ids = np.random.randint(0, 1000, size=(1000,))
ids = np.random.randint(0, 1000, size=(1000,))
ids = np.unique(ids)

In [46]:
t0 = time.time()

for i in ids:
    my_tree1.update(i)

t1 = time.time()
normal_time = t1 - t0
print("time for update: {}".format(normal_time))

time for update: 0.028980016708374023


In [47]:
t0 = time.time()

my_tree2.update_(ids)

t1 = time.time()
numpy_time = t1 - t0
print("time for numpy update: {}".format(numpy_time))
print("speedup ratio: {}".format(normal_time / numpy_time))

time for numpy update: 0.001996755599975586
speedup ratio: 14.51355223880597


In [29]:
np.square((my_tree1.ps_tree - my_tree2.ps_tree)).sum()

0.0

*** ask jh whether the now_max_tree_len is expected ***

In [30]:
my_tree1.now_max_tree_len, my_tree2.now_max_tree_len

(1001, 1631)