[Fix] Fix CUTLASS grouped GEMM segfault for empty groups #3067

Open
Baibaifan wants to merge 1 commit into NVIDIA:main from Baibaifan:empty_groupgemm

@Baibaifan Baibaifan commented Jun 1, 2026

Description

Handle grouped GEMM calls where all groups are empty.

MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.

Return early after filtering all GEMMs in te_general_grouped_gemm, and add a
defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
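
The control flow described above can be sketched in plain Python. This is only an illustrative model, not the actual C++ change: `grouped_gemm` and `multi_tensor_gemm` are hypothetical stand-ins for te_general_grouped_gemm and nvte_multi_tensor_gemm, tensors are modeled as flat lists of floats, and the `a[0]` access stands in for the dtype probe that would read out of bounds when no groups remain.

```python
from typing import List, Optional

Tensor = List[float]  # hypothetical stand-in for a real tensor


def multi_tensor_gemm(a: List[Tensor], b: List[Tensor], d: List[Tensor]) -> Optional[str]:
    """Model of the backend entry point (hypothetical, not the real API)."""
    num_gemms = len(a)
    if num_gemms <= 0:
        # Defensive guard: nothing to compute, so never touch a[0]/b[0]/d[0].
        return None
    _ = (a[0], b[0], d[0])  # models the probe that reads the first group
    return f"ran {num_gemms} gemm(s)"


def grouped_gemm(a: List[Tensor], b: List[Tensor], d: List[Tensor]) -> Optional[str]:
    """Model of the wrapper: filter empty groups, then return early if none remain."""
    kept_a, kept_b, kept_d = [], [], []
    for a_i, b_i, d_i in zip(a, b, d):
        if len(a_i) == 0 or len(b_i) == 0:
            d_i[:] = [0.0] * len(d_i)  # zero the output in place and skip the group
            continue
        kept_a.append(a_i)
        kept_b.append(b_i)
        kept_d.append(d_i)
    if not kept_a:
        return None  # early return: every group was filtered out
    return multi_tensor_gemm(kept_a, kept_b, kept_d)
```

In this model the wrapper-level return handles the MoE case, while the guard in `multi_tensor_gemm` protects any caller that passes zero groups directly.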

Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Changes

Please list the changes introduced in this PR:

  • Return early from te_general_grouped_gemm when all GEMMs were filtered.
  • Add a defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
  • Add a Hopper-only regression test for all-empty grouped GEMM inputs under
    NVTE_USE_CUTLASS_GROUPED_GEMM=1.
  • Cover TN, NN, and NT layouts.

Testing

pytest -q tests/pytorch/test_grouped_linear.py::test_grouped_gemm_cutlass_empty_groups -s
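
The regression test runs under NVTE_USE_CUTLASS_GROUPED_GEMM=1. One way to scope such a flag to a single call, so it does not leak into other tests, is sketched below. This is an assumption about test hygiene, not a description of how the PR's test sets the variable; `run_with_cutlass_flag` is a hypothetical helper, and whether the library reads the variable at call time or caches it earlier is not stated in this PR.

```python
import os
from unittest import mock

ENV_FLAG = "NVTE_USE_CUTLASS_GROUPED_GEMM"


def run_with_cutlass_flag(fn):
    """Call fn() with the flag set to "1", restoring the environment afterwards.

    Hypothetical helper for illustration only.
    """
    # patch.dict restores os.environ to its previous state on exit,
    # even if fn() raises.
    with mock.patch.dict(os.environ, {ENV_FLAG: "1"}):
        return fn()
```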

Checklist:

  • I have read and followed the [contributing guidelines]
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
@Baibaifan Baibaifan requested a review from ksivaman as a code owner June 1, 2026 07:02
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 1, 2026
@Baibaifan (Author)

Hi @ptrendx,

This is a new pull request opened after resolving the merge conflict. The original pull request is #3037.

greptile-apps Bot (Contributor) commented Jun 1, 2026

Greptile Summary

This PR fixes a native segfault in the CUTLASS grouped GEMM path triggered when MoE routing produces a microbatch where all experts receive zero tokens. The existing per-group filtering loop already zeroed empty-group outputs, but never guarded against the case where all groups were filtered out, leaving nvte_multi_tensor_gemm to be called with zero groups and then immediately dereference A[0]/B[0]/D[0] inside is_supported_dtype().

  • te_general_grouped_gemm: adds an early return after the filtering loop when te_A_wrappers is empty; output tensors and bias grads are already handled in-place by the existing per-group zeroing logic, so the return value is semantically identical to the normal exit path.
  • nvte_multi_tensor_gemm: adds a num_gemms <= 0 guard as a belt-and-suspenders defense for any direct caller of this public API with zero groups.
  • Test: adds a Hopper-only regression test covering TN, NN, and NT layouts under NVTE_USE_CUTLASS_GROUPED_GEMM=1; the NT case is the most meaningful assertion (non-empty wgrad buffer verified to be zeroed), while TN/NN operate on 0-element tensors so the assert_close there is primarily a crash guard.

Confidence Score: 4/5

Safe to merge; the two C++ changes are minimal, targeted, and leave all existing non-empty code paths unaffected.

Both C++ changes are correct and narrow: the early return in te_general_grouped_gemm reuses the existing per-group zeroing path, and the num_gemms <= 0 guard in nvte_multi_tensor_gemm is a pure defense-in-depth no-op for normal inputs. The test for TN/NN layouts uses 0-element output tensors, so those assert_close calls pass trivially and cannot catch a future regression in output correctness for those layouts.

tests/pytorch/test_grouped_linear.py — the TN/NN assertions on empty tensors are vacuous; consider adding a non-empty output variant to those layout tests to strengthen regression coverage.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/gemm/cublaslt_gemm.cu | Adds a defensive num_gemms <= 0 guard at the entry of nvte_multi_tensor_gemm, preventing a segfault that occurred when accessing A[0]/B[0]/D[0] in is_supported_dtype() with zero groups on the Hopper/CUTLASS path. |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Adds an early return in te_general_grouped_gemm after the per-group filtering loop when all groups are empty, preventing nvte_multi_tensor_gemm from being invoked with a zero-length group list. Output tensors are already zeroed/left unchanged by the existing per-group handling before the new check. |
| tests/pytorch/test_grouped_linear.py | Adds a Hopper-only regression test for all-empty grouped GEMM across TN/NN/NT layouts with NVTE_USE_CUTLASS_GROUPED_GEMM=1. The assertions are meaningful for NT (non-empty output zeroed in-place) but trivially pass for TN/NN (output tensors already empty). |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[general_grouped_gemm Python] --> B[te_general_grouped_gemm C++]
    B --> C{Loop over each group}
    C --> D{A or B has zero elements?}
    D -- Yes --> E[Zero output in-place, continue]
    E --> C
    D -- No --> F[Push group to wrappers]
    F --> C
    C -- All done --> G{te_A_wrappers empty? NEW GUARD}
    G -- Yes --> H[return bias early - no segfault]
    G -- No --> I[nvte_multi_tensor_gemm]
    I --> J{num_gemms le 0? NEW GUARD}
    J -- Yes --> K[return early - defensive check]
    J -- No --> L{Hopper AND CUTLASS enabled?}
    L -- Yes --> M[is_supported_dtype accesses A0 B0 D0]
    M --> N[cutlass_grouped_gemm]
    L -- No --> O[multi_stream_cublas_gemm]

Reviews (1): Last reviewed commit: "[fix] Fix CUTLASS grouped GEMM segfault ..."

Comment on lines +912 to +913
    for tensor in out:
        torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)

P2 Vacuous assertion for TN/NN layouts

For TN and NN layouts, out is constructed as [torch.empty(0, n)] / [torch.empty(0, k)], so the tensor has zero elements. torch.zeros_like(tensor) on a 0-element tensor is also 0-element, and assert_close on two empty tensors always succeeds regardless of the fix. The test effectively only checks "no crash occurs" for those two layouts — it cannot catch a regression where the output buffer wasn't properly left untouched. Only the NT layout exercises a non-empty output that is meaningfully asserted to be zero.
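
The vacuous-pass behavior can be illustrated without a GPU. The sketch below is a plain-Python stand-in for an elementwise "all zeros" check, not a call into torch.testing; `assert_all_zero` and `check_passes` are hypothetical names used only for illustration.

```python
from typing import Sequence


def assert_all_zero(buf: Sequence[float]) -> None:
    """Raise AssertionError if any element of buf is non-zero."""
    # With an empty buf there is nothing to inspect, so this never raises.
    if any(x != 0.0 for x in buf):
        raise AssertionError("buffer contains non-zero values")


def check_passes(buf: Sequence[float]) -> bool:
    """Return True when the all-zero check accepts buf, False otherwise."""
    try:
        assert_all_zero(buf)
    except AssertionError:
        return False
    return True
```

An empty buffer is accepted no matter what the code under test did, whereas a non-empty buffer holding stale values is rejected. Only the latter kind of check can distinguish a correct result from a regression.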

ptrendx (Member) commented Jun 1, 2026

/te-ci pytorch

@Baibaifan (Author)

/te-ci pytorch

Hi @ptrendx,

The failing check doesn't seem to be related to my changes. Could you re-run it when convenient?
