[PyTorch] Make modules.GroupedLinear graph-safe#3038
Conversation
Greptile SummaryThis PR introduces a new cuBLASLt grouped GEMM path for
Confidence Score: 4/5Merge with caution: the new grouped-tensor path silently loses its CUDA-graph-safety guarantee when callers pass m_splits as a list rather than a CUDA tensor. The new transformer_engine/pytorch/module/grouped_linear.py — specifically the Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["GroupedLinear.forward(inp, m_splits)"] --> B["Normalize m_splits\n(list → CPU tensor)"]
B --> C{_is_grouped_tensor_path_supported?}
C -- "Yes\n(SM100+, MXFP8/BF16/FP16,\nenv var set)" --> D["_forward_grouped_tensor\n(GroupedTensor path)"]
C -- No --> E["Legacy path\ntex.split_quantize / general_grouped_gemm"]
D --> D1["tex.group_quantize\n(or _make_grouped_tensor)"]
D1 --> D2["general_grouped_gemm_for_grouped_tensor\n(TN layout — fprop)"]
D2 --> D3["Return out\n(CUDA-graph-safe iff m_splits is CUDA tensor)"]
E --> E1["tex.split_quantize\n(or torch.split)"]
E1 --> E2["general_grouped_gemm\n(per-group GEMM)"]
subgraph Backward grouped tensor path
B1["_backward_grouped_tensor"] --> B2["dgrad: general_grouped_gemm_for_grouped_tensor NN"]
B1 --> B3["bias grad: compute_grouped_dbias\nor bgrad_group_quantize"]
B1 --> B4{delay_wgrad_compute?}
B4 -- Yes --> B5["wgrad_store.put\ngrouped_x grouped_dy wgrad_list"]
B5 --> B6["backward_dw called later\n grouped_gemm_wgrad NT"]
B4 -- No --> B7["grouped_gemm_wgrad NT\ngeneral_grouped_gemm_for_grouped_tensor"]
end
Reviews (6): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
Signed-off-by: Xin Yao <xiny@nvidia.com>
d176247 to
698383e
Compare
|
/te-ci pytorch |
|
/te-ci pytorch |
There was a problem hiding this comment.
In the future, we should consider moving the grouped MLP tests from test_fusible_ops.py into this file.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Handle tensor splits in both legacy and graph-safe impls. Create weight grad tensors as subviews of a larger buffer. Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
| m_splits : torch.Tensor | ||
| Split sizes for the input tensor. |
There was a problem hiding this comment.
Going into the future, I think we should prefer passing in the splits as a torch.Tensor, even if they are on CPU. This makes the API more consistent. We need to support lists of ints for backward compatibility, but it should be considered deprecated.
timmoon10
left a comment
There was a problem hiding this comment.
LGTM, pending CI.
As a followup, we should consolidate the implementation so that we can reuse the same code in te.ops.GroupedLinear and te.GroupedLinear.
|
/te-ci pytorch |
| if not isinstance(m_splits, torch.Tensor): | ||
| # Convert list of ints to tensor for backward compatibility | ||
| m_splits = torch.tensor(m_splits, dtype=torch.int64, device="cpu") | ||
| elif m_splits.dtype != torch.int64: | ||
| m_splits = m_splits.to(dtype=torch.int64) | ||
| if m_splits.size() != (num_gemms,): | ||
| raise ValueError( | ||
| f"Number of splits ({len(m_splits)}) should match number of" | ||
| f" GEMMs ({self.num_gemms})." | ||
| f"Shape of splits tensor ({tuple(m_splits.size())}) " | ||
| f"does not match number of GEMMs ({num_gemms})." | ||
| ) |
There was a problem hiding this comment.
CPU→GPU sync silently breaks CUDA-graph capture on the new path
When NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 is active and m_splits is supplied as a Python list (or as a CPU tensor), the normalization at line 1690 creates a CPU tensor. Inside _forward_grouped_tensor that CPU tensor is moved to the compute device via split_sizes = m_splits.to(device=device) (line 241). This CPU→GPU transfer cannot be captured by CUDA Graph, silently defeating the PR's primary graph-safety guarantee for any caller that still passes a list.
Since the legacy path (which correctly handles lists) is only a m_splits.tolist() call away, a simple guard here would prevent the silent mis-use: raise (or warn) when _is_grouped_tensor_path_supported() would return True but m_splits is not already on the compute device.
There was a problem hiding this comment.
If the splits are passed in as a list or on CPU, there's no hope of CUDA Graph capture anyways. A blocking H2D memcpy is the best we can do. Warning is also excessive, since CPU splits are perfectly valid (if suboptimal) when running without CUDA Graphs.
Description
Enable grouped quantization and cuBLASLt grouped gemm for
modules.GroupedLinearto benefit cases where cuteDSL fused grouped gemm is not available.Move grouped gemm and grouped linear related tests to a standalone file.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: