style(model-based): fix mypy and polish api docstring #240
Conversation
There are two blank lines here.
    _env_id (str): The environment id.
    _device (torch.device): The device.
    _env (CMDP): The environment.
    _cfgs (Config): The configuration.
    _ep_ret (torch.Tensor): The episode return.
    _ep_cost (torch.Tensor): The episode cost.
    _ep_len (torch.Tensor): The episode length.
    _last_dynamics_update (float): The last time of dynamics update.
    _last_policy_update (float): The last time of policy update.
    _last_eval (float): The last time of evaluation.
Private attributes do not need to be shown.
Codecov Report
@@ Coverage Diff @@
##             dev     #240      +/-   ##
==========================================
- Coverage   96.99%   96.98%   -0.01%
==========================================
  Files         134      134
  Lines        6867     6888      +21
==========================================
+ Hits         6660     6680      +20
- Misses        207      208       +1
... and 1 file with indirect coverage changes
 def get_cost_from_obs_tensor(self, obs: torch.Tensor) -> torch.Tensor:
     """Get cost from tensor observation.

     Args:
-        obs (torch.Tensor): The observation.
+        obs (torch.Tensor): The tensor version of observation.
     """
     return (
         self._env.get_cost_from_obs_tensor(obs)
         if hasattr(self._env, 'get_cost_from_obs_tensor')
-        else None
+        else torch.zeros(1)
     )

-def get_lidar_from_coordinate(self, obs: torch.Tensor) -> torch.Tensor | None:
+def get_lidar_from_coordinate(self, obs: np.ndarray) -> torch.Tensor | None:
The original version returning None is pretty bad because there is no if/else handling downstream to deal with None, so a zero tensor is used as the default here instead.
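The zero-default fallback can be sketched in a torch-free toy (the class and env names below are hypothetical; the real code returns `torch.zeros(1)` rather than a list):

```python
class CostEnvWrapper:
    """Hypothetical sketch of the fallback pattern discussed above."""

    def __init__(self, env):
        self._env = env

    def get_cost_from_obs_tensor(self, obs):
        # Delegate when the wrapped env supports the method; otherwise
        # return a zero cost (standing in for torch.zeros(1)) so that
        # None never escapes and callers need no if/else handling.
        if hasattr(self._env, 'get_cost_from_obs_tensor'):
            return self._env.get_cost_from_obs_tensor(obs)
        return [0.0]
```

Callers can now always treat the result as a valid cost value, which is the point of the change.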
    Keyword Args:
        render_mode (str, optional): The render mode, ranging from 'human', 'rgb_array', 'rgb_array_list'.
            Defaults to 'rgb_array'.
        camera_name (str, optional): The camera name.
        camera_id (int, optional): The camera id.
        width (int, optional): The width of the rendered image. Defaults to 256.
        height (int, optional): The height of the rendered image. Defaults to 256.
    """
A new style of `**kwargs` docstring.
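The `Keyword Args:` section follows the Google docstring convention for `**kwargs`; a minimal sketch (the function name `render_env` and its body are hypothetical, not from the PR):

```python
def render_env(**kwargs):
    """Render the environment.

    Keyword Args:
        render_mode (str, optional): The render mode. Defaults to 'rgb_array'.
        width (int, optional): The width of the rendered image. Defaults to 256.
        height (int, optional): The height of the rendered image. Defaults to 256.
    """
    # Pull each keyword argument out with its documented default.
    return {
        'render_mode': kwargs.get('render_mode', 'rgb_array'),
        'width': kwargs.get('width', 256),
        'height': kwargs.get('height', 256),
    }
```

Documenting each accepted key under `Keyword Args:` makes the otherwise opaque `**kwargs` signature self-describing.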
@@ -262,23 +282,22 @@ def roll_out(  # pylint: disable=too-many-arguments,too-many-locals
         truncated,
         next_state,
         info,
-        action_info,
Unused arguments.
     )
     self._current_obs = next_state
     if terminated or truncated:
         self._log_metrics(logger)
         self._reset_log()
         self._current_obs, _ = self.reset()
         if algo_reset_func is not None:
-            algo_reset_func(current_step)
+            algo_reset_func()
Unused arguments.
-    if torch.is_tensor(data):
-        return (data - self._mean_t) / self._std_t
-    return (data - self._mean) / self._std
+    return (data - self._mean_t) / self._std_t
Unused branch.
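The single-code-path normalization can be sketched torch-free (`Scaler` here is a hypothetical stand-in; the real code keeps `_mean_t`/`_std_t` as tensors):

```python
class Scaler:
    """Hypothetical sketch: with the unused numpy branch dropped,
    normalization always runs through one representation."""

    def fit(self, data):
        n = len(data)
        self._mean = sum(data) / n
        var = sum((x - self._mean) ** 2 for x in data) / n
        self._std = max(var ** 0.5, 1e-8)  # guard against zero std

    def transform(self, data):
        # Single code path: no is_tensor branching needed.
        return [(x - self._mean) / self._std for x in data]
```

Removing the dead numpy branch simplifies the method without changing behavior, since callers only ever pass tensors.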
-    return action, info
+    return action
The returned `info={}` is always empty and therefore useless.
    assert isinstance(delta_state, torch.Tensor), 'delta_state should be torch.Tensor'
    inputs = torch.cat((state, action), -1)
    inputs = torch.reshape(inputs, (inputs.shape[0], -1))

    labels = torch.reshape(delta_state, (delta_state.shape[0], -1))
    if self._cfgs.dynamics_cfgs.predict_reward:
        labels = torch.cat(((torch.reshape(reward, (reward.shape[0], -1))), labels), -1)
    if self._cfgs.dynamics_cfgs.predict_cost:
        labels = torch.cat(((torch.reshape(cost, (cost.shape[0], -1))), labels), -1)
    inputs = inputs.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()
If delta_state is not a tensor, something went wrong upstream, so I think an assert is better than an if here.
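The fail-fast choice can be illustrated with a torch-free sketch (`build_labels` is a hypothetical name; the real code asserts `isinstance(delta_state, torch.Tensor)` and concatenates with `torch.cat`):

```python
def build_labels(delta_state, reward=None):
    # Fail fast: a wrong type here signals an upstream bug, so an
    # assertion is clearer than silently branching around bad input.
    assert isinstance(delta_state, list), 'delta_state should be a list of rows'
    labels = [list(row) for row in delta_state]
    if reward is not None:
        # Prepend the reward column when reward prediction is enabled,
        # mirroring the torch.cat along the last dimension.
        labels = [[r] + row for r, row in zip(reward, labels)]
    return labels
```

An assertion surfaces the bug at the call site instead of letting a malformed batch propagate into training.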
@@ -206,34 +224,46 @@ def _wrapper(
     if self._env.num_envs == 1:
         self._env = Unsqueeze(self._env, device=self._device)

 def roll_out(  # pylint: disable=too-many-arguments,too-many-locals
Use `rollout` instead of `roll_out`.
@@ -244,16 +274,14 @@ def roll_out(  # pylint: disable=too-many-arguments,too-many-locals

     epoch_steps = 0

     while epoch_steps < roll_out_step and current_step < self._cfgs.train_cfgs.total_steps:
         action, action_info = act_func(current_step, self._current_obs, self._env)
The returned info is always `{}`, which is useless.
    self._predict_cost = predict_cost
    self._state_size: int = state_size
    self._reward_size: int = reward_size
    self._cost_siz: int = cost_size
Suggested change:
-    self._cost_siz: int = cost_size
+    self._cost_size: int = cost_size
Won't this cause exceptions in CI?
Well, I don't know either, but I fixed it.
@@ -466,8 +581,9 @@ def train(
     holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]
     self._ensemble_model.scaler.fit(train_inputs)

     train_mse_losses = []
Will this change the original code behavior?
I have considered it. The original function used `train_mse_loss` inside the for loop, which could raise an unbound-variable error. I now log the total training loss, but still use the in-loop loss to decide whether to break. The performance curve suggests this does not change the original code behavior.
    labels = torch.cat(((torch.reshape(cost, (cost.shape[0], -1))), labels), -1)
    inputs = inputs.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()
    assert not torch.is_tensor(inputs) and not torch.is_tensor(
Is this assertion useless?
You can see on lines 263 and 264 that inputs and labels have already been converted to numpy.ndarray, so I don't think this assertion makes any sense.
LGTM.
Description
- `list` -> `list[float]`.
- Some docstrings say `np.ndarray` but actually it is `torch.Tensor`.

Motivation and Context
Some code and docs have problems and need to be polished.
This pull request solves issue #230.
Types of changes
What types of changes does your code introduce? Put an `x` in all the boxes that apply:

Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!
- I have run `make format`. (required)
- I have run `make lint`. (required)
- I have made `make test` pass. (required)