From 04269b845ee4493cc0d75e0696ada3582701248d Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Thu, 30 Apr 2026 00:21:52 +0000
Subject: [PATCH] Add workaround for cuteDSL stride requirement for zero-token
 expert

Signed-off-by: Kirthi Shankar Sivamani
---
 .../pytorch/ops/fused/backward_grouped_mlp.py | 44 ++++++++++++++-----
 1 file changed, 33 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
index 510fea0edd..0921d523cc 100644
--- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
+++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@@ -58,19 +58,41 @@ def _cudnn_compute_wgrad(
     fp8_dtype = torch.float8_e4m3fn
 
-    # a_tensor = DY^T = (out_features, total_tokens) row-major
-    a_tensor = grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T
-    # b_tensor = X = (total_tokens, in_features) column-major
-    b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features)
-
     sfa_leading_dim = ((out_features + 127) // 128) * 128
     sfb_leading_dim = ((in_features + 127) // 128) * 128
-    sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view(
-        dtype=torch.float8_e8m0fnu
-    )
-    sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view(
-        dtype=torch.float8_e8m0fnu
-    )
+
+    if total_tokens == 0:
+        # A workaround for the case with zero-token experts.
+        # Even for this case, cuteDSL still imposes the same
+        # stride requirements on the input and scale tensors.
+        device = grouped_dy.columnwise_data.device
+        a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device)
+        b_tensor = torch.empty_strided(
+            (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device
+        )
+        sfa_tensor = torch.empty_strided(
+            (sfa_leading_dim, 0),
+            (16, 1),
+            dtype=torch.float8_e8m0fnu,
+            device=device,
+        )
+        sfb_tensor = torch.empty_strided(
+            (sfb_leading_dim, 0),
+            (16, 1),
+            dtype=torch.float8_e8m0fnu,
+            device=device,
+        )
+    else:
+        a_tensor = (
+            grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T
+        )
+        b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features)
+        sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view(
+            dtype=torch.float8_e8m0fnu
+        )
+        sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view(
+            dtype=torch.float8_e8m0fnu
+        )
 
     # Prepare wgrad output
     if single_grouped_weight:
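
For readers unfamiliar with the trick the patch relies on: torch.empty_strided
allocates no storage for a zero-element tensor, but the tensor still records
the requested strides, which is what lets the zero-token path satisfy the
layout checks without touching the kernel. Below is a minimal standalone
sketch of that behavior; the dimension values are made-up placeholders, not
values from the patch.

    import torch

    # Placeholder sizes for illustration only.
    out_features, in_features = 256, 128
    fp8_dtype = torch.float8_e4m3fn

    # A zero-token operand carries no data, yet torch.empty_strided pins
    # exactly the shape/stride metadata a stride check would inspect.
    a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype)
    print(a_tensor.shape, a_tensor.stride(), a_tensor.numel())
    # torch.Size([256, 0]) (16, 1) 0

    # The scale-factor leading dimension is rounded up to a multiple of
    # 128, mirroring sfa_leading_dim/sfb_leading_dim in the patch.
    sfa_leading_dim = ((out_features + 127) // 128) * 128
    print(sfa_leading_dim)  # 256

The (16, 1) stride on the zero-width operands appears to mirror the layout
the non-empty path produces via the .T view; that reading is an inference
from the patch itself, not documented cuteDSL behavior.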