From 303a9a3e410f0a377023ca4e17304b4ceb97ea26 Mon Sep 17 00:00:00 2001
From: Wei Xu
Date: Wed, 1 May 2024 19:11:55 -0700
Subject: [PATCH] Improve normalized advantage calculation

1. Use LazyBatchNorm for advantage normalization. The motivation is that for
   DDP training, the normalization statistics will combine the statistics
   from all GPUs.
2. Calculate normalized advantage in PPOAlgorithm.preprocess so that the
   normalization is based on a much larger batch.
---
 alf/algorithms/actor_critic_loss.py | 32 +++++++++++-----------
 alf/algorithms/ppo_algorithm.py     | 41 +++++++++++++++++++++++++----
 2 files changed, 52 insertions(+), 21 deletions(-)

diff --git a/alf/algorithms/actor_critic_loss.py b/alf/algorithms/actor_critic_loss.py
index 79f567127..cbff5a8f2 100644
--- a/alf/algorithms/actor_critic_loss.py
+++ b/alf/algorithms/actor_critic_loss.py
@@ -28,21 +28,6 @@
                                  ["pg_loss", "td_loss", "neg_entropy"])
 
 
-def _normalize_advantages(advantages, variance_epsilon=1e-8):
-    # advantages is of shape [T, B] or [T, B, N], where N is reward dim
-    # this function normalizes over all elements in the input advantages
-    shape = advantages.shape
-    # shape: [TB, 1] or [TB, N]
-    advantages = advantages.reshape(np.prod(advantages.shape[:2]), -1)
-
-    adv_mean = advantages.mean(0)
-    adv_var = torch.var(advantages, dim=0, unbiased=False)
-
-    normalized_advantages = (
-        (advantages - adv_mean) / (torch.sqrt(adv_var) + variance_epsilon))
-    return normalized_advantages.reshape(*shape)
-
-
 @alf.configurable
 class ActorCriticLoss(Loss):
     def __init__(self,
@@ -97,6 +82,12 @@ def __init__(self,
         self._lambda = td_lambda
         self._use_td_lambda_return = use_td_lambda_return
         self._normalize_advantages = normalize_advantages
+        if normalize_advantages:
+            # Note that convert_sync_batchnorm does not work with LazyBatchNorm
+            # in general. Fortunately, it works for affine=False and track_running_stats=False
+            # since no parameter needs to be created.
+            self._adv_norm = torch.nn.LazyBatchNorm1d(
+                eps=1e-8, affine=False, track_running_stats=False)
         assert advantage_clip is None or advantage_clip > 0, (
             "Clipping value should be positive!")
         self._advantage_clip = advantage_clip
@@ -107,6 +98,10 @@ def __init__(self,
     def gamma(self):
         return self._gamma.clone()
 
+    @property
+    def normalizing_advantages(self):
+        return self._normalize_advantages
+
     def forward(self, info):
         """Cacluate actor critic loss. The first dimension of all the tensors is
         time dimension and the second dimesion is the batch dimension.
@@ -147,7 +142,12 @@ def _summarize(v, r, adv, suffix):
                            advantages[..., i], suffix)
 
         if self._normalize_advantages:
-            advantages = _normalize_advantages(advantages)
+            if hasattr(info, "normalized_advantages"):
+                advantages = info.normalized_advantages
+            else:
+                bt = advantages.shape[0] * advantages.shape[1]
+                adv = self._adv_norm(advantages.reshape(bt, -1))
+                advantages = adv.reshape_as(advantages)
 
         if self._advantage_clip:
             advantages = torch.clamp(advantages, -self._advantage_clip,
diff --git a/alf/algorithms/ppo_algorithm.py b/alf/algorithms/ppo_algorithm.py
index ef1bdcfa4..30989c8ef 100644
--- a/alf/algorithms/ppo_algorithm.py
+++ b/alf/algorithms/ppo_algorithm.py
@@ -13,6 +13,7 @@
 # limitations under the License.
"""PPO algorithm.""" +import functools import torch import alf @@ -24,9 +25,18 @@ PPOInfo = namedtuple( "PPOInfo", [ - "step_type", "discount", "reward", "action", "rollout_log_prob", - "rollout_action_distribution", "returns", "advantages", - "action_distribution", "value", "reward_weights" + "step_type", + "discount", + "reward", + "action", + "rollout_log_prob", + "rollout_action_distribution", + "returns", + "advantages", + "action_distribution", + "value", + "reward_weights", + "normalized_advantages", ], default_value=()) @@ -41,6 +51,15 @@ class PPOAlgorithm(ActorCriticAlgorithm): `baselines.ppo2`. """ + @functools.wraps(ActorCriticAlgorithm.__init__) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Note that onvert_sync_batchnorm does not work with LazyBatchNorm + # in general. Fortunately, it works for affine=False and track_running_stats=False + # since no parameter needs to be created. + self._adv_norm = torch.nn.LazyBatchNorm1d( + eps=1e-8, affine=False, track_running_stats=False) + @property def on_policy(self): return False @@ -82,12 +101,24 @@ def preprocess_experience(self, root_inputs: TimeStep, rollout_info, discounts=discounts, td_lambda=self._loss._lambda, time_major=False) - advantages = tensor_utils.tensor_extend_zero(advantages, dim=1) + if self._loss.normalizing_advantages: + bt = advantages.shape[0] * advantages.shape[1] + normalized_advantages = self._adv_norm(advantages.reshape(bt, -1)) + normalized_advantages = normalized_advantages.reshape_as( + advantages) + normalized_advantages = tensor_utils.tensor_extend_zero( + normalized_advantages, dim=1) + else: + normalized_advantages = () + + advantages = tensor_utils.tensor_extend_zero(advantages, dim=1) returns = value + advantages return root_inputs, PPOInfo( rollout_action_distribution=rollout_info.action_distribution, rollout_log_prob=rollout_info.log_prob, returns=returns, action=rollout_info.action, - advantages=advantages) + advantages=advantages, + normalized_advantages=normalized_advantages, + )