Revert "Use running statistics to normalize advantages for PPO and AC (
Browse files Browse the repository at this point in the history
…#1655)"

This reverts commit 1de48a9.
breakds committed Jun 10, 2024
1 parent d3ea2e0 commit ad586df
Showing 5 changed files with 25 additions and 92 deletions.
3 changes: 1 addition & 2 deletions alf/algorithms/actor_critic_algorithm.py
@@ -137,8 +137,7 @@ def __init__(self,
         self._actor_network = actor_network
         self._value_network = value_network
         if loss is None:
-            loss = loss_class(
-                reward_dim=reward_spec.numel, debug_summaries=debug_summaries)
+            loss = loss_class(debug_summaries=debug_summaries)
         self._loss = loss

         # The following checkpoint loading hook handles the case when value
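
For reference: the default loss no longer receives the reward dimension because the restored ActorCriticLoss builds a torch.nn.LazyBatchNorm1d, which infers its feature count from the first batch it sees. A minimal sketch (not from the repository), with an illustrative batch shape:

    import torch

    # LazyBatchNorm1d defers choosing num_features until its first forward
    # call, so the same construction works for any reward dimension.
    adv_norm = torch.nn.LazyBatchNorm1d(
        eps=1e-8, affine=False, track_running_stats=False)

    advantages = torch.randn(256, 3)   # hypothetical [batch, reward_dim] advantages
    normalized = adv_norm(advantages)  # num_features is inferred as 3 here
    print(normalized.mean(dim=0))      # approximately zero per dimension
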
66 changes: 10 additions & 56 deletions alf/algorithms/actor_critic_loss.py
@@ -28,33 +28,15 @@
["pg_loss", "td_loss", "neg_entropy"])


def normalize(batch_norm, x):
batch_norm.train()
momentum = batch_norm.momentum
if batch_norm.num_batches_tracked * momentum < 1.0:
# For the first few batches, we do cumulative moving average
batch_norm.momentum = None
batch_norm(x)
batch_norm.momentum = momentum
# We use the running mean and variance of the advantages to normalize
# since the batch may not be large enough to properly normalize within
# the batch.
batch_norm.eval()
return batch_norm(x)


@alf.configurable
class ActorCriticLoss(Loss):
def __init__(self,
reward_dim=1,
gamma=0.99,
td_error_loss_fn=element_wise_squared_loss,
use_gae=False,
td_lambda=0.95,
use_td_lambda_return=True,
normalize_advantages=False,
normalize_scalar_advantages=False,
advantage_norm_momentum=0.9,
advantage_clip=None,
entropy_regularization=None,
td_loss_weight=1.0,
@@ -69,7 +51,6 @@ def __init__(self,
                 - entropy_regularization * entropy)

         Args:
-            reward_dim (int): dimension of the reward.
             gamma (float|list[float]): A discount factor for future rewards. For
                 multi-dim reward, this can also be a list of discounts, each
                 discount applies to a reward dim.
@@ -84,14 +65,8 @@
                 ``(td_lambda_return = gae_advantage + value_predictions)``.
             td_lambda (float): Lambda parameter for TD-lambda computation.
             normalize_advantages (bool): If True, normalize advantage to zero
-                mean and unit variance within batch for calculating policy
+                mean and unit variance within batch for caculating policy
                 gradient. This is commonly used for PPO.
-            normalize_scalar_advantages (bool): If False, the normalization is
-                performed for each reward dimension. If True, the normalization
-                is performed for the weighted sum of advantages using reward_weights.
-                Note that this will take precedence over `normalize_advantages`.
-            advantage_norm_momentum (float): Momentum for moving average of
-                mean and variance of advantages (same as the momentum for nn.BatchNorm1d).
             advantage_clip (float): If set, clip advantages to :math:`[-x, x]`
             entropy_regularization (float): Coefficient for entropy
                 regularization loss term.
@@ -106,23 +81,13 @@
         self._use_gae = use_gae
         self._lambda = td_lambda
         self._use_td_lambda_return = use_td_lambda_return
-        if normalize_scalar_advantages:
-            self._adv_norm = torch.nn.BatchNorm1d(
-                num_features=1,
-                eps=1e-8,
-                momentum=advantage_norm_momentum,
-                affine=False,
-                track_running_stats=True)
-            normalize_advantages = False
-        elif normalize_advantages:
-            self._adv_norm = torch.nn.BatchNorm1d(
-                num_features=reward_dim,
-                eps=1e-8,
-                momentum=advantage_norm_momentum,
-                affine=False,
-                track_running_stats=True)
         self._normalize_advantages = normalize_advantages
-        self._normalize_scalar_advantages = normalize_scalar_advantages
+        if normalize_advantages:
+            # Note that convert_sync_batchnorm does not work with LazyBatchNorm
+            # in general. Fortunately, it works for affine=False and track_running_stats=False
+            # since no parameter needs to be created.
+            self._adv_norm = torch.nn.LazyBatchNorm1d(
+                eps=1e-8, affine=False, track_running_stats=False)
         assert advantage_clip is None or advantage_clip > 0, (
             "Clipping value should be positive!")
         self._advantage_clip = advantage_clip
@@ -137,10 +102,6 @@ def gamma(self):
     def normalizing_advantages(self):
         return self._normalize_advantages

-    @property
-    def normalizing_scalar_advantages(self):
-        return self._normalize_scalar_advantages
-
     def forward(self, info):
         """Calculate actor critic loss. The first dimension of all the tensors is
         time dimension and the second dimension is the batch dimension.
@@ -179,27 +140,20 @@ def _summarize(v, r, adv, suffix):
                 suffix = '/' + str(i)
                 _summarize(value[..., i], returns[..., i],
                            advantages[..., i], suffix)

         if self._normalize_advantages:
             if hasattr(info, "normalized_advantages"):
                 advantages = info.normalized_advantages
             else:
                 bt = advantages.shape[0] * advantages.shape[1]
-                adv = normalize(self._adv_norm, advantages.reshape(bt, -1))
-                advantages = adv.reshape_as(advantages)
-        elif self._normalize_scalar_advantages:
-            if hasattr(info, "normalized_advantages"):
-                advantages = info.normalized_advantages
-            else:
-                advantages = (advantages * info.reward_weights).sum(-1)
-                adv = normalize(self._adv_norm, advantages.reshape(-1, 1))
+                adv = self._adv_norm(advantages.reshape(bt, -1))
                 advantages = adv.reshape_as(advantages)

         if self._advantage_clip:
             advantages = torch.clamp(advantages, -self._advantage_clip,
                                      self._advantage_clip)

-        if info.reward_weights != () and not self._normalize_scalar_advantages:
-            # reward_weights has already been applied for self._normalize_scalar_advantages
+        if info.reward_weights != ():
             advantages = (advantages * info.reward_weights).sum(-1)
         pg_loss = self._pg_loss(info, advantages.detach())

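
For reference, a minimal sketch (not from the repository) contrasting the two normalization strategies this revert switches between; the batch shape and momentum value are illustrative:

    import torch

    advantages = torch.randn(64, 1)  # hypothetical [batch, 1] advantages

    # Restored behavior (this commit): per-batch normalization. With
    # track_running_stats=False, BatchNorm always uses the current batch's
    # mean and variance, so nothing carries over between iterations.
    per_batch_norm = torch.nn.LazyBatchNorm1d(
        eps=1e-8, affine=False, track_running_stats=False)
    per_batch = per_batch_norm(advantages)

    # Reverted behavior (commit 1de48a9): running statistics. The removed
    # `normalize` helper updated the running mean/variance in train() mode
    # (using a cumulative average for the first few batches), then switched
    # to eval() mode so the output is normalized by the running values rather
    # than by the possibly small current batch.
    running_norm = torch.nn.BatchNorm1d(
        num_features=1, eps=1e-8, momentum=0.9, affine=False,
        track_running_stats=True)
    running_norm.train()
    running_norm(advantages)                 # update running statistics only
    running_norm.eval()
    with_running_stats = running_norm(advantages)
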
26 changes: 11 additions & 15 deletions alf/algorithms/ppo_algorithm.py
@@ -19,7 +19,6 @@
 import alf
 from alf.algorithms.actor_critic_algorithm import ActorCriticAlgorithm
 from alf.algorithms.ppo_loss import PPOLoss
-from alf.algorithms.actor_critic_loss import normalize
 from alf.data_structures import namedtuple, TimeStep
 from alf.utils import value_ops, tensor_utils
 from alf.nest.utils import convert_device
@@ -52,6 +51,15 @@ class PPOAlgorithm(ActorCriticAlgorithm):
     `baselines.ppo2`.
     """

+    @functools.wraps(ActorCriticAlgorithm.__init__)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Note that convert_sync_batchnorm does not work with LazyBatchNorm
+        # in general. Fortunately, it works for affine=False and track_running_stats=False
+        # since no parameter needs to be created.
+        self._adv_norm = torch.nn.LazyBatchNorm1d(
+            eps=1e-8, affine=False, track_running_stats=False)
+
     @property
     def on_policy(self):
         return False
@@ -94,21 +102,9 @@ def preprocess_experience(self, root_inputs: TimeStep, rollout_info,
             td_lambda=self._loss._lambda,
             time_major=False)

-        if self._loss.normalizing_scalar_advantages:
-            if self.has_multidim_reward():
-                scalar_advantages = (advantages * self.reward_weights).sum(-1)
-            else:
-                scalar_advantages = advantages
-            normalized_advantages = normalize(self._loss._adv_norm,
-                                              scalar_advantages.reshape(-1, 1))
-            normalized_advantages = normalized_advantages.reshape_as(
-                scalar_advantages)
-            normalized_advantages = tensor_utils.tensor_extend_zero(
-                normalized_advantages, dim=1)
-        elif self._loss.normalizing_advantages:
+        if self._loss.normalizing_advantages:
             bt = advantages.shape[0] * advantages.shape[1]
-            normalized_advantages = normalize(self._loss._adv_norm,
-                                              advantages.reshape(bt, -1))
+            normalized_advantages = self._adv_norm(advantages.reshape(bt, -1))
             normalized_advantages = normalized_advantages.reshape_as(
                 advantages)
             normalized_advantages = tensor_utils.tensor_extend_zero(
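
For reference, a minimal sketch (not from the repository) of the restored preprocessing step: advantages of shape [B, T, ...] (time_major=False) are flattened so the batch norm sees one row per (environment, step) pair, reshaped back, and padded with a zero for the final step. The helper name and shapes below are illustrative, and tensor_utils.tensor_extend_zero is assumed to append one zero slice along the given dimension:

    import torch

    def normalize_ppo_advantages(advantages: torch.Tensor,
                                 adv_norm: torch.nn.Module) -> torch.Tensor:
        # Flatten [B, T, ...] into [B * T, ...] so the statistics cover every
        # (environment, step) pair in the training batch.
        bt = advantages.shape[0] * advantages.shape[1]
        normalized = adv_norm(advantages.reshape(bt, -1)).reshape_as(advantages)
        # Pad a zero along the time dimension for the last step, which has no
        # computed advantage (the role tensor_extend_zero plays above).
        zero = torch.zeros_like(normalized[:, :1])
        return torch.cat([normalized, zero], dim=1)

    adv_norm = torch.nn.LazyBatchNorm1d(
        eps=1e-8, affine=False, track_running_stats=False)
    padded = normalize_ppo_advantages(torch.randn(8, 50, 1), adv_norm)
    print(padded.shape)  # torch.Size([8, 51, 1])
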
5 changes: 1 addition & 4 deletions alf/algorithms/ppo_algorithm_test.py
@@ -71,10 +71,7 @@ def create_algorithm(env, use_rnn=False, learning_rate=1e-1):
         config=config,
         actor_network_ctor=actor_net,
         value_network_ctor=value_net,
-        loss=PPOLoss(
-            reward_dim=env.reward_spec().numel,
-            gamma=1.0,
-            debug_summaries=DEBUGGING),
+        loss=PPOLoss(gamma=1.0, debug_summaries=DEBUGGING),
         optimizer=optimizer,
         debug_summaries=DEBUGGING)

17 changes: 2 additions & 15 deletions alf/algorithms/ppo_loss.py
@@ -27,13 +27,10 @@ class PPOLoss(ActorCriticLoss):
"""PPO loss."""

def __init__(self,
reward_dim=1,
gamma=0.99,
td_error_loss_fn=element_wise_squared_loss,
td_lambda=0.95,
normalize_advantages=True,
normalize_scalar_advantages=False,
advantage_norm_momentum=0.9,
compute_advantages_internally=False,
advantage_clip=None,
entropy_regularization=None,
@@ -61,7 +58,6 @@ def __init__(self,
             much.

         Args:
-            reward_dim (int): dimension of the reward.
             gamma (float|list[float]): A discount factor for future rewards. For
                 multi-dim reward, this can also be a list of discounts, each
                 discount applies to a reward dim.
@@ -70,14 +66,8 @@
                 Q values and returns the loss for each element of the batch.
             td_lambda (float): Lambda parameter for TD-lambda computation.
             normalize_advantages (bool): If True, normalize advantage to zero
-                mean and unit variance within batch for calculating policy
+                mean and unit variance within batch for caculating policy
                 gradient.
-            normalize_scalar_advantages (bool): If False, the normalization is
-                performed for each reward dimension. If True, the normalization
-                is performed for the weighted sum of advantages using reward_weights.
-                Note that this will take precedence over `normalize_advantages`.
-            advantage_norm_momentum (float): Momentum for moving average of
-                mean and variance of advantages (same as the momentum for nn.BatchNorm1d).
             compute_advantages_internally (bool): Normally PPOLoss does not
                 compute the advantage and it expects the info to carry the
                 already-computed advantage. If this flag is set to True, PPOLoss
@@ -100,16 +90,13 @@
"""

super().__init__(
reward_dim=reward_dim,
super(PPOLoss, self).__init__(
gamma=gamma,
td_error_loss_fn=td_error_loss_fn,
use_gae=True,
td_lambda=td_lambda,
use_td_lambda_return=True,
normalize_advantages=normalize_advantages,
normalize_scalar_advantages=normalize_scalar_advantages,
advantage_norm_momentum=advantage_norm_momentum,
advantage_clip=advantage_clip,
entropy_regularization=entropy_regularization,
td_loss_weight=td_loss_weight,
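
For reference, a minimal usage sketch (not from the repository): after the revert, PPOLoss is configured without reward_dim, normalize_scalar_advantages, or advantage_norm_momentum, which are no longer accepted. The values below are illustrative:

    from alf.algorithms.ppo_loss import PPOLoss

    loss = PPOLoss(
        gamma=0.99,
        td_lambda=0.95,
        normalize_advantages=True,
        entropy_regularization=1e-4)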
