Skip to content

Commit

Permalink
Weight losses for PER + clip grads
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Jun 18, 2018
1 parent cf4c315 commit d6538df
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, args, env):
self.batch_size = args.batch_size
self.n = args.multi_step
self.discount = args.discount
+ self.norm_clip = args.norm_clip

self.online_net = DQN(args, self.action_space, args.quantile).to(device=args.device)
if args.model and os.path.isfile(args.model):
Expand Down Expand Up @@ -93,9 +94,11 @@ def learn(self, mem):
loss = torch.sum(torch.abs(self.cumulative_density - (u < 0).to(torch.float32)) * huber_loss, 1) # Quantile Huber loss ρκτ(u) = |τ − δ{u<0}|Lκ(u)
else:
loss = -torch.sum(m * ps_a.log(), 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
+ loss = weights * loss # Importance weight losses
self.online_net.zero_grad()
- (weights * loss).mean().backward() # Importance weight losses
+ loss.mean().backward() # Backpropagate minibatch loss
self.optimiser.step()
+ nn.utils.clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm
if self.quantile:
loss = (self.atoms * loss).clamp(max=5) # Heuristic for prioritised replay

Expand Down
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
parser.add_argument('--lr', type=float, default=0.0000625, metavar='η', help='Learning rate')
parser.add_argument('--adam-eps', type=float, default=1.5e-4, metavar='ε', help='Adam epsilon')
parser.add_argument('--batch-size', type=int, default=32, metavar='SIZE', help='Batch size')
+ parser.add_argument('--norm-clip', type=float, default=10, metavar='NORM', help='Max L2 norm for gradient clipping')
parser.add_argument('--learn-start', type=int, default=int(80e3), metavar='STEPS', help='Number of steps before starting training')
parser.add_argument('--evaluate', action='store_true', help='Evaluate only')
parser.add_argument('--evaluation-interval', type=int, default=100000, metavar='STEPS', help='Number of training steps between evaluations')
Expand Down

0 comments on commit d6538df

Please sign in to comment.