[fix] Fix CUTLASS grouped GEMM segfault for empty groups#3037
[fix] Fix CUTLASS grouped GEMM segfault for empty groups#3037Baibaifan wants to merge 4 commits into
Conversation
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Greptile SummaryThis PR fixes a native segfault that occurred when the CUTLASS grouped-GEMM path was reached with all groups empty:
Confidence Score: 5/5The change is safe to merge: both guards are minimal, idempotent, and sit on paths that were already broken (they only fire when there is nothing to do). Both fixes are tightly scoped: the No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[general_grouped_gemm Python] --> B[te_general_grouped_gemm C++]
B --> C{Loop over A/B groups}
C -->|te_A.numel==0 or te_B.numel==0| D[zero out_tensor / bias / pre_gelu_out\ncontinue]
D --> C
C -->|all groups processed| E{te_A_wrappers.empty?\nNEW GUARD}
E -->|yes - all groups filtered| F[return bias early\navoid segfault]
E -->|no - some non-empty groups| G[swizzle scales\nbuild NVTETensor vectors]
G --> H[nvte_multi_tensor_gemm]
H --> I{num_gemms <= 0?\nNEW GUARD}
I -->|yes| J[return early]
I -->|no - Hopper + CUTLASS| K[CUTLASS grouped GEMM]
I -->|no - other path| L[cuBLAS multi-stream GEMM]
Reviews (3): Last reviewed commit: "[fix] add empty_groups unit test." | Re-trigger Greptile |
| for tensor in out: | ||
| torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) | ||
|
|
||
|
|
||
| def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: | ||
| data = grouped_tensor.rowwise_data |
There was a problem hiding this comment.
Zero-assertion is trivially true for TN and NN layouts
For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.
|
Hi @Baibaifan, could you resolve the conflicts? |
hi, @ptrendx Due to a code merging issue, I have resubmitted a pull request (PR). The new PR can be found here: #3067. |
Description
Handle grouped GEMM calls where all groups are empty.
MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.
Return early after filtering all GEMMs in
te_general_grouped_gemm, and add adefensive
num_gemms <= 0guard in nvte_multi_tensor_gemm.Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.
Type of change
Changes
Please list the changes introduced in this PR:
te_general_grouped_gemmwhen all GEMMs were filtered.num_gemms <= 0guard innvte_multi_tensor_gemm.NVTE_USE_CUTLASS_GROUPED_GEMM=1.TN,NN, andNTlayouts.Testing
pytest -q tests/pytorch/test_numerics.py::test_grouped_gemm_cutlass_empty_groups -sChecklist: