diff --git a/agent.py b/agent.py index b18f482..61a4a33 100644 --- a/agent.py +++ b/agent.py @@ -83,7 +83,7 @@ def learn(self, mem): (weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss self.optimiser.step() - mem.update_priorities(idxs, loss.detach()) # Update priorities of sampled transitions + mem.update_priorities(idxs, loss.detach().cpu().numpy()) # Update priorities of sampled transitions def update_target_net(self): self.target_net.load_state_dict(self.online_net.state_dict()) diff --git a/memory.py b/memory.py index 5d126c0..db15e15 100644 --- a/memory.py +++ b/memory.py @@ -1,6 +1,7 @@ import random from collections import namedtuple import torch +import numpy as np Transition = namedtuple('Transition', ('timestep', 'state', 'action', 'reward', 'nonterminal')) @@ -126,14 +127,15 @@ def sample(self, batch_size): probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch) states, next_states, = torch.stack(states), torch.stack(next_states) actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals) - probs = torch.tensor(probs, dtype=torch.float32, device=self.device) / p_total # Calculate normalised probabilities + probs = np.array(probs, dtype = np.float32)/p_total # Calculate normalised probabilities capacity = self.capacity if self.transitions.full else self.transitions.index weights = (capacity * probs) ** -self.priority_weight # Compute importance-sampling weights w - weights = weights / weights.max() # Normalise by max importance-sampling weight from batch + weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device) # Normalise by max importance-sampling weight from batch return tree_idxs, states, actions, returns, next_states, nonterminals, weights + def update_priorities(self, idxs, priorities): - priorities.pow_(self.priority_exponent) + priorities = np.power(priorities, self.priority_exponent) [self.transitions.update(idx, priority) for idx, priority in zip(idxs, priorities)] # Set up internal state for iterator