Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/pytorch/test_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,53 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)


@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch):
dtype = torch.bfloat16
z, k, n = 1, 2048, 1536
m_splits = [0] * z

if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input
out = [torch.empty(0, n, dtype=dtype, device="cuda")] # output
grad = False
single_output = True
elif layout == "NN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output
out = [torch.empty(0, k, dtype=dtype, device="cuda")] # dgrad
grad = True
single_output = True
else: # layout == "NT"
A = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input
B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
grad = True
single_output = False

monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")
general_grouped_gemm(
A,
B,
out,
[None] * z,
dtype,
m_splits=m_splits,
grad=grad,
layout=layout,
single_output=single_output,
)
torch.cuda.synchronize()

for tensor in out:
torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)
Comment on lines +912 to +913
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Vacuous assertion for TN/NN layouts

For TN and NN layouts, out is constructed as [torch.empty(0, n)] / [torch.empty(0, k)], so the tensor has zero elements. torch.zeros_like(tensor) on a 0-element tensor is also 0-element, and assert_close on two empty tensors always succeeds regardless of the fix. The test effectively only checks "no crash occurs" for those two layouts — it cannot catch a regression where the output buffer wasn't properly left untouched. Only the NT layout exercises a non-empty output that is meaningfully asserted to be zero.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!



def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
data = grouped_tensor.rowwise_data
if data is None:
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_gemm);

if (num_gemms <= 0) {
return;
}

const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
}

if (te_A_wrappers.empty()) {
return bias;
}

// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

Expand Down
Loading