44 changes: 33 additions & 11 deletions transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
@@ -58,19 +58,41 @@ def _cudnn_compute_wgrad(
 
     fp8_dtype = torch.float8_e4m3fn
 
-    # a_tensor = DY^T = (out_features, total_tokens) row-major
-    a_tensor = grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T
-    # b_tensor = X = (total_tokens, in_features) column-major
-    b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features)
-
     sfa_leading_dim = ((out_features + 127) // 128) * 128
     sfb_leading_dim = ((in_features + 127) // 128) * 128
-    sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view(
-        dtype=torch.float8_e8m0fnu
-    )
-    sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view(
-        dtype=torch.float8_e8m0fnu
-    )
+
+    if total_tokens == 0:
+        # A workaround for the case with zero-token experts.
+        # Even for this case, cuteDSL still requires the same
+        # stride requirements for the input and scale tensors.
Comment on lines +64 to +67

Contributor

P2 Missing TODO for temporary workaround

The PR description states this workaround will be removed once the upstream fix lands in cutedsl, but the in-code comment has no corresponding TODO or issue-tracker reference. Without one, there's no actionable reminder to clean this up once the upstream fix is released.

Suggested change
-    if total_tokens == 0:
-        # A workaround for the case with zero-token experts.
-        # Even for this case, cuteDSL still requires the same
-        # stride requirements for the input and scale tensors.
+    if total_tokens == 0:
+        # TODO: Remove this workaround once cuteDSL relaxes stride
+        # divisibility requirements for zero-element tensors (tracked in
+        # <upstream issue link>).
+        # A workaround for the case with zero-token experts.
+        # Even for this case, cuteDSL still requires the same
+        # stride requirements for the input and scale tensors.

+        device = grouped_dy.columnwise_data.device
+        a_tensor = torch.empty_strided((out_features, 0), (16, 1), dtype=fp8_dtype, device=device)
+        b_tensor = torch.empty_strided(
+            (0, in_features), (in_features, 1), dtype=fp8_dtype, device=device
+        )
+        sfa_tensor = torch.empty_strided(
+            (sfa_leading_dim, 0),
+            (16, 1),
+            dtype=torch.float8_e8m0fnu,
+            device=device,
+        )
+        sfb_tensor = torch.empty_strided(
+            (sfb_leading_dim, 0),
+            (16, 1),
+            dtype=torch.float8_e8m0fnu,
+            device=device,
+        )
Comment on lines +69 to +84

Contributor

P2 Hardcoded stride 16 undocumented

The value 16 is used as the leading stride for both a_tensor and the scale tensors (sfa_tensor, sfb_tensor) in the zero-token path, but there is no comment explaining why 16 specifically satisfies cuteDSL's divisibility requirement. In the non-zero path the leading stride of a_tensor is 1 (column-major after transpose), so this value is not derived from the tensor layout. If the cuteDSL requirement ever changes (e.g. requires 32 or 128 alignment), this silent constant will be wrong without any indication of why it was chosen. A brief comment citing the minimum stride constraint would make future maintenance safer.
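A minimal sketch of one way to address this, hoisting the magic number into a documented module-level constant. The constant's name and the stated divisibility requirement are assumptions for illustration, not confirmed cuteDSL documentation; it also assumes a PyTorch build that exposes the FP8 dtypes used in this file, and the sizes are arbitrary:

import torch

# Hypothetical constant: assumed minimum leading-stride alignment that
# cuteDSL accepts for zero-element operands. Verify against the kernel's
# actual layout checks before adopting.
CUTEDSL_MIN_LEADING_STRIDE = 16

fp8_dtype = torch.float8_e4m3fn
out_features, sfa_leading_dim = 256, 256  # illustrative sizes
device = torch.device("cpu")  # illustration only; the real path uses the CUDA device

# Same zero-token tensors as in the diff, with the stride choice named and documented.
a_tensor = torch.empty_strided(
    (out_features, 0),
    (CUTEDSL_MIN_LEADING_STRIDE, 1),
    dtype=fp8_dtype,
    device=device,
)
sfa_tensor = torch.empty_strided(
    (sfa_leading_dim, 0),
    (CUTEDSL_MIN_LEADING_STRIDE, 1),
    dtype=torch.float8_e8m0fnu,
    device=device,
)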

+    else:
+        a_tensor = (
+            grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T
+        )
+        b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features)
+        sfa_tensor = grouped_dy.columnwise_scale_inv.view(sfa_leading_dim, -1).view(
+            dtype=torch.float8_e8m0fnu
+        )
+        sfb_tensor = grouped_x.columnwise_scale_inv.view(sfb_leading_dim, -1).view(
+            dtype=torch.float8_e8m0fnu
+        )

     # Prepare wgrad output
     if single_grouped_weight:
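For context on the zero-token workaround above: torch.empty_strided with a zero-size dimension allocates no storage, but the resulting tensor still reports the requested stride metadata, which is presumably what cuteDSL's layout checks inspect. A standalone illustration (shapes arbitrary; assumes a PyTorch build with FP8 dtypes):

import torch

# Zero-element tensor that still carries explicit strides.
t = torch.empty_strided((16, 0), (16, 1), dtype=torch.float8_e4m3fn)
print(t.numel())   # 0 -- nothing is stored
print(t.stride())  # (16, 1) -- stride metadata survives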