chore: patch KL loss to prevent nans #876
Conversation
Signed-off-by: rohitrango <rohit.rango@gmail.com>
@terrykong please review
Walkthrough

Removed packaging/check scripts and a third-party license under the 3rdparty Megatron workspaces; added an optional clamping parameter to a KL penalty utility in nemo_rl, applying clamping by default to the log-probability ratio before computing the penalty.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Caller as Trainer/Algorithm
    participant Utils as utils.calculate_kl_penalty_joschu2020
    Caller->>Utils: logprobs_policy, logprobs_reference, clamp_value=20.0 (default)
    activate Utils
    Utils->>Utils: r = logprobs_reference - logprobs_policy
    alt clamp_value is not None
        note over Utils: Clamp r to [-clamp_value, clamp_value]
    else
        note over Utils: No clamping
    end
    Utils->>Utils: penalty = exp(r) - r - 1
    Utils-->>Caller: penalty tensor
    deactivate Utils
```
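From the walkthrough and sequence diagram, the patched utility plausibly looks like the sketch below. This is a reconstruction, not the PR's actual code: the function name comes from the diagram, while the exact signature, the 20.0 default, and the docstring are assumptions.

```python
from typing import Optional

import torch


def calculate_kl_penalty_joschu2020(
    logprobs_policy: torch.Tensor,
    logprobs_reference: torch.Tensor,
    clamp_value: Optional[float] = 20.0,  # assumed default, per the diagram
) -> torch.Tensor:
    """k3 KL estimator (Schulman, 2020): exp(r) - r - 1, with r = log p_ref - log p_policy."""
    r = logprobs_reference - logprobs_policy
    if clamp_value is not None:
        # Bound the log-ratio so exp(r) cannot overflow to inf,
        # which would otherwise surface as a NaN loss downstream.
        r = r.clamp(min=-clamp_value, max=clamp_value)
    return torch.exp(r) - r - 1
```

The penalty is non-negative and zero exactly when the two log-probabilities agree; passing `clamp_value=None` restores the unclamped behavior.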
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches: ✅ 3 checks passed
ℹ️ File Consistency Check (based on commit 97ff7ea, PR #876): This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/algorithms/utils.py (2)
75-80: Bug: constructing device from uninitialized variable; use rewards.device

This will raise at runtime (`reward_device` is referenced before assignment). Use the tensor's device:

```diff
-    device_ordinal = rewards.get_device()
-    if device_ordinal == -1:
-        reward_device = torch.device("cpu")
-    else:
-        reward_device = torch.device(reward_device)
+    reward_device = rewards.device
```
287-287: Bug: dict has no attribute .size — use the tensor batch dimension instead

This will crash. Use the length of a tensor in the batch (e.g., `input_ids.size(0)`) for the batch size:

```diff
-    min_padding = (math.ceil(batch.size / (mbs * dp_size)) * mbs * dp_size) - batch.size
+    batch_size = batch["input_ids"].size(0)
+    min_padding = (math.ceil(batch_size / (mbs * dp_size)) * mbs * dp_size) - batch_size
```
🧹 Nitpick comments (4)
nemo_rl/algorithms/utils.py (4)
42-45: Docstring missing new parameter details

Please add an Args entry for `clamp_value`, including how to disable it (None) and dtype considerations (fp16/bf16).
94-107: Leave-one-out denominator can undercount if the current sample is invalid

When `leave_one_out_baseline` is True and the current sample is invalid, subtracting 1 still reduces `num_valid`. Consider subtracting `valid_mask[prompt_idx]` (per-row) instead of a scalar.
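To illustrate the per-row suggestion, here is a hypothetical standalone helper (tensor names and the exact grouping used in nemo_rl are assumptions): the baseline for row i should exclude row i from the mean only when row i is itself valid.

```python
import torch


def loo_baseline(rewards: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    """Leave-one-out baseline over the valid samples in one prompt group.

    Subtracting valid_mask per row (instead of a scalar 1) keeps the
    denominator correct when the current sample is itself invalid.
    """
    num_valid = valid_mask.sum()              # scalar count of valid samples
    total = (rewards * valid_mask).sum()      # sum over valid rewards only
    # Row i: remove row i from numerator/denominator only if it is valid;
    # clamp avoids division by zero when a group has <= 1 valid sample.
    denom = (num_valid - valid_mask).clamp(min=1)
    return (total - rewards * valid_mask) / denom
```

For an invalid row nothing is removed from either the numerator or the denominator, so its baseline is simply the mean of the valid rewards.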
123-131: Typo: surpress_user_warnings → suppress_user_warnings

Rename for correctness; keep a deprecated alias if this is public:

```diff
-def surpress_user_warnings(f):  # type: ignore
+def suppress_user_warnings(f):  # type: ignore
     @wraps(f)
     def wrapper(*args, **kwargs):  # type: ignore
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=UserWarning)
             output = f(*args, **kwargs)
         return output
     return wrapper
+
+
+# Backward-compatible alias for the old (misspelled) name
+surpress_user_warnings = suppress_user_warnings
```
291-332: Pad all mutually-dependent fields or assert presence

Consider asserting required keys (input_ids, input_lengths, sample_mask, and token_mask/reference_policy_logprobs if present) and that the padded shapes align. Avoid silent shape drift.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- 3rdparty/Megatron-Bridge-workspace/is_megatron_bridge_installed.py (0 hunks)
- 3rdparty/Megatron-Bridge-workspace/pyproject.toml (0 hunks)
- 3rdparty/Megatron-Bridge-workspace/setup.py (0 hunks)
- 3rdparty/Megatron-LM-workspace/is_megatron_installed.py (0 hunks)
- 3rdparty/Megatron-LM-workspace/pyproject.toml (0 hunks)
- 3rdparty/Megatron-LM-workspace/setup.py (0 hunks)
- 3rdparty/THIRD_PARTY_LICENSE_MATPLOTLIB (0 hunks)
- nemo_rl/algorithms/utils.py (2 hunks)
💤 Files with no reviewable changes (7)
- 3rdparty/Megatron-LM-workspace/pyproject.toml
- 3rdparty/Megatron-Bridge-workspace/setup.py
- 3rdparty/Megatron-Bridge-workspace/pyproject.toml
- 3rdparty/THIRD_PARTY_LICENSE_MATPLOTLIB
- 3rdparty/Megatron-LM-workspace/is_megatron_installed.py
- 3rdparty/Megatron-LM-workspace/setup.py
- 3rdparty/Megatron-Bridge-workspace/is_megatron_bridge_installed.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/utils.py (1)
81-85: Device mismatch hazard when indexing prompts

Ensure `is_matching_prompt` and the arange index live on the same device; using `rewards.device` (as suggested above) resolves it, but please double-check that `prompts.device == rewards.device` in callers.
Signed-off-by: rohitrango <rohit.rango@gmail.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>

Signed-off-by: rohitrango <rohit.rango@gmail.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do?
Patches #874, where some models may start with a NaN loss due to very high absolute KL values.
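A minimal demonstration of the failure mode, assuming the k3 penalty form exp(r) - r - 1 discussed in the review: once the log-probability ratio exceeds roughly 88, exp(r) overflows float32 to inf, and the loss degenerates. Clamping the log-ratio first keeps the penalty large but finite.

```python
from typing import Optional

import torch


def k3_penalty(r: torch.Tensor, clamp_value: Optional[float] = None) -> torch.Tensor:
    """exp(r) - r - 1, optionally clamping the log-ratio first."""
    if clamp_value is not None:
        r = r.clamp(min=-clamp_value, max=clamp_value)
    return torch.exp(r) - r - 1


r = torch.tensor([100.0])                # extreme log-prob ratio
print(k3_penalty(r))                     # inf: exp(100) overflows float32
print(k3_penalty(r, clamp_value=20.0))   # large but finite
```

The 20.0 clamp here mirrors the default described in the review summary; any bound below ~88 prevents the float32 overflow.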