Commit: Munchausen RL
Haichao-Zhang committed Mar 9, 2022
1 parent 17fcb4e commit 28f1cef
Showing 2 changed files with 36 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -27,6 +27,7 @@ Read the ALF documentation [here](https://alf.readthedocs.io/).
|[OAC](alf/algorithms/oac_algorithm.py)|Off-policy RL|Ciosek et al. "Better Exploration with Optimistic Actor-Critic" [arXiv:1910.12807](https://arxiv.org/abs/1910.12807)|
|[HER](https://github.com/HorizonRobotics/alf/blob/911d9573866df41e9e3adf6cdd94ee03016bf5a8/alf/algorithms/data_transformer.py#L672)|Off-policy RL|Andrychowicz et al. "Hindsight Experience Replay" [arXiv:1707.01495](https://arxiv.org/abs/1707.01495)|
|[TAAC](alf/algorithms/taac_algorithm.py)|Off-policy RL|Yu et al. "TAAC: Temporally Abstract Actor-Critic for Continuous Control" [arXiv:2104.06521](https://arxiv.org/abs/2104.06521)|
|[Munchausen RL](alf/algorithms/sac_algorithm.py)|Off-policy RL|Vieillard et al. "Munchausen Reinforcement Learning" [arXiv:2007.14430](https://arxiv.org/abs/2007.14430)|
|[DIAYN](alf/algorithms/diayn_algorithm.py)|Intrinsic motivation/Exploration|Eysenbach et al. "Diversity is All You Need: Learning Diverse Skills without a Reward Function" [arXiv:1802.06070](https://arxiv.org/abs/1802.06070)|
|[ICM](alf/algorithms/icm_algorithm.py)|Intrinsic motivation/Exploration|Pathak et al. "Curiosity-driven Exploration by Self-supervised Prediction" [arXiv:1705.05363](https://arxiv.org/abs/1705.05363)|
|[RND](alf/algorithms/rnd_algorithm.py)|Intrinsic motivation/Exploration|Burda et al. "Exploration by Random Network Distillation" [arXiv:1810.12894](https://arxiv.org/abs/1810.12894)|
35 changes: 35 additions & 0 deletions alf/algorithms/sac_algorithm.py
@@ -153,6 +153,7 @@ def __init__(self,
                 reward_weights=None,
                 epsilon_greedy=None,
                 use_entropy_reward=True,
                 munchausen_reward_weight=0,
                 normalize_entropy_reward=False,
                 calculate_priority=False,
                 num_critic_replicas=2,
@@ -204,6 +205,11 @@ def __init__(self,
                from ``config.epsilon_greedy`` and then
                ``alf.get_config_value(TrainerConfig.epsilon_greedy)``.
            use_entropy_reward (bool): whether to include entropy as reward
            munchausen_reward_weight (float): the weight for augmenting the
                task reward with the Munchausen reward introduced in
                ``Munchausen Reinforcement Learning``, which is essentially
                the entropy-temperature-scaled log probability of the taken
                action. A non-positive value means the Munchausen reward is
                not used.
            normalize_entropy_reward (bool): if True, normalize entropy reward
                to reduce bias in episodic cases. Only used if
                ``use_entropy_reward==True``.
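
As a minimal usage sketch (assuming ALF's ``alf.config`` interface; the weight value 0.9 is illustrative, not a recommendation from this commit), the new parameter could be set from a training conf like this:

import alf
# imported so that SacAlgorithm's configurable arguments are registered
from alf.algorithms.sac_algorithm import SacAlgorithm

alf.config(
    'SacAlgorithm',
    use_entropy_reward=True,
    # normalize_entropy_reward must stay False when the Munchausen
    # reward is enabled (enforced by an assert in __init__)
    normalize_entropy_reward=False,
    munchausen_reward_weight=0.9)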
@@ -267,6 +273,11 @@ def __init__(self,
            critic_network_cls, q_network_cls)

        self._use_entropy_reward = use_entropy_reward
        self._munchausen_reward_weight = max(0., munchausen_reward_weight)
        if munchausen_reward_weight > 0:
            assert not normalize_entropy_reward, (
                "should not normalize entropy "
                "reward when using munchausen reward")

        if reward_spec.numel > 1:
            assert self._act_type != ActionType.Mixed, (
@@ -846,6 +857,30 @@ def _calc_critic_loss(self, info: SacInfo):
        When the reward is multi-dim, the entropy reward will be added to *all*
        dims.
        """
        if self._munchausen_reward_weight > 0:
            with torch.no_grad():
                # calculate the log probability of the rollout action
                log_pi_rollout_a = nest.map_structure(
                    lambda dist, a: dist.log_prob(a), info.action_distribution,
                    info.action)

                if self._act_type == ActionType.Mixed:
                    # For mixed type, add log_pi separately
                    log_pi_rollout_a = type(self._action_spec)(
                        (sum(nest.flatten(log_pi_rollout_a[0])),
                         sum(nest.flatten(log_pi_rollout_a[1]))))
                else:
                    log_pi_rollout_a = sum(nest.flatten(log_pi_rollout_a))

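                # Munchausen reward: alpha * log_pi of the rollout action,
                # where alpha = exp(log_alpha) is the learned entropy
                # temperature.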
                munchausen_reward = nest.map_structure(
                    lambda la, lp: torch.exp(la) * lp, self._log_alpha,
                    log_pi_rollout_a)
                munchausen_reward = sum(nest.flatten(munchausen_reward))
                info = info._replace(
                    reward=(
                        info.reward + self._munchausen_reward_weight *
                        common.expand_dims_as(munchausen_reward, info.reward)))

        if self._use_entropy_reward:
            with torch.no_grad():
                log_pi = info.log_pi

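For reference, a self-contained numeric sketch of the reward augmentation added above, assuming a single continuous action distribution; the tensor values and variable names are illustrative and not part of the commit:

import torch
import torch.distributions as td

log_alpha = torch.tensor(-1.0)           # learned SAC temperature in log space
dist = td.Normal(torch.zeros(3), torch.ones(3))
action = dist.sample()

log_pi = dist.log_prob(action).sum(-1)   # log pi(a|s), summed over action dims
munchausen_reward = torch.exp(log_alpha) * log_pi   # alpha * log pi(a|s)

weight = 0.9                             # munchausen_reward_weight
task_reward = torch.tensor(1.0)
augmented_reward = task_reward + weight * munchausen_reward
# log pi of a sampled continuous action is negative here, so the augmentation
# discounts the reward of low-probability actions, acting as a soft KL-style
# regularizer toward the current policy.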