Improve normalized advantage calculation
1. Use LazyBatchNorm for advantage normalization. The motivation is that for
DDP training, the normalization statistics will combine the statistics from all GPUs.

2. Calculate the normalized advantage in PPOAlgorithm.preprocess_experience so that
the normalization is based on a much larger batch.
emailweixu committed May 7, 2024
1 parent 98ec37d commit 303a9a3
Showing 2 changed files with 52 additions and 21 deletions.
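
As context for the diffs below, here is a minimal sketch (not part of the commit; the shapes are made up) of what normalizing advantages with a LazyBatchNorm1d configured with affine=False and track_running_stats=False amounts to: the [T, B] or [T, B, N] tensor is flattened to [T*B, N], the layer standardizes each column with batch statistics, and the result is reshaped back. In training mode this is (x - mean) / sqrt(var + eps), essentially what the removed _normalize_advantages did, except that eps is added to the variance rather than to the standard deviation.

import torch

# Sketch only; T, B, N are hypothetical (time, batch, reward dims).
T, B, N = 8, 4, 1
advantages = torch.randn(T, B, N)

adv_norm = torch.nn.LazyBatchNorm1d(
    eps=1e-8, affine=False, track_running_stats=False)

flat = advantages.reshape(T * B, -1)            # [T*B, N]
normalized = adv_norm(flat).reshape_as(advantages)

# With track_running_stats=False the layer always uses batch statistics:
mean = flat.mean(dim=0)
var = flat.var(dim=0, unbiased=False)
manual = (flat - mean) / torch.sqrt(var + 1e-8)
assert torch.allclose(normalized.reshape(T * B, -1), manual, atol=1e-5)
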
32 changes: 16 additions & 16 deletions alf/algorithms/actor_critic_loss.py
@@ -28,21 +28,6 @@
["pg_loss", "td_loss", "neg_entropy"])


def _normalize_advantages(advantages, variance_epsilon=1e-8):
# advantages is of shape [T, B] or [T, B, N], where N is reward dim
# this function normalizes over all elements in the input advantages
shape = advantages.shape
# shape: [TB, 1] or [TB, N]
advantages = advantages.reshape(np.prod(advantages.shape[:2]), -1)

adv_mean = advantages.mean(0)
adv_var = torch.var(advantages, dim=0, unbiased=False)

normalized_advantages = (
(advantages - adv_mean) / (torch.sqrt(adv_var) + variance_epsilon))
return normalized_advantages.reshape(*shape)


@alf.configurable
class ActorCriticLoss(Loss):
def __init__(self,
@@ -97,6 +82,12 @@ def __init__(self,
self._lambda = td_lambda
self._use_td_lambda_return = use_td_lambda_return
self._normalize_advantages = normalize_advantages
if normalize_advantages:
# Note that convert_sync_batchnorm does not work with LazyBatchNorm
# in general. Fortunately, it works for affine=False and track_running_stats=False
# since no parameter needs to be created.
self._adv_norm = torch.nn.LazyBatchNorm1d(
eps=1e-8, affine=False, track_running_stats=False)
assert advantage_clip is None or advantage_clip > 0, (
"Clipping value should be positive!")
self._advantage_clip = advantage_clip
@@ -107,6 +98,10 @@
def gamma(self):
return self._gamma.clone()

@property
def normalizing_advantages(self):
return self._normalize_advantages

def forward(self, info):
"""Cacluate actor critic loss. The first dimension of all the tensors is
time dimension and the second dimesion is the batch dimension.
@@ -147,7 +142,12 @@ def _summarize(v, r, adv, suffix):
advantages[..., i], suffix)

if self._normalize_advantages:
advantages = _normalize_advantages(advantages)
if hasattr(info, "normalized_advantages"):
advantages = info.normalized_advantages
else:
bt = advantages.shape[0] * advantages.shape[1]
adv = self._adv_norm(advantages.reshape(bt, -1))
advantages = adv.reshape_as(advantages)

if self._advantage_clip:
advantages = torch.clamp(advantages, -self._advantage_clip,
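
The comment about convert_sync_batchnorm refers to torch.nn.SyncBatchNorm.convert_sync_batchnorm. A hedged sketch of how that conversion would typically be applied for DDP training follows; the wrapper function and names are illustrative, not part of ALF or this commit.

import torch

def wrap_for_ddp(algorithm: torch.nn.Module, device_id: int) -> torch.nn.Module:
    # Illustrative only: assumes torch.distributed is already initialized and
    # `algorithm` is a Module that owns the LazyBatchNorm1d created above.
    # convert_sync_batchnorm replaces BatchNorm layers with SyncBatchNorm, so
    # the advantage mean/variance are computed across all GPUs; per the
    # comment above, this is safe here because affine=False and
    # track_running_stats=False mean there are no parameters to copy.
    algorithm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(algorithm)
    return torch.nn.parallel.DistributedDataParallel(
        algorithm, device_ids=[device_id])
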
41 changes: 36 additions & 5 deletions alf/algorithms/ppo_algorithm.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""PPO algorithm."""

import functools
import torch

import alf
@@ -24,9 +25,18 @@

PPOInfo = namedtuple(
"PPOInfo", [
"step_type", "discount", "reward", "action", "rollout_log_prob",
"rollout_action_distribution", "returns", "advantages",
"action_distribution", "value", "reward_weights"
"step_type",
"discount",
"reward",
"action",
"rollout_log_prob",
"rollout_action_distribution",
"returns",
"advantages",
"action_distribution",
"value",
"reward_weights",
"normalized_advantages",
],
default_value=())

@@ -41,6 +51,15 @@ class PPOAlgorithm(ActorCriticAlgorithm):
`baselines.ppo2`.
"""

@functools.wraps(ActorCriticAlgorithm.__init__)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Note that convert_sync_batchnorm does not work with LazyBatchNorm
# in general. Fortunately, it works for affine=False and track_running_stats=False
# since no parameter needs to be created.
self._adv_norm = torch.nn.LazyBatchNorm1d(
eps=1e-8, affine=False, track_running_stats=False)

@property
def on_policy(self):
return False
@@ -82,12 +101,24 @@ def preprocess_experience(self, root_inputs: TimeStep, rollout_info,
discounts=discounts,
td_lambda=self._loss._lambda,
time_major=False)
advantages = tensor_utils.tensor_extend_zero(advantages, dim=1)

if self._loss.normalizing_advantages:
bt = advantages.shape[0] * advantages.shape[1]
normalized_advantages = self._adv_norm(advantages.reshape(bt, -1))
normalized_advantages = normalized_advantages.reshape_as(
advantages)
normalized_advantages = tensor_utils.tensor_extend_zero(
normalized_advantages, dim=1)
else:
normalized_advantages = ()

advantages = tensor_utils.tensor_extend_zero(advantages, dim=1)
returns = value + advantages
return root_inputs, PPOInfo(
rollout_action_distribution=rollout_info.action_distribution,
rollout_log_prob=rollout_info.log_prob,
returns=returns,
action=rollout_info.action,
advantages=advantages)
advantages=advantages,
normalized_advantages=normalized_advantages,
)
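
The net effect of the second file is that, when the loss normalizes advantages, the statistics come from the whole batch seen by preprocess_experience (shape [B, T], time_major=False) rather than from each training mini-batch. A rough, self-contained sketch of that data flow with hypothetical sizes; the final concatenation stands in for tensor_utils.tensor_extend_zero(..., dim=1):

import torch

# Hypothetical sizes: B environments, T unroll steps (time_major=False).
B, T = 256, 32
advantages = torch.randn(B, T)

adv_norm = torch.nn.LazyBatchNorm1d(
    eps=1e-8, affine=False, track_running_stats=False)

# Statistics are taken over all B*T advantage values at once...
normalized = adv_norm(advantages.reshape(B * T, -1)).reshape_as(advantages)

# ...then one zero step is appended along the time dimension, mirroring what
# tensor_utils.tensor_extend_zero(normalized_advantages, dim=1) does.
normalized = torch.cat([normalized, torch.zeros(B, 1)], dim=1)   # [B, T+1]
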
