[Feature] Add off-policy sequence masking algorithm proposed in DeepSeek v3.2 #999
Conversation
slime/utils/arguments.py
Outdated
```python
    help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370",
)
parser.add_argument(
    "--enable-opsm",
```
Please rename this to `--use-xxx` to align the naming with the other flags. I also wonder if we could add an additional `--use-off-policy-sequence-mask`, since the original paper does not mention the "OPSM" abbreviation.

Also, please add the DeepSeek v3.2 paper to the help text.
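A minimal sketch of what the renamed flag could look like; the flag names and help wording here are this comment's suggestion, not the final API, and the paper reference is left as a placeholder rather than inventing a URL:

```python
import argparse

parser = argparse.ArgumentParser()
# Align with the other --use-xxx flags; keep a spelled-out alias since
# the paper never uses the "OPSM" abbreviation (both names hypothetical).
parser.add_argument(
    "--use-opsm",
    "--use-off-policy-sequence-mask",
    action="store_true",
    help="Off-policy sequence masking from the DeepSeek v3.2 paper "
    "(paper reference to be added here).",
)

# Either spelling sets the same destination, args.use_opsm.
args = parser.parse_args(["--use-off-policy-sequence-mask"])
print(args.use_opsm)  # prints True
```

With multiple option strings, argparse derives the destination (`use_opsm`) from the first long option, so the long alias needs no extra plumbing.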
slime/backends/fsdp_utils/actor.py
Outdated
```python
opsm_mask = None
opsm_clipfrac_num = 0
if getattr(self.args, "enable_opsm", False):
```
```diff
-if getattr(self.args, "enable_opsm", False):
+if self.args.use_opsm:
```
slime/backends/fsdp_utils/actor.py
Outdated
```python
ppo_kl = old_log_probs - log_probs

opsm_mask = None
opsm_clipfrac_num = 0
```
Please move these two values inside the `if getattr(self.args, "enable_opsm", False):` block, and always use `enable_opsm` as the flag to check whether OPSM is enabled.
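A sketch of the suggested shape, with the OPSM state created only inside the flag-gated branch so the flag is the single source of truth. The helper name, the mask condition, and `args.opsm_delta` usage are illustrative stand-ins, not slime's actual code:

```python
import torch


def maybe_compute_opsm(args, log_probs, old_log_probs):
    """Return (opsm_mask, opsm_clipfrac_num); None/0 when OPSM is disabled."""
    opsm_mask = None
    opsm_clipfrac_num = 0
    if getattr(args, "enable_opsm", False):  # the only gate for OPSM
        ppo_kl = old_log_probs - log_probs
        # Hypothetical per-token mask: drop tokens whose KL exceeds the threshold.
        opsm_mask = (ppo_kl <= args.opsm_delta).float()
        # Count masked tokens on-device; no .item() here.
        opsm_clipfrac_num = (opsm_mask == 0).sum()
    return opsm_mask, opsm_clipfrac_num
```

Downstream code can then branch on `args.enable_opsm` directly rather than on `opsm_mask is not None`.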
```python
]
# Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering
cp_size = mpu.get_context_parallel_world_size()
need_full_log_probs = (args.enable_opsm or args.advantage_estimator == "gspo") and cp_size > 1
```
```diff
-need_full_log_probs = (args.enable_opsm or args.advantage_estimator == "gspo") and cp_size > 1
+need_full_log_probs = (args.enable_opsm or args.advantage_estimator == "gspo")
```

It seems the code below already handles `cp_size == 1`.
```python
    ]
    if cp_size > 1
    else log_probs
)
```
There seems to be an early return in `all_gather_with_cp`, so we can always call `all_gather_with_cp` here.
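A sketch of the pattern being suggested, assuming `all_gather_with_cp` short-circuits when there is no context parallelism; this simplified helper (and its `gather_fn` parameter) is a stand-in for slime's actual implementation:

```python
import torch


def all_gather_with_cp(tensor, cp_size, gather_fn=None):
    # Early return: with no context parallelism the local tensor already
    # is the full tensor, so callers never need their own cp_size check.
    if cp_size == 1:
        return tensor
    # gather_fn stands in for the real all-gather across the
    # context-parallel group (hypothetical signature).
    return gather_fn(tensor)


# Callers can then drop the `... if cp_size > 1 else log_probs` branching:
log_probs = torch.tensor([0.1, 0.2])
full_log_probs = all_gather_with_cp(log_probs, cp_size=1)
```

Pushing the `cp_size == 1` case into the helper keeps every call site a single unconditional line.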
```python
ppo_kl = torch.cat(ppo_kl, dim=0)
# Compute OPSM mask if enabled
opsm_mask = None
opsm_clipfrac_num = 0
```
Same as the comment on the FSDP backend: please move these inside the `args.enable_opsm` branch.
```python
old_log_probs = torch.cat(old_log_probs, dim=0)
log_probs = torch.cat(log_probs, dim=0)
ppo_kl = old_log_probs - log_probs
```
When running with OPSM but without GSPO, will the shape of `ppo_kl` differ from that of `opsm_mask`?
slime/utils/ppo_utils.py
Outdated
```python
# Calculate sequence-level advantage (mean of advantage values)
# For GRPO, advantage is constant across the sequence, so mean == any element
seq_advantage = advantage.mean()
```
Hmm... from the DeepSeek v3.2 paper, it seems we don't need a per-sequence advantage here but a per-token advantage.
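A hedged sketch of the per-token variant this comment is suggesting. The masking condition mirrors the sequence-level diff below (`adv < 0` and `KL > delta`), applied elementwise instead of once per sequence; whether token-level is actually what the paper intends is exactly the open question here:

```python
import torch


def opsm_mask_per_token(advantage, ppo_kl, delta):
    """Mask individual tokens (not whole sequences) that are both
    negatively advantaged and too far off-policy: adv < 0 and KL > delta."""
    condition = (advantage < 0) & (ppo_kl > delta)
    mask = torch.where(condition, torch.zeros_like(ppo_kl), torch.ones_like(ppo_kl))
    # Count masked tokens on-device; defer any .item() to logging time.
    clipfrac_num = condition.sum()
    return mask, clipfrac_num
```

For GRPO the advantage is constant across a sequence, so the two variants differ only through the per-token KL term.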
slime/utils/ppo_utils.py
Outdated
```python
# This mask applies to the entire sequence
condition = (seq_advantage < 0) & (seq_kl > args.opsm_delta)
mask_chunk = torch.where(condition, torch.zeros_like(local_log_prob), torch.ones_like(local_log_prob))
opsm_clipfrac_num += condition.int().item()
```
Please don't call `.item()` inside the loop; it triggers a GPU-to-CPU copy and a synchronization on every iteration.
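A sketch of the suggested fix: accumulate the count as a tensor inside the loop, then make the single `.item()` call (the only forced device-to-host sync) after the loop finishes. The helper name and per-chunk condition are illustrative:

```python
import torch


def count_masked(kl_chunks, delta):
    # Accumulate on-device; each `+=` enqueues async GPU work.
    # In real code, allocate this on the same device as the inputs.
    clipfrac_num = torch.zeros((), dtype=torch.long)
    for seq_kl in kl_chunks:
        clipfrac_num += (seq_kl > delta).sum()
    # Exactly one synchronization point, after the loop.
    return clipfrac_num.item()
```

Calling `.item()` per iteration would serialize the host against the GPU once per chunk; deferring it keeps the loop fully asynchronous.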