-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
min_prob 永遠返回 0 #183
Comments
有些人為了解決這個問題，採用了 https://blog.csdn.net/gsww404/article/details/103673852 中的做法，但我認為那也不是很好的做法。而且原實現在 add 時要遍歷整個 tree 計算 max p，在 sample 時又要遍歷整個 tree 計算 min p，浪費了很多資源。以下是我整理過的代碼：
import random
import numpy as np
class PrioritizedMemory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """Proportional prioritized experience replay buffer backed by a SumTree.

    Each transition is stored with priority ``(|error| + e) ** a``. Sampling
    splits the total priority mass into ``n`` equal segments and draws one
    transition from each, so a transition is picked with probability
    proportional to its priority.

    Importance-sampling weights are normalised by the largest weight in the
    sampled batch, not by the global minimum leaf probability. This avoids
    scanning every leaf and never divides by the zero priorities of
    still-empty leaves.
    """

    e = 0.01  # added to |error| so a zero-error transition keeps a non-zero priority
    a = 0.6  # prioritisation exponent: 0 gives uniform sampling, 1 fully proportional
    beta = 0.4  # initial importance-sampling exponent, annealed towards 1 in sample()
    beta_increment_per_sampling = 0.001  # amount beta grows on every sample() call

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        """Map a TD error (scalar or array) to a strictly positive priority."""
        return (np.abs(error) + self.e) ** self.a

    def add(self, error, sample):
        """Store ``sample`` with a priority derived from its TD ``error``."""
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        """Draw ``n`` transitions with probability proportional to priority.

        Args:
            n: number of transitions to draw; must be at least 1.

        Returns:
            ``(idxs, batch, is_weight)``:
            - ``idxs``: the tree leaf indices, to pass back to ``batch_update``.
            - ``batch``: the stored transitions.
            - ``is_weight``: a float array of importance-sampling weights,
              scaled so that its maximum is 1.

        Raises:
            ValueError: if ``n < 1`` or the memory holds no priority mass yet.
        """
        if n < 1:
            raise ValueError("n must be a positive integer")
        total = self.tree.total()
        if total <= 0:
            raise ValueError("cannot sample from an empty memory")

        batch = []
        idxs = []
        priorities = []
        segment = total / n
        # Anneal beta towards 1 so the bias correction becomes full late in training.
        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

        for i in range(n):
            # Stratified sampling: one uniform draw inside each equal slice of the mass.
            s = random.uniform(segment * i, segment * (i + 1))
            idx, p, data = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        # Convert explicitly: dividing a plain list only works if total()
        # happens to return a NumPy scalar.
        sampling_probabilities = np.asarray(priorities, dtype=float) / total
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()
        return idxs, batch, is_weight

    def batch_update(self, tree_idx, abs_errors):
        """Refresh the priorities of previously sampled leaves from new TD errors."""
        for leaf_idx, err in zip(tree_idx, abs_errors):
            self.tree.update(leaf_idx, self._get_priority(err))
class SumTree:
    r"""Binary sum tree holding transition priorities in its leaves.

    Every internal node stores the sum of its two children, so the root holds
    the total priority. A prefix-sum lookup (``get``) walks one root-to-leaf
    path.

    Tree structure and array storage::

        Tree index:
                 0         -> storing priority sum
                / \
              1     2
             / \   / \
            3   4 5   6    -> storing priority for transitions

        Array type for storing:
        [0, 1, 2, 3, 4, 5, 6]
    """

    def __init__(self, capacity):
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        # Internal nodes occupy [0, capacity - 1); leaves occupy the last `capacity` slots.
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)  # payload of leaf i is data[i]
        self.data_pointer = 0  # next slot to write; wraps to 0 once capacity is reached
        self.n_entries = 0  # number of filled slots, saturates at capacity

    def update(self, tree_idx, p):
        """Set leaf ``tree_idx`` to priority ``p`` and refresh all its ancestors.

        Each ancestor is recomputed as the exact sum of its two children
        rather than shifted by a delta. Rounding errors therefore cannot
        accumulate over many updates, and internal sums stay consistent with
        the leaves below them.
        """
        self.tree[tree_idx] = p
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            left = 2 * tree_idx + 1
            self.tree[tree_idx] = self.tree[left] + self.tree[left + 1]

    def total(self):
        """Return the total priority mass (the root value)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest slot when full."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace when exceed the capacity
            self.data_pointer = 0
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def get(self, v):
        """Find the leaf whose cumulative-priority interval contains ``v``.

        Returns:
            ``(leaf_idx, priority, data)``, where ``leaf_idx`` is the index
            into ``self.tree``.

        A child whose subtree has zero mass is never entered while its
        sibling has mass. Without this, a ``v`` that overshoots the stored
        sums by floating-point rounding would land on an empty leaf
        (priority 0, data 0).
        """
        idx = 0
        n_nodes = len(self.tree)
        while 2 * idx + 1 < n_nodes:  # stop once idx is a leaf
            left_idx = 2 * idx + 1
            left = self.tree[left_idx]
            right = self.tree[left_idx + 1]
            if (v <= left and left > 0) or right <= 0:
                idx = left_idx
            else:
                v -= left
                idx = left_idx + 1
        data_idx = idx - self.capacity + 1
        return idx, self.tree[idx], self.data[data_idx]
|
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py#L114
由於 SumTree 在剛開始的時候存在大量優先級為 0（尚未填充）的葉節點，
所以 np.min 會返回 0（即 min_prob 為 0），
從而導致依賴 min_prob 的重要性採樣權重計算
返回錯誤的結果。
The text was updated successfully, but these errors were encountered: