Skip to content

Commit

Permalink
Fix alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
Haichao-Zhang committed Mar 9, 2022
1 parent 28f1cef commit fd151e6
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion alf/algorithms/sac_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def __init__(self,
critic_network_cls, q_network_cls)

self._use_entropy_reward = use_entropy_reward
self._munchausen_reward_weight = min(0, munchausen_reward_weight)
self._munchausen_reward_weight = max(0, munchausen_reward_weight)
if munchausen_reward_weight > 0:
assert not normalize_entropy_reward, (
"should not normalize entropy "
Expand Down Expand Up @@ -875,7 +875,13 @@ def _calc_critic_loss(self, info: SacInfo):
munchausen_reward = nest.map_structure(
lambda la, lp: torch.exp(la) * lp, self._log_alpha,
log_pi_rollout_a)
# [T, B]
munchausen_reward = sum(nest.flatten(munchausen_reward))
# shift the reward forward by one step in time, with zero-padding
munchausen_reward = torch.cat((torch.zeros_like(
munchausen_reward[0:1]), munchausen_reward[:-1]),
dim=0)

info = info._replace(
reward=(
info.reward + self._munchausen_reward_weight *
Expand Down

0 comments on commit fd151e6

Please sign in to comment.