Skip to content

Commit

Permalink
Remove transitions straddling buffer index
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Jan 14, 2018
1 parent f398c1a commit 9cb5f59
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
9 changes: 5 additions & 4 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,18 @@ def act(self, state):

def learn(self, mem):
idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)
batch_size = len(idxs) # May return less than specified if invalid transitions sampled

# Calculate current state probabilities
ps = self.policy_net(states) # Probabilities p(s_t, ·; θpolicy)
ps_a = ps[range(self.batch_size), actions] # p(s_t, a_t; θpolicy)
ps_a = ps[range(batch_size), actions] # p(s_t, a_t; θpolicy)

# Calculate nth next state probabilities
pns = self.policy_net(next_states).data # Probabilities p(s_t+n, ·; θpolicy)
dns = self.support.expand_as(pns) * pns # Distribution d_t+n = (z, p(s_t+n, ·; θpolicy))
argmax_indices_ns = dns.sum(2).max(1)[1] # Perform argmax action selection using policy network: argmax_a[(z, p(s_t+n, a; θpolicy))]
pns = self.target_net(next_states).data # Probabilities p(s_t+n, ·; θtarget)
pns_a = pns[range(self.batch_size), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θpolicy))]; θtarget)
pns_a = pns[range(batch_size), argmax_indices_ns] # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θpolicy))]; θtarget)
pns_a *= nonterminals # Set p = 0 for terminal nth next states as all possible expected returns = expected reward at final transition

# Compute Tz (Bellman operator T applied to z)
Expand All @@ -67,8 +68,8 @@ def learn(self, mem):
l, u = b.floor().long(), b.ceil().long()

# Distribute probability of Tz
m = states.data.new(self.batch_size, self.atoms).zero_()
offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).long().unsqueeze(1).expand(self.batch_size, self.atoms).type_as(actions)
m = states.data.new(batch_size, self.atoms).zero_()
offset = torch.linspace(0, ((batch_size - 1) * self.atoms), batch_size).long().unsqueeze(1).expand(batch_size, self.atoms).type_as(actions)
m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1)) # m_l = m_l + p(s_t+n, a*)(u - b)
m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1)) # m_u = m_u + p(s_t+n, a*)(b - l)

Expand Down
13 changes: 7 additions & 6 deletions memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def sample(self, batch_size):
batch = [self.transitions.find(s) for s in samples] # Retrieve samples from tree
probs, idxs, tree_idxs = zip(*batch) # Unpack unnormalised probabilities (priorities), data indices, tree indices
# TODO: Check that transitions with 0 probability are not returned/make sure samples are valid
# If any transitions straddle current index, remove them (simpler than replacing with unique valid transitions)
probs, idxs, tree_idxs = self.dtype_float(probs), self.dtype_long(idxs), self.dtype_long(tree_idxs)
valid_idxs = idxs.sub(self.transitions.index).abs_() > max(self.history, self.n)
probs, idxs, tree_idxs = probs[valid_idxs], idxs[valid_idxs], tree_idxs[valid_idxs]

# Retrieve all required transition data (from t - h to t + n)
full_transitions = [[self.transitions.get(i + t) for i in idxs] for t in range(1 - self.history, self.n + 1)] # Time x batch
Expand All @@ -119,14 +123,11 @@ def sample(self, batch_size):
for n in range(self.n - 1):
# Invalid nth next states have reward 0 and hence do not affect calculation
returns = [R + self.discount ** n * transition.reward for R, transition in zip(returns, full_transitions[self.history + n])]
returns = self.dtype_float(returns) # TODO: Make sure this doesn't cause issues around current buffer index
returns = self.dtype_float(returns)

nonterminals = [transition.nonterminal for transition in full_transitions[self.history + self.n - 1]] # Mask for non-terminal nth next states
for t in range(self.history, self.history + self.n): # Hack: if nth next state is invalid (overlapping transition), treat it as terminal
nonterminals = [nonterm and (trans.timestep - pre_trans.timestep) == 1 for nonterm, trans, pre_trans in zip(nonterminals, full_transitions[t], full_transitions[t - 1])]
nonterminals = self.dtype_float(nonterminals).unsqueeze(1)
nonterminals = self.dtype_float([transition.nonterminal for transition in full_transitions[self.history + self.n - 1]]).unsqueeze(1) # Mask for non-terminal nth next states

probs = Variable(self.dtype_float(probs)) / p_total # Calculate normalised probabilities
probs = Variable(probs) / p_total # Calculate normalised probabilities
weights = (self.capacity * probs) ** -self.priority_weight # Compute importance-sampling weights w
weights = weights / weights.max() # Normalise by max importance-sampling weight from batch

Expand Down

0 comments on commit 9cb5f59

Please sign in to comment.