fix: Float32RMSNorm torch.compile crash on PyTorch 2.11+ #1650

Merged

akoumpa merged 5 commits into main from hemild/fix-torch-compile-rmsnorm on Apr 2, 2026

Conversation

@hemildesai (Contributor)

Summary

  • Extract Float32RMSNorm.forward into a standalone _float32_rms_norm_fwd() compiled function, which removes the dynamo guards on self and module state and reduces the number of guard-state combinations (a minimal sketch of the pattern follows this list)
  • Bump the default dynamo cache_size_limit to 64 unconditionally in compile_utils.py so that per-method @torch.compile decorators don't hit FailOnRecompileLimitHit from varying ndim/autocast/grad_mode combinations (a config sketch follows the RMSNorm sketch below)
  • Root cause: PyTorch 2.11 enforces its recompilation limit (default 8) more strictly, and MoE training with variable-length sequences triggers more than 8 guard-state combinations
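The snippet below is a minimal sketch of the refactor in the first bullet, assuming a standard fp32 RMSNorm formula; the function name _float32_rms_norm_fwd comes from this PR's description, while the exact signature and body in the repo may differ.

```python
import torch
from torch import nn


@torch.compile  # compiling a free function avoids dynamo guards on `self` and module state
def _float32_rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Upcast to fp32 for the reduction, normalize, then cast back to the input dtype.
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    x_fp32 = x_fp32 * torch.rsqrt(variance + eps)
    return weight * x_fp32.to(x.dtype)


class Float32RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward is now a thin wrapper; only tensors and a scalar cross the compile boundary
        return _float32_rms_norm_fwd(x, self.weight, self.eps)
```

Because the compiled callable no longer closes over the module, dynamo only needs to guard on the tensor arguments and eps rather than on module attributes.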

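For the cache_size_limit bump, a hedged sketch of the relevant dynamo knob follows; the helper name configure_torch_dynamo appears in the review discussion, but the actual compile_utils.py implementation may differ in name and structure.

```python
import torch._dynamo


def configure_torch_dynamo(cache_size_limit: int = 64) -> None:
    # Per this PR, PyTorch 2.11 fails with FailOnRecompileLimitHit once a
    # compiled function exceeds its recompile budget (default 8). Raising the
    # limit gives per-method @torch.compile room for the guard-state
    # combinations produced by varying ndim, autocast, and grad_mode.
    torch._dynamo.config.cache_size_limit = cache_size_limit
```

As noted in the review below, a module-level default like this only takes effect when full model compilation is disabled, since apply_torch_compile otherwise applies CompileConfig's own (higher) limit.
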
Test plan

  • Verified that the fix runs 50 steps of Qwen3 MoE 30B SFT with rms_norm: torch_fp32 on PyTorch 2.11 (torch 2.11.0a0+eb65b36914.nv26.02) without a crash
  • TPS/GPU: 3138 (fixed torch_fp32) vs 3218 (TE RMSNorm), comparable performance
  • Pre-commit hooks pass (ruff, ruff-format)
  • CI unit tests

🤖 Generated with Claude Code

copy-pr-bot (Bot) commented on Apr 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai (Contributor, Author)

/claude review

@hemildesai (Contributor, Author)

/ok to test 3f567d4

hemildesai changed the title from "Fix Float32RMSNorm torch.compile crash on PyTorch 2.11+" to "fix: Float32RMSNorm torch.compile crash on PyTorch 2.11+" on Apr 1, 2026
Comment thread on nemo_automodel/components/utils/compile_utils.py (Outdated)

claude (Bot) left a comment

Clean refactor — extracting the forward into a standalone compiled function to reduce dynamo guards makes sense.

One concern flagged inline: the module-level configure_torch_dynamo(cache_size_limit=64) sets a limit lower than the existing default of 256 used by CompileConfig. This is only effective when full model compile is disabled (otherwise apply_torch_compile overrides it). The comment says "bump" which is misleading — worth clarifying the intent or aligning the values.

@hemildesai (Contributor, Author)

/claude review

@hemildesai (Contributor, Author)

/ok to test 3406485

@hemildesai (Contributor, Author)

/claude review

@hemildesai (Contributor, Author)

/ok to test b18eaae

Fixes L0_Unit_Tests_CPU failure introduced by 980f23d on main.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai (Contributor, Author)

/claude review

@hemildesai (Contributor, Author)

/ok to test a2f4ee8

claude (Bot) left a comment

LGTM

akoumpa enabled auto-merge (squash) on Apr 2, 2026, 01:45
akoumpa merged commit ec2f724 into main on Apr 2, 2026
53 checks passed
akoumpa deleted the hemild/fix-torch-compile-rmsnorm branch on Apr 2, 2026, 02:22