I noticed that you have already provided an on-policy version of the CRPO algorithm and demonstrated its competitive performance.
Furthermore, the CRPO paper states that the algorithm can be extended to the off-policy setting. Do you plan to provide an off-policy version of CRPO in the future?
I have referred to the ONCRPO code you provided and followed the CRPO guidelines to develop an off-policy version of the algorithm. However, its hyperparameters still need tuning, and I would welcome an in-depth discussion with your team on this matter.
Solution
# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the off-policy version of the CRPO algorithm, built on Soft Actor-Critic."""

import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.sac import SAC


@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class SACCRPO(SAC):
    """The Off-policy CRPO algorithm.

    References:
        - Title: CRPO: A New Approach for Safe Reinforcement Learning with Convergence Guarantee.
        - Authors: Tengyu Xu, Yingbin Liang, Guanghui Lan.
        - URL: `CRPO <https://arxiv.org/pdf/2011.05869.pdf>`_.
    """

    def _init(self) -> None:
        """Initialize an instance of :class:`SACCRPO`."""
        super()._init()
        # Counters for how often the actor was updated on each objective.
        self._rew_update: int = 0
        self._cost_update: int = 0

    def _init_log(self) -> None:
        """Log the SACCRPO specific information.

        +-----------------+--------------------------------------------+
        | Things to log   | Description                                |
        +=================+============================================+
        | Misc/RewUpdate  | The number of times the reward is updated. |
        +-----------------+--------------------------------------------+
        | Misc/CostUpdate | The number of times the cost is updated.   |
        +-----------------+--------------------------------------------+
        """
        super()._init_log()
        self._logger.register_key('Misc/RewUpdate')
        self._logger.register_key('Misc/CostUpdate')
        self._logger.register_key('Loss/loss_r_mean')
        self._logger.register_key('Loss/loss_c_max')
        self._logger.register_key('Loss/loss_c_mean')

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute `pi/actor` loss."""
        action = self._actor_critic.actor.predict(obs, deterministic=False)
        log_prob = self._actor_critic.actor.log_prob(action)
        # SAC reward objective: entropy term minus the smaller of the two Q-values.
        loss_q_r_1, loss_q_r_2 = self._actor_critic.reward_critic(obs, action)
        loss_r = self._alpha * log_prob - torch.min(loss_q_r_1, loss_q_r_2)
        # Cost objective: the cost critic's estimate for the sampled action.
        loss_c = self._actor_critic.cost_critic(obs, action)[0]

        # Plain CRPO-style rule, kept for reference:
        # if loss_c.max().item() <= self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance:
        #     self._rew_update += 1
        #     loss = loss_r
        # else:
        #     self._cost_update += 1
        #     loss = loss_c

        # Minimize cost only when the bound is exceeded and the mean cost
        # estimate is larger than the mean reward loss; otherwise optimize reward.
        if (
            loss_c.max().item() > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance
            and loss_c.mean().item() > loss_r.mean().item()
        ):
            self._cost_update += 1
            loss = loss_c
        else:
            self._rew_update += 1
            loss = loss_r

        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': loss_r.mean().item(),
                'Loss/loss_c_mean': loss_c.mean().item(),
                'Loss/loss_c_max': loss_c.max().item(),
            }
        )
        return loss.mean()

    def _log_when_not_update(self) -> None:
        """Store placeholder values for the SACCRPO keys when no update has run."""
        super()._log_when_not_update()
        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': 0,
                'Loss/loss_c_mean': 0,
                'Loss/loss_c_max': 0,
            }
        )
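To make the switching rule in `_loss_pi` easier to reason about while tuning, here is a minimal, framework-free sketch of the two conditions that appear in the code above: the commented-out rule, which looks only at the largest cost-critic value, and the active rule, which additionally requires the mean cost value to exceed the mean reward loss. The function names and the plain-float inputs are illustrative assumptions and are not part of OmniSafe.

```python
from statistics import fmean
from typing import Sequence


def original_rule(cost_values: Sequence[float], cost_limit: float, tolerance: float) -> str:
    """Commented-out rule: decide from the largest cost-critic value alone."""
    if max(cost_values) <= cost_limit + tolerance:
        return "reward"
    return "cost"


def posted_rule(
    cost_values: Sequence[float],
    reward_losses: Sequence[float],
    cost_limit: float,
    tolerance: float,
) -> str:
    """Active rule: switch to cost only if the bound is exceeded AND
    the mean cost value is larger than the mean reward loss."""
    violated = max(cost_values) > cost_limit + tolerance
    cost_dominates = fmean(cost_values) > fmean(reward_losses)
    return "cost" if violated and cost_dominates else "reward"


# Hypothetical batch: the bound is exceeded, but the reward loss is larger on average.
cost_values = [30.0, 20.0]
reward_losses = [40.0, 50.0]
print(original_rule(cost_values, cost_limit=25.0, tolerance=0.0))
print(posted_rule(cost_values, reward_losses, cost_limit=25.0, tolerance=0.0))
```

On this batch the two rules disagree, which shows that the extra comparison against `loss_r` can keep the actor on the reward objective even while the cost bound is exceeded. That coupling between the scale of the reward loss and the switching decision may be one reason the parameters are hard to tune.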
Alternatives
No response
Additional context
No response
We are delighted by your proactive implementation of this algorithm and grateful for your contribution to the advancement of safe reinforcement learning. Your implementation is well aligned with our criteria for the off-policy version of the CRPO algorithm. However, a few areas warrant refinement:
Because the navigation tasks in Safety-Gymnasium limit the constraint violation accumulated over an entire episode, we recommend replacing the condition
if (loss_c.max().item() > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance and loss_c.mean().item() > loss_r.mean().item()):
with one based on ep_cost, where ep_cost is the episodic cost value obtained from the logger.
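The exact replacement condition is not shown in this thread, so the following is only a sketch of one plausible reading: compare the episodic cost `ep_cost` against `cost_limit + tolerance` and pick the objective from that. `select_objective` and `count_updates` are hypothetical helper names, and the episodic costs are passed in as plain floats; in OmniSafe the value would be read from the logger, and the accessor and key for that should be checked against the codebase.

```python
from typing import Dict, Iterable


def select_objective(ep_cost: float, cost_limit: float, tolerance: float) -> str:
    """Return which objective the actor should optimize for this update.

    Assumption: the switch depends only on the episodic cost, not on
    per-sample cost-critic values or on the reward loss.
    """
    if ep_cost > cost_limit + tolerance:
        return "cost"
    return "reward"


def count_updates(ep_costs: Iterable[float], cost_limit: float, tolerance: float) -> Dict[str, int]:
    """Tally reward and cost updates over a sequence of episodic costs."""
    counts = {"reward": 0, "cost": 0}
    for ep_cost in ep_costs:
        counts[select_objective(ep_cost, cost_limit, tolerance)] += 1
    return counts


# Hypothetical episodic costs logged across four updates.
print(count_updates([10.0, 30.0, 26.0, 40.0], cost_limit=25.0, tolerance=2.0))
```

Because the decision uses a single scalar per update, the reward and cost update counters track how often the episodic constraint is violated directly, independent of the scale of the critic outputs.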
Furthermore, we recommend providing more comprehensive documentation and including performance curves for the algorithm.
Before submitting, please run the make pre-commit and make test commands in the repository root to ensure that the code adheres to OmniSafe's standards.
These suggestions should significantly improve the quality of your CRPO implementation. If any questions or uncertainties arise, please feel free to discuss them with us.