
Rename MambaModel/MambaStack to HybridModel/HybridStack #4099

Merged
Phlip79 merged 20 commits into NVIDIA:main from Phlip79:philip/rename-to-hybrid
Apr 19, 2026

Conversation

Phlip79 (Member) commented Apr 1, 2026

Summary

Rename the generic model classes in megatron/core/ that support multiple layer types (Mamba SSM, Attention, MoE, GDN, MLP) via hybrid_layer_pattern:

  • MambaModel → HybridModel, MambaStack → HybridStack, MambaStackSubmodules → HybridStackSubmodules
  • mamba_stack_spec → hybrid_stack_spec, mamba_inference_stack_spec → hybrid_inference_stack_spec
  • get_mamba_stack_modelopt_spec → get_hybrid_stack_modelopt_spec
  • Move canonical files to megatron/core/models/hybrid/ (hybrid_model.py, hybrid_block.py, hybrid_layer_specs.py, hybrid_layer_allocation.py)
  • Backward-compatible re-export stubs at old import paths (megatron.core.models.mamba, megatron.core.ssm.mamba_block, etc.)
  • MambaModel is a thin subclass of HybridModel that accepts the deprecated mamba_stack_spec kwarg
  • Mamba-specific SSM classes (MambaLayer, MambaMixer, MambaContextParallel, etc.) unchanged
  • megatron/core/models/hybrid/__init__.py is intentionally empty to avoid circular import with megatron.core

This PR only touches megatron/core/. Non-core renames (scripts, tools, tests, examples) are in #4159.
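The backward-compatible re-export stubs described above can be sketched roughly as follows. This is a toy, self-contained illustration of the pattern, not the PR's actual code: the `megatron_demo` package name and the stand-in classes are hypothetical.

```python
import sys
import types

# Stand-in for the canonical module after the rename. In the PR the real
# canonical location is megatron/core/models/hybrid/hybrid_model.py.
canonical = types.ModuleType("megatron_demo.hybrid.hybrid_model")

class HybridModel:
    """Toy stand-in for the renamed model class."""

canonical.HybridModel = HybridModel
sys.modules[canonical.__name__] = canonical

# The stub left at the old path does nothing but re-export the new name,
# so existing imports keep resolving to the very same class object.
stub = types.ModuleType("megatron_demo.mamba.mamba_model")
stub.MambaModel = canonical.HybridModel
sys.modules[stub.__name__] = stub

# Old-style import still works (sys.modules is consulted before the
# import machinery searches the filesystem):
from megatron_demo.mamba.mamba_model import MambaModel
```

In the real tree the stub file would simply contain a `from megatron.core.models.hybrid.hybrid_model import ...` line rather than this in-memory module registration, but the effect for callers is the same.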

Testing

Functional tests

@Phlip79 Phlip79 requested review from a team as code owners April 1, 2026 21:30
copy-pr-bot (Bot) commented Apr 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

github-actions (Bot, Contributor) commented Apr 1, 2026

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft April 1, 2026 21:30
Phlip79 (Member, Author) commented Apr 1, 2026

/ok to test 360b582

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Apr 1, 2026
Phlip79 (Member, Author) commented Apr 1, 2026

/ok to test c7dec8a

Phlip79 (Member, Author) commented Apr 2, 2026

/claude review

Review threads: tests/unit_tests/inference/engines/test_hybrid_prefix_caching_e2e.py (outdated), pretrain_hybrid.py, megatron/core/models/hybrid/hybrid_model.py (outdated), megatron/core/models/hybrid/hybrid_block.py
claude (Bot, Contributor) left a comment

Clean mechanical rename with proper backward-compatible aliases and re-exports. Left a few minor nits on stale references to 'MambaBlock' in comments and a '2026-2026' copyright typo, but nothing blocking.

Phlip79 (Member, Author) commented Apr 2, 2026

/claude review

Review thread: megatron/core/models/hybrid/hybrid_model.py (outdated)

claude (Bot, Contributor) left a comment

Clean rename with good backward-compatible re-exports at the old import paths. One issue: the MambaModel = HybridModel alias doesn't cover the renamed mamba_stack_spec → hybrid_stack_spec keyword parameter in __init__, so existing callers using MambaModel(mamba_stack_spec=...) will break with a TypeError. See inline comment for a suggested fix.
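The kind of fix this review is asking for, and the "thin subclass" the PR summary describes, amount to the following pattern. This is a hedged, self-contained sketch: the real `__init__` signatures in Megatron-LM take many more arguments, and the class bodies here are toy stand-ins.

```python
import warnings

class HybridModel:
    """Toy stand-in for the renamed class (hypothetical simplified signature)."""
    def __init__(self, hybrid_stack_spec=None):
        self.hybrid_stack_spec = hybrid_stack_spec

class MambaModel(HybridModel):
    """Thin backward-compat subclass. Unlike a bare `MambaModel = HybridModel`
    alias, it accepts the deprecated mamba_stack_spec kwarg and forwards it,
    so old callers no longer hit a TypeError."""
    def __init__(self, *args, mamba_stack_spec=None, **kwargs):
        if mamba_stack_spec is not None:
            warnings.warn(
                "mamba_stack_spec is deprecated; use hybrid_stack_spec",
                DeprecationWarning,
                stacklevel=2,
            )
            kwargs.setdefault("hybrid_stack_spec", mamba_stack_spec)
        super().__init__(*args, **kwargs)
```

With this shim, `MambaModel(mamba_stack_spec=spec)` keeps working while emitting a DeprecationWarning, and `isinstance(model, HybridModel)` checks still pass.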

Review threads: megatron/core/models/hybrid/hybrid_layer_specs.py, megatron/post_training/arguments.py, megatron/inference/utils.py (outdated), megatron/core/ssm/mamba_hybrid_layer_allocation.py
Phlip79 (Member, Author) commented Apr 2, 2026

/ok to test f115ba7

@Phlip79 Phlip79 marked this pull request as ready for review April 3, 2026 02:07
@Phlip79 Phlip79 requested a review from a team as a code owner April 3, 2026 02:07
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the "Approved" ("All necessary approvals have been made") label Apr 16, 2026
ko3n1g (Contributor) left a comment

Has this been cross-tested with MBridge?

Phlip79 (Member, Author) commented Apr 16, 2026

> Has this been cross-tested with MBridge?

I just re-kicked off MBridge testing (previously failing due to linting error on MBridge side). All changes in this PR are backwards compatible.

ko3n1g (Contributor) left a comment

OK, I see you've got an eye on testing & backwards compatibility. I'll lift my block.

Phlip79 and others added 2 commits April 18, 2026 00:39
…brid

# Conflicts:
#	megatron/core/inference/contexts/dynamic_context.py
#	megatron/core/models/mamba/mamba_layer_specs.py
#	megatron/core/models/mamba/mamba_model.py
#	megatron/core/ssm/mamba_block.py
#	megatron/core/ssm/mamba_hybrid_layer_allocation.py
Test classes named after MambaModel/MambaStack/MambaStackSubmodules
are renamed to match the new Hybrid class names:
- TestMambaModel -> TestHybridModel
- TestMambaQKLayernorm -> TestHybridQKLayernorm
- TestMambaWithDynamicInference -> TestHybridWithDynamicInference
- TestMambaMoEModel -> TestHybridMoEModel
- TestMambaBlock -> TestHybridBlock
- TestModelOptMambaModel -> TestModelOptHybridModel
- TestMultiTokenPredictionMamba -> TestMultiTokenPredictionHybrid
- TestParallelMambaBlockCudagraphs -> TestParallelHybridBlockCudagraphs

Also updates the TestHybridQKLayernorm class (added by upstream merge)
to use HybridModel/hybrid_stack_spec consistently.

Test classes for Mamba-specific SSM components (MambaLayer, MambaMixer,
MambaContextParallel, MambaMetadata, MambaSlotAllocator, etc.) are
unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Phlip79 Phlip79 requested a review from a team as a code owner April 18, 2026 00:50
Phlip79 (Member, Author) commented Apr 18, 2026

/ok to test b88d93b

This stale file was missed in the original directory rename. The
canonical location is now modelopt/hybrid/model_specs.py. This stub
re-exports from the new canonical location to preserve backward
compatibility with any external imports from the old path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Phlip79 (Member, Author) commented Apr 18, 2026

/ok to test 3375b10

MultiTokenPredictionLayer and MultiTokenPredictionBlock now accept
mamba_submodules as a deprecated alias for hybrid_submodules, emitting
DeprecationWarning and forwarding the value (raises if both are set).
Also drops the redundant explicit HybridModel import in mamba_model.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
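The deprecated-alias behavior this commit describes can be sketched as follows. The class name and signature below are toy stand-ins (the real MultiTokenPredictionLayer takes many more arguments); only the kwarg-forwarding pattern is the point.

```python
import warnings

class MTPLayerDemo:
    """Toy stand-in for MultiTokenPredictionLayer's kwarg handling
    (hypothetical simplified signature)."""
    def __init__(self, hybrid_submodules=None, mamba_submodules=None):
        if mamba_submodules is not None:
            if hybrid_submodules is not None:
                # Refuse ambiguous calls that set both the old and new name.
                raise ValueError(
                    "pass hybrid_submodules or mamba_submodules, not both"
                )
            warnings.warn(
                "mamba_submodules is deprecated; use hybrid_submodules",
                DeprecationWarning,
                stacklevel=2,
            )
            hybrid_submodules = mamba_submodules  # forward the old value
        self.hybrid_submodules = hybrid_submodules
```

Raising when both names are supplied avoids silently preferring one value over the other, which is the usual choice for this kind of rename shim.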
Phlip79 (Member, Author) commented Apr 18, 2026

/ok to test 556eaa5

The test_dsa_layer_types test was added to test_hybrid_block.py by an
upstream merge and used bare references to mamba_stack_spec and
MambaStack (not via the backward-compat import). The file imports
hybrid_stack_spec and HybridStack, so those bare references failed with NameError.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Phlip79 (Member, Author) commented Apr 18, 2026

/ok to test 5202abc

@Phlip79 Phlip79 enabled auto-merge April 18, 2026 05:23
@Phlip79 Phlip79 added this pull request to the merge queue Apr 18, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24598437932

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Apr 18, 2026
@Phlip79 Phlip79 added this pull request to the merge queue Apr 19, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24637957751

Merged via the queue into NVIDIA:main with commit 15e07a2 Apr 19, 2026
67 of 68 checks passed
@Phlip79 Phlip79 deleted the philip/rename-to-hybrid branch April 19, 2026 20:37
Victarry added a commit to yanring/Megatron-LM that referenced this pull request Apr 20, 2026
* origin/main: (286 commits)
  Rename MambaModel/MambaStack to HybridModel/HybridStack (NVIDIA#4099)
  Fix Megatron initialization with extra_args_provider (NVIDIA#4327)
  Fix RL to once again work with --skip-train (NVIDIA#4249)
  Add activation logging and tokens per expert logging (NVIDIA#3842)
  Make param_index_map always use unpacked (full numel) offsets (NVIDIA#4328)
  FA4 Inference (NVIDIA#4186)
  Fix RL reward due to stop token (NVIDIA#4096)
  cp: Fix UT timeout (NVIDIA#4310) (NVIDIA#4373)
  feat(ckpt): add --async-ckpt-use-cpu-shm argument (NVIDIA#4355)
  Update copy-pr-bot.yaml [skip ci]
  Docs: improve docstrings and comments in example training loop (NVIDIA#4041)
  Add QK layernorm support for dot-product attention in MambaModel (NVIDIA#4067)
  Fix bug with non-partial rollouts (NVIDIA#3964)
  [docs] ci: use parent-relative json_url for version picker (NVIDIA#4367)
  Add tables and histogram for RL staleness (NVIDIA#4097)
  Port DeepSeek Sparse Attention to `MambaModel` (NVIDIA#3553)
  docs: bump versions1.json to 0.17.0 (latest) (NVIDIA#4360)
  Fix potential coredump issue that occurs when saving a checkpoint (NVIDIA#1871)
  ci(gb200): add 1-node mr-github functional test variants (NVIDIA#4334)
  fix: wait for async P2P send before deallocating output tensor (NVIDIA#4047)
  ...

# Conflicts:
#	megatron/core/transformer/cuda_graphs.py
santhnm2 pushed a commit to santhnm2/Megatron-LM that referenced this pull request Apr 20, 2026
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@Phlip79 Phlip79 mentioned this pull request Apr 22, 2026

Labels

Approved (All necessary approvals have been made), complexity: high
