Improve normalized advantage calculation #1642

Merged · 2 commits · May 15, 2024
Changes from all commits
32 changes: 16 additions & 16 deletions alf/algorithms/actor_critic_loss.py
@@ -28,21 +28,6 @@
["pg_loss", "td_loss", "neg_entropy"])


def _normalize_advantages(advantages, variance_epsilon=1e-8):
# advantages is of shape [T, B] or [T, B, N], where N is reward dim
# this function normalizes over all elements in the input advantages
shape = advantages.shape
# shape: [TB, 1] or [TB, N]
advantages = advantages.reshape(np.prod(advantages.shape[:2]), -1)

adv_mean = advantages.mean(0)
adv_var = torch.var(advantages, dim=0, unbiased=False)

normalized_advantages = (
(advantages - adv_mean) / (torch.sqrt(adv_var) + variance_epsilon))
return normalized_advantages.reshape(*shape)


@alf.configurable
class ActorCriticLoss(Loss):
def __init__(self,
@@ -97,6 +82,12 @@ def __init__(self,
self._lambda = td_lambda
self._use_td_lambda_return = use_td_lambda_return
self._normalize_advantages = normalize_advantages
if normalize_advantages:
# Note that convert_sync_batchnorm does not work with LazyBatchNorm
# in general. Fortunately, it works for affine=False and
# track_running_stats=False since no parameters need to be created.
self._adv_norm = torch.nn.LazyBatchNorm1d(
eps=1e-8, affine=False, track_running_stats=False)
assert advantage_clip is None or advantage_clip > 0, (
"Clipping value should be positive!")
self._advantage_clip = advantage_clip
@@ -107,6 +98,10 @@
def gamma(self):
return self._gamma.clone()

@property
def normalizing_advantages(self):
return self._normalize_advantages

def forward(self, info):
"""Cacluate actor critic loss. The first dimension of all the tensors is
time dimension and the second dimesion is the batch dimension.
@@ -147,7 +142,12 @@ def _summarize(v, r, adv, suffix):
advantages[..., i], suffix)

if self._normalize_advantages:
advantages = _normalize_advantages(advantages)
if hasattr(info, "normalized_advantages"):
advantages = info.normalized_advantages
else:
bt = advantages.shape[0] * advantages.shape[1]
adv = self._adv_norm(advantages.reshape(bt, -1))
advantages = adv.reshape_as(advantages)

if self._advantage_clip:
advantages = torch.clamp(advantages, -self._advantage_clip,
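The replacement relies on the fact that, in training mode with affine=False and track_running_stats=False, BatchNorm1d normalizes each column with the batch mean and biased batch variance — essentially what the deleted _normalize_advantages helper did, except that the helper divided by (sqrt(var) + eps) while batch norm divides by sqrt(var + eps). A minimal standalone sketch (not part of the PR; shapes are made up) comparing the two:

# Sketch: compare the deleted helper with LazyBatchNorm1d-based normalization.
import torch


def _normalize_advantages_old(advantages, variance_epsilon=1e-8):
    # The helper removed by this PR (numpy dependency dropped for brevity).
    shape = advantages.shape
    flat = advantages.reshape(shape[0] * shape[1], -1)
    adv_mean = flat.mean(0)
    adv_var = torch.var(flat, dim=0, unbiased=False)
    normalized = (flat - adv_mean) / (torch.sqrt(adv_var) + variance_epsilon)
    return normalized.reshape(*shape)


adv = torch.randn(8, 4)  # [T, B]
adv_norm = torch.nn.LazyBatchNorm1d(
    eps=1e-8, affine=False, track_running_stats=False)
bt = adv.shape[0] * adv.shape[1]
new = adv_norm(adv.reshape(bt, -1)).reshape_as(adv)  # batch statistics only
old = _normalize_advantages_old(adv)
print((new - old).abs().max())  # tiny; eps enters the two formulas differently

As the added comment notes, SyncBatchNorm conversion is only safe here because with affine=False and track_running_stats=False the module owns no parameters or running-stat buffers.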
12 changes: 0 additions & 12 deletions alf/algorithms/actor_critic_loss_test.py
@@ -18,21 +18,9 @@
import torch.distributions as td
import numpy as np

from alf.algorithms.actor_critic_loss import _normalize_advantages
from alf.utils.dist_utils import compute_entropy, compute_log_probability


class TestAdvantageNormalization(unittest.TestCase):
def test_advantage_normalization(self):
advantages = torch.Tensor([[1, 2], [3, 4.0]])
# results computed from tf
normalized_advantages_expected = torch.Tensor(
[[-1.3416407, -0.4472136], [0.4472136, 1.3416407]])
normalized_advantages_obtained = _normalize_advantages(advantages)
np.testing.assert_array_almost_equal(normalized_advantages_obtained,
normalized_advantages_expected)


class TestEntropyExpand(unittest.TestCase):
def test_entropy(self):
m = td.categorical.Categorical(torch.Tensor([0.25, 0.75]))
41 changes: 36 additions & 5 deletions alf/algorithms/ppo_algorithm.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""PPO algorithm."""

import functools
import torch

import alf
@@ -24,9 +25,18 @@

PPOInfo = namedtuple(
"PPOInfo", [
"step_type", "discount", "reward", "action", "rollout_log_prob",
"rollout_action_distribution", "returns", "advantages",
"action_distribution", "value", "reward_weights"
"step_type",
"discount",
"reward",
"action",
"rollout_log_prob",
"rollout_action_distribution",
"returns",
"advantages",
"action_distribution",
"value",
"reward_weights",
"normalized_advantages",
],
default_value=())

@@ -41,6 +51,15 @@ class PPOAlgorithm(ActorCriticAlgorithm):
`baselines.ppo2`.
"""

@functools.wraps(ActorCriticAlgorithm.__init__)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Note that convert_sync_batchnorm does not work with LazyBatchNorm
# in general. Fortunately, it works for affine=False and
# track_running_stats=False since no parameters need to be created.
self._adv_norm = torch.nn.LazyBatchNorm1d(
eps=1e-8, affine=False, track_running_stats=False)

@property
def on_policy(self):
return False
@@ -82,12 +101,24 @@ def preprocess_experience(self, root_inputs: TimeStep, rollout_info,
discounts=discounts,
td_lambda=self._loss._lambda,
time_major=False)
advantages = tensor_utils.tensor_extend_zero(advantages, dim=1)

if self._loss.normalizing_advantages:
bt = advantages.shape[0] * advantages.shape[1]
normalized_advantages = self._adv_norm(advantages.reshape(bt, -1))
normalized_advantages = normalized_advantages.reshape_as(
advantages)
normalized_advantages = tensor_utils.tensor_extend_zero(
normalized_advantages, dim=1)
else:
normalized_advantages = ()

advantages = tensor_utils.tensor_extend_zero(advantages, dim=1)
returns = value + advantages
return root_inputs, PPOInfo(
rollout_action_distribution=rollout_info.action_distribution,
rollout_log_prob=rollout_info.log_prob,
returns=returns,
action=rollout_info.action,
advantages=advantages)
advantages=advantages,
normalized_advantages=normalized_advantages,
)
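For the shape bookkeeping in preprocess_experience: with time_major=False the GAE output is batch-major, the normalization is applied once over the whole replayed batch, and both the raw and normalized advantages are padded with a zero time step so they line up with the value tensor. A rough sketch with assumed shapes (extend_zero below is a stand-in for alf.utils.tensor_utils.tensor_extend_zero):

# Sketch of the new preprocessing flow (shapes assumed; extend_zero is a
# stand-in for alf.utils.tensor_utils.tensor_extend_zero).
import torch


def extend_zero(x, dim):
    # Append one zero-filled step along `dim`.
    pad_shape = list(x.shape)
    pad_shape[dim] = 1
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)


B, T = 16, 10
value = torch.randn(B, T)           # critic values over the replayed steps
advantages = torch.randn(B, T - 1)  # GAE output, batch-major (time_major=False)

adv_norm = torch.nn.LazyBatchNorm1d(
    eps=1e-8, affine=False, track_running_stats=False)
flat = advantages.reshape(B * (T - 1), -1)
normalized_advantages = adv_norm(flat).reshape_as(advantages)
normalized_advantages = extend_zero(normalized_advantages, dim=1)  # [B, T]

advantages = extend_zero(advantages, dim=1)  # [B, T]
returns = value + advantages                 # targets use the raw advantages
print(normalized_advantages.shape, returns.shape)

Because ActorCriticLoss.forward now prefers info.normalized_advantages when that field is present, each PPO minibatch reuses statistics computed over the full replayed batch instead of renormalizing its own slice.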