Fix alignment
Haichao-Zhang committed Mar 9, 2022
1 parent 28f1cef commit 4866283
Showing 1 changed file with 21 additions and 15 deletions.
alf/algorithms/sac_algorithm.py: 36 changes (21 additions, 15 deletions)
@@ -273,7 +273,7 @@ def __init__(self,
             critic_network_cls, q_network_cls)
 
         self._use_entropy_reward = use_entropy_reward
-        self._munchausen_reward_weight = min(0, munchausen_reward_weight)
+        self._munchausen_reward_weight = max(0, munchausen_reward_weight)
         if munchausen_reward_weight > 0:
             assert not normalize_entropy_reward, (
                 "should not normalize entropy "
@@ -857,6 +857,20 @@ def _calc_critic_loss(self, info: SacInfo):
         When the reward is multi-dim, the entropy reward will be added to *all*
         dims.
         """
+        if self._use_entropy_reward:
+            with torch.no_grad():
+                log_pi = info.log_pi
+                if self._entropy_normalizer is not None:
+                    log_pi = self._entropy_normalizer.normalize(log_pi)
+                entropy_reward = nest.map_structure(
+                    lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
+                    log_pi)
+                entropy_reward = sum(nest.flatten(entropy_reward))
+                discount = self._critic_losses[0].gamma * info.discount
+                info = info._replace(
+                    reward=(info.reward + common.expand_dims_as(
+                        entropy_reward * discount, info.reward)))
+
         if self._munchausen_reward_weight > 0:
             with torch.no_grad():
                 # calculate the log probability of the rollout action
@@ -875,26 +889,18 @@ def _calc_critic_loss(self, info: SacInfo):
                 munchausen_reward = nest.map_structure(
                     lambda la, lp: torch.exp(la) * lp, self._log_alpha,
                     log_pi_rollout_a)
+                # [T, B]
                 munchausen_reward = sum(nest.flatten(munchausen_reward))
+                # forward shift the reward one-step temporally, with zero-padding
+                munchausen_reward = torch.cat((torch.zeros_like(
+                    munchausen_reward[0:1]), munchausen_reward[:-1]),
+                                              dim=0)
+
                 info = info._replace(
                     reward=(
                         info.reward + self._munchausen_reward_weight *
                         common.expand_dims_as(munchausen_reward, info.reward)))
 
-        if self._use_entropy_reward:
-            with torch.no_grad():
-                log_pi = info.log_pi
-                if self._entropy_normalizer is not None:
-                    log_pi = self._entropy_normalizer.normalize(log_pi)
-                entropy_reward = nest.map_structure(
-                    lambda la, lp: -torch.exp(la) * lp, self._log_alpha,
-                    log_pi)
-                entropy_reward = sum(nest.flatten(entropy_reward))
-                discount = self._critic_losses[0].gamma * info.discount
-                info = info._replace(
-                    reward=(info.reward + common.expand_dims_as(
-                        entropy_reward * discount, info.reward)))
-
         critic_info = info.critic
         critic_losses = []
         for i, l in enumerate(self._critic_losses):
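For readers of this diff, here is a minimal standalone sketch (not part of the commit) of what the newly added one-step forward shift with zero-padding does to a [T, B]-shaped Munchausen reward tensor; the shapes and values below are made up purely for illustration.

import torch

# Toy [T, B] tensor standing in for the per-step Munchausen bonus
# alpha * log_pi(a_t | s_t); T is the temporal axis, B the batch axis.
T, B = 4, 2
munchausen_reward = torch.arange(1., T * B + 1).reshape(T, B)

# Shift forward one step along time and zero-pad the first step, the same
# operation as the torch.cat call added in this commit.
shifted = torch.cat(
    (torch.zeros_like(munchausen_reward[0:1]), munchausen_reward[:-1]), dim=0)

print(shifted)
# tensor([[0., 0.],
#         [1., 2.],
#         [3., 4.],
#         [5., 6.]])

The effect is that the log-probability bonus computed for the action at step t is added to the reward entry at step t + 1, which appears to be the temporal alignment the commit title refers to.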
