From e26c8e42a383efff50c65aadb8724eba03d02a22 Mon Sep 17 00:00:00 2001
From: deepbrain
Date: Sun, 21 Oct 2018 13:33:52 -0700
Subject: [PATCH 1/4] Fixed a bug in the replay memory. In my tests the new
 memory requirements are reduced 3x and the training speed is increased by
 3-4x.

---
 agent.py  | 2 +-
 memory.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/agent.py b/agent.py
index b18f482..61a4a33 100644
--- a/agent.py
+++ b/agent.py
@@ -83,7 +83,7 @@ def learn(self, mem):
     (weights * loss).mean().backward()  # Backpropagate importance-weighted minibatch loss
     self.optimiser.step()
 
-    mem.update_priorities(idxs, loss.detach())  # Update priorities of sampled transitions
+    mem.update_priorities(idxs, loss.detach().cpu().numpy())  # Update priorities of sampled transitions
 
   def update_target_net(self):
     self.target_net.load_state_dict(self.online_net.state_dict())

diff --git a/memory.py b/memory.py
index 5d126c0..75f9cda 100644
--- a/memory.py
+++ b/memory.py
@@ -1,6 +1,7 @@
 import random
 from collections import namedtuple
 import torch
+import numpy as np
 
 Transition = namedtuple('Transition', ('timestep', 'state', 'action', 'reward', 'nonterminal'))
 
@@ -133,7 +134,7 @@ def sample(self, batch_size):
     return tree_idxs, states, actions, returns, next_states, nonterminals, weights
 
   def update_priorities(self, idxs, priorities):
-    priorities.pow_(self.priority_exponent)
+    priorities = np.power(priorities, self.priority_exponent)
     [self.transitions.update(idx, priority) for idx, priority in zip(idxs, priorities)]
 
   # Set up internal state for iterator

From 0aea5fb13b4d763b95cd0fb9937f46ae8bf3796b Mon Sep 17 00:00:00 2001
From: deepbrain
Date: Sun, 21 Oct 2018 20:24:13 -0700
Subject: [PATCH 2/4] Fixed a bug in the replay memory

---
 memory.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/memory.py b/memory.py
index 75f9cda..9884518 100644
--- a/memory.py
+++ b/memory.py
@@ -125,14 +125,16 @@ def sample(self, batch_size):
     segment = p_total / batch_size  # Batch size number of segments, based on sum over all probabilities
     batch = [self._get_sample_from_segment(segment, i) for i in range(batch_size)]  # Get batch of valid samples
     probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch)
-    states, next_states, = torch.stack(states), torch.stack(next_states)
+    states, next_states, = np.stack(states), np.stack(next_states)
     actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals)
-    probs = torch.tensor(probs, dtype=torch.float32, device=self.device) / p_total  # Calculate normalised probabilities
+    probs = np.array(probs, dtype = np.float32)/p_total #, dtype=torch.float32, device=self.device).
+    #/ torch.tensor(p_total, dtype=torch.float32, device=self.device) # Calculate normalised probabilities
     capacity = self.capacity if self.transitions.full else self.transitions.index
     weights = (capacity * probs) ** -self.priority_weight  # Compute importance-sampling weights w
-    weights = weights / weights.max()  # Normalise by max importance-sampling weight from batch
+    weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device)  # Normalise by max importance-sampling weight from batch
     return tree_idxs, states, actions, returns, next_states, nonterminals, weights
 
+
   def update_priorities(self, idxs, priorities):
     priorities = np.power(priorities, self.priority_exponent)
     [self.transitions.update(idx, priority) for idx, priority in zip(idxs, priorities)]

From 1e2d7a4b91b2769698df228599063f9e744c3b71 Mon Sep 17 00:00:00 2001
From: deepbrain
Date: Sun, 21 Oct 2018 21:08:51 -0700
Subject: [PATCH 3/4] Fixed a bug in the replay memory

---
 memory.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/memory.py b/memory.py
index 9884518..de72231 100644
--- a/memory.py
+++ b/memory.py
@@ -127,8 +127,7 @@ def sample(self, batch_size):
     probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch)
     states, next_states, = np.stack(states), np.stack(next_states)
     actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals)
-    probs = np.array(probs, dtype = np.float32)/p_total #, dtype=torch.float32, device=self.device).
-    #/ torch.tensor(p_total, dtype=torch.float32, device=self.device) # Calculate normalised probabilities
+    probs = np.array(probs, dtype = np.float32)/p_total  # Calculate normalised probabilities
     capacity = self.capacity if self.transitions.full else self.transitions.index
     weights = (capacity * probs) ** -self.priority_weight  # Compute importance-sampling weights w
     weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device)  # Normalise by max importance-sampling weight from batch

From 5a9f42ace2bcbcb4dd4ddeaa29dbfd7f057d099e Mon Sep 17 00:00:00 2001
From: deepbrain
Date: Mon, 22 Oct 2018 11:15:38 -0700
Subject: [PATCH 4/4] revert to work with the tensor states

---
 memory.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/memory.py b/memory.py
index de72231..db15e15 100644
--- a/memory.py
+++ b/memory.py
@@ -125,7 +125,7 @@ def sample(self, batch_size):
     segment = p_total / batch_size  # Batch size number of segments, based on sum over all probabilities
     batch = [self._get_sample_from_segment(segment, i) for i in range(batch_size)]  # Get batch of valid samples
     probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch)
-    states, next_states, = np.stack(states), np.stack(next_states)
+    states, next_states, = torch.stack(states), torch.stack(next_states)
     actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals)
     probs = np.array(probs, dtype = np.float32)/p_total  # Calculate normalised probabilities
     capacity = self.capacity if self.transitions.full else self.transitions.index
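
Taken together, the four patches settle on one convention for crossing the torch/NumPy boundary: detach the per-sample loss and move it to the CPU once (loss.detach().cpu().numpy()), do the priority and importance-weight arithmetic in NumPy, and wrap only the final normalised weights back into a tensor on the training device, while states stay as stacked tensors (PATCH 4/4). The standalone sketch below illustrates that flow under stated assumptions: the priority_exponent and priority_weight values and the hand-made probs vector are illustrative placeholders, and the surrounding Memory/sum-tree machinery is omitted; this is not the repository's full API.

    # Minimal sketch of the data flow the patch series converges on.
    # Assumes only torch and numpy; all concrete values are illustrative.
    import numpy as np
    import torch

    priority_exponent = 0.5  # omega; placeholder value, not taken from the repo
    priority_weight = 0.4    # beta; placeholder value, not taken from the repo
    capacity = 8             # stand-in for the replay memory's capacity

    # Per-sample losses, as produced on the training device in agent.py
    loss = torch.rand(4, requires_grad=True)

    # PATCH 1/4: detach and move to NumPy once, so update_priorities never
    # touches the autograd graph or the GPU again
    priorities = np.power(loss.detach().cpu().numpy(), priority_exponent)

    # PATCHES 2-3/4: keep the probability/weight arithmetic in NumPy...
    probs = np.array([0.5, 0.2, 0.2, 0.1], dtype=np.float32)
    p_total = probs.sum()
    probs = probs / p_total                           # normalised probabilities
    weights = (capacity * probs) ** -priority_weight  # importance-sampling weights w

    # ...and convert only the final, max-normalised weights to a tensor
    weights = torch.tensor(weights / weights.max(), dtype=torch.float32)
    print(priorities, weights)

Doing the per-transition sum-tree updates on plain NumPy values rather than device tensors avoids a host/device round trip per update, which is consistent with the memory and speed gains reported in PATCH 1/4.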