Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/pytorch/test_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,53 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2)


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0),
    reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch):
    """Call ``general_grouped_gemm`` with all-zero ``m_splits`` and CUTLASS enabled.

    Every group has a split size of 0, so the operands that carry the split
    dimension are created with zero rows. After the call, each output tensor
    must compare exactly equal to zeros. For the "NT" layout the output
    buffer starts out filled with random values; for "TN" and "NN" the
    output itself has zero rows.
    """
    dtype = torch.bfloat16
    num_groups, k, n = 1, 2048, 1536
    group_ids = range(num_groups)
    tensor_opts = {"dtype": dtype, "device": "cuda"}

    def filled(rows, cols):
        # Randomly initialised (rows, cols) tensor on the GPU.
        return torch.randn(rows, cols, **tensor_opts)

    def rowless(cols):
        # Tensor with zero rows and `cols` columns on the GPU.
        return torch.empty(0, cols, **tensor_opts)

    if layout in ("TN", "NN"):
        A = [filled(n, k) for _ in group_ids]  # weight
        single_output = True
        if layout == "TN":
            B = [rowless(k) for _ in group_ids]  # input
            out = [rowless(n)]  # output
            grad = False
        else:
            B = [rowless(n) for _ in group_ids]  # grad_output
            out = [rowless(k)]  # dgrad
            grad = True
    else:  # layout == "NT"
        A = [rowless(k) for _ in group_ids]  # input
        B = [rowless(n) for _ in group_ids]  # grad_output
        out = [filled(n, k) for _ in group_ids]  # wgrad
        grad = True
        single_output = False

    # Opt in to the CUTLASS grouped GEMM path for the duration of this test.
    monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1")
    no_bias = [None] * num_groups
    general_grouped_gemm(
        A,
        B,
        out,
        no_bias,
        dtype,
        m_splits=[0] * num_groups,
        grad=grad,
        layout=layout,
        single_output=single_output,
    )
    torch.cuda.synchronize()

    # Zero tolerances: outputs must be exactly zero, not merely close to it.
    for result in out:
        expected = torch.zeros_like(result)
        torch.testing.assert_close(result, expected, rtol=0, atol=0)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
data = grouped_tensor.rowwise_data
if data is None:
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_gemm);

if (num_gemms <= 0) {
return;
}

const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPED_GEMM", false);
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
}

if (te_A_wrappers.empty()) {
return bias;
}

// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

Expand Down