This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

new feature - implementation of Quantile Regression DQN (https://arxiv.org/pdf/1710.10044v1.pdf)

API change - Distributional DQN renamed to Categorical DQN
itaicaspi-intel committed Nov 1, 2017
1 parent 1ad6262 commit a8bce98
Showing 10 changed files with 157 additions and 17 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -195,7 +195,8 @@ python3 coach.py -p Hopper_A3C -n 16
* [Dueling Q Network](https://arxiv.org/abs/1511.06581)
* [Mixed Monte Carlo (MMC)](https://arxiv.org/abs/1703.01310)
* [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860)
* [Distributional Deep Q Network ](https://arxiv.org/abs/1707.06887)
* [Categorical Deep Q Network (C51)](https://arxiv.org/abs/1707.06887)
* [Quantile Regression Deep Q Network (QR-DQN)](https://arxiv.org/pdf/1710.10044v1.pdf)
* [Bootstrapped Deep Q Network](https://arxiv.org/abs/1602.04621)
* [N-Step Q Learning](https://arxiv.org/abs/1602.01783) | **Distributed**
* [Neural Episodic Control (NEC) ](https://arxiv.org/abs/1703.01988)
3 changes: 2 additions & 1 deletion agents/__init__.py
@@ -22,7 +22,7 @@
from agents.ddqn_agent import *
from agents.dfp_agent import *
from agents.dqn_agent import *
from agents.distributional_dqn_agent import *
from agents.categorical_dqn_agent import *
from agents.mmc_agent import *
from agents.n_step_q_agent import *
from agents.naf_agent import *
@@ -32,3 +32,4 @@
from agents.policy_optimization_agent import *
from agents.ppo_agent import *
from agents.value_optimization_agent import *
from agents.qr_dqn_agent import *
agents/distributional_dqn_agent.py → agents/categorical_dqn_agent.py
@@ -17,8 +17,8 @@
from agents.value_optimization_agent import *


# Distributional Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
class DistributionalDQNAgent(ValueOptimizationAgent):
# Categorical Deep Q Network - https://arxiv.org/pdf/1707.06887.pdf
class CategoricalDQNAgent(ValueOptimizationAgent):
    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
        ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
        self.z_values = np.linspace(self.tp.agent.v_min, self.tp.agent.v_max, self.tp.agent.atoms)
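For context, C51 keeps a fixed support `z_values` over `[v_min, v_max]` and derives Q-values by taking the expectation of that support under the predicted per-atom probabilities. A minimal NumPy sketch of that step, assuming a `(batch, actions, atoms)` probability tensor; the helper name is illustrative and not part of this commit:

```python
import numpy as np

# Illustrative only: expected Q-values under the categorical distribution.
# probabilities: (batch, actions, atoms), z_values: (atoms,)
def categorical_q_values(probabilities, z_values):
    return np.dot(probabilities, z_values)  # -> (batch, actions)
```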
62 changes: 62 additions & 0 deletions agents/qr_dqn_agent.py
@@ -0,0 +1,62 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from agents.value_optimization_agent import *


# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
class QuantileRegressionDQNAgent(ValueOptimizationAgent):
    def __init__(self, env, tuning_parameters, replicated_device=None, thread_id=0):
        ValueOptimizationAgent.__init__(self, env, tuning_parameters, replicated_device, thread_id)
        self.quantile_probabilities = np.ones(self.tp.agent.atoms) / float(self.tp.agent.atoms)

    # prediction's format is (batch, actions, atoms)
    def get_q_values(self, quantile_values):
        return np.dot(quantile_values, self.quantile_probabilities)

    def learn_from_batch(self, batch):
        current_states, next_states, actions, rewards, game_overs, _ = self.extract_batch(batch)

        # get the quantiles of the next states and current states
        next_state_quantiles = self.main_network.target_network.predict(next_states)
        current_quantiles = self.main_network.online_network.predict(current_states)

        # get the optimal actions to take for the next states
        target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)

        # calculate the Bellman update
        batch_idx = list(range(self.tp.batch_size))
        rewards = np.expand_dims(rewards, -1)
        game_overs = np.expand_dims(game_overs, -1)
        TD_targets = rewards + (1.0 - game_overs) * self.tp.agent.discount \
                     * next_state_quantiles[batch_idx, target_actions]

        # get the locations of the selected actions within the batch for indexing purposes
        actions_locations = [[b, a] for b, a in zip(batch_idx, actions)]

        # calculate the cumulative quantile probabilities and reorder them to fit the sorted quantiles order
        cumulative_probabilities = np.array(range(self.tp.agent.atoms + 1)) / float(self.tp.agent.atoms)  # tau_i
        quantile_midpoints = 0.5 * (cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
        quantile_midpoints = np.tile(quantile_midpoints, (self.tp.batch_size, 1))
        for idx in range(self.tp.batch_size):
            quantile_midpoints[idx, :] = quantile_midpoints[idx, np.argsort(current_quantiles[batch_idx, actions])[idx]]

        # train
        result = self.main_network.train_and_sync_networks([current_states, actions_locations, quantile_midpoints], TD_targets)
        total_loss = result[0]

        return total_loss
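A toy illustration of the two quantities this update relies on: the quantile midpoints tau^hat_i used as regression targets, and the Q-values obtained by averaging the quantile locations (each quantile carries uniform probability 1/N). This is a standalone sketch with made-up numbers, not part of the commit:

```python
import numpy as np

# With N = 5 quantiles, tau_i = i / N and the midpoints tau^hat_i sit between them.
num_atoms = 5
tau = np.arange(num_atoms + 1) / num_atoms   # [0. , 0.2, 0.4, 0.6, 0.8, 1. ]
tau_hat = 0.5 * (tau[1:] + tau[:-1])         # [0.1, 0.3, 0.5, 0.7, 0.9]

# Each quantile has probability 1 / N, so the Q-value is the mean of the
# predicted quantile locations for a given (state, action).
quantile_locations = np.array([-1.0, 0.0, 0.5, 1.0, 2.0])  # made-up values
q_value = quantile_locations.mean()                         # 0.5
```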

3 changes: 2 additions & 1 deletion architectures/tensorflow_components/general_network.py
@@ -80,7 +80,8 @@ def get_output_head(self, head_type, head_idx, loss_weight=1.):
            OutputTypes.NAF: NAFHead,
            OutputTypes.PPO: PPOHead,
            OutputTypes.PPO_V: PPOVHead,
            OutputTypes.DistributionalQ: DistributionalQHead
            OutputTypes.CategoricalQ: CategoricalQHead,
            OutputTypes.QuantileRegressionQ: QuantileRegressionQHead
        }
        return output_mapping[head_type](self.tp, head_idx, loss_weight, self.network_is_local)

48 changes: 46 additions & 2 deletions architectures/tensorflow_components/heads.py
@@ -462,10 +462,10 @@ def _build_module(self, input_layer):
        tf.losses.add_loss(self.loss)


class DistributionalQHead(Head):
class CategoricalQHead(Head):
    def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
        Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
        self.name = 'distributional_dqn_head'
        self.name = 'categorical_dqn_head'
        self.num_actions = tuning_parameters.env_instance.action_space_size
        self.num_atoms = tuning_parameters.agent.atoms

@@ -484,3 +484,47 @@ def _build_module(self, input_layer):
        self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.target, logits=values_distribution)
        tf.losses.add_loss(self.loss)
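The categorical head above is trained with a cross entropy between the (already projected) target distribution over atoms and the softmax of the predicted atom logits. A rough NumPy equivalent of that loss, purely as a reference sketch; the function and argument names are illustrative, not from this commit:

```python
import numpy as np

# Illustrative only: cross entropy between target atom probabilities m
# (batch, atoms) and predicted atom logits (batch, atoms).
def categorical_dqn_loss(m, logits):
    # log-softmax over the atom dimension (shifted by the max for stability)
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    log_softmax = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    return -np.sum(m * log_softmax, axis=-1)  # one loss value per sample
```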


class QuantileRegressionQHead(Head):
    def __init__(self, tuning_parameters, head_idx=0, loss_weight=1., is_local=True):
        Head.__init__(self, tuning_parameters, head_idx, loss_weight, is_local)
        self.name = 'quantile_regression_dqn_head'
        self.num_actions = tuning_parameters.env_instance.action_space_size
        self.num_atoms = tuning_parameters.agent.atoms  # we use atom / quantile interchangeably
        self.huber_loss_interval = 1  # k

    def _build_module(self, input_layer):
        self.actions = tf.placeholder(tf.int32, [None, 2], name="actions")
        self.quantile_midpoints = tf.placeholder(tf.float32, [None, self.num_atoms], name="quantile_midpoints")
        self.input = [self.actions, self.quantile_midpoints]

        # the output of the head is the N unordered quantile locations {theta_1, ..., theta_N}
        quantiles_locations = tf.layers.dense(input_layer, self.num_actions * self.num_atoms)
        quantiles_locations = tf.reshape(quantiles_locations, (tf.shape(quantiles_locations)[0], self.num_actions, self.num_atoms))
        self.output = quantiles_locations

        self.quantiles = tf.placeholder(tf.float32, shape=(None, self.num_atoms), name="quantiles")
        self.target = self.quantiles

        # only the quantiles of the taken action are taken into account
        quantiles_for_used_actions = tf.gather_nd(quantiles_locations, self.actions)

        # reorder the output quantiles and the target quantiles as a preparation step for calculating the loss
        # the output quantiles vector and the quantile midpoints are tiled as rows of an NxN matrix (N = num quantiles)
        # the target quantiles vector is tiled as columns of an NxN matrix
        theta_i = tf.tile(tf.expand_dims(quantiles_for_used_actions, -1), [1, 1, self.num_atoms])
        T_theta_j = tf.tile(tf.expand_dims(self.target, -2), [1, self.num_atoms, 1])
        tau_i = tf.tile(tf.expand_dims(self.quantile_midpoints, -1), [1, 1, self.num_atoms])
        # Huber loss of the Bellman errors u = T(theta_j) - theta_i
        bellman_errors = T_theta_j - theta_i
        abs_error = tf.abs(bellman_errors)
        quadratic = tf.minimum(abs_error, self.huber_loss_interval)
        huber_loss = self.huber_loss_interval * (abs_error - quadratic) + 0.5 * quadratic ** 2

        # Quantile Huber loss: |tau_i - 1{u < 0}| * L_k(u), where the indicator uses the signed error
        quantile_huber_loss = tf.abs(tau_i - tf.cast(bellman_errors < 0, dtype=tf.float32)) * huber_loss

        # Quantile regression loss (the probability for each quantile is 1/num_quantiles)
        quantile_regression_loss = tf.reduce_sum(quantile_huber_loss) / float(self.num_atoms)
        self.loss = quantile_regression_loss
        tf.losses.add_loss(self.loss)
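For reference, the quantile Huber loss from the QR-DQN paper is rho^kappa_tau(u) = |tau - 1{u < 0}| * L_kappa(u), applied to the signed errors u = T(theta_j) - theta_i. A small NumPy sketch of the same quantity, useful as a sanity check against the TensorFlow graph above; names are illustrative and not part of the commit:

```python
import numpy as np

# Illustrative only: element-wise quantile Huber loss for signed errors u and
# quantile midpoints tau (broadcastable arrays), with Huber interval kappa.
def quantile_huber_loss(u, tau, kappa=1.0):
    abs_u = np.abs(u)
    quadratic = np.minimum(abs_u, kappa)
    huber = 0.5 * quadratic ** 2 + kappa * (abs_u - quadratic)
    return np.abs(tau - (u < 0).astype(np.float64)) * huber
```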
15 changes: 11 additions & 4 deletions configurations.py
@@ -42,7 +42,8 @@ class OutputTypes(object):
    NAF = 7
    PPO = 8
    PPO_V = 9
    DistributionalQ = 10
    CategoricalQ = 10
    QuantileRegressionQ = 11


class MiddlewareTypes(object):
@@ -307,14 +308,20 @@ class BootstrappedDQN(DQN):
    num_output_head_copies = 10


class DistributionalDQN(DQN):
    type = 'DistributionalDQNAgent'
    output_types = [OutputTypes.DistributionalQ]
class CategoricalDQN(DQN):
    type = 'CategoricalDQNAgent'
    output_types = [OutputTypes.CategoricalQ]
    v_min = -10.0
    v_max = 10.0
    atoms = 51


class QuantileRegressionDQN(DQN):
    type = 'QuantileRegressionDQNAgent'
    output_types = [OutputTypes.QuantileRegressionQ]
    atoms = 51


class NEC(AgentParameters):
    type = 'NECAgent'
    optimizer_type = 'RMSProp'
algorithms/value_optimization/distributional_dqn.md → algorithms/value_optimization/categorical_dqn.md
@@ -1,4 +1,4 @@
# Distributional DQN
# Categorical DQN

**Actions space:** Discrete

2 changes: 1 addition & 1 deletion docs/mkdocs.yml
@@ -15,7 +15,7 @@ pages:
- 'DQN' : algorithms/value_optimization/dqn.md
- 'Double DQN' : algorithms/value_optimization/double_dqn.md
- 'Dueling DQN' : algorithms/value_optimization/dueling_dqn.md
- 'Distributional DQN' : algorithms/value_optimization/distributional_dqn.md
- 'Categorical DQN' : algorithms/value_optimization/categorical_dqn.md
- 'Mixed Monte Carlo' : algorithms/value_optimization/mmc.md
- 'Persistent Advantage Learning' : algorithms/value_optimization/pal.md
- 'Neural Episodic Control' : algorithms/value_optimization/nec.md
32 changes: 28 additions & 4 deletions presets.py
@@ -70,6 +70,18 @@ def __init__(self):
        self.num_heatup_steps = 1000



class Doom_Basic_QRDQN(Preset):
    def __init__(self):
        Preset.__init__(self, QuantileRegressionDQN, Doom, ExplorationParameters)
        self.env.level = 'basic'
        self.agent.num_steps_between_copying_online_weights_to_target = 1000
        self.learning_rate = 0.00025
        self.agent.num_episodes_in_experience_replay = 200
        self.num_heatup_steps = 1000



class Doom_Basic_OneStepQ(Preset):
    def __init__(self):
        Preset.__init__(self, NStepQ, Doom, ExplorationParameters)
@@ -340,9 +352,9 @@ def __init__(self):
        self.test_min_return_threshold = 150


class CartPole_DistributionalDQN(Preset):
class CartPole_C51(Preset):
    def __init__(self):
        Preset.__init__(self, DistributionalDQN, GymVectorObservation, ExplorationParameters)
        Preset.__init__(self, CategoricalDQN, GymVectorObservation, ExplorationParameters)
        self.env.level = 'CartPole-v0'
        self.agent.num_steps_between_copying_online_weights_to_target = 100
        self.learning_rate = 0.00025
@@ -356,6 +368,18 @@ def __init__(self):
        self.test_min_return_threshold = 150


class CartPole_QRDQN(Preset):
    def __init__(self):
        Preset.__init__(self, QuantileRegressionDQN, GymVectorObservation, ExplorationParameters)
        self.env.level = 'CartPole-v0'
        self.agent.num_steps_between_copying_online_weights_to_target = 100
        self.learning_rate = 0.00025
        self.agent.num_episodes_in_experience_replay = 200
        self.num_heatup_steps = 1000
        self.exploration.epsilon_decay_steps = 3000
        self.agent.discount = 1.0
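Presumably, the new presets can then be launched the same way as the existing ones shown in the README, e.g. `python3 coach.py -p CartPole_QRDQN` or `python3 coach.py -p Doom_Basic_QRDQN`.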


# The below preset matches the hyper-parameters setting as in the original DQN paper.
# This is a very resource intensive preset, and might easily blow up your RAM (> 100GB of usage).
# Try reducing the number of transitions in the experience replay (50e3 might be a reasonable number to start with),
@@ -377,9 +401,9 @@ def __init__(self):
        self.evaluate_every_x_episodes = 100


class Breakout_DistributionalDQN(Preset):
class Breakout_C51(Preset):
    def __init__(self):
        Preset.__init__(self, DistributionalDQN, Atari, ExplorationParameters)
        Preset.__init__(self, CategoricalDQN, Atari, ExplorationParameters)
        self.env.level = 'BreakoutDeterministic-v4'
        self.agent.num_steps_between_copying_online_weights_to_target = 10000
        self.learning_rate = 0.00025

