[Pytorch][Common] Hybrid quantization #2817
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds hybrid (per-direction) quantization support to TransformerEngine's PyTorch frontend, allowing rowwise and columnwise directions to use different quantization formats (e.g., MXFP8 rowwise + Float8 columnwise). It introduces
Confidence Score: 4/5
Safe to merge for Hopper and Blackwell workflows; on non-Hopper, FSDP2 sharding assigns the wrong nominal shape to Float8 columnwise sub-storages, which propagates through subsequent all-gather cycles.
The aten.split handler for Float8Tensor uses split_transpose_tensor.shape ([K, M/n]) as the nominal shard shape when _data is None, instead of the correct [M/n, K]. This wrong shape is picked up by _infer_shape inside fsdp_post_all_gather and baked into the assembled hybrid tensor's columnwise sub-storage on every FSDP2 iteration. Everything else in the PR is well-constructed.
transformer_engine/pytorch/tensor/float8_tensor.py (aten.split shape for columnwise-only path) and transformer_engine/pytorch/cpp_extensions/gemm.py (_unwrap_hybrid_A/B None passthrough).
Important Files Changed
timmoon10
left a comment
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
    rowwise_result = self.rowwise_quantizer.quantize(tensor)
    columnwise_result = self.columnwise_quantizer.quantize(tensor)
Do we handle the case where not all usages are needed? I'd expect something like:
Suggested change:
-   rowwise_result = self.rowwise_quantizer.quantize(tensor)
-   columnwise_result = self.columnwise_quantizer.quantize(tensor)
+   rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+   columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
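The suggestion above can be illustrated with a minimal, self-contained sketch. `ToyHybridQuantizer` and the two fake quantize functions are hypothetical stand-ins, not TransformerEngine classes; only the `rowwise_usage` / `columnwise_usage` gating mirrors the suggested lines.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class ToyHybridQuantizer:
    """Hypothetical stand-in for HybridQuantizer (not TransformerEngine code)."""

    rowwise_quantize: Callable[[Any], Any]
    columnwise_quantize: Callable[[Any], Any]
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def quantize(self, tensor: Any) -> Tuple[Optional[Any], Optional[Any]]:
        # Run a sub-quantizer only when its direction is actually requested.
        rowwise_result = self.rowwise_quantize(tensor) if self.rowwise_usage else None
        columnwise_result = (
            self.columnwise_quantize(tensor) if self.columnwise_usage else None
        )
        return rowwise_result, columnwise_result


calls: List[str] = []  # records which sub-quantizers were invoked


def fake_rowwise(x: Any) -> Any:
    calls.append("rowwise")
    return ("rowwise", x)


def fake_columnwise(x: Any) -> Any:
    calls.append("columnwise")
    return ("columnwise", x)


quantizer = ToyHybridQuantizer(fake_rowwise, fake_columnwise, columnwise_usage=False)
row, col = quantizer.quantize([1.0, 2.0])
```

With `columnwise_usage=False`, the columnwise sub-quantizer is never called, so the unused direction costs nothing and its slot is left as `None`.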
        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> HybridQuantizedTensor:
        self.rowwise_quantizer.internal = True
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
This would not work under FSDP2.
    def factory(role):
        if role == "linear_weight":
            return HybridQuantizer(
                rowwise_quantizer=_make_fp8_quantizer(),
                columnwise_quantizer=_make_mxfp8_quantizer(),
            )
        if role == "linear_input":
            return HybridQuantizer(
                rowwise_quantizer=_make_fp8_quantizer(),
                columnwise_quantizer=_make_nvfp4_quantizer(),
            )
        if role in ("linear_grad_output", "linear_grad_input"):
            return HybridQuantizer(
                rowwise_quantizer=_make_mxfp8_quantizer(),
                columnwise_quantizer=_make_nvfp4_quantizer(),
            )
        return None
This is horrifying. Good test.
    outs = [
        Float8Tensor.make_like(
            tensor,
            data=split_tensor,
            data_transpose=split_transpose_tensor,
-           shape=split_tensor.shape,
+           shape=(
+               split_tensor.shape
+               if split_tensor is not None
+               else split_transpose_tensor.shape
+           ),
        )
        for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
    ]
When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.
Suggested change:
    outs = [
        Float8Tensor.make_like(
            tensor,
            data=split_tensor,
            data_transpose=split_transpose_tensor,
            shape=(
                split_tensor.shape
                if split_tensor is not None
-               else split_transpose_tensor.shape
+               # _transpose has shape [K, M/n] but the shard's nominal shape
+               # is [M/n, K]. Recover the correct shard shape by reversing
+               # the last two dims of the transposed piece.
+               else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0])
            ),
        )
        for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
    ]
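To make the shape arithmetic in this suggestion concrete, here is a small sketch on plain tuples (no tensors involved). `nominal_shard_shape` is a hypothetical helper written for illustration and does not exist in TransformerEngine; it assumes, as the comment above states, that the transposed buffer carries K as its leading dimension. Note that the suggested expression moves the leading dimension to the end, which for a 2D shard is the same as swapping the two dims.

```python
from typing import Optional, Sequence, Tuple


def nominal_shard_shape(
    data_shape: Optional[Sequence[int]],
    transpose_shape: Optional[Sequence[int]],
) -> Tuple[int, ...]:
    """Hypothetical helper: pick the nominal (non-transposed) shard shape."""
    if data_shape is not None:
        # Rowwise data is present, so its shape is already the nominal one.
        return tuple(data_shape)
    if transpose_shape is None:
        raise ValueError("need at least one of data_shape or transpose_shape")
    # Columnwise-only case: undo the transpose by moving the leading dim
    # (assumed to be K) to the end, e.g. [K, M/n] -> [M/n, K].
    return (*transpose_shape[1:], transpose_shape[0])


# Assumed example sizes: M = 8, K = 16, split into n = 2 shards along M.
M, K, n = 8, 16, 2
transposed_piece = (K, M // n)  # shape of one split of the transposed buffer

buggy_shape = transposed_piece  # falling back to the transpose shape as-is
fixed_shape = nominal_shard_shape(None, transposed_piece)
rowwise_shape = nominal_shard_shape((M // n, K), transposed_piece)
```

Here `buggy_shape` is `(16, 4)` while `fixed_shape` and `rowwise_shape` agree on `(4, 16)`, which is the consistency the review comment asks for between the rowwise and columnwise-only paths.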
Description
Hybrid (per-direction) quantization. Functional.
C++ optimizations (fusions, etc.) will come in the next PRs.
TODO: double quantization
Integration
Ecosystem integration (all functional, unit-tested):
Megatron-LM integration status:
--fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights)
- [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination including same-format, cross-format Float8, single-direction)
- [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise): each is rejected per-direction by quantize_master_weights; the unblocker is a TE-side cast-helper / kernel.
--fp{4,8}-param-gather (fix private attribute access)
--fp{4,8}-param-gather
- [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
- [TODO] NVFP4 sub-storage FSDP2 hooks
_hybrid_split_quantize under Megatron MoE)
Review
Total diff: +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py): ~1000
Adjacent modifications: ~1000
Tests are the rest
Surface to review is ~2000 lines
Suggested reading order
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: