🚀 Feature
Add support for masking out invalid actions when making predictions. This would allow models to converge much faster for environments where many actions may be illegal depending on the state.
This feature could be implemented as a wrapper around a gym environment that adds a method returning the mask. The Stable Baselines algorithms would check for the wrapper and use the mask if available. The mask is a boolean tensor with the shape of the action space; it replaces the logits of invalid actions with very large negative values in the underlying probability distribution.
Here is an MVP.
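For illustration, here is a minimal sketch of what the wrapper and the logit masking could look like. The `ActionMaskWrapper` class, its `action_masks()` method, and the `valid_action_mask()` helper are assumptions made for this example, not an existing SB3 API:

```python
import gym
import numpy as np
import torch as th


class ActionMaskWrapper(gym.Wrapper):
    """Expose a boolean mask of valid actions for the current state.

    `valid_action_mask()` is a hypothetical method that the wrapped
    environment would implement for a Discrete action space.
    """

    def action_masks(self) -> np.ndarray:
        # True = action is valid in the current state.
        return self.env.valid_action_mask()


def mask_logits(logits: th.Tensor, mask: th.Tensor) -> th.Tensor:
    """Replace the logits of invalid actions with a very large negative
    value, so softmax assigns them effectively zero probability."""
    return th.where(mask, logits, th.tensor(-1e8, dtype=logits.dtype, device=logits.device))
```

The algorithms would then call `action_masks()` on the wrapped environment before sampling and apply the masking inside the categorical distribution.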
Motivation
In many environments, there may be portions of the action space that are invalid to select in a given state. Without a way to avoid sampling these actions, training becomes less efficient. Models have to waste time exploring the invalid portions of the space, which may become prohibitively expensive for large action spaces. See this paper for more details.
Alternatives
To the best of my knowledge, the only alternative is to accept that invalid actions may be selected and to discourage them by penalizing each invalid choice with a large negative reward. This is just the status quo. Note that action masking would be optional, and the status quo would stay the default.
Additional context
I ran into this problem in practice when building models for board games. To work around it, I implemented an MVP of this feature in a fork. I'd be happy to make a PR. The branch is here.
Checklist
I have checked that there is no similar issue in the repo (required)
this would be a good fit for our contrib repo: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
Please open the same issue there and close this one; make sure to read the CONTRIBUTING guide of the contrib repo first ;) (there are some important details in there).
Hello @araffin, thanks for all the work on sb3. sb3-contrib looks like a good home for most of the changes (e.g. wrappers, distributions). The only concern I have is that the internals of the algorithms have to change to use the masks, for example forward() in ActorCriticPolicy or collect_rollouts() in OnPolicyAlgorithm. So I'm concerned that a lot of the "base" algorithm code will have to be ported/duplicated in sb3-contrib.
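To make that concern concrete, here is a rough sketch of the kind of change forward() would need, assuming the distribution gained an apply_masking() hook. This is an illustration of the proposal, not the actual SB3 code:

```python
import torch as th


def forward(self, obs: th.Tensor, deterministic: bool = False,
            action_masks: th.Tensor = None):
    # Hypothetical variant of ActorCriticPolicy.forward() that threads an
    # optional mask through to the action distribution.
    features = self.extract_features(obs)
    latent_pi, latent_vf = self.mlp_extractor(features)
    values = self.value_net(latent_vf)
    distribution = self._get_action_dist_from_latent(latent_pi)
    if action_masks is not None:
        # apply_masking() is an assumed hook that sets the logits of
        # invalid actions to a large negative value before sampling.
        distribution.apply_masking(action_masks)
    actions = distribution.get_actions(deterministic=deterministic)
    log_prob = distribution.log_prob(actions)
    return actions, values, log_prob
```

collect_rollouts() would need a similar change to fetch the mask from the environment at each step and pass it along, which is why the base classes themselves would be touched.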