Skip to content

Commit

Permalink
Add gradient clipping (#73)
Browse files Browse the repository at this point in the history
* Add gradient clipping
Loading branch information…
Kaixhin committed May 22, 2020
1 parent b13e98f commit f52981f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions agent.py
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
from torch import optim
from torch.nn.utils import clip_grad_norm_

from model import DQN

Expand All @@ -19,6 +20,7 @@ def __init__(self, args, env):
self.batch_size = args.batch_size
self.n = args.multi_step
self.discount = args.discount
self.norm_clip = args.norm_clip

self.online_net = DQN(args, self.action_space).to(device=args.device)
if args.model: # Load pretrained model if provided
Expand Down Expand Up @@ -92,6 +94,7 @@ def learn(self, mem):
loss = -torch.sum(m * log_ps_a, 1) # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
self.online_net.zero_grad()
(weights * loss).mean().backward() # Backpropagate importance-weighted minibatch loss
clip_grad_norm_(self.online_net.parameters(), self.norm_clip) # Clip gradients by L2 norm
self.optimiser.step()

mem.update_priorities(idxs, loss.detach().cpu().numpy()) # Update priorities of sampled transitions
Expand Down
1 change: 1 addition & 0 deletions main.py
Expand Up @@ -44,6 +44,7 @@
# Optimiser hyperparameters (Adam settings from the Rainbow paper)
parser.add_argument('--learning-rate', type=float, default=0.0000625, metavar='η', help='Learning rate')
parser.add_argument('--adam-eps', type=float, default=1.5e-4, metavar='ε', help='Adam epsilon')
parser.add_argument('--batch-size', type=int, default=32, metavar='SIZE', help='Batch size')
# Consumed in Agent.learn via torch.nn.utils.clip_grad_norm_: gradients of the
# online network are clipped to this maximum L2 norm before the optimiser step.
parser.add_argument('--norm-clip', type=float, default=10, metavar='NORM', help='Max L2 norm for gradient clipping')
parser.add_argument('--learn-start', type=int, default=int(20e3), metavar='STEPS', help='Number of steps before starting training')
# Evaluation controls: --evaluate runs evaluation only (no training)
parser.add_argument('--evaluate', action='store_true', help='Evaluate only')
parser.add_argument('--evaluation-interval', type=int, default=100000, metavar='STEPS', help='Number of training steps between evaluations')
Expand Down

0 comments on commit f52981f

Please sign in to comment.