
[Feature Request] Will you provide an off-policy version of the CRPO method in the future? #267

Closed
guanjiayi opened this issue Aug 14, 2023 · 2 comments
Labels: enhancement (New feature or request)

Comments

@guanjiayi

Required prerequisites

Motivation

  • I have noticed that you already provide an on-policy version of the CRPO algorithm and have demonstrated its competitive performance.
  • Furthermore, the CRPO paper states that the algorithm can be extended to the off-policy setting. I would like to ask whether you plan to provide an off-policy version of the CRPO method in the future.
  • Following the OnCRPO code you provide and the guidelines of the CRPO paper, I have developed an off-policy version of the algorithm. However, its parameters require fine-tuning (the two relevant knobs are sketched just below this list), and I am hoping for an in-depth discussion with your team on this matter.
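For reference, the two algorithm-specific knobs that required tuning are the cost limit and the tolerance read from algo_cfgs in the Solution code below. A hypothetical override could look like the following; the key names mirror the attributes accessed in that code, and the values are untuned placeholders, not recommended defaults:

# Hypothetical overrides for the two SACCRPO-specific hyperparameters; the keys
# mirror self._cfgs.algo_cfgs.cost_limit and self._cfgs.algo_cfgs.tolerance used
# in the Solution code, and the values are illustrative placeholders.
custom_cfgs = {
    'algo_cfgs': {
        'cost_limit': 25.0,  # episodic cost budget the policy should respect
        'tolerance': 5.0,    # slack above the limit before switching to the cost update
    },
}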

Solution

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrangian version of Soft Actor-Critic algorithm."""


import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.sac import SAC


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class SACCRPO(SAC):
    """The Off-policy CRPO algorithm.

    References:
        - Title: CRPO: A New Approach for Safe Reinforcement Learning with Convergence Guarantee.
        - Authors: Tengyu Xu, Yingbin Liang, Guanghui Lan.
        - URL: `CRPO <https://arxiv.org/pdf/2011.05869.pdf>`_.
    """

    def _init(self) -> None:
        """Initialize an instance of :class:`SACCRPO`."""

        super()._init()
        self._rew_update: int  = 0
        self._cost_update: int = 0

    def _init_log(self) -> None:
        """Log the SACCRPO specific information.

        +-----------------+--------------------------------------------+
        | Things to log   | Description                                |
        +=================+============================================+
        | Misc/RewUpdate  | The number of times the reward is updated. |
        +-----------------+--------------------------------------------+
        | Misc/CostUpdate | The number of times the cost is updated.   |
        +-----------------+--------------------------------------------+
        """
        super()._init_log()
        self._logger.register_key('Misc/RewUpdate')
        self._logger.register_key('Misc/CostUpdate')
        self._logger.register_key('Loss/loss_r_mean')
        self._logger.register_key('Loss/loss_c_max')
        self._logger.register_key('Loss/loss_c_mean')

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute `pi/actor` loss."""
        action = self._actor_critic.actor.predict(obs, deterministic=False)
        log_prob = self._actor_critic.actor.log_prob(action)
        loss_q_r_1, loss_q_r_2 = self._actor_critic.reward_critic(obs, action)
        loss_r = self._alpha * log_prob - torch.min(loss_q_r_1, loss_q_r_2)
        loss_c = self._actor_critic.cost_critic(obs, action)[0]

        # Alternative switching rule (kept for reference): compare only the maximum
        # cost Q-value against the limit plus tolerance.
        # if loss_c.max().item() <= self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance:
        #     self._rew_update += 1
        #     loss = loss_r
        # else:
        #     self._cost_update += 1
        #     loss = loss_c

        if (
            loss_c.max().item() > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance
            and loss_c.mean().item() > loss_r.mean().item()
        ):
            self._cost_update += 1
            loss = loss_c
        else:
            self._rew_update += 1
            loss = loss_r

        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': loss_r.mean().item(),
                'Loss/loss_c_mean': loss_c.mean().item(),
                'Loss/loss_c_max': loss_c.max().item(),
            }
        )
        return loss.mean()

    def _log_when_not_update(self) -> None:
        super()._log_when_not_update()
        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': 0,
                'Loss/loss_c_mean': 0,
                'Loss/loss_c_max': 0,
            }
        )
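
If the class above is also exported from OmniSafe's off-policy algorithm package so that the registry can resolve it by name (an assumption about the repository layout), a training run could then be launched through the standard omnisafe.Agent interface, for example:

import omnisafe

# Illustrative launch of the SACCRPO class sketched above on a Safety-Gymnasium
# navigation task; the environment id and the overrides are examples only, and
# cost_limit/tolerance must be defined in the algorithm's config for the
# attribute accesses in _loss_pi to work.
custom_cfgs = {'algo_cfgs': {'cost_limit': 25.0, 'tolerance': 5.0}}
agent = omnisafe.Agent('SACCRPO', 'SafetyPointGoal1-v0', custom_cfgs=custom_cfgs)
agent.learn()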

Alternatives

No response

Additional context

No response

guanjiayi added the enhancement label Aug 14, 2023
@Gaiejj
Member

Gaiejj commented Aug 15, 2023

We express our delight in your proactive implementation of the novel algorithm and extend our gratitude for your contributions to the advancement of safe reinforcement learning. Your implementation is commendably aligned with our stipulated criteria for the off-policy version of the CRPO algorithm. However, there are areas that warrant refinement as follows:

  • Given that the navigation tasks in the Safety-Gymnasium framework require the constraint violation to stay within the prescribed limit over an entire episode, it is recommended that the condition

 if (loss_c.max().item() > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance and
     loss_c.mean().item() > loss_r.mean().item()):

should be:

 if ep_cost > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance:

where ep_cost is the episodic cost value obtained from the logger (a sketch of this revision follows the list below).

  • Furthermore, it is advised that greater emphasis be placed on the provision of more comprehensive documentation, and the inclusion of performance curves pertinent to the relevant algorithm.
  • Prior to submission, we encourage you to run the make pre-commit and make test commands at the root directory to ensure that the codebase adheres to the established standards of OmniSafe.
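
For concreteness, a minimal sketch of the revised _loss_pi under the first suggestion, assuming the episodic cost is exposed through the logger's 'Metrics/EpCost' key and read with get_stats as in other OmniSafe algorithms, could look like the following; the rest of the class would stay unchanged:

    def _loss_pi(self, obs: torch.Tensor) -> torch.Tensor:
        r"""Compute the `pi/actor` loss, switching on the logged episodic cost (sketch)."""
        action = self._actor_critic.actor.predict(obs, deterministic=False)
        log_prob = self._actor_critic.actor.log_prob(action)
        q1_r, q2_r = self._actor_critic.reward_critic(obs, action)
        loss_r = self._alpha * log_prob - torch.min(q1_r, q2_r)
        loss_c = self._actor_critic.cost_critic(obs, action)[0]

        # Episodic cost from the logger; 'Metrics/EpCost' and get_stats() are
        # assumed to behave as in other OmniSafe off-policy algorithms.
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        if ep_cost > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance:
            self._cost_update += 1
            loss = loss_c
        else:
            self._rew_update += 1
            loss = loss_r

        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': loss_r.mean().item(),
                'Loss/loss_c_mean': loss_c.mean().item(),
                'Loss/loss_c_max': loss_c.max().item(),
            }
        )
        return loss.mean()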

These suggestions are expected to significantly enhance the quality of your CRPO implementation. Should any queries or uncertainties arise, please feel free to engage in a discourse with us.

@guanjiayi
Author

Thank you for your reply, and we also extend our sincere gratitude for your valuable suggestions.
