[Fix] Fix CUTLASS grouped GEMM segfault for empty groups#3067
Conversation
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Greptile SummaryThis PR fixes a native segfault in the CUTLASS grouped GEMM path triggered when MoE routing produces a microbatch where all experts receive zero tokens. The existing per-group filtering loop already zero'd empty-group outputs, but never guarded against the case where all groups were filtered out — leaving
Confidence Score: 4/5Safe to merge; the two C++ changes are minimal, targeted, and leave all existing non-empty code paths unaffected. Both C++ changes are correct and narrow: the early return in te_general_grouped_gemm reuses the existing per-group zeroing path, and the num_gemms <= 0 guard in nvte_multi_tensor_gemm is a pure defense-in-depth no-op for normal inputs. The test for TN/NN layouts uses 0-element output tensors, so those assert_close calls pass trivially and cannot catch a future regression in output correctness for those layouts. tests/pytorch/test_grouped_linear.py — the TN/NN assertions on empty tensors are vacuous; consider adding a non-empty output variant to those layout tests to strengthen regression coverage. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[general_grouped_gemm Python] --> B[te_general_grouped_gemm C++]
B --> C{Loop over each group}
C --> D{A or B has zero elements?}
D -- Yes --> E[Zero output in-place, continue]
E --> C
D -- No --> F[Push group to wrappers]
F --> C
C -- All done --> G{te_A_wrappers empty? NEW GUARD}
G -- Yes --> H[return bias early - no segfault]
G -- No --> I[nvte_multi_tensor_gemm]
I --> J{num_gemms le 0? NEW GUARD}
J -- Yes --> K[return early - defensive check]
J -- No --> L{Hopper AND CUTLASS enabled?}
L -- Yes --> M[is_supported_dtype accesses A0 B0 D0]
M --> N[cutlass_grouped_gemm]
L -- No --> O[multi_stream_cublas_gemm]
Reviews (1): Last reviewed commit: "[fix] Fix CUTLASS grouped GEMM segfault ..." | Re-trigger Greptile |
| for tensor in out: | ||
| torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) |
There was a problem hiding this comment.
Vacuous assertion for TN/NN layouts
For TN and NN layouts, out is constructed as [torch.empty(0, n)] / [torch.empty(0, k)], so the tensor has zero elements. torch.zeros_like(tensor) on a 0-element tensor is also 0-element, and assert_close on two empty tensors always succeeds regardless of the fix. The test effectively only checks "no crash occurs" for those two layouts — it cannot catch a regression where the output buffer wasn't properly left untouched. Only the NT layout exercises a non-empty output that is meaningfully asserted to be zero.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
/te-ci pytorch |
hi @ptrendx The falling check doesn't seem to be related to my changes. Would it be convenient to retest it? |
Description
Handle grouped GEMM calls where all groups are empty.
MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.
Return early after filtering all GEMMs in
te_general_grouped_gemm, and add adefensive
num_gemms <= 0guard in nvte_multi_tensor_gemm.Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.
Type of change
Changes
Please list the changes introduced in this PR:
te_general_grouped_gemmwhen all GEMMs were filtered.num_gemms <= 0guard innvte_multi_tensor_gemm.NVTE_USE_CUTLASS_GROUPED_GEMM=1.TN,NN, andNTlayouts.Testing
pytest -q tests/pytorch/test_grouped_linear.py::test_grouped_gemm_cutlass_empty_groups -sChecklist: