Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: change wrapper setting #73

Merged
merged 13 commits into from
Jan 12, 2023
Merged

refactor: change wrapper setting #73

merged 13 commits into from
Jan 12, 2023

Conversation

Gaiejj
Copy link
Member

@Gaiejj Gaiejj commented Jan 11, 2023

Description

refactor: change wrapper setting

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Put an x in all the boxes that apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide. (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly. (required for a bug fix or a new feature)
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format. (required)
  • I have checked the code using make lint. (required)
  • I have ensured make test pass. (required)

@@ -13,15 +13,15 @@
# limitations under the License.
# ==============================================================================
"""Implementation of the DDPG algorithm."""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This blank line is required.

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Implementation of the SAC algorithm."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blank line.

@@ -13,25 +13,27 @@
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrange version of the SAC algorithm."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blank line.

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Implementation of the SDDPG algorithm."""
from typing import Dict, NamedTuple, Tuple
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blank line.

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Implementation of the TD3 algorithm."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blank line.

@@ -13,23 +13,27 @@
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrange version of the TD3 algorithm."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blank line.

@@ -28,6 +30,11 @@ class PPOEarlyTerminated(PPO):
URL: https://arxiv.org/abs/2107.04200
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this line different from Reference's comment in off-policy algorithms?

"""Update."""
raw_data, data = self.buf.pre_process_data()
# First update Lagrange multiplier parameter
# pylint: disable=too-many-locals
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use disable?

from omnisafe.common.lagrange import Lagrange


@registry.register
class PPOLag(PolicyGradient, Lagrange):
class PPOLag(PPO, Lagrange):
"""The Lagrange version of the PPO algorithm.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to cite the safety gym paper here?

URL: https://arxiv.org/abs/1705.10528
- Title: Constrained Policy Optimization
- Authors: Joshua Achiam, David Held, Aviv Tamar, Pieter Abbeel.
- URL: https://arxiv.org/abs/1705.10528
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

URL style.

@@ -37,42 +39,82 @@ class PCPO(TRPO):
URL: https://arxiv.org/abs/2010.03152
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

URL styple.

info = {}
return cost_loss, info

# pylint: disable=too-many-locals,invalid-name,too-many-arguments
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable-next

loss_pi_before = loss_pi.item()
# get prob. distribution before updates
p_dist = self.actor_critic.actor(data['obs'])
p_dist = self.actor_critic.actor(obs)
# Train policy with multiple steps of gradient descent
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lowercase

loss_pi, pi_info = self.compute_loss_pi(data=data)
# Process the advantage function.
processed_adv = self.compute_surrogate(adv=adv, cost_adv=cost_adv)
# Compute the loss of policy net.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lowercase

self.actor_optimizer.zero_grad()
loss_pi, pi_info = self.compute_loss_pi(data=data)
# Process the advantage function.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lowercase

return self.std

def forward(self, raw_data=None):
"""FOrward"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error spelling

if terminal and not epoch_ended:
# Log info about epoch
self.rollout_log(logger, idx)
# Only save EpRet / EpLen if trajectory finished
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete this line?

if use_rand_action:
action = self.env.action_space.sample()
# Step the env
# pylint: disable=unused-variable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable

else:
logger.store(**{'Values/V': value})
@WRAPPER_REGISTRY.register
class EarlyTerminatedWrapper(CMDPWrapper): # pylint: disable=too-many-instance-attributes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable


from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.common.lagrange import Lagrange
from omnisafe.models.constraint_actor_q_critic import ConstraintActorQCritic


@registry.register
class TD3Lag(TD3, Lagrange): # pylint: disable=too-many-instance-attributes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable


from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.td3 import TD3
from omnisafe.common.lagrange import Lagrange
from omnisafe.models.constraint_actor_q_critic import ConstraintActorQCritic


@registry.register
class TD3Lag(TD3, Lagrange): # pylint: disable=too-many-instance-attributes
"""The Lagrange version of the TD3 algorithm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

"""compute loss for policy"""
dist, _log_p = self.actor_critic.actor(data['obs'], data['act'])
ratio = torch.exp(_log_p - data['log_p'])
# pylint: disable=too-many-arguments
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable

def compute_loss_pi(self, data: dict):
"""Compute policy loss."""
dist, _log_p = self.actor_critic.actor(data['obs'], data['act'])
# pylint: disable=too-many-arguments
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable

Copy link
Member

@zmsn-2077 zmsn-2077 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

omnisafe/algorithms/off_policy/ddpg.py Outdated Show resolved Hide resolved
omnisafe/algorithms/off_policy/ddpg.py Outdated Show resolved Hide resolved
omnisafe/algorithms/off_policy/ddpg.py Outdated Show resolved Hide resolved
omnisafe/algorithms/off_policy/ddpg_lag.py Outdated Show resolved Hide resolved
omnisafe/algorithms/off_policy/ddpg_lag.py Show resolved Hide resolved
omnisafe/algorithms/on_policy/naive_lagrange/pdo.py Outdated Show resolved Hide resolved
omnisafe/algorithms/on_policy/second_order/cpo.py Outdated Show resolved Hide resolved
omnisafe/common/buffer.py Outdated Show resolved Hide resolved
omnisafe/common/buffer.py Show resolved Hide resolved
omnisafe/common/normalize.py Outdated Show resolved Hide resolved
omnisafe/algorithms/on_policy/pid_lagrange/trpo_pid.py Outdated Show resolved Hide resolved
omnisafe/evaluator.py Outdated Show resolved Hide resolved
omnisafe/evaluator.py Outdated Show resolved Hide resolved
omnisafe/evaluator.py Outdated Show resolved Hide resolved
omnisafe/evaluator.py Show resolved Hide resolved
@zmsn-2077 zmsn-2077 self-requested a review January 11, 2023 16:10
Copy link
Member

@zmsn-2077 zmsn-2077 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After careful review, I agree with these changes.

Copy link
Member

@calico-1226 calico-1226 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I approve these changes.

@Gaiejj Gaiejj merged commit 60bf0bf into PKU-Alignment:dev Jan 12, 2023
@Gaiejj Gaiejj deleted the dev branch January 12, 2023 08:07
zmsn-2077 pushed a commit that referenced this pull request Feb 5, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants