Merge pull request #33 from deepbrain/master
Improve the replay memory to avoid storing the autograd graphs in it
Kaixhin committed Oct 23, 2018
2 parents de04b85 + 5a9f42a commit de446cc
Showing 2 changed files with 6 additions and 4 deletions.
agent.py: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def learn(self, mem):
     (weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
     self.optimiser.step()
 
-    mem.update_priorities(idxs, loss.detach())  # Update priorities of sampled transitions
+    mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions
 
   def update_target_net(self):
     self.target_net.load_state_dict(self.online_net.state_dict())
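
The agent-side change is the crux of the commit: loss is produced by differentiable operations, so it carries a grad_fn referencing the whole autograd graph of the forward pass. Handing it to the replay memory as-is would keep that graph, and every intermediate tensor in it, alive for as long as the priority is stored. A minimal standalone sketch of the difference, using a toy network rather than the repository's DQN:

import torch

# Toy model standing in for the agent's network (illustrative, not from the repo)
net = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
loss = (net(x) ** 2).squeeze()  # per-sample losses, still attached to the graph

print(loss.grad_fn is not None)           # True: storing loss would retain the whole graph
priorities = loss.detach().cpu().numpy()  # plain ndarray: no graph, no GPU memory held
print(type(priorities))                   # <class 'numpy.ndarray'>
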
memory.py: 5 additions & 3 deletions
@@ -1,6 +1,7 @@
 import random
 from collections import namedtuple
 import torch
+import numpy as np
 
 
 Transition = namedtuple('Transition', ('timestep', 'state', 'action', 'reward', 'nonterminal'))
@@ -126,14 +127,15 @@ def sample(self, batch_size):
     probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch)
     states, next_states, = torch.stack(states), torch.stack(next_states)
     actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals)
-    probs = torch.tensor(probs, dtype=torch.float32, device=self.device) / p_total  # Calculate normalised probabilities
+    probs = np.array(probs, dtype=np.float32) / p_total  # Calculate normalised probabilities
     capacity = self.capacity if self.transitions.full else self.transitions.index
     weights = (capacity * probs) ** -self.priority_weight  # Compute importance-sampling weights w
-    weights = weights / weights.max()  # Normalise by max importance-sampling weight from batch
+    weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device)  # Normalise by max importance-sampling weight from batch
     return tree_idxs, states, actions, returns, next_states, nonterminals, weights
 
   def update_priorities(self, idxs, priorities):
-    priorities.pow_(self.priority_exponent)
+    priorities = np.power(priorities, self.priority_exponent)
     [self.transitions.update(idx, priority) for idx, priority in zip(idxs, priorities)]
 
   # Set up internal state for iterator
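
On the memory side, sampling probabilities and importance-sampling weights are now computed in NumPy and converted to a torch tensor only once, at the end of sample; update_priorities correspondingly swaps the in-place tensor pow_ for np.power, since it now receives a NumPy array. A rough standalone sketch of the weight computation, with made-up values standing in for the sum-tree quantities:

import numpy as np
import torch

p_total = 10.0                                                 # total priority mass (illustrative)
probs = np.array([0.5, 1.0, 2.0], dtype=np.float32) / p_total  # normalised P(i)
capacity, priority_weight = 3, 0.4                             # N and beta (illustrative)

weights = (capacity * probs) ** -priority_weight               # w_i = (N * P(i)) ** -beta
weights = torch.tensor(weights / weights.max(), dtype=torch.float32)  # to torch only at the end

priorities = np.power(np.array([0.3, 0.8]), 0.5)               # graph-free priority update
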
