fix: skip_reference_policy_logprobs_calculation=true crashes training #2174
Closed
ShriyaRishab wants to merge 1 commit into
Fixes NVIDIA-NeMo#1968: Setting skip_reference_policy_logprobs_calculation=true with reference_policy_kl_penalty=0 crashed training in three ways:

Bug 1: use_reference_model() context manager crashed when the reference model was never initialized (AttributeError on reference_state_dict). Fix: Added an early-return guard in use_reference_model() for all three worker types (megatron, dtensor v1, dtensor v2); it yields without swapping when the reference model is None/missing.

Bug 2: The async GRPO path unconditionally called get_reference_policy_logprobs() without checking the skip flag. Fix: Added the same skip guard as the sync path, setting zeros_like for reference_policy_logprobs when skipping.

Bug 3: Missing reference_policy_logprobs key in train_data caused shape mismatches downstream in loss computation. Fix: Both sync and async paths now explicitly set train_data['reference_policy_logprobs'] = zeros_like(prev_logprobs) when skipping.

Also added a _has_reference_model() helper and a zeros fallback in base_policy_worker.get_reference_policy_logprobs() as defense-in-depth.
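The early-return guard described under Bug 1 can be sketched as follows. This is a minimal illustration, not the worker code from the PR: `ToyWorker`, its `weights` dict, and the body of `_has_reference_model()` are assumptions made for the example. Only the names `use_reference_model()`, `_has_reference_model()`, and `reference_state_dict` come from the description above.

```python
from contextlib import contextmanager


class ToyWorker:
    """Hypothetical stand-in for a policy worker (not the NeMo RL class)."""

    def __init__(self, weights, reference_state_dict=None):
        self.weights = dict(weights)
        # May be None when the reference model was never initialized.
        self.reference_state_dict = reference_state_dict

    def _has_reference_model(self):
        # getattr with a default also covers the "attribute missing" case.
        return getattr(self, "reference_state_dict", None) is not None

    @contextmanager
    def use_reference_model(self):
        if not self._has_reference_model():
            # Nothing to swap in: run the body against the current weights.
            yield
            return
        saved = self.weights
        self.weights = dict(self.reference_state_dict)
        try:
            yield
        finally:
            # Always restore the policy weights, even if the body raised.
            self.weights = saved


no_ref = ToyWorker({"w": 1.0})
with no_ref.use_reference_model():
    weights_seen_without_reference = dict(no_ref.weights)

with_ref = ToyWorker({"w": 1.0}, reference_state_dict={"w": 0.5})
with with_ref.use_reference_model():
    weights_seen_with_reference = dict(with_ref.weights)
```

Without a reference state dict the body sees the unchanged policy weights; with one, the reference weights are visible inside the block and the policy weights are restored on exit.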
@ShriyaRishab skip_reference_policy_logprobs_calculation should only skip the computation of reference model logprobs; it should not skip the initialization of reference model states. Why do we need to do the changes in |
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
Cherry-picked PR NVIDIA-NeMo#2174 didn't run ruff format on the worker files it touched. This commit applies the format pass so subsequent diffs stay clean. No functional changes.

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
Adds a functional smoke test for the path enabled by PR NVIDIA-NeMo#2178 plus the auto-skip safety net added in response to yuki-97's review:

> and I think it's better to add a functional test (or modify one
> exist functional test) for reference_policy_kl_penalty == 0.

The test runs a 2-step GRPO with reference_policy_kl_penalty=0 and without explicitly setting skip_reference_policy_logprobs_calculation, then asserts:

* the auto-skip log line fires (proves setup() override worked);
* the existing "Reference policy logprob calculation will be skipped" confirmation log fires;
* standard probs_ratio + gen_kl_error metric envelopes pass (PR NVIDIA-NeMo#2174 zeros placeholder keeps loss math valid when KL penalty is zero).

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
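To illustrate why a zeros placeholder keeps the loss math valid when the KL penalty is zero: the penalty enters the loss scaled by its coefficient, so any finite placeholder contributes exactly nothing. The sketch below assumes a per-token estimator of the form exp(d) - d - 1 with d = ref - curr. Both the estimator and the function name `kl_penalty_term` are assumptions for illustration and are not taken from the NeMo RL loss function.

```python
import math


def kl_penalty_term(curr_logprobs, ref_logprobs, beta):
    """Mean per-token KL estimate scaled by beta (hypothetical helper)."""
    total = 0.0
    for curr, ref in zip(curr_logprobs, ref_logprobs):
        diff = ref - curr
        # Non-negative estimator: exp(d) - d - 1 >= 0 for every real d.
        total += math.exp(diff) - diff - 1.0
    return beta * total / len(curr_logprobs)


curr = [-1.2, -0.3]
zeros_placeholder = [0.0] * len(curr)  # plays the role of zeros_like(prev_logprobs)

# With beta == 0 the placeholder values never influence the loss.
penalty_when_disabled = kl_penalty_term(curr, zeros_placeholder, beta=0.0)

# With beta > 0 the same placeholder would bias the loss, which is why
# skipping only makes sense together with a zero KL penalty.
penalty_when_enabled = kl_penalty_term(curr, zeros_placeholder, beta=0.1)
```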
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
…ratio

Adds two parametrized unit tests in tests/unit/algorithms/test_grpo.py that cover both grpo_train and async_grpo_train:

- test_grpo_train_skips_reference_policy_logprobs_when_configured: guards issue NVIDIA-NeMo#1968 / PRs NVIDIA-NeMo#2174, NVIDIA-NeMo#2178 by asserting that policy.get_reference_policy_logprobs is never called when grpo.skip_reference_policy_logprobs_calculation=True.
- test_grpo_train_skips_prev_logprobs_when_force_on_policy_ratio: guards PR NVIDIA-NeMo#2177 by asserting that policy.get_logprobs is never called when loss_fn.force_on_policy_ratio=True.

Both tests reuse the existing mock_grpo_components fixture and the mock_async_grpo_infrastructure helper, so they require no GPU / Ray cluster and run in CI in milliseconds (modulo cold-start import cost).

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Linglin Jing <linglinj@nvidia.com>
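The shape of such a "never called" assertion can be sketched with the standard library alone. This is not the test from tests/unit/algorithms/test_grpo.py: `toy_logprob_phase` is a hypothetical, heavily simplified stand-in for the logprob phase of a training step, and plain lists stand in for tensors.

```python
from unittest.mock import MagicMock


def toy_logprob_phase(policy, train_data, skip_reference):
    """Hypothetical stand-in for the logprob phase of one training step."""
    prev_logprobs = policy.get_logprobs(train_data)
    train_data["prev_logprobs"] = prev_logprobs
    if skip_reference:
        # Placeholder with the same length, so downstream shapes still match.
        train_data["reference_policy_logprobs"] = [0.0] * len(prev_logprobs)
    else:
        train_data["reference_policy_logprobs"] = (
            policy.get_reference_policy_logprobs(train_data)
        )
    return train_data


policy = MagicMock()
policy.get_logprobs.return_value = [-0.5, -1.0]

result = toy_logprob_phase(policy, {}, skip_reference=True)

# The regression guard: the reference forward pass must not run at all.
policy.get_reference_policy_logprobs.assert_not_called()
policy.get_logprobs.assert_called_once()
```

Asserting on the mock rather than on the output catches a regression where the reference call is made and its result is then discarded.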
jinglinglingling added a commit to jinglinglingling/RL that referenced this pull request May 9, 2026
The two regression tests added in this PR drive `grpo_train` / `async_grpo_train` through code paths that call `torch.zeros_like(prev_logprobs)` (PRs NVIDIA-NeMo#2174 / NVIDIA-NeMo#2178) and `torch.zeros_like(generation_logprobs)` (PR NVIDIA-NeMo#2177). Under the bare `mock_grpo_components` fixture those inputs are `MagicMock` objects, so CI failed with `TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not MagicMock` at `nemo_rl/algorithms/grpo.py:1801`.

Add a `_patched_logprob_phase` context manager that swaps in real tensors for `policy.get_logprobs`, `policy.get_reference_policy_logprobs`, and `batched_message_log_to_flat_message`, and use it in both the sync and async branches of the two new tests.

Signed-off-by: Linglin Jing <linglinj@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
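The failure mode and the remedy can be reproduced in miniature without torch. In the sketch below, `toy_zeros_like` is a stand-in that, like `torch.zeros_like`, rejects a `MagicMock` input, and `patched_logprob_phase` is a hypothetical simplification of the `_patched_logprob_phase` idea: it temporarily makes the mocked policy return concrete data. Plain lists replace tensors, and the helper's signature is an assumption.

```python
from contextlib import contextmanager
from unittest.mock import MagicMock


def toy_zeros_like(values):
    """Stand-in for torch.zeros_like: only accepts real data, not mocks."""
    if not isinstance(values, list):
        raise TypeError(f"expected list, got {type(values).__name__}")
    return [0.0] * len(values)


@contextmanager
def patched_logprob_phase(policy, logprobs):
    """Temporarily make the mocked policy return concrete logprobs."""
    saved = (
        policy.get_logprobs.return_value,
        policy.get_reference_policy_logprobs.return_value,
    )
    policy.get_logprobs.return_value = list(logprobs)
    policy.get_reference_policy_logprobs.return_value = list(logprobs)
    try:
        yield policy
    finally:
        # Restore the original auto-generated mock return values.
        policy.get_logprobs.return_value = saved[0]
        policy.get_reference_policy_logprobs.return_value = saved[1]


policy = MagicMock()

try:
    toy_zeros_like(policy.get_logprobs({}))
    bare_mock_failed = False
except TypeError:
    # Mirrors the CI failure: a MagicMock reached a tensor-only function.
    bare_mock_failed = True

with patched_logprob_phase(policy, [-0.5, -1.0]) as patched:
    zeros = toy_zeros_like(patched.get_logprobs({}))
```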
What does this PR do?
Summary
Setting `skip_reference_policy_logprobs_calculation=true` in GRPO config crashes because:
- `reference_policy_logprobs` is never assigned to `train_data` when skipped
- `use_reference_model()` context manager crashes when no reference state dict exists
Fixes #1968
Root Cause
Three code paths needed fixes:
- `grpo.py` sync path: missing `train_data["reference_policy_logprobs"]` assignment
- `grpo.py` async path: same
- `use_reference_model()` tries to swap non-existent state dicts
Fix
- Assign `torch.zeros_like(prev_logprobs)` to `reference_policy_logprobs`
- Add `_has_reference_model()` base method
- `get_reference_policy_logprobs()`: return zeros if no reference model
- `use_reference_model()` context managers: yield without swapping if no reference state dict
Issues
List issues that this PR closes (syntax):
#1968
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information