Conversation

@yitianlian (Collaborator)

No description provided.

help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370",
)
parser.add_argument(
"--enable-opsm",
Contributor

Please rename this to --use-xxx to align the naming with the other flags, and I wonder if we could have an additional --use-off-policy-sequence-mask, since the original paper does not mention the opsm abbreviation.

Also, please add the DeepSeek-V3.2 paper to the help text.
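
Something along these lines would cover both spellings (just a sketch of the suggestion, not the merged code; the flag names and help text here are illustrative):

```python
# Illustrative sketch only: a short flag plus a spelled-out alias, both writing
# to args.use_opsm, with the DeepSeek-V3.2 report referenced in the help text.
parser.add_argument(
    "--use-opsm",
    "--use-off-policy-sequence-mask",
    dest="use_opsm",
    action="store_true",
    help="Off-policy sequence mask (OPSM) from the DeepSeek-V3.2 technical report.",
)
```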


opsm_mask = None
opsm_clipfrac_num = 0
if getattr(self.args, "enable_opsm", False):
Contributor

Suggested change
- if getattr(self.args, "enable_opsm", False):
+ if self.args.use_opsm:

ppo_kl = old_log_probs - log_probs

opsm_mask = None
opsm_clipfrac_num = 0
Contributor

Please move these two values inside the if getattr(self.args, "enable_opsm", False): block, and always use enable_opsm as the flag for checking whether OPSM is enabled.
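
i.e., roughly (a sketch only, assuming the masking code below only touches these variables when OPSM is enabled):

```python
# Sketch of the suggestion, not the merged code: keep the OPSM bookkeeping
# inside the feature gate and use args.enable_opsm as the single on/off check.
if self.args.enable_opsm:
    opsm_mask = None       # filled in by the masking logic further down
    opsm_clipfrac_num = 0
```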

]
# Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering
cp_size = mpu.get_context_parallel_world_size()
need_full_log_probs = (args.enable_opsm or args.advantage_estimator == "gspo") and cp_size > 1
Contributor

Suggested change
- need_full_log_probs = (args.enable_opsm or args.advantage_estimator == "gspo") and cp_size > 1
+ need_full_log_probs = (args.enable_opsm or args.advantage_estimator == "gspo")

It seems that the code below already handles cp_size == 1.

]
if cp_size > 1
else log_probs
)
Contributor

There seems to be an early return in all_gather_with_cp, so we can always run all_gather_with_cp here.
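
For context, the early-return pattern referred to looks roughly like the helper below; this is an illustrative stand-in, not slime's actual all_gather_with_cp:

```python
import torch
import torch.distributed as dist

# Hypothetical helper, for illustration only: with an early return for
# cp_size == 1, callers can invoke it unconditionally instead of branching.
def gather_with_cp_sketch(local: torch.Tensor, cp_group) -> torch.Tensor:
    cp_size = dist.get_world_size(group=cp_group)
    if cp_size == 1:
        return local  # nothing to gather across context-parallel ranks
    chunks = [torch.empty_like(local) for _ in range(cp_size)]
    dist.all_gather(chunks, local, group=cp_group)
    return torch.cat(chunks, dim=0)
```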

ppo_kl = torch.cat(ppo_kl, dim=0)
# Compute OPSM mask if enabled
opsm_mask = None
opsm_clipfrac_num = 0
Contributor

Same as the comment in the FSDP backend: please move these inside the args.enable_opsm branch.

old_log_probs = torch.cat(old_log_probs, dim=0)
log_probs = torch.cat(log_probs, dim=0)
ppo_kl = old_log_probs - log_probs

Contributor

When running with OPSM but without GSPO, will the shape of ppo_kl differ from opsm_mask?


# Calculate sequence-level advantage (mean of advantage values)
# For GRPO, advantage is constant across the sequence, so mean == any element
seq_advantage = advantage.mean()
Contributor

Hmm... it seems that, from the DeepSeek-V3.2 paper, we don't need a per-sequence advantage but a per-token advantage.
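
One possible reading of that (a sketch, not the merged code): keep the sequence-level KL test but use the per-token advantage directly, so the mask can differ token by token within a sequence.

```python
# Sketch only; `advantage` is assumed to be the per-token advantage tensor of
# shape [seq_len], and seq_kl / args.opsm_delta are as in the code below.
condition = (advantage < 0) & (seq_kl > args.opsm_delta)    # [seq_len] bool
mask_chunk = (~condition).to(local_log_prob.dtype)          # 1 keeps a token, 0 masks it
```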

# This mask applies to the entire sequence
condition = (seq_advantage < 0) & (seq_kl > args.opsm_delta)
mask_chunk = torch.where(condition, torch.zeros_like(local_log_prob), torch.ones_like(local_log_prob))
opsm_clipfrac_num += condition.int().item()
Contributor

Please don't call .item() inside the loop; it triggers a GPU -> CPU copy and a synchronization.
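
i.e., roughly (a sketch; `sequences` here stands in for whatever the actual per-sequence loop iterates over):

```python
# Keep the counter on the GPU while iterating and read it back once at the end,
# so there is a single device -> host synchronization instead of one per sequence.
opsm_clipfrac_num = torch.zeros((), dtype=torch.int64, device=log_probs.device)

for local_log_prob, seq_advantage, seq_kl in sequences:       # hypothetical loop
    condition = (seq_advantage < 0) & (seq_kl > args.opsm_delta)
    opsm_clipfrac_num += condition.int()                       # no .item() inside the loop

opsm_clipfrac = opsm_clipfrac_num.item()                       # one GPU -> CPU sync
```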

@zhuzilin zhuzilin merged commit caa4af2 into THUDM:main Dec 2, 2025
4 of 5 checks passed
Fengzdadi pushed a commit to Fengzdadi/slime that referenced this pull request Dec 19, 2025