Commit

Fix alignment
Haichao-Zhang committed Mar 9, 2022
1 parent 28f1cef commit b330630
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions alf/algorithms/sac_algorithm.py
@@ -273,7 +273,7 @@ def __init__(self,
critic_network_cls, q_network_cls)

self._use_entropy_reward = use_entropy_reward
self._munchausen_reward_weight = min(0, munchausen_reward_weight)
self._munchausen_reward_weight = max(0, munchausen_reward_weight)
if munchausen_reward_weight > 0:
assert not normalize_entropy_reward, (
"should not normalize entropy "
@@ -853,10 +853,23 @@ def _calc_critic_loss(self, info: SacInfo):
(There is an issue in their implementation: their "terminals" can't
differentiate between discount=0 (NormalEnd) and discount=1 (TimeOut).
In the latter case, masking should not be performed.)
When the reward is multi-dim, the entropy reward will be added to *all*
dims.
"""
if self._use_entropy_reward:
with torch.no_grad():
log_pi = info.log_pi
if self._entropy_normalizer is not None:
log_pi = self._entropy_normalizer.normalize(log_pi)
entropy_reward = nest.map_structure(
lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
log_pi)
entropy_reward = sum(nest.flatten(entropy_reward))
discount = self._critic_losses[0].gamma * info.discount
# When the reward is multi-dim, the entropy reward will be
# added to *all* dims.
info = info._replace(
reward=(info.reward + common.expand_dims_as(
entropy_reward * discount, info.reward)))

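The block above adds the (optionally normalized) entropy reward, -exp(log_alpha) * log_pi, scaled by gamma * discount, to every dimension of the reward. A standalone sketch of the same computation, assuming a single scalar log_alpha tensor and [T, B] tensors (illustrative, not the class code):

    import torch

    # Standalone sketch of the entropy-reward augmentation above; assumes a
    # single scalar log_alpha tensor and [T, B] tensors (illustrative only).
    def add_entropy_reward(reward, log_pi, log_alpha, discount, gamma):
        with torch.no_grad():
            entropy_reward = -torch.exp(log_alpha) * log_pi   # [T, B]
            scaled = entropy_reward * gamma * discount        # [T, B]
            # When the reward is multi-dim, broadcast over the reward dims.
            while scaled.ndim < reward.ndim:
                scaled = scaled.unsqueeze(-1)
            return reward + scaled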
if self._munchausen_reward_weight > 0:
with torch.no_grad():
# calculate the log probability of the rollout action
@@ -875,26 +888,22 @@
munchausen_reward = nest.map_structure(
lambda la, lp: torch.exp(la) * lp, self._log_alpha,
log_pi_rollout_a)
# [T, B]
munchausen_reward = sum(nest.flatten(munchausen_reward))
# forward shift the munchausen reward one-step temporally,
# with zero-padding for the first step. This dummy reward
# for the first step does not impact training as it is not
# used in TD-learning.
munchausen_reward = torch.cat((torch.zeros_like(
munchausen_reward[0:1]), munchausen_reward[:-1]),
dim=0)
# When the reward is multi-dim, the munchausen reward will be
# added to *all* dims.
info = info._replace(
reward=(
info.reward + self._munchausen_reward_weight *
common.expand_dims_as(munchausen_reward, info.reward)))

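The Munchausen term exp(log_alpha) * log_pi is computed for the rollout action at each step and then shifted forward one step in time, so the bonus for the action taken at step t is added to the reward observed at step t + 1; the zero-padded first step never enters TD learning. A standalone sketch of the shift and weighting, again assuming [T, B] tensors (illustrative, not the class code):

    import torch

    # Standalone sketch of the Munchausen reward shift above; assumes [T, B]
    # tensors and a scalar log_alpha tensor (illustrative only).
    def add_munchausen_reward(reward, log_pi_rollout_a, log_alpha, weight):
        with torch.no_grad():
            m_reward = torch.exp(log_alpha) * log_pi_rollout_a   # [T, B]
            # Shift forward one step temporally, zero-padding the first step;
            # the dummy first-step value is never used by TD learning.
            m_reward = torch.cat(
                (torch.zeros_like(m_reward[0:1]), m_reward[:-1]), dim=0)
            # When the reward is multi-dim, broadcast over the reward dims.
            while m_reward.ndim < reward.ndim:
                m_reward = m_reward.unsqueeze(-1)
            return reward + weight * m_reward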
if self._use_entropy_reward:
with torch.no_grad():
log_pi = info.log_pi
if self._entropy_normalizer is not None:
log_pi = self._entropy_normalizer.normalize(log_pi)
entropy_reward = nest.map_structure(
lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
log_pi)
entropy_reward = sum(nest.flatten(entropy_reward))
discount = self._critic_losses[0].gamma * info.discount
info = info._replace(
reward=(info.reward + common.expand_dims_as(
entropy_reward * discount, info.reward)))

critic_info = info.critic
critic_losses = []
for i, l in enumerate(self._critic_losses):
