[PyTorch] Error out if constructing LayerNormLinear with row tensor parallelism #2688
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile Summary

This PR adds an explicit error when LayerNormLinear is constructed with row tensor parallelism, which the module does not actually support.

Key changes:
- Raise NotImplementedError in LayerNormLinear.__init__ when parallel_mode == "row"
- Disable the Userbuffers test case for row-TP LayerNormLinear

Confidence Score: 4/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["LayerNormLinear.__init__(parallel_mode)"] --> B{"parallel_mode in\nGemmParallelModes?"}
    B -- No --> C["assert fails ❌"]
    B -- Yes --> D{"parallel_mode\n== 'row'?"}
    D -- Yes --> E["raise NotImplementedError\n'Normalization does not support\ntensor-parallel distribution.' ❌\n(NEW in this PR)"]
    D -- No --> F{"parallel_mode\n== 'column'?"}
    F -- Yes --> G["out_features //= tp_size"]
    F -- "No (None)" --> H["No feature-size adjustment"]
    G --> I["Continue init…"]
    H --> I
```
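For reference, here is a minimal Python sketch of the constructor logic the flowchart describes. The standalone helper `init_parallel_mode` and module-level `GemmParallelModes` tuple are illustrative only; the real checks live inside `LayerNormLinear.__init__` with more surrounding context.

```python
# Minimal sketch of the constructor checks shown in the flowchart.
GemmParallelModes = ("row", "column", None)

def init_parallel_mode(parallel_mode, out_features, tp_size):
    # Unknown modes fail the existing assertion.
    assert parallel_mode in GemmParallelModes, f"parallel_mode {parallel_mode} not supported"

    # NEW in this PR: normalization cannot consume an input that is
    # sharded along the inner dimension, so row TP is rejected up front.
    if parallel_mode == "row":
        raise NotImplementedError(
            "Normalization does not support tensor-parallel distribution."
        )

    # Column TP shards the output features across the TP group.
    if parallel_mode == "column":
        out_features //= tp_size

    return out_features  # the rest of __init__ continues from here
```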
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch L1

/te-ci pytorch
ksivaman left a comment

We could clean this up further, since the module currently does attempt to implement row TP, but it is better to keep that code so it is easier to add support in the future.
… parallelism (NVIDIA#2688)

* Error out if constructing LayerNormLinear with row tensor parallelism

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable Userbuffers test for row-TP LayerNormLinear

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Description
LayerNormLinear modules with row tensor parallelism have input tensors that are sharded along the inner dimension: TransformerEngine/transformer_engine/pytorch/module/layernorm_linear.py
Lines 1199 to 1200 in 7e48fa1
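The embedded snippet does not render here; as a hypothetical stand-in (the function name and signature are invented for illustration), the row-TP branch shards the module's input dimension across the tensor-parallel group:

```python
def shard_in_features(in_features: int, parallel_mode: str, tp_size: int) -> int:
    """Hypothetical stand-in for the referenced lines: under row TP,
    each rank holds only a 1/tp_size shard of the inner dimension."""
    if parallel_mode == "row":
        in_features = in_features // tp_size
    return in_features
```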
However, we currently don't support tensor-parallel LayerNorm or RMSNorm, which would require a tensor-parallel all-reduce to compute the normalization statistics. If the user attempts to run LayerNormLinear with row tensor parallelism, they hit an illegal memory access when the norm kernel accesses values in the unsharded norm weight tensor (see the sketch below). We haven't seen problems so far because row TP is usually used for the proj and fc2 layers, which are usually Linears. This PR adds an error message to make the failure more obvious.
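To make the missing piece concrete, here is a hypothetical sketch (not TransformerEngine code; the function name is invented) of what a row-TP LayerNorm would have to do: since each rank holds only a shard of the inner dimension, the mean and variance must be accumulated with an all-reduce over the tensor-parallel group, with correspondingly sharded norm weight and bias.

```python
import torch
import torch.distributed as dist

def row_tp_layer_norm(x_local, weight_local, bias_local, tp_group, eps=1e-5):
    """Hypothetical row-TP LayerNorm (not implemented in TE).

    x_local:                  [..., local_dim] shard of the inner dimension
    weight_local, bias_local: matching [local_dim] shards of the norm params
    """
    full_dim = x_local.shape[-1] * dist.get_world_size(group=tp_group)

    # Statistics span the full inner dimension, so per-rank partial
    # sums are combined with an all-reduce (default op is SUM).
    partial = torch.stack([x_local.sum(dim=-1), x_local.pow(2).sum(dim=-1)])
    dist.all_reduce(partial, group=tp_group)
    mean = partial[0] / full_dim
    var = partial[1] / full_dim - mean * mean

    x_hat = (x_local - mean.unsqueeze(-1)) * torch.rsqrt(var.unsqueeze(-1) + eps)
    return x_hat * weight_local + bias_local
```

With this PR, the unsupported case instead fails fast at construction time. Assuming a TP process group `tp_group` of size `tp_size` is already initialized:

```python
import transformer_engine.pytorch as te

# Raises NotImplementedError immediately, instead of an illegal
# memory access inside the norm kernel at runtime.
layer = te.LayerNormLinear(
    in_features=1024,
    out_features=1024,
    parallel_mode="row",  # row TP is the unsupported case
    tp_group=tp_group,
    tp_size=tp_size,
)
```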
Type of change
Changes
- Error out if constructing LayerNormLinear with row tensor parallelism

Checklist: