Skip to content

Commit

Permalink
Improve robustness for QR prioritisation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Jun 7, 2018
1 parent c4ce26c commit cf4c315
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Rainbow
=======
[![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md)

Rainbow: Combining Improvements in Deep Reinforcement Learning [[1]](#references). Includes quantile regression loss [[2]](#references): run with `--quantile --atoms 200`.
Rainbow: Combining Improvements in Deep Reinforcement Learning [[1]](#references). Includes quantile regression loss [[2]](#references): run with `--quantile`.

Results and pretrained models can be found in the [releases](https://github.com/Kaixhin/Rainbow/releases).

Expand Down
11 changes: 6 additions & 5 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ class Agent():
def __init__(self, args, env):
self.action_space = env.action_space()
self.quantile = args.quantile
self.atoms = args.atoms # Alternatively number of quantiles
self.atoms = args.quantiles if args.quantile else args.atoms
if args.quantile:
self.cumulative_density = (2 * torch.arange(args.atoms).to(device=args.device) + 1) / (2 * args.atoms) # Quantile cumulative probability weights τ
self.cumulative_density = (2 * torch.arange(self.atoms).to(device=args.device) + 1) / (2 * self.atoms) # Quantile cumulative probability weights τ
else:
self.Vmin = args.V_min
self.Vmax = args.V_max
self.support = torch.linspace(args.V_min, args.V_max, args.atoms).to(device=args.device) # Support (range) of z
self.delta_z = (args.V_max - args.V_min) / (args.atoms - 1)
self.support = torch.linspace(args.V_min, args.V_max, self.atoms).to(device=args.device) # Support (range) of z
self.delta_z = (args.V_max - args.V_min) / (self.atoms - 1)
self.batch_size = args.batch_size
self.n = args.multi_step
self.discount = args.discount
Expand Down Expand Up @@ -96,7 +96,8 @@ def learn(self, mem):
self.online_net.zero_grad()
(weights * loss).mean().backward() # Importance weight losses
self.optimiser.step()
loss *= 100 if self.quantile else 1 # Heuristic for prioritised replay
if self.quantile:
loss = (self.atoms * loss).clamp(max=5) # Heuristic for prioritised replay

mem.update_priorities(idxs, loss.detach()) # Update priorities of sampled transitions

Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
parser.add_argument('--history-length', type=int, default=4, metavar='T', help='Number of consecutive states processed')
parser.add_argument('--hidden-size', type=int, default=512, metavar='SIZE', help='Network hidden size')
parser.add_argument('--noisy-std', type=float, default=0.1, metavar='σ', help='Initial standard deviation of noisy linear layers')
parser.add_argument('--atoms', type=int, default=51, metavar='C', help='Discretised size of value distribution/number of quantiles')
parser.add_argument('--atoms', type=int, default=51, metavar='C', help='Discretised size of value distribution')
parser.add_argument('--quantiles', type=int, default=200, metavar='Q', help='Number of quantiles')
parser.add_argument('--quantile', action='store_true', help='Use quantile regression')
parser.add_argument('--V-min', type=float, default=-10, metavar='V', help='Minimum of value distribution support')
parser.add_argument('--V-max', type=float, default=10, metavar='V', help='Maximum of value distribution support')
Expand Down
6 changes: 3 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def forward(self, input):
class DQN(nn.Module):
def __init__(self, args, action_space, quantile=True):
super().__init__()
self.atoms = args.atoms # Alternatively number of quantiles
self.atoms = args.quantiles if quantile else args.atoms
self.action_space = action_space
self.quantile = quantile

Expand All @@ -56,8 +56,8 @@ def __init__(self, args, action_space, quantile=True):
self.conv3 = nn.Conv2d(64, 64, 3)
self.fc_h_v = NoisyLinear(3136, args.hidden_size, std_init=args.noisy_std)
self.fc_h_a = NoisyLinear(3136, args.hidden_size, std_init=args.noisy_std)
self.fc_z_v = NoisyLinear(args.hidden_size, args.atoms, std_init=args.noisy_std)
self.fc_z_a = NoisyLinear(args.hidden_size, action_space * args.atoms, std_init=args.noisy_std)
self.fc_z_v = NoisyLinear(args.hidden_size, self.atoms, std_init=args.noisy_std)
self.fc_z_a = NoisyLinear(args.hidden_size, action_space * self.atoms, std_init=args.noisy_std)

def forward(self, x):
x = F.relu(self.conv1(x))
Expand Down

0 comments on commit cf4c315

Please sign in to comment.