
[PyTorch] Error out if constructing LayerNormLinear with row tensor parallelism #2688

Merged
ksivaman merged 4 commits into NVIDIA:main from timmoon10:tmoon/row-tp-layernorm-linear
Mar 12, 2026

Conversation

@timmoon10
Collaborator

Description

LayerNormLinear modules with row tensor parallelism have input tensors that are sharded along the inner dimension:

    elif self.parallel_mode == "row":
        self.in_features = divide(self.in_features, self.tp_size)

However, we currently don't support tensor-parallel LayerNorm or RMSNorm, which would require a tensor-parallel all-reduce to compute the normalization statistics. If a user attempts to run LayerNormLinear with row tensor parallelism, they hit an illegal memory access when the norm kernel reads values from the unsharded norm weight tensor. We haven't seen problems so far because row TP is typically used for the proj and fc2 layers, which are usually plain Linear modules.
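As a small illustration (not from the PR) of why the norm statistics cannot be computed on a row-sharded input without a tensor-parallel all-reduce:

    import torch

    # Full activation on a single device: [tokens, hidden]
    x = torch.randn(4, 8)
    full_mean = x.mean(dim=-1)        # per-token mean that LayerNorm needs

    # Under row TP with tp_size=2, each rank holds only half the hidden dim
    shard = x[:, :4]                  # rank 0's shard of the input
    local_mean = shard.mean(dim=-1)   # generally != full_mean

    # Reconciling the local statistics with the full ones would require an
    # all-reduce over the TP group, which the norm kernels do not perform.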

This PR adds an error message to make the failure more obvious.
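As a rough sketch of the guard (standalone and hypothetical; the actual check lives in LayerNormLinear.__init__ and the surrounding code may differ):

    # Hypothetical standalone version of the new check; GemmParallelModes is
    # assumed here to be the usual (None, "column", "row") set of valid modes.
    GemmParallelModes = (None, "column", "row")

    def check_parallel_mode(parallel_mode):
        assert parallel_mode in GemmParallelModes, f"Unsupported parallel_mode: {parallel_mode}"
        if parallel_mode == "row":
            raise NotImplementedError(
                "Normalization does not support tensor-parallel distribution."
            )

    check_parallel_mode("column")   # fine
    # check_parallel_mode("row")    # would raise NotImplementedError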

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Error out if constructing LayerNormLinear with row tensor parallelism

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 added the bug (Something isn't working) label on Feb 17, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Feb 17, 2026

Greptile Summary

This PR adds an explicit NotImplementedError in LayerNormLinear.__init__ when parallel_mode="row" is passed, preventing a cryptic illegal memory access that users would otherwise encounter at runtime. The corresponding row-parallel LayerNormLinear test cases are also removed from the comm+GEMM overlap test suite.
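For illustration, a hedged sketch of how the failure now surfaces at construction time (arguments are representative, and running this requires a Transformer Engine install with a CUDA device):

    import transformer_engine.pytorch as te

    try:
        layer = te.LayerNormLinear(
            in_features=1024,
            out_features=1024,
            parallel_mode="row",  # unsupported: the norm input would be row-sharded
        )
    except NotImplementedError as err:
        print(err)  # Normalization does not support tensor-parallel distribution.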

Key changes:

  • layernorm_linear.py: Raises NotImplementedError("Normalization does not support tensor-parallel distribution.") immediately after the GemmParallelModes assertion when parallel_mode == "row", making the unsupported configuration fail fast with a clear message.
  • test_comm_gemm_overlap.py: Removes (te.LayerNormLinear.__name__, "row", False) entries and their corresponding IDs from the test_bulk_overlaps, test_layers_with_overlap_bf16, and test_layers_with_overlap_fp8 parametrize lists.

Minor issues found:

  • The elif self.parallel_mode == "row": self.in_features = divide(...) branch immediately following the raise is now dead code and should be removed.
  • The parallel_mode docstring still documents 'row' as a valid value ({None, 'column', 'row'}); it should be updated to reflect the new restriction.

Confidence Score: 4/5

  • This PR is safe to merge — it converts a silent runtime crash into a clear NotImplementedError with no risk of regressions for supported configurations.
  • The change is minimal, targeted, and correct: the guard fires only on the unsupported "row" mode, leaving all "column" and None paths entirely unaffected. The only issues are a small piece of dead code and an outdated docstring, neither of which affects runtime behaviour.
  • transformer_engine/pytorch/module/layernorm_linear.py — dead elif branch (line 1202) and stale docstring (line 1096) should be cleaned up.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/layernorm_linear.py | Adds a NotImplementedError guard in LayerNormLinear.__init__ when parallel_mode="row" is passed. Two minor follow-up issues: the elif self.parallel_mode == "row" branch on the next lines is now dead code, and the parallel_mode docstring still lists 'row' as a valid option. |
| tests/pytorch/distributed/test_comm_gemm_overlap.py | Removes LayerNormLinear row-parallel test cases from test_bulk_overlaps and test_layers_with_overlap_bf16/test_layers_with_overlap_fp8 parametrize lists, consistent with the new NotImplementedError restriction added to the module. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["LayerNormLinear.__init__(parallel_mode)"] --> B{"parallel_mode in\nGemmParallelModes?"}
    B -- No --> C["assert fails ❌"]
    B -- Yes --> D{"parallel_mode\n== 'row'?"}
    D -- Yes --> E["raise NotImplementedError\n'Normalization does not support\ntensor-parallel distribution.' ❌\n(NEW — this PR)"]
    D -- No --> F{"parallel_mode\n== 'column'?"}
    F -- Yes --> G["out_features //= tp_size"]
    F -- No (None) --> H["No feature-size adjustment"]
    G --> I["Continue init…"]
    H --> I

Comments Outside Diff (2)

  1. transformer_engine/pytorch/module/layernorm_linear.py, lines 1202-1203

    Unreachable dead code

    Lines 1202-1203 can never be reached because the raise NotImplementedError above unconditionally exits the constructor whenever self.parallel_mode == "row". This elif branch is now dead code and can be removed to avoid confusion; a control-flow sketch illustrating this follows the list below.

  2. transformer_engine/pytorch/module/layernorm_linear.py, lines 1096-1099

    Docstring still lists 'row' as a valid option

    The parallel_mode parameter documentation still advertises 'row' as a valid value, but LayerNormLinear now raises NotImplementedError if that value is passed. This can mislead users into thinking the option is supported. The docstring should be updated to reflect the new restriction.
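To make the dead-code comment above concrete, here is a hypothetical, self-contained reconstruction of the control flow described in the review (names and structure are approximations, not the actual file contents):

    def divide(numerator, denominator):
        return numerator // denominator

    def sketch_init(parallel_mode, in_features, out_features, tp_size):
        if parallel_mode == "row":
            # Added by this PR: exits immediately for row tensor parallelism.
            raise NotImplementedError(
                "Normalization does not support tensor-parallel distribution."
            )
        if parallel_mode == "column":
            out_features = divide(out_features, tp_size)
        elif parallel_mode == "row":
            # Unreachable: the raise above already handled "row".
            in_features = divide(in_features, tp_size)
        return in_features, out_features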

Last reviewed commit: fc7c11c

greptile-apps[bot]

This comment was marked as outdated.

@greptile-apps
Contributor

greptile-apps Bot commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/module/layernorm_linear.py
Unreachable code: this elif branch will never execute since row parallel mode now raises NotImplementedError on line 1197. Consider removing these lines or moving the error check after this conditional.

@timmoon10

This comment was marked as outdated.

ptrendx
ptrendx previously approved these changes Feb 18, 2026
@timmoon10
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps Bot left a comment


2 files reviewed, 3 comments


@greptile-apps
Contributor

greptile-apps Bot commented Feb 20, 2026

Additional Comments (3)

transformer_engine/pytorch/module/layernorm_linear.py
Dead code - the NotImplementedError raised on line 1197 prevents execution from reaching this row-parallel case


transformer_engine/pytorch/module/layernorm_linear.py
Dead code - row-parallel mode raises NotImplementedError at line 1197, so these will always be False

        # Row-parallel overlaps (disabled - not supported)
        self.ub_overlap_rs_fprop = False
        self.ub_overlap_ag_dgrad = False

transformer_engine/pytorch/module/layernorm_linear.py
Dead code - row-parallel mode is not supported (line 1197), so this will always be False

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        # Note: RPL is not supported for LayerNormLinear
        self.gemm_bias_unfused_add = False

@ksivaman
Member

/te-ci pytorch

Member

@ksivaman ksivaman left a comment


There is more we could clean up, since the module still attempts to implement row parallelism in places, but it is better to keep that code to make future support easier.

@ksivaman ksivaman merged commit 6a68c73 into NVIDIA:main Mar 12, 2026
18 of 24 checks passed
@ksivaman ksivaman deleted the tmoon/row-tp-layernorm-linear branch March 12, 2026 22:21
vthumbe1503 pushed a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
… parallelism (NVIDIA#2688)

* Error out if constructing LayerNormLinear with row tensor parallelism

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable Userbuffers test for row-TP LayerNormLinear

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Labels

bug Something isn't working
