-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: change wrapper setting #73
Conversation
@@ -13,15 +13,15 @@ | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Implementation of the DDPG algorithm.""" | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This blank line is required.
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Implementation of the SAC algorithm.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blank line.
@@ -13,25 +13,27 @@ | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Implementation of the Lagrange version of the SAC algorithm.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blank line.
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Implementation of the SDDPG algorithm.""" | |||
from typing import Dict, NamedTuple, Tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blank line.
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Implementation of the TD3 algorithm.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blank line.
@@ -13,23 +13,27 @@ | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Implementation of the Lagrange version of the TD3 algorithm.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blank line.
@@ -28,6 +30,11 @@ class PPOEarlyTerminated(PPO): | |||
URL: https://arxiv.org/abs/2107.04200 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this line different from Reference's comment in off-policy algorithms?
"""Update.""" | ||
raw_data, data = self.buf.pre_process_data() | ||
# First update Lagrange multiplier parameter | ||
# pylint: disable=too-many-locals |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we use disable?
from omnisafe.common.lagrange import Lagrange | ||
|
||
|
||
@registry.register | ||
class PPOLag(PolicyGradient, Lagrange): | ||
class PPOLag(PPO, Lagrange): | ||
"""The Lagrange version of the PPO algorithm. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I need to cite the safety gym paper here?
URL: https://arxiv.org/abs/1705.10528 | ||
- Title: Constrained Policy Optimization | ||
- Authors: Joshua Achiam, David Held, Aviv Tamar, Pieter Abbeel. | ||
- URL: https://arxiv.org/abs/1705.10528 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
URL style.
@@ -37,42 +39,82 @@ class PCPO(TRPO): | |||
URL: https://arxiv.org/abs/2010.03152 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
URL styple.
info = {} | ||
return cost_loss, info | ||
|
||
# pylint: disable=too-many-locals,invalid-name,too-many-arguments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable-next
loss_pi_before = loss_pi.item() | ||
# get prob. distribution before updates | ||
p_dist = self.actor_critic.actor(data['obs']) | ||
p_dist = self.actor_critic.actor(obs) | ||
# Train policy with multiple steps of gradient descent |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lowercase
loss_pi, pi_info = self.compute_loss_pi(data=data) | ||
# Process the advantage function. | ||
processed_adv = self.compute_surrogate(adv=adv, cost_adv=cost_adv) | ||
# Compute the loss of policy net. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lowercase
self.actor_optimizer.zero_grad() | ||
loss_pi, pi_info = self.compute_loss_pi(data=data) | ||
# Process the advantage function. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lowercase
omnisafe/common/normalize.py
Outdated
return self.std | ||
|
||
def forward(self, raw_data=None): | ||
"""FOrward""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
error spelling
if terminal and not epoch_ended: | ||
# Log info about epoch | ||
self.rollout_log(logger, idx) | ||
# Only save EpRet / EpLen if trajectory finished |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete this line?
omnisafe/wrappers/cmdp_wrapper.py
Outdated
if use_rand_action: | ||
action = self.env.action_space.sample() | ||
# Step the env | ||
# pylint: disable=unused-variable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable
else: | ||
logger.store(**{'Values/V': value}) | ||
@WRAPPER_REGISTRY.register | ||
class EarlyTerminatedWrapper(CMDPWrapper): # pylint: disable=too-many-instance-attributes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable
|
||
from omnisafe.algorithms import registry | ||
from omnisafe.algorithms.off_policy.td3 import TD3 | ||
from omnisafe.common.lagrange import Lagrange | ||
from omnisafe.models.constraint_actor_q_critic import ConstraintActorQCritic | ||
|
||
|
||
@registry.register | ||
class TD3Lag(TD3, Lagrange): # pylint: disable=too-many-instance-attributes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable
|
||
from omnisafe.algorithms import registry | ||
from omnisafe.algorithms.off_policy.td3 import TD3 | ||
from omnisafe.common.lagrange import Lagrange | ||
from omnisafe.models.constraint_actor_q_critic import ConstraintActorQCritic | ||
|
||
|
||
@registry.register | ||
class TD3Lag(TD3, Lagrange): # pylint: disable=too-many-instance-attributes | ||
"""The Lagrange version of the TD3 algorithm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.
"""compute loss for policy""" | ||
dist, _log_p = self.actor_critic.actor(data['obs'], data['act']) | ||
ratio = torch.exp(_log_p - data['log_p']) | ||
# pylint: disable=too-many-arguments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable
def compute_loss_pi(self, data: dict): | ||
"""Compute policy loss.""" | ||
dist, _log_p = self.actor_critic.actor(data['obs'], data['act']) | ||
# pylint: disable=too-many-arguments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
disable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After careful review, I agree with these changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I approve these changes.
Description
refactor: change wrapper setting
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax
close #15213
if this solves the issue #15213Types of changes
What types of changes does your code introduce? Put an
x
in all the boxes that apply:Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!
make format
. (required)make lint
. (required)make test
pass. (required)