fix: resolve TP+PP for nemotron super 49B#1607

Merged
akoumpa merged 4 commits into main from huiyingl/fix-tp-pp-pipeline-parallelism-bugs
Mar 25, 2026

Conversation

@HuiyingLi (Contributor) commented Mar 25, 2026

single node hellaswag: [image]
2 nodes squad: [image]

When pipeline parallelism splits a model, `nn.ModuleList` layers are
converted to `nn.ModuleDict`. Three issues surfaced with custom models
(e.g. DeciLM/Nemotron-49B) that use explicit `self.num_heads` in
attention views and return tuples from decoder layers:

1. `_update_attention_head_counts_for_tp` iterates `for layer in layers`,
   which yields string keys (not modules) for `ModuleDict` — head counts
   were never updated, causing shape mismatches in the Q/K/V view.

2. The walrus-operator fallback for `causal_mask_mapping` could leave a
   raw 2D `attention_mask` in place of the expected 4D causal mask when
   the import or computation failed silently.

3. The batch device-move code filtered out `None` values from nested
   dicts, dropping `causal_mask_mapping` entries for sdpa-configured
   models where `create_causal_mask` returns `None`.
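The iteration bug in (1) can be sketched without torch: `nn.ModuleDict`, like a plain `dict`, yields string keys under direct iteration, while `nn.ModuleList` yields the modules themselves. A minimal normalizing helper (the name `iter_layers` is illustrative, not the repo's API; plain containers stand in for the torch ones):

```python
def iter_layers(layers):
    """Yield layer modules whether `layers` is list-like (nn.ModuleList)
    or dict-like (nn.ModuleDict, as produced by the PP split)."""
    # Dict-like containers yield string keys under plain iteration, so a
    # `for layer in layers` loop would silently skip every module and the
    # head counts would never be updated; take .values() instead.
    if hasattr(layers, "values"):
        yield from layers.values()
    else:
        yield from layers

# Plain containers stand in for nn.ModuleList / nn.ModuleDict here.
module_list = ["layer0", "layer1"]
module_dict = {"0": "layer0", "1": "layer1"}
assert list(iter_layers(module_list)) == ["layer0", "layer1"]
assert list(iter_layers(module_dict)) == ["layer0", "layer1"]
# Naive iteration over the dict yields keys, not modules:
assert list(module_dict) == ["0", "1"]
```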
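Issue (3) comes down to how the device-move helper treats `None` inside nested dicts. A hedged sketch of the fixed behavior (pure Python; `move_to_device` and `FakeTensor` are illustrative stand-ins, not the repo's actual helper or `torch.Tensor`):

```python
def move_to_device(obj, device):
    """Recursively move tensor-like values to `device`, preserving None
    entries (e.g. causal_mask_mapping values when create_causal_mask
    returns None for sdpa) instead of filtering them out."""
    if isinstance(obj, dict):
        # Keep every key, including those mapped to None.
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    if hasattr(obj, "to"):  # tensor-like
        return obj.to(device)
    return obj  # None, ints, strings pass through unchanged

class FakeTensor:
    """Minimal stand-in for torch.Tensor in this sketch."""
    def __init__(self, device="cpu"):
        self.device = device
    def to(self, device):
        return FakeTensor(device)

batch = {"input_ids": FakeTensor(), "causal_mask_mapping": {"sdpa": None}}
moved = move_to_device(batch, "cuda:0")
assert moved["input_ids"].device == "cuda:0"
assert "sdpa" in moved["causal_mask_mapping"]        # entry survives the move
assert moved["causal_mask_mapping"]["sdpa"] is None  # None is preserved, not dropped
```

The buggy variant effectively did `{k: v.to(device) for k, v in d.items() if v is not None}`, which is why the sdpa mapping entries vanished before reaching the attention code.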

Additionally, decoder layers in older-style HF models (pre-v5) return
tuples rather than bare tensors, and raw 2D padding masks that leak
through the pipeline schedule need to be dropped before reaching
custom attention code.
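The tuple-vs-tensor contract above can be normalized with a tiny unwrap step (the helper name is illustrative; pre-v5 HF decoder layers put `hidden_states` first in the returned tuple):

```python
def unwrap_hidden_states(layer_output):
    """Return hidden_states from a decoder layer's output.

    Older-style (pre-v5) HF decoder layers return a tuple whose first
    element is hidden_states (optionally followed by attention weights
    or cache entries); newer layers return the tensor directly.
    """
    if isinstance(layer_output, tuple):
        return layer_output[0]
    return layer_output

# Strings stand in for tensors in this sketch.
assert unwrap_hidden_states(("hs", "attn_weights")) == "hs"
assert unwrap_hidden_states("hs") == "hs"
```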

Verified on nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 with tp4pp2
(100 training steps, hellaswag dataset, 8xH100).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@copy-pr-bot (Bot) commented Mar 25, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa (Contributor) commented Mar 25, 2026

/ok to test 15f9cbe

@HuiyingLi (Contributor, Author) commented

/claude review

claude[bot] previously approved these changes Mar 25, 2026

LGTM. Both fixes are correct and well-targeted:

  1. ModuleDict iteration (`parallelizer.py`): properly handles the PP-converted `ModuleDict` by iterating `.values()` instead of yielding keys.
  2. Tuple unpacking + kwargs removal (`hf_utils.py`): correctly extracts `hidden_states` from the decoder layer's tuple output, matching the standard HF contract.

akoumpa commented Mar 25, 2026

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@thomasdhc thomasdhc force-pushed the huiyingl/fix-tp-pp-pipeline-parallelism-bugs branch from 4abe87d to 578e85c Compare March 25, 2026 23:09
akoumpa added 2 commits March 25, 2026 16:23
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
@akoumpa (Contributor) left a comment

LGTM. thank you @HuiyingLi

@akoumpa akoumpa merged commit 9db8c1f into main Mar 25, 2026
4 checks passed
@akoumpa akoumpa deleted the huiyingl/fix-tp-pp-pipeline-parallelism-bugs branch March 25, 2026 23:25
@chtruong814 (Contributor) commented

/claude review

claude[bot] left a comment

LGTM

linnanwang pushed a commit that referenced this pull request Apr 24, 2026
* fix: resolve TP+PP pipeline parallelism bugs for custom HF models

* update recipe

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

* fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

---------

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>