Skip to content

chore: patch KL loss to prevent nans#876

Merged
parthchadha merged 8 commits intomainfrom
rohit/kl_fix
Sep 25, 2025
Merged

chore: patch KL loss to prevent nans#876
parthchadha merged 8 commits intomainfrom
rohit/kl_fix

Conversation

@rohitrango
Copy link
Contributor

@rohitrango rohitrango commented Aug 8, 2025

What does this PR do ?

Patches #874 where some models may start with a nan loss due to very high absolute kl values.

Summary by CodeRabbit

  • New Features
    • Added a configurable clamp to the KL penalty calculation for improved numerical stability (default 20.0; can be disabled).
  • Chores
    • Removed deprecated Megatron-related packaging/configuration artifacts and installation check scripts from third-party workspaces.
    • Cleaned up unused third-party license document for Matplotlib.

Signed-off-by: rohitrango <rohit.rango@gmail.com>
@euronymous-aithal
Copy link
Contributor

@terrykong please review

parthchadha
parthchadha previously approved these changes Sep 17, 2025
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 17, 2025

Walkthrough

Removed packaging/check scripts and a third-party license under 3rdparty Megatron workspaces; added an optional clamping parameter to a KL penalty utility in nemo_rl, applying clamping by default to the log-probability ratio before computing the penalty.

Changes

Cohort / File(s) Summary of edits
Megatron-Bridge workspace cleanup
3rdparty/Megatron-Bridge-workspace/is_megatron_bridge_installed.py, 3rdparty/Megatron-Bridge-workspace/pyproject.toml, 3rdparty/Megatron-Bridge-workspace/setup.py
Deleted installation check module and packaging/build configuration files. Removed module-level INSTALLED variable.
Megatron-LM workspace cleanup
3rdparty/Megatron-LM-workspace/is_megatron_installed.py, 3rdparty/Megatron-LM-workspace/pyproject.toml, 3rdparty/Megatron-LM-workspace/setup.py
Deleted installation check module and packaging/build scripts. Removed module-level INSTALLED variable.
Third-party license removal
3rdparty/THIRD_PARTY_LICENSE_MATPLOTLIB
Deleted Matplotlib third-party license text file.
RL utilities update
nemo_rl/algorithms/utils.py
Updated calculate_kl_penalty_joschu2020 to accept clamp_value: Optional[float] = 20.0; clamps r = logprobs_reference - logprobs_policy to [-clamp_value, clamp_value] when not None before computing penalty.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller as Trainer/Algorithm
  participant Utils as utils.calculate_kl_penalty_joschu2020
  Caller->>Utils: logprobs_policy, logprobs_reference, clamp_value=20.0 (default)
  activate Utils
  Utils->>Utils: r = logprobs_reference - logprobs_policy
  alt clamp_value is not None
    note over Utils: Clamp r to [-clamp_value, clamp_value]
  else
    note over Utils: No clamping
  end
  Utils->>Utils: penalty = exp(r) - r - 1
  Utils-->>Caller: penalty tensor
  deactivate Utils
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

I tidied my burrow, swept out old crates,
Packed up the Megatron maps and gates.
A gentle clamp on ratios tight,
Keeps my gradients calm at night.
Thump-thump—deploy! with whiskers wide,
Leaner trails, a smoother ride. 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title "chore: patch KL loss to prevent nans" is a concise, single-sentence summary that directly reflects the main change (adding a clamp to the KL penalty to avoid NaN losses) as shown in nemo_rl/algorithms/utils.py; it is specific enough for a reviewer scanning history and free of noisy details.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch rohit/kl_fix

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link

ℹ️ File Consistency Check

Check based on commit: 97ff7ea (PR #876 from rohit/kl_fix)

This is a test comment


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@github-actions
Copy link

ℹ️ File Consistency Check

Check based on commit: 4a4e5de (PR #876 from rohit/kl_fix)

This is a test comment


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
nemo_rl/algorithms/utils.py (2)

75-80: Bug: constructing device from uninitialized variable; use rewards.device

This will raise at runtime (reward_device referenced before assignment). Use the tensor’s device.

-    device_ordinal = rewards.get_device()
-    if device_ordinal == -1:
-        reward_device = torch.device("cpu")
-    else:
-        reward_device = torch.device(reward_device)
+    reward_device = rewards.device

287-287: Bug: dict has no attribute .size — use tensor batch dimension instead

This will crash. Use the length of a tensor in the batch (e.g., input_ids.size(0)) for batch size.

-    min_padding = (math.ceil(batch.size / (mbs * dp_size)) * mbs * dp_size) - batch.size
+    batch_size = batch["input_ids"].size(0)
+    min_padding = (math.ceil(batch_size / (mbs * dp_size)) * mbs * dp_size) - batch_size
🧹 Nitpick comments (4)
nemo_rl/algorithms/utils.py (4)

42-45: Docstring missing new parameter details

Please add an Args entry for clamp_value, including how to disable (None) and dtype considerations (fp16/bf16).


94-107: Leave‑one‑out denominator can undercount if the current sample is invalid

When leave_one_out_baseline is True and the current sample is invalid, subtracting 1 still reduces num_valid. Consider subtracting valid_mask[prompt_idx] (per‑row) instead of a scalar.


123-131: Typo: surpress_user_warnings → suppress_user_warnings

Rename for correctness; keep a deprecated alias if this is public.

-def surpress_user_warnings(f):  # type: ignore
+def suppress_user_warnings(f):  # type: ignore
     @wraps(f)
     def wrapper(*args, **kwargs):  # type: ignore
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=UserWarning)
             output = f(*args, **kwargs)
         return output
@@
-    return wrapper
+    return wrapper
+
+# Backward compatibility
+surpress_user_warnings = suppress_user_warnings

291-332: Pad all mutually‑dependent fields or assert presence

Consider asserting required keys (input_ids, input_lengths, sample_mask, optional token_mask/reference_policy_logprobs if present) and that padded shapes align. Avoid silent shape drift.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 94a3d49 and 97ff7ea.

📒 Files selected for processing (8)
  • 3rdparty/Megatron-Bridge-workspace/is_megatron_bridge_installed.py (0 hunks)
  • 3rdparty/Megatron-Bridge-workspace/pyproject.toml (0 hunks)
  • 3rdparty/Megatron-Bridge-workspace/setup.py (0 hunks)
  • 3rdparty/Megatron-LM-workspace/is_megatron_installed.py (0 hunks)
  • 3rdparty/Megatron-LM-workspace/pyproject.toml (0 hunks)
  • 3rdparty/Megatron-LM-workspace/setup.py (0 hunks)
  • 3rdparty/THIRD_PARTY_LICENSE_MATPLOTLIB (0 hunks)
  • nemo_rl/algorithms/utils.py (2 hunks)
💤 Files with no reviewable changes (7)
  • 3rdparty/Megatron-LM-workspace/pyproject.toml
  • 3rdparty/Megatron-Bridge-workspace/setup.py
  • 3rdparty/Megatron-Bridge-workspace/pyproject.toml
  • 3rdparty/THIRD_PARTY_LICENSE_MATPLOTLIB
  • 3rdparty/Megatron-LM-workspace/is_megatron_installed.py
  • 3rdparty/Megatron-LM-workspace/setup.py
  • 3rdparty/Megatron-Bridge-workspace/is_megatron_bridge_installed.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/utils.py (1)

81-85: Device mismatch hazard when indexing prompts

Ensure is_matching_prompt and the arange index live on the same device; using rewards.device above resolves it, but please double‑check prompts.device == rewards.device in callers.

@github-actions
Copy link

ℹ️ File Consistency Check

Check based on commit: 29f7072 (PR #876 from rohit/kl_fix)

This is a test comment


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@parthchadha parthchadha requested a review from a team as a code owner September 24, 2025 21:37
@parthchadha parthchadha enabled auto-merge (squash) September 24, 2025 21:39
@parthchadha parthchadha merged commit 7aa7071 into main Sep 25, 2025
25 checks passed
@parthchadha parthchadha deleted the rohit/kl_fix branch September 25, 2025 05:02
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: rohitrango <rohit.rango@gmail.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: rohitrango <rohit.rango@gmail.com>
Co-authored-by: Parth Chadha <pchadha@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants