
Add QK layernorm support for dot-product attention in MambaModel#4067

Merged
Phlip79 merged 21 commits into NVIDIA:main from Phlip79:philip/add-qk-norm on Apr 17, 2026

Conversation

Phlip79 (Member) commented Mar 31, 2026

Summary

  • Adds config-driven QK norm support to SelfAttention.__init__ so models with static specs (e.g. MambaModel) can enable QK layernorm via --qk-layernorm without modifying their specs
  • Config selects the default norm class (TENorm or L2Norm); spec overrides if q_layernorm/k_layernorm are explicitly set
  • GPTModel behavior is unchanged — its specs always provide q_layernorm/k_layernorm, which take precedence
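The precedence described above can be sketched as follows. This is a minimal illustration only: `resolve_qk_norm` and the two stand-in norm classes are hypothetical names, not the actual Megatron code.

```python
# Minimal sketch of the precedence rules from the summary above.
# resolve_qk_norm, L2Norm, and TENorm are illustrative stand-ins,
# not the real Megatron classes.
class L2Norm:
    """Stand-in for Megatron's L2Norm."""

class TENorm:
    """Stand-in for the TransformerEngine-backed norm class."""

def resolve_qk_norm(spec_norm, qk_layernorm=False, qk_l2_norm=False):
    # A spec-provided norm always wins (the GPTModel path).
    if spec_norm is not None:
        return spec_norm
    # Otherwise fall back to the config flags (the MambaModel path).
    if qk_l2_norm:
        return L2Norm
    if qk_layernorm:
        return TENorm
    return None
```

With both flags off and no spec-provided norm, the helper returns `None`, matching the default-config case in the test plan below.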

Test plan

  • Default config (no norm) — verify q_layernorm and k_layernorm are None
  • qk_layernorm=True — verify norms are created via TENorm
  • qk_l2_norm=True — verify norms are L2Norm instances
  • Spec-provided norm takes precedence over config default
  • Forward pass with qk_layernorm=True produces correct output shape
  • Functional tests

Closes MCORE-20.

Convert static mamba_stack_spec and mamba_inference_stack_spec into
config-driven functions (get_mamba_stack_spec, get_mamba_inference_stack_spec)
that read qk_layernorm and qk_l2_norm from TransformerConfig, matching
GPTModel's approach. Backward-compatible constants are preserved.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
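The spec-function pattern that commit describes might look roughly like this. A sketch under stated assumptions only: the real `get_mamba_stack_spec` in `megatron/core/models/mamba/mamba_layer_specs.py` returns a full ModuleSpec tree, not a dict, and the `TransformerConfig` here is a stand-in carrying only the two relevant flags.

```python
from dataclasses import dataclass

# Illustrative stand-in for TransformerConfig with only the two flags
# this PR reads; the real config has many more fields.
@dataclass
class TransformerConfig:
    qk_layernorm: bool = False
    qk_l2_norm: bool = False

def get_mamba_stack_spec(config: TransformerConfig) -> dict:
    # Choose the QK norm class from the config, mirroring GPTModel's
    # layer-spec approach.
    if config.qk_l2_norm:
        norm = "L2Norm"
    elif config.qk_layernorm:
        norm = "TENorm"
    else:
        norm = None
    return {"q_layernorm": norm, "k_layernorm": norm}

# A backward-compatible constant built from the default config, as the
# commit message says the real module preserves.
mamba_stack_spec = get_mamba_stack_spec(TransformerConfig())
```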
copy-pr-bot (bot) commented Mar 31, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Phlip79 (Member, Author) commented Mar 31, 2026

/claude review

Comment thread on megatron/core/models/mamba/mamba_layer_specs.py (outdated)
claude (bot) left a comment

Light review — the refactor from static specs to config-driven functions looks correct and consistent with the GPT layer specs pattern. One gap: the new QK-norm code paths have no test coverage (see inline comment).

Tests cover: default (no config), qk_layernorm=True, qk_l2_norm=True,
inference spec, backward-compatible constant, and a full forward pass
with qk_layernorm enabled.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Phlip79 Phlip79 marked this pull request as ready for review March 31, 2026 00:20
@Phlip79 Phlip79 requested review from a team as code owners March 31, 2026 00:20
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team March 31, 2026 00:20
Phlip79 (Member, Author) commented Mar 31, 2026

/claude review

claude (bot) left a comment

LGTM

Phlip79 (Member, Author) commented Mar 31, 2026

/ok to test 367b8a8

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 31, 2026
Phlip79 (Member, Author) commented Mar 31, 2026

/ok to test 1c46827

Phlip79 (Member, Author) commented Mar 31, 2026

/ok to test b28a11f

janEbert (Contributor) commented:

I think this is exactly what we want to avoid; as far as I understand, we do not want to start making the spec dynamic in code. :)
Isn't this solvable by dynamically passing the arguments in the MambaStack or MambaModel constructors?

Phlip79 (Member, Author) commented Mar 31, 2026

/claude review

claude (bot) left a comment

LGTM

Phlip79 (Member, Author) commented Mar 31, 2026

/ok to test bdf2d12

Phlip79 (Member, Author) commented Mar 31, 2026

/ok to test 488e448

@Phlip79 Phlip79 removed the request for review from a team April 15, 2026 19:05
Phlip79 and others added 2 commits April 16, 2026 22:08
Replace the default_norm_cls intermediate with explicit branches on
qk_l2_norm / qk_layernorm, matching the file's ValueError convention
and validating spec/config consistency when both flags are disabled.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
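The explicit branching that commit describes might look like the sketch below. The function name and string placeholders are hypothetical, not the actual `attention.py` code; the point is the branch-per-flag structure with a ValueError when the spec and config disagree.

```python
# Sketch of explicit branches with the spec/config consistency check
# described in the commit message above. select_qk_norm and the string
# placeholders are illustrative, not the real SelfAttention.__init__.
def select_qk_norm(qk_layernorm, qk_l2_norm, spec_norm,
                   te_norm="TENorm", l2_norm="L2Norm"):
    if qk_l2_norm:
        return spec_norm if spec_norm is not None else l2_norm
    elif qk_layernorm:
        return spec_norm if spec_norm is not None else te_norm
    else:
        # Both flags off: a spec that still provides a norm is inconsistent
        # and is rejected rather than silently honored.
        if spec_norm is not None:
            raise ValueError(
                "Spec provides a QK norm, but neither qk_layernorm nor "
                "qk_l2_norm is enabled in the config."
            )
        return None
```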
…add-qk-norm

# Conflicts:
#	megatron/core/transformer/attention.py
@Phlip79 Phlip79 requested a review from a team as a code owner April 16, 2026 22:09
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the Final Review (PR is in the "final review" stage) label Apr 16, 2026
Phlip79 (Member, Author) commented Apr 16, 2026

/ok to test 861fe3b

Phlip79 and others added 2 commits April 16, 2026 23:58
The strict else-branch in SelfAttention requires config and spec to agree
when any QK norm is active. This test was relying on spec-only enablement,
which is now rejected.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Phlip79 (Member, Author) commented Apr 17, 2026

/ok to test e16a762

Covers the two new behaviors introduced by this PR: config.qk_layernorm
falling back to TENorm when the spec leaves q/k_layernorm as None, and
ValueError when the spec sets a concrete norm but both config flags are
disabled.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Phlip79 (Member, Author) commented Apr 17, 2026

/ok to test 236c50f

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Phlip79 (Member, Author) commented Apr 17, 2026

/ok to test c8a313d

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Approved (All necessary approvals have been made) label Apr 17, 2026
TENorm is a factory class (__new__ returns te.pytorch.LayerNorm), so
isinstance against TENorm itself is always False. Check against the
returned class instead. Also add the symmetric qk_l2_norm config-only
fallback test to close the coverage gap.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
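The factory-class pitfall that commit fixes can be reproduced with a toy example. `NormFactory` and `RealNorm` are stand-ins for TENorm and the `te.pytorch.LayerNorm` it returns; only the `__new__` behavior is modeled.

```python
# Toy reproduction of the pitfall: a factory whose __new__ returns an
# instance of a different class, as TENorm does.
class RealNorm:
    """Stand-in for the class the factory actually returns."""

class NormFactory:
    def __new__(cls, *args, **kwargs):
        # Returning a non-NormFactory instance means __init__ is skipped
        # and NormFactory never appears in the instance's type.
        return RealNorm()

norm = NormFactory()
# isinstance against the factory is always False...
factory_check = isinstance(norm, NormFactory)
# ...so tests must check against the class the factory returns instead.
returned_check = isinstance(norm, RealNorm)
```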
Phlip79 (Member, Author) commented Apr 17, 2026

/ok to test 1ca9061

@Phlip79 Phlip79 enabled auto-merge April 17, 2026 21:37
@Phlip79 Phlip79 added this pull request to the merge queue Apr 17, 2026
svcnvidia-nemo-ci commented:

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24588820395

svcnvidia-nemo-ci commented:

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24589587870

Merged via the queue into NVIDIA:main with commit e15ec3c Apr 17, 2026
63 checks passed
@Phlip79 Phlip79 deleted the philip/add-qk-norm branch April 17, 2026 23:08
Victarry added a commit to yanring/Megatron-LM that referenced this pull request Apr 20, 2026
* origin/main: (286 commits)
  Rename MambaModel/MambaStack to HybridModel/HybridStack (NVIDIA#4099)
  Fix Megatron initialization with extra_args_provider (NVIDIA#4327)
  Fix RL to once again work with --skip-train (NVIDIA#4249)
  Add activation logging and tokens per expert logging (NVIDIA#3842)
  Make param_index_map always use unpacked (full numel) offsets (NVIDIA#4328)
  FA4 Inference (NVIDIA#4186)
  Fix RL reward due to stop token (NVIDIA#4096)
  cp: Fix UT timeout (NVIDIA#4310) (NVIDIA#4373)
  feat(ckpt): add --async-ckpt-use-cpu-shm argument (NVIDIA#4355)
  Update copy-pr-bot.yaml [skip ci]
  Docs: improve docstrings and comments in example training loop (NVIDIA#4041)
  Add QK layernorm support for dot-product attention in MambaModel (NVIDIA#4067)
  Fix bug with non-partial rollouts (NVIDIA#3964)
  [docs] ci: use parent-relative json_url for version picker (NVIDIA#4367)
  Add tables and histogram for RL staleness (NVIDIA#4097)
  Port DeepSeek Sparse Attention to `MambaModel` (NVIDIA#3553)
  docs: bump versions1.json to 0.17.0 (latest) (NVIDIA#4360)
  Fix potential coredump issue that occurs when saving a checkpoint (NVIDIA#1871)
  ci(gb200): add 1-node mr-github functional test variants (NVIDIA#4334)
  fix: wait for async P2P send before deallocating output tensor (NVIDIA#4047)
  ...

# Conflicts:
#	megatron/core/transformer/cuda_graphs.py

Labels

Approved (All necessary approvals have been made); complexity: medium

7 participants