Skip to content

Commit

Permalink
Parallelized sampling from the replay buffer and building the segment…
Browse files Browse the repository at this point in the history
… tree (#608)

* Parallelized sampling from the replay buffer and building the segment tree.
  • Loading branch information
flodorner authored and Miffyli committed Jan 4, 2020
1 parent b06ab28 commit c7084c8
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 45 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Parallelized updating and sampling from the replay buffer in DQN. (@flodorner)

- Docker build script, `scripts/build_docker.sh`, can push images automatically.

Expand Down Expand Up @@ -597,3 +598,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner
84 changes: 59 additions & 25 deletions stable_baselines/common/segment_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
import operator
import numpy as np


def unique(sorted_array):
"""
More efficient implementation of np.unique for sorted arrays
:param sorted_array: (np.ndarray)
:return:(np.ndarray) sorted_array without duplicate elements
"""
if len(sorted_array) == 1:
return sorted_array
left = sorted_array[:-1]
right = sorted_array[1:]
uniques = np.append(right != left, True)
return sorted_array[uniques]


class SegmentTree(object):
Expand All @@ -8,7 +22,7 @@ def __init__(self, capacity, operation, neutral_element):
https://en.wikipedia.org/wiki/Segment_tree
Can be used as regular array, but with two
Can be used as regular array that supports Index arrays, but with two
important differences:
a) setting item's value is slightly slower.
Expand All @@ -26,6 +40,7 @@ def __init__(self, capacity, operation, neutral_element):
self._capacity = capacity
self._value = [neutral_element for _ in range(2 * capacity)]
self._operation = operation
self.neutral_element = neutral_element

def _reduce_helper(self, start, end, node, node_start, node_end):
if start == node_start and end == node_end:
Expand Down Expand Up @@ -61,29 +76,36 @@ def reduce(self, start=0, end=None):
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

def __setitem__(self, idx, val):
# index of the leaf
idx += self._capacity
self._value[idx] = val
idx //= 2
while idx >= 1:
self._value[idx] = self._operation(
self._value[2 * idx],
self._value[2 * idx + 1]
# indexes of the leaf
idxs = idx + self._capacity
self._value[idxs] = val
if isinstance(idxs, int):
idxs = np.array([idxs])
# go up one level in the tree and remove duplicate indexes
idxs = unique(idxs // 2)
while len(idxs) > 1 or idxs[0] > 0:
# as long as there are non-zero indexes, update the corresponding values
self._value[idxs] = self._operation(
self._value[2 * idxs],
self._value[2 * idxs + 1]
)
idx //= 2
# go up one level in the tree and remove duplicate indexes
idxs = unique(idxs // 2)

def __getitem__(self, idx):
assert 0 <= idx < self._capacity
assert np.max(idx) < self._capacity
assert 0 <= np.min(idx)
return self._value[self._capacity + idx]


class SumSegmentTree(SegmentTree):
def __init__(self, capacity):
super(SumSegmentTree, self).__init__(
capacity=capacity,
operation=operator.add,
operation=np.add,
neutral_element=0.0
)
self._value = np.array(self._value)

def sum(self, start=0, end=None):
"""
Expand All @@ -98,33 +120,45 @@ def sum(self, start=0, end=None):
def find_prefixsum_idx(self, prefixsum):
"""
Find the highest index `i` in the array such that
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum for each entry in prefixsum
if array values are probabilities, this function
allows to sample indexes according to the discrete
probability efficiently.
:param prefixsum: (float) upperbound on the sum of array prefix
:return: (int) highest index satisfying the prefixsum constraint
:param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix
:return: (np.ndarray) highest indexes satisfying the prefixsum constraint
"""
assert 0 <= prefixsum <= self.sum() + 1e-5
idx = 1
while idx < self._capacity: # while non-leaf
if self._value[2 * idx] > prefixsum:
idx = 2 * idx
else:
prefixsum -= self._value[2 * idx]
idx = 2 * idx + 1
if isinstance(prefixsum, float):
prefixsum = np.array([prefixsum])
assert 0 <= np.min(prefixsum)
assert np.max(prefixsum) <= self.sum() + 1e-5
assert isinstance(prefixsum[0], float)

idx = np.ones(len(prefixsum), dtype=int)
cont = np.ones(len(prefixsum), dtype=bool)

while np.any(cont): # while not all nodes are leafs
idx[cont] = 2 * idx[cont]
prefixsum_new = np.where(self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum)
# prepare update of prefixsum for all right children
idx = np.where(np.logical_or(self._value[idx] > prefixsum, np.logical_not(cont)), idx, idx + 1)
# Select child node for non-leaf nodes
prefixsum = prefixsum_new
# update prefixsum
cont = idx < self._capacity
# collect leafs
return idx - self._capacity


class MinSegmentTree(SegmentTree):
def __init__(self, capacity):
super(MinSegmentTree, self).__init__(
capacity=capacity,
operation=min,
operation=np.minimum,
neutral_element=float('inf')
)
self._value = np.array(self._value)

def min(self, start=0, end=None):
"""
Expand Down
35 changes: 15 additions & 20 deletions stable_baselines/deepq/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,12 @@ def add(self, obs_t, action, reward, obs_tp1, done):
self._it_min[idx] = self._max_priority ** self._alpha

def _sample_proportional(self, batch_size):
res = []
for _ in range(batch_size):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
mass = []
total = self._it_sum.sum(0, len(self._storage) - 1)
# TODO(szymon): should we ensure no repeats?
mass = np.random.random(size=batch_size) * total
idx = self._it_sum.find_prefixsum_idx(mass)
return idx

def sample(self, batch_size, beta=0):
"""
Expand All @@ -166,16 +165,11 @@ def sample(self, batch_size, beta=0):
assert beta > 0

idxes = self._sample_proportional(batch_size)

weights = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage)) ** (-beta)

for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage)) ** (-beta)
weights.append(weight / max_weight)
weights = np.array(weights)
p_sample = self._it_sum[idxes] / self._it_sum.sum()
weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
encoded_sample = self._encode_sample(idxes)
return tuple(list(encoded_sample) + [weights, idxes])

Expand All @@ -191,10 +185,11 @@ def update_priorities(self, idxes, priorities):
denoted by variable `idxes`.
"""
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
assert priority > 0
assert 0 <= idx < len(self._storage)
self._it_sum[idx] = priority ** self._alpha
self._it_min[idx] = priority ** self._alpha
assert np.min(priorities) > 0
assert np.min(idxes) >= 0
assert np.max(idxes) < len(self.storage)
self._it_sum[idxes] = priorities ** self._alpha
self._it_min[idxes] = priorities ** self._alpha

self._max_priority = max(self._max_priority, np.max(priorities))

self._max_priority = max(self._max_priority, priority)
76 changes: 76 additions & 0 deletions tests/test_segment_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ def test_tree_set():
"""
tree = SumSegmentTree(4)

tree[np.array([2, 3])] = [1.0, 3.0]

assert np.isclose(tree.sum(), 4.0)
assert np.isclose(tree.sum(0, 2), 0.0)
assert np.isclose(tree.sum(0, 3), 1.0)
assert np.isclose(tree.sum(2, 3), 1.0)
assert np.isclose(tree.sum(2, -1), 1.0)
assert np.isclose(tree.sum(2, 4), 4.0)

tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0

Expand All @@ -26,6 +36,17 @@ def test_tree_set_overlap():
"""
tree = SumSegmentTree(4)

tree[np.array([2])] = 1.0
tree[np.array([2])] = 3.0

assert np.isclose(tree.sum(), 3.0)
assert np.isclose(tree.sum(2, 3), 3.0)
assert np.isclose(tree.sum(2, -1), 3.0)
assert np.isclose(tree.sum(2, 4), 3.0)
assert np.isclose(tree.sum(1, 2), 0.0)

tree = SumSegmentTree(4)

tree[2] = 1.0
tree[2] = 3.0

Expand All @@ -51,6 +72,19 @@ def test_prefixsum_idx():
assert tree.find_prefixsum_idx(1.01) == 3
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(4.00) == 3
assert np.all(tree.find_prefixsum_idx([0.0, 0.5, 0.99, 1.01, 3.00, 4.00]) == [2, 2, 2, 3, 3, 3])

tree = SumSegmentTree(4)

tree[np.array([2, 3])] = [1.0, 3.0]

assert tree.find_prefixsum_idx(0.0) == 2
assert tree.find_prefixsum_idx(0.5) == 2
assert tree.find_prefixsum_idx(0.99) == 2
assert tree.find_prefixsum_idx(1.01) == 3
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(4.00) == 3
assert np.all(tree.find_prefixsum_idx([0.0, 0.5, 0.99, 1.01, 3.00, 4.00]) == [2, 2, 2, 3, 3, 3])


def test_prefixsum_idx2():
Expand All @@ -59,6 +93,17 @@ def test_prefixsum_idx2():
"""
tree = SumSegmentTree(4)

tree[np.array([0, 1, 2, 3])] = [0.5, 1.0, 1.0, 3.0]

assert tree.find_prefixsum_idx(0.00) == 0
assert tree.find_prefixsum_idx(0.55) == 1
assert tree.find_prefixsum_idx(0.99) == 1
assert tree.find_prefixsum_idx(1.51) == 2
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(5.50) == 3

tree = SumSegmentTree(4)

tree[0] = 0.5
tree[1] = 1.0
tree[2] = 1.0
Expand Down Expand Up @@ -109,6 +154,37 @@ def test_max_interval_tree():
assert np.isclose(tree.min(2, -1), 4.0)
assert np.isclose(tree.min(3, 4), 3.0)

tree = MinSegmentTree(4)

tree[np.array([0, 2, 3])] = [1.0, 0.5, 3.0]

assert np.isclose(tree.min(), 0.5)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.5)
assert np.isclose(tree.min(0, -1), 0.5)
assert np.isclose(tree.min(2, 4), 0.5)
assert np.isclose(tree.min(3, 4), 3.0)

tree[np.array([2])] = 0.7

assert np.isclose(tree.min(), 0.7)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.7)
assert np.isclose(tree.min(0, -1), 0.7)
assert np.isclose(tree.min(2, 4), 0.7)
assert np.isclose(tree.min(3, 4), 3.0)

tree[np.array([2])] = 4.0

assert np.isclose(tree.min(), 1.0)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 1.0)
assert np.isclose(tree.min(0, -1), 1.0)
assert np.isclose(tree.min(2, 4), 3.0)
assert np.isclose(tree.min(2, 3), 4.0)
assert np.isclose(tree.min(2, -1), 4.0)
assert np.isclose(tree.min(3, 4), 3.0)


if __name__ == '__main__':
test_tree_set()
Expand Down

0 comments on commit c7084c8

Please sign in to comment.