Skip to content

feat(deepseek-v4): add Multi-Token Prediction (MTP) training support#2191

Merged
HuiyingLi merged 25 commits into
NVIDIA-NeMo:mainfrom
khazic:khazic/feat/deepseek-v4-flash-mtp
May 19, 2026
Merged

feat(deepseek-v4): add Multi-Token Prediction (MTP) training support#2191
HuiyingLi merged 25 commits into
NVIDIA-NeMo:mainfrom
khazic:khazic/feat/deepseek-v4-flash-mtp

Conversation

@khazic
Copy link
Copy Markdown
Contributor

@khazic khazic commented May 8, 2026

Summary

Adds Multi-Token Prediction (MTP) training support for DeepSeek V4 (Flash). MTP layers run as standard pre-norm attention + MoE blocks (no HC machinery), with rotary embeddings shared from the main backbone. The auxiliary loss is computed via the recipe-side calculate_mtp_loss and added to the main CE loss.

What's in this PR

Model side

  • components/models/common/mtp/: model-agnostic scaffold (MTPConfig, MTPModule, roll_tensor).
  • components/models/deepseek_v4/mtp.py: V4-specific DeepseekV4MTPSublayer and build_deepseek_v4_mtp factory. compress_ratios is forced to None for MTP attention to avoid IndexError past the backbone layer count; rotary refs are stored via object.__setattr__ so they don't pollute state_dict.
  • components/models/deepseek_v4/model.py: DeepseekV4ForCausalLM now constructs self.mtp when num_nextn_predict_layers > 0 and returns a DeepseekV4CausalLMOutput dataclass (logits + optional mtp_per_depth_h).

State-dict adapter

  • from_hf runs MTP layers (layers.{N+k}.*) through the same dequantize / aggregate-experts / rename pipeline as the backbone (renumber to layers.{k}.*, run pipeline, re-prefix to mtp.layers.{k}.*). Previously MTP keys bypassed dequantization and FP8/FP4 buffers were left raw.
  • to_hf rewrites mtp.layers.{k}.* into model.layers.{N+k}.* and runs the unified split / rename / quantize path; an explicit fallback strips the leftover model. prefix for fusion-only modules (eh_proj / enorm / hnorm / final_layernorm) that have no entry in the rename table.

Recipe (recipes/llm/train_ft.py)

  • calculate_mtp_loss: per-depth CE through the configured loss class (FusedLinearCE / MaskedCE), summed with loss_scaling_factor / D weighting.
  • _forward_backward_step (non-PP branch) reads out.mtp_per_depth_h and adds the MTP loss to the main loss.
  • _mtp_is_enabled(cfg, model_parts) + setup-time guard: raises NotImplementedError if pipeline parallelism is enabled together with MTP, since the PP schedule does not currently aggregate the MTP auxiliary loss. PP + MTP is intentionally deferred to a follow-up PR.

Tests

  • test_deepseek_v4_mtp.py: config / construction / forward / backward / state-dict coverage.
  • test_dsv4_state_dict_adapter.py: MTP round-trip for layer rename, FP8 dequantize, expert aggregation, and the fusion-only fallback in both directions.
  • test_dsv4_model_smoke.py: updated to read .logits from the new dataclass output.

Overlap with #2161

PR #2161 (Nemotron V3 MTP) introduces the same calculate_mtp_loss helper and the same non-PP integration in _forward_backward_step. Those two regions are byte-identical between the branches.

This is intentional — both PRs need the same recipe-side scaffolding, and the model-agnostic MTP base (components/models/common/mtp/) is shared. When #2161 lands first, those duplicated lines will be auto-resolved on rebase, and this PR will reduce to the V4-specific changes (model, MTP sublayer, adapter, PP guard, V4 tests).

Test plan

wandb: https://wandb.ai/Nemo-automodel/huiyingl_workspace?nw=nwuserhuiyingl
image

khazic added 2 commits May 8, 2026 17:20
- Add model-agnostic MTP scaffold (MTPConfig, MTPModule, roll_tensor) under
  nemo_automodel/components/models/common/mtp/
- Add DeepseekV4MTPSublayer: pre-norm attention+MoE blocks without HC
  machinery; compress_ratios forced to None to avoid IndexError; rotary
  embeddings stored as non-registered references via object.__setattr__
- Add build_mtp_config_from_hf and build_deepseek_v4_mtp factory functions
- Add DeepseekV4CausalLMOutput dataclass so forward returns logits + optional
  mtp_per_depth_h list for MTP loss computation in train_ft.py
- Update DeepseekV4ForCausalLM.__init__ to construct MTP module when
  num_nextn_predict_layers > 0
- Update state_dict_adapter.py: from_hf splits MTP keys and converts back
- Add calculate_mtp_loss to train_ft.py and wire into _forward_backward_step
- Add 8 unit tests covering config, construction, forward, backward, state dict

Signed-off-by: khazic <khazzz1c@gmail.com>
State-dict adapter:
- from_hf: route MTP layers (layers.{N+k}.*) through dequantize +
  aggregate-experts + rename pipeline by renumbering them as layers.{k}.*
  and re-prefixing the result to mtp.layers.{k}.*. Previously MTP keys
  bypassed dequantization, leaving FP8/FP4 buffers undequantized.
- to_hf: rewrite mtp.layers.{k}.* into model.layers.{N+k}.* and run the
  unified split / rename / quantize path; strip the leftover model.
  prefix for fusion-only modules (eh_proj, enorm, hnorm, final_layernorm)
  that have no entry in the rename table.
- Drop dead _apply_inverse_rename helper.

Recipe (train_ft.py):
- Add _mtp_is_enabled(cfg, model_parts) helper that detects MTP via
  YAML override (model.config.num_nextn_predict_layers) or via an
  enabled mtp_config attribute on any constructed submodule.
- Raise NotImplementedError in setup() when PP and MTP are both
  enabled. The PP schedule does not aggregate the MTP auxiliary loss,
  so the MTP head would silently receive no gradients. PP + MTP
  wiring is intentionally deferred to a follow-up PR.
- Add TODO marker in _forward_backward_step PP branch pointing at the
  same follow-up.

Tests:
- Fix test_forward_shape / test_backward to read .logits from the new
  DeepseekV4CausalLMOutput dataclass returned by forward.
- Add MTP round-trip coverage: layer rename, FP8 dequantize, expert
  aggregation, to_hf rename / split / quantize, and the fusion-only
  fallback for both directions.

Signed-off-by: khazic <khazzz1c@gmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 3990e0c

DeepSeek-V4 HF safetensors emit MTP layer keys in two forms:

* ``model.layers.{N+k}.*`` for the standard self_attn / mlp / norms
  (carries the canonical ``model.`` prefix like every backbone block).
* ``layers.{N+k}.*`` for V4's MTP-only fusion modules (``eh_proj``,
  ``enorm``, ``hnorm``, ``final_layernorm``) which sit outside the
  HF ``model.`` namespace.

The previous split regex (``r"^layers\.(\d+)\."``) only matched the
unprefixed form, so the prefixed self_attn / mlp / norms keys silently
fell into the backbone bucket. They were then renamed by the standard
backbone pipeline and ended up at ``model.layers.{N+k}.*`` in the
converted state dict — but the model only has ``model.layers.{0..N-1}``,
so DCP load dropped them and ``model.mtp.layers[*].*`` started from
random init. End result: MTP-enabled training silently ran without
loading the MTP head weights from the HF checkpoint.

Repro on a tiny config (num_hidden_layers=2, num_nextn_predict_layers=1):

    Model expects 38 mtp.* state_dict keys
    adapter.from_hf produced  4 mtp.* keys (the 4 unprefixed fusion ones)
    35 mtp.* keys MISSING, 24 keys leaked to model.layers.2.* (dropped)

Make the regex prefix-tolerant (``^(model\.)?layers\.(\d+)\.``) and use
the second capture group as the layer index. After the fix, the same
repro produces 0 missing / 0 extra, and a save→load round-trip via
to_hf -> from_hf reconstructs every mtp.* key the model exposes.

Add a regression test ``test_from_hf_renames_mtp_layer_with_model_prefix``
that exercises the prefixed form so this cannot silently regress again.

Signed-off-by: khazic <khazzz1c@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test c228ec4

Comment thread nemo_automodel/components/moe/parallelizer.py Outdated
HuiyingLi and others added 4 commits May 15, 2026 17:21
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: Adil <47084919+adil-a@users.noreply.github.com>
Signed-off-by: Huiying <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test c4e0fcc

@yiakwy-xpu-ml-framework-team
Copy link
Copy Markdown

@HuiyingLi @khazic Do we have accuracy and accept length report ?

@HuiyingLi
Copy link
Copy Markdown
Contributor

HuiyingLi commented May 17, 2026

@yiakwy-xpu-ml-framework-team Thank you for the comment. I will work on an accuracy report.

HuiyingLi and others added 2 commits May 18, 2026 15:49
…resume

DeepseekV4Indexer's wkv / wgate / ape_param / wq_b / weights_proj feed
into topk(...).indices, which is non-differentiable. No gradient ever
reaches these params, so AdamW never allocates lazy state slots, so
checkpoint save serializes a partial optim state, so DCP resume fails:

  RuntimeError: Missing key in checkpoint state_dict:
  optim.state.model.layers.{i}.self_attn.compressor.indexer.wkv.weight.step.

Mirror PR NVIDIA-NeMo#1698's KV-sharing fix: freeze the dead params in
apply_model_infrastructure() before sharding so they are never tracked
by the optimizer.

Locally reproduced with 4-layer DSV4-Flash + compress_ratios=[0,0,4,128]
on 8x H100 via torchrun: un-fixed resume errors on the missing key at
layer 2; fixed resume loads cleanly with bit-identical step-0/1 losses
to the un-fixed forward pass.

Adds 9 CPU unit tests under tests/unit_tests/models/deepseek_v4/
test_indexer_freeze.py covering no-op cases, frozen-FQN correctness,
non-indexer params untouched, the bug reproduction, and a full save/load
roundtrip.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Conflict in nemo_automodel/components/distributed/pipelining/functional.py
between two designs for keeping non-layer PP dependencies on the right
stages:

* HEAD: generic `_get_optional_hook(model, "customize_pipeline_stage_modules")`
  + a per-model `customize_pipeline_stage_modules` method. Covers DSV4's
  rotary_emb_compress, hc_head, and the new mtp module.
* origin/main: inline `is_v4_keep` special-case in functional.py. Covers
  rotary_emb_compress, hc_head, and swa_rotary_emb (the latter actually
  belongs to MiMoV2Flash, not DSV4 -- main's check was not gated on
  model_type for that one).

Kept HEAD's hook pattern and ported swa_rotary_emb coverage to the model
class where it belongs:

* nemo_automodel/components/models/mimo_v2_flash/model.py: add
  customize_pipeline_stage_modules on MiMoV2FlashForCausalLM that pins
  swa_rotary_emb to every PP stage.
* tests/unit_tests/models/mimo_v2_flash/test_model.py: add a unit test
  for the new hook.

All four modules (rotary_emb_compress, hc_head, mtp, swa_rotary_emb) are
now covered. DSV4-specific knowledge stays in deepseek_v4/model.py;
MiMo-specific knowledge stays in mimo_v2_flash/model.py;
pipelining/functional.py remains model-agnostic.

Tests: 75 passed locally (DSV4 indexer-freeze, DSV4 MTP pipeline hooks,
generic test_functional.py hook tests, new MiMo hook test).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 74c5284

Add the blank line ruff expects between third-party torch import and
the first-party nemo_automodel imports. Caught by `ruff check` during
the babysit pre-flight; no behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/claude review

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 93905f8

return self.lm_head(hidden)

@torch.no_grad()
def customize_pipeline_stage_modules(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: The new method was inserted between the @torch.no_grad() decorator and initialize_weights, so the decorator now applies to customize_pipeline_stage_modules instead of initialize_weights.

On origin/main, line 677–678 is:

@torch.no_grad()
def initialize_weights(

After this PR, @torch.no_grad() decorates customize_pipeline_stage_modules (harmless but unintended), and initialize_weights loses its decorator (weight init may inadvertently build autograd graphs).

Fix: move the @torch.no_grad() back onto initialize_weights (line 695).

Suggested change
def customize_pipeline_stage_modules(
def customize_pipeline_stage_modules(

modules.append(fqn)
return stage_modules

def initialize_weights(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add @torch.no_grad() back here — it was stolen by the insertion above.

Suggested change
def initialize_weights(
@torch.no_grad()
def initialize_weights(

The previous commit inserted `customize_pipeline_stage_modules` between
`@torch.no_grad()` and `initialize_weights`, which silently transferred
the decorator onto the new method (harmless) and stripped it from
`initialize_weights` (could let weight init build autograd graphs).

Move the decorator back onto `initialize_weights`. Pure structural fix
caught by review feedback on NVIDIA-NeMo#2191.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/claude review

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test e2c0b5c

Comment thread nemo_automodel/components/loss/mtp.py Outdated
Comment on lines +42 to +43
from nemo_automodel.recipes.llm.train_ft import calculate_loss, calculate_mtp_loss

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug/layering concern: components/loss/ imports from recipes/llm/train_ft.py, which in turn imports PipelineCausalLMLoss from this file (line 83 of train_ft.py). The deferred import avoids the import-time cycle, but this is an architectural inversion — components should not depend on recipes.

Could calculate_loss and calculate_mtp_loss be extracted into a shared utility (e.g., components/loss/utils.py or alongside this file) that both this module and train_ft.py import?

@@ -0,0 +1,289 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Copyright year is 2025, but this is a new file added in 2026. The other new file in this PR (components/loss/mtp.py) correctly uses 2026.

Suggested change
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good PR — solid test coverage (MTP config, construction, forward/backward, state-dict round-trip, PP hooks, indexer freeze, and PipelineCausalLMLoss). The generic customize_pipeline_stage_modules / get_pipeline_stage_metas hooks are a nice decoupling of DSV4-specific PP logic from the framework.

Two items:

  1. Circular import between components/loss/mtp.py and recipes/llm/train_ft.py — see inline comment. The deferred import avoids the import-time cycle, but the dependency direction (component → recipe) is inverted.

  2. PR description vs. code mismatch — the description says "raises NotImplementedError if pipeline parallelism is enabled together with MTP" and "PP + MTP is intentionally deferred to a follow-up PR." However, this PR actually implements PP+MTP support via PipelineCausalLMLoss, get_pipeline_stage_metas, and customize_pipeline_stage_modules. The description should be updated to reflect that PP+MTP is now supported (non-interleaved schedules at least).

… year

Two follow-ups on the latest claude review (PR NVIDIA-NeMo#2191):

1. **Layering inversion fixed.** ``components/loss/mtp.py``'s
   ``PipelineCausalLMLoss.forward`` had a deferred ``from
   nemo_automodel.recipes.llm.train_ft import calculate_loss,
   calculate_mtp_loss`` to dodge an import-time cycle. The dependency
   direction was wrong: components should never reach into recipes.

   - Extract ``calculate_loss`` (already model-agnostic, used by KD too)
     into a new ``nemo_automodel/components/loss/utils.py``.
   - Move ``calculate_mtp_loss`` alongside ``PipelineCausalLMLoss`` in
     ``components/loss/mtp.py``; it now imports ``calculate_loss`` from
     ``..utils`` at module scope.
   - ``recipes/llm/train_ft.py`` re-imports both names from the new
     locations, preserving the public ``train_ft.calculate_loss`` /
     ``train_ft.calculate_mtp_loss`` import surface that
     ``recipes/llm/kd.py`` and ``tests/functional_tests/checkpoint/
     test_dcp.py`` already rely on. Drop now-unused
     ``get_lm_head_module`` / ``get_lm_head_weight`` imports from
     train_ft.

2. **Copyright year.** ``deepseek_v4/mtp.py`` is a new file added in
   2026, not 2025; align with sibling new file ``loss/mtp.py``.

No behavior change. 179 unit tests across the touched modules pass
locally.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/claude review

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 65e5b2b

claude[bot]
claude Bot previously approved these changes May 18, 2026
Copy link
Copy Markdown
Contributor

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 65e5b2b

…ontract

CI's import-linter `Components must not import each other` contract was
broken by the previous layering refactor, with three violations:

  - loss.mtp -> training.model_output_utils (pre-existing)
  - loss.mtp -> utils.model_utils (introduced)
  - loss.utils -> utils.model_utils (introduced)

Inline the three small helpers `get_lm_head_module`, `get_lm_head_weight`,
`get_final_hidden_states` as private `_`-prefixed copies inside
``components/loss/utils.py`` so ``components/loss/`` no longer reaches
into other component packages. Use them from both ``loss/utils.py``
(calculate_loss) and ``loss/mtp.py`` (calculate_mtp_loss,
PipelineCausalLMLoss.forward).

`lint-imports` now reports `1 kept, 0 broken`. 116 unit tests pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test da7aac2

Copy link
Copy Markdown
Collaborator

@adil-a adil-a left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, awesome work everyone!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants