
Feat: MLA eagle#689

Merged
h-guo18 merged 11 commits into main from haoguo/mla-eagle
Dec 19, 2025

Conversation

@h-guo18
Contributor

@h-guo18 h-guo18 commented Dec 15, 2025

What does this PR do?

Type of change: New Feature

Overview:

  • Add MLA EAGLE support
    • Add a new argument, "eagle_decoder_type", to switch between the llama and kimik2 EAGLE decoders;
    • Add patches to dynamically load the kimik2 model implementation;
    • Add a new default config for Kimi K2;
    • Refactor EAGLE export to support multi-layer/multi-type EAGLE export concisely;
      • Rename some modules to simplify the export logic;
  • Other minor improvements.

Usage

# Add a code snippet demonstrating how to use this
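The usage placeholder above was left unfilled. As a purely hypothetical sketch (not code from this PR), selecting the draft decoder via the new "eagle_decoder_type" argument might look like the following; the decoder class names and config keys here are illustrative assumptions, not the PR's actual identifiers:

```python
# Hypothetical sketch, NOT code from this PR: illustrates how the new
# "eagle_decoder_type" argument could select between the llama and kimik2
# (MLA) EAGLE draft decoders. Class names are illustrative placeholders.
EAGLE_DECODERS = {
    "llama": "LlamaEagleDecoderLayer",    # default, llama-style attention
    "kimik2": "KimiK2EagleDecoderLayer",  # MLA attention, Kimi K2 style
}

def resolve_eagle_decoder(eagle_config: dict) -> str:
    """Return the decoder class name for the configured eagle_decoder_type."""
    decoder_type = eagle_config.get("eagle_decoder_type", "llama")
    try:
        return EAGLE_DECODERS[decoder_type]
    except KeyError:
        raise ValueError(
            f"Unknown eagle_decoder_type {decoder_type!r}; "
            f"expected one of {sorted(EAGLE_DECODERS)}"
        ) from None

print(resolve_eagle_decoder({"eagle_decoder_type": "kimik2"}))
```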

Testing

  • Tested that Kimi K2 Thinking works with eagle_decoder_type=kimik2 (screenshot omitted).
  • Tested that Llama 3.2 1B works with eagle_decoder_type=llama (screenshot omitted).

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@copy-pr-bot

copy-pr-bot Bot commented Dec 15, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@h-guo18 h-guo18 self-assigned this Dec 15, 2025
@h-guo18 h-guo18 changed the title MLA eagle Feat: MLA eagle Dec 15, 2025
@copy-pr-bot

copy-pr-bot Bot commented Dec 15, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Comment thread examples/speculative_decoding/eagle_config.json
Comment thread examples/speculative_decoding/eagle_utils.py
Comment thread modelopt/torch/speculative/plugins/megatron_eagle.py
Comment thread modelopt/torch/speculative/plugins/transformers.py
Comment thread modelopt/torch/speculative/plugins/transformers.py
@yeyu-nvidia
Contributor

We will also need to add eagle_decoder_type support in megatron_eagle.py, as well as export support. This can be left to a follow-up PR.

Comment thread modelopt/torch/speculative/utils.py
Contributor

@benchislett benchislett left a comment


Also, are we loading an MoE layer here? Are we overriding it with MLP somehow?

Comment thread modelopt/torch/export/plugins/hf_spec_export.py
Comment thread modelopt/torch/export/plugins/hf_spec_export.py
@codecov

codecov Bot commented Dec 18, 2025

Codecov Report

❌ Patch coverage is 37.50000% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.65%. Comparing base (b286165) to head (11187be).
⚠️ Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/utils.py 28.57% 25 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #689      +/-   ##
==========================================
- Coverage   74.73%   74.65%   -0.09%     
==========================================
  Files         192      192              
  Lines       18870    18909      +39     
==========================================
+ Hits        14103    14117      +14     
- Misses       4767     4792      +25     

@yeyu-nvidia
Contributor

Please test the end-to-end parallel-draft pipeline (convert, train, and export) before merging.

@h-guo18 h-guo18 requested a review from yeyu-nvidia December 18, 2025 03:55
Comment thread examples/speculative_decoding/eagle_config.json Outdated
Contributor

@yeyu-nvidia yeyu-nvidia left a comment


Please address the comments

Comment thread examples/speculative_decoding/eagle_config.json Outdated
@yeyu-nvidia yeyu-nvidia dismissed their stale review December 18, 2025 06:47

Discussed offline

@h-guo18 h-guo18 marked this pull request as ready for review December 18, 2025 06:58
@h-guo18 h-guo18 requested review from a team as code owners December 18, 2025 06:58
Comment thread examples/speculative_decoding/eagle_config.json Outdated
@@ -43,3 +44,4 @@ def modify(
self.eagle_report_acc = eagle_report_acc
self.eagle_reuse_base_decoder = eagle_reuse_base_decoder
Contributor


This should be removed since you have eagle_decoder_type now

Contributor Author


Doesn't Megatron still use this argument?

@yeyu-nvidia
Contributor

yeyu-nvidia commented Dec 18, 2025

Training failed on a8. Command to reproduce: bash launch_train.sh --save_steps 20 --data_path /workspace/scratch.yeyu_hw/Daring-Anteater/llama3.2_1B_fp8.jsonl --training_seq_len 512

[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/scratch.yeyu_hw/TensorRT-Model-Optimizer/examples/speculative_decoding/main.py", line 263, in
[rank7]: train()
[rank7]: File "/workspace/scratch.yeyu_hw/TensorRT-Model-Optimizer/examples/speculative_decoding/main.py", line 257, in train
[rank7]: trainer.train(resume_from_checkpoint=checkpoint)
[rank7]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2325, in train
[rank7]: return inner_training_loop(
[rank7]: ^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 2674, in _inner_training_loop
[rank7]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 4020, in training_step
[rank7]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/scratch.yeyu_hw/TensorRT-Model-Optimizer/examples/speculative_decoding/eagle_utils.py", line 502, in compute_loss
[rank7]: loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/transformers/trainer.py", line 4110, in compute_loss
[rank7]: outputs = model(**inputs)
[rank7]: ^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank7]: return self._call_impl(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank7]: return forward_call(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/distributed.py", line 1633, in forward
[rank7]: inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/distributed.py", line 1529, in _pre_forward
[rank7]: self._sync_buffers()
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/distributed.py", line 2166, in _sync_buffers
[rank7]: self._sync_module_buffers(authoritative_rank)
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/distributed.py", line 2170, in _sync_module_buffers
[rank7]: self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/distributed.py", line 2192, in _default_broadcast_coalesced
[rank7]: self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
[rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/parallel/distributed.py", line 2107, in _distributed_broadcast_coalesced
[rank7]: dist._broadcast_coalesced(
[rank7]: RuntimeError: No backend type associated with device type cpu
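For context on the error above: DDP's buffer synchronization broadcasts every registered buffer through the process-group backend, and a GPU-only group (NCCL) has no backend registered for CPU tensors, so a single buffer left on CPU is enough to trigger this failure. The usual fix is to move the whole model, buffers included, onto the CUDA device before wrapping it in DDP. The following is a toy, torch-free model of that dispatch, purely illustrative and not PyTorch internals:

```python
# Toy illustration (NOT PyTorch source) of the backend dispatch behind
# "RuntimeError: No backend type associated with device type cpu".
REGISTERED_BACKENDS = {"cuda": "nccl"}  # typical GPU-only training process group

def backend_for(device_type: str) -> str:
    """Look up the collective backend for a tensor's device type."""
    try:
        return REGISTERED_BACKENDS[device_type]
    except KeyError:
        raise RuntimeError(
            f"No backend type associated with device type {device_type}"
        ) from None

def sync_buffers(buffer_device_types):
    # DDP broadcasts every registered buffer; one CPU-resident buffer
    # is enough to make the whole sync fail.
    for dev in buffer_device_types:
        backend_for(dev)

sync_buffers(["cuda", "cuda"])  # fine: all buffers on the GPU
```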

@yeyu-nvidia
Contributor

Please test the end-to-end parallel-draft pipeline (convert, train, and export) before merging.

tested

@h-guo18
Contributor Author

h-guo18 commented Dec 18, 2025

Training failed on a8. Command to reproduce: bash launch_train.sh --save_steps 20 --data_path /workspace/scratch.yeyu_hw/Daring-Anteater/llama3.2_1B_fp8.jsonl --training_seq_len 512

[... full traceback as quoted above ...]
[rank7]: RuntimeError: No backend type associated with device type cpu

fixed

Comment thread examples/speculative_decoding/eagle_config.json Outdated
Contributor

@yeyu-nvidia yeyu-nvidia left a comment


Megatron will need some API refactoring due to this PR. We will need to add MLA to megatron as well.

h-guo18 and others added 11 commits December 18, 2025 22:39
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: yeyu-nvidia <yeyu@nvidia.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
import torch.distributed
from huggingface_hub import snapshot_download
from torch import nn
from transformers.cache_utils import DynamicCache
Collaborator


These utils require transformers and would be better moved to /plugins. No need to change now; just a remark.

Collaborator

@ChenhanYu ChenhanYu left a comment


I left a short comment regarding the dependency on transformers. This is great work.

@h-guo18 h-guo18 enabled auto-merge (squash) December 18, 2025 23:57
@h-guo18 h-guo18 merged commit bdd10c2 into main Dec 19, 2025
36 checks passed
@h-guo18 h-guo18 deleted the haoguo/mla-eagle branch December 19, 2025 00:11
@h-guo18
Contributor Author

h-guo18 commented Dec 19, 2025

Megatron will need some API refactoring due to this PR. We will need to add MLA to megatron as well.

I think we should only refactor something if a new feature needs it, not merely because of this PR. I would appreciate knowing if there is a better way to implement this feature. Thanks

@yeyu-nvidia
Contributor

yeyu-nvidia commented Dec 19, 2025

I think we should only refactor something if a new feature needs it, not merely because of this PR. I would appreciate knowing if there is a better way to implement this feature. Thanks

Isn't MLA a need for Megatron? This PR disables eagle_reuse_base_decoder for HF and introduces an MLA decoder. Don't we need to refactor to enable eagle_decoder_type for Megatron as well and deprecate eagle_reuse_base_decoder? What is the definition of "need for a new feature" if we claim to support something but only halfway support it?


Labels

None yet

Projects

None yet

4 participants