
[Bug] An error in MaskPPO training #81

Open
Yangxiaojun1230 opened this issue Jul 5, 2022 · 19 comments
Labels
more information needed Please fill the issue template completely

Comments

@Yangxiaojun1230

System Info
Describe the characteristic of your environment:

Describe how the library was installed: pip
sb3-contrib: 1.5.1a9
Python: 3.8.13
Stable-Baselines3: 1.5.1a9
PyTorch: 1.11.0+cu102
GPU Enabled: False
Numpy: 1.22.3
Gym: 0.21.0

My training code is as below:
model = MaskablePPO("MultiInputPolicy", env, gamma=0.4, seed=32, verbose=0)
model.learn(300000)
My action space is spaces.Discrete(). It seems to be a problem in the torch distribution __init__(): the input logits had invalid values, and the error happened at an unpredictable training step.
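For context, the constraint that ultimately fails in the traceback below is PyTorch's Simplex check, which requires the probabilities to sum to 1 within roughly 1e-6. A minimal sketch with illustrative values (not taken from this training run):

```python
import torch
from torch.distributions.constraints import simplex

# Illustrative values: this sum is off from 1 by 2e-6, just outside
# the ~1e-6 tolerance used by PyTorch's Simplex constraint
bad_probs = torch.tensor([0.5, 0.5 - 2e-6])
good_probs = torch.tensor([0.5, 0.5])

ok_bad = bool(simplex.check(bad_probs))    # False: fails the tolerance
ok_good = bool(simplex.check(good_probs))  # True
```

When argument validation is enabled, a `probs` tensor failing this check raises the ValueError shown in the traceback.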

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:579, in MaskablePPO.learn(self, total_timesteps, callback, log_interval, eval_env, eval_freq, n_eval_episodes, tb_log_name, eval_log_path, reset_num_timesteps, use_masking)
576 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
577 self.logger.dump(step=self.num_timesteps)
--> 579 self.train()
581 callback.on_training_end()
583 return self

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:439, in MaskablePPO.train(self)
435 if isinstance(self.action_space, spaces.Discrete):
436 # Convert discrete action from float to long
437 actions = rollout_data.actions.long().flatten()
--> 439 values, log_prob, entropy = self.policy.evaluate_actions(
440 rollout_data.observations,
441 actions,
442 action_masks=rollout_data.action_masks,
443 )
445 values = values.flatten()
446 # Normalize advantage

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/policies.py:280, in MaskableActorCriticPolicy.evaluate_actions(self, obs, actions, action_masks)
278 distribution = self._get_action_dist_from_latent(latent_pi)
279 if action_masks is not None:
--> 280 distribution.apply_masking(action_masks)
281 log_prob = distribution.log_prob(actions)
282 values = self.value_net(latent_vf)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py:152, in MaskableCategoricalDistribution.apply_masking(self, masks)
150 def apply_masking(self, masks: Optional[np.ndarray]) -> None:
151 assert self.distribution is not None, "Must set distribution parameters"
--> 152 self.distribution.apply_masking(masks)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py:62, in MaskableCategorical.apply_masking(self, masks)
59 logits = self._original_logits
61 # Reinitialize with updated logits
---> 62 super().__init__(logits=logits)
64 # self.probs may already be cached, so we must force an update
65 self.probs = logits_to_probs(self.logits)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/torch/distributions/categorical.py:64, in Categorical.__init__(self, probs, logits, validate_args)
62 self._num_events = self._param.size()[-1]
63 batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
---> 64 super(Categorical, self).__init__(batch_shape, validate_args=validate_args)

File ~/anaconda3/envs/stable_base/lib/python3.8/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
53 valid = constraint.check(value)
54 if not valid.all():
---> 55 raise ValueError(
56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
59 f"to satisfy the constraint {repr(constraint)}, "
60 f"but found invalid values:\n{value}"
61 )
62 super(Distribution, self).__init__()

ValueError: Expected parameter probs (Tensor of shape (64, 400)) of distribution MaskableCategorical(probs: torch.Size([64, 400]), logits: torch.Size([64, 400])) to satisfy the constraint Simplex(), but found invalid values:

@araffin araffin added the more information needed Please fill the issue template completely label Jul 5, 2022
@Yangxiaojun1230
Author

I checked my env using check_env(env), which failed with the error message below, even though this env was used successfully before with SB3 PPO.
AssertionError: Error while checking key=bin_size_h: The observation returned by the reset() method does not match the given observation space.

My observation space and observations are declared as below; I couldn't find any problem.

self.observation_space = spaces.Dict(spaces={
    "state_grid": spaces.MultiBinary(self.max_num),
    "node_placed": spaces.MultiBinary(self.max_inst_num),
    "cur_node_w": spaces.Box(low=0.0, high=1, shape=(1,), dtype=np.float32),
    "cur_node_h": spaces.Box(low=0.0, high=1, shape=(1,), dtype=np.float32),
    "bin_size_w": spaces.Box(low=0.0, high=1, shape=(self.max_num,), dtype=np.float32),
    "bin_size_h": spaces.Box(low=0.0, high=1, shape=(self.max_num,), dtype=np.float32),
    "node_size_w": spaces.Box(low=0.0, high=1, shape=(self.max_inst_num,), dtype=np.float32),
    "node_size_h": spaces.Box(low=0.0, high=1, shape=(self.max_inst_num,), dtype=np.float32),
})

def get_obs(self):
    return collections.OrderedDict([
        ("state_grid", self.state_bin),
        ("node_placed", self.node_placed),
        ("cur_node_w", self.cur_node_w),
        ("cur_node_h", self.cur_node_h),
        ("bin_size_w", self.bin_size_w),
        ("bin_size_h", self.bin_size_h),
        ("node_size_w", self.node_size_w),
        ("node_size_h", self.node_size_h),
    ])

@Miffyli
Member

Miffyli commented Jul 5, 2022

Hey. We do not offer tech support, and it is hard to give guidance without further code. If you can replicate this issue with minimal code and you believe your code is correct, please include that minimal code here. Meanwhile, I recommend double-checking what comes out of your reset function.

@Yangxiaojun1230
Author

Hi guys,
I solved the problem by changing dtype=np.float32 -> dtype=np.float64.

@Yangxiaojun1230
Author

Hi guys,
The error happened again, and I found the root cause is in the torch Categorical class: it does a constraint check, and the failing check is |sum(probs) - 1| < 1e-6. The value of sum(probs) - 1 in this case is -1.6e-6. Is this caused by the apply_masking() function?
Maybe the function could set the dtype to float64, or change -1e8 to -1e6 in the code below? Any advice would be appreciated.
"HUGE_NEG = th.tensor(-1e8, dtype=self.logits.dtype, device=device)"
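To illustrate the precision argument (random illustrative logits, not data from this run): in float32, the row sums of a softmax over masked logits can drift from 1 by a few ULPs, occasionally approaching the 1e-6 tolerance, while float64 keeps them many orders of magnitude closer.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(64, 400)

# Mask most of the 400 actions the way the masked distribution does,
# by overwriting their logits with a huge negative number (-1e8)
allowed = torch.rand(64, 400) < 0.1
masked = torch.where(allowed, logits, torch.tensor(-1e8))

probs32 = torch.softmax(masked, dim=-1)                       # float32
probs64 = torch.softmax(masked, dim=-1, dtype=torch.float64)  # float64

# Worst-case deviation of the row sums from 1 in each precision
err32 = (probs32.sum(-1) - 1).abs().max().item()
err64 = (probs64.sum(-1) - 1).abs().max().item()
```

With float64 accumulation the deviation stays far below the 1e-6 Simplex tolerance, which is why switching the softmax dtype works around the validation failure.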

@svolokh

svolokh commented Jul 11, 2022

I've run into the same problem when using the maskable PPO implementation (this is only relevant in debug mode, where arg validation is enabled by default). Here is a repro for the problem.

The issue seems to be that, for some combinations of logits, logits_to_probs will return a value for probs that does not sum to 1 within the tolerance limit (due to precision issues), causing the arg validation constraint for the Categorical class to fail. Normally (without action masking), the Categorical constructor does not run arg validation for probs, since it is not present when the constructor is invoked. However, with action masking enabled, the Categorical constructor ends up being called multiple times here, once with probs set, so it runs arg validation on probs and will therefore fail on rare occasions.

One way I've found to fix this is to change how the apply_masking method deals with the cached probs here: instead of force-updating probs, just remove it if it is present before calling the constructor. So basically, introduce the following code before the constructor call:

# remove cached probs if present
if 'probs' in self.__dict__:
    delattr(self, 'probs')

# Reinitialize with updated logits
super().__init__(logits=logits)

If this looks reasonable I can make a PR. Thanks!
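Sketched as a self-contained subclass (a hypothetical MaskedCategorical for illustration, not the sb3-contrib class, which also keeps the original pre-mask logits around), the proposed fix looks roughly like:

```python
import torch
from torch.distributions import Categorical
from torch.distributions.utils import logits_to_probs

class MaskedCategorical(Categorical):
    """Hypothetical sketch of the proposed fix."""

    def apply_masking(self, mask: torch.Tensor) -> None:
        huge_neg = torch.tensor(-1e8, dtype=self.logits.dtype, device=self.logits.device)
        logits = torch.where(mask, self.logits, huge_neg)

        # Remove cached probs if present, so the constructor re-run below
        # only validates `logits` (whose constraint is simply "real-valued")
        if "probs" in self.__dict__:
            delattr(self, "probs")

        # Reinitialize with updated logits
        super().__init__(logits=logits)

        # Force-recompute probs from the new logits
        self.probs = logits_to_probs(self.logits)

dist = MaskedCategorical(logits=torch.zeros(4), validate_args=True)
_ = dist.probs  # touch the lazy property so probs gets cached
dist.apply_masking(torch.tensor([True, False, True, False]))
```

Because `probs` is deleted before the re-initialization, the inherited argument validation only sees `logits` and the Simplex check on the stale cached `probs` can no longer fire.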

@Yangxiaojun1230
Author

@svolokh Thanks for your information.
In my case, I overrode the softmax in torch to use F.softmax(logits, dim=-1, dtype=torch.double), which made it work.

@araffin
Member

araffin commented Jul 18, 2022

@svolokh thanks for the info, would validate_args=False also solve the issue? (probably cleaner than deleting the cached probs)

@svolokh

svolokh commented Jul 19, 2022

@araffin That does indeed fix the issue as well!
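For reference, validate_args is part of the standard torch.distributions interface; passing False simply disables the constraint checks at construction time, while sampling and log-probability computation work as usual (a quick sketch with arbitrary logits):

```python
import torch
from torch.distributions import Categorical

# validate_args=False skips the Simplex/arg-constraint checks on init
dist = Categorical(logits=torch.randn(64, 400), validate_args=False)
actions = dist.sample()        # one action index per batch row
logp = dist.log_prob(actions)  # log-probabilities of the sampled actions
```

Note that this silences the check rather than fixing the underlying float32 rounding, which is why the discussion below about learning behavior after the change matters.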

@araffin
Member

araffin commented Jul 20, 2022

Good to hear =)
Then I would be happy to receive a PR that solves this issue ;)

svolokh added a commit to svolokh/stable-baselines3-contrib that referenced this issue Jul 22, 2022
@dervan

dervan commented Jul 22, 2022

Hi, I happened to have the same issue and I applied the very same fix as @svolokh in his first post. I was quite surprised that @araffin decided that ignoring validation is the cleaner solution: I think it definitely isn't! The validation of the logits themselves should still be done, and calling the __init__ method of the Categorical class with already (incorrectly!) filled probs is at least a suspicious practice. Could you elaborate on why you think that removing the probs is not a good idea?

@araffin
Member

araffin commented Jul 25, 2022

Hello,

Could you elaborate on why you think that removing the probs is not a good idea?

The idea behind it is to use a feature that is in the interface of PyTorch, to avoid manually deleting an attribute (which may have side effects).

EDIT: another reason is that deleting the attribute has the same effect: #81 (comment)

calling the __init__ method of the Categorical class with already (incorrectly!) filled probs is at least a suspicious practice.

Of course, best would be to solve the root cause of the problem.
I haven't looked too much into that problem (didn't write that code either), but from what I gather, it is an error due to numerical imprecision, so if you have a better solution that avoids deleting attributes and keeps argument validation, I'm of course up for it ;)

@hsjung02

Hi, I got the same error and found this issue.
Is there any reason that the validate_args=False fix has not been released?

@koliber31

Hi, I got the same error and found this issue. Is there any reason that the validate_args=False fix has not been released?

Does it work in your case after changing that line of code?

@hsjung02

Yes, at least it doesn't produce the same error as before. However, I don't know whether it affects the learning performance.

@koliber31

Yes, at least it doesn't produce the same error as before. However, I don't know whether it affects the learning performance.

Did you check if your agent learns after this change? I mean, does it learn at all? Because after this change the error stopped occurring, but my agent wasn't able to learn anything.

@hsjung02

hsjung02 commented Jul 13, 2023 via email

@koliber31

I did check it, and as I said, it stopped learning entirely. Below are screenshots of the learning curves: those with 150k and 260k timesteps show runs that ended with the error (without validate_args=False), and the one with 4M timesteps shows a run with validate_args=False.
[Screenshots: 260kSteps, 150kSteps, wykresy]
As you can tell, after this change it doesn't learn at all. Would you do me a favor and run your training (if you still have it, of course) just to see if something is happening, and tell me the results?

@hsjung02

hsjung02 commented Jul 13, 2023 via email

@yiptsangkin

As you can tell, after this change it doesn't learn at all. Would you do me a favor and run your training (if you still have it, of course) just to see if something is happening, and tell me the results?

Same problem: after changing validate_args to False, my agent cannot learn anymore. Have you solved this problem?
