[Pytorch][Common] Hybrid quantization #2817

Open
negvet wants to merge 19 commits into NVIDIA:main from negvet:hybrid_quantization

Conversation

@negvet
Collaborator

@negvet negvet commented Mar 31, 2026

Description

Hybrid (per-direction) quantization. Functional.
C++ optimizations (fusions, etc.) will come in the next PRs.

TODO: double quantization

Integration

Ecosystem integration (all functional, unit-tested):

  • [Done] quantized_model_init
  • [Done] FSDP2 (TODO: optimize communication buffers)
  • [Done] CPU offloading
  • [Done] Activation recomputation
  • [Done] TP/SP (TODO: enable quantized AG)

Megatron-LM integration status:

  • [Done] 1 GPU baseline
  • [Done] DP + distributed optimizer
  • [TODO] quantized_model_init + --fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
    - [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination including same-format, cross-format Float8, single-direction)
    - [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
  • [TODO] Megatron-FSDP + --fp{4,8}-param-gather (fix private attribute access)
  • [TODO] Torch FSDP2 + --fp{4,8}-param-gather
    - [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
    - [TODO] NVFP4 sub-storage FSDP2 hooks
  • [Done] Activation recompute
  • [Done] CPU offload
  • [Done] TP/SP/PP
  • [Done] MoE + EP + grouped GEMM (qwen3 MoE; _hybrid_split_quantize under Megatron MoE)

Review

Total diff +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py) ~1000
Adjacent modifications ~1000
Tests are the rest

Surface to review is ~2000 lines

Suggested reading order

  1. Foundation — 7553e6a: Python containers + quantize/gemm dispatch/unwrap
  • tensor/hybrid_tensor.py — HybridQuantizer + HybridQuantizedTensor
  • tensor/storage/hybrid_tensor_storage.py
  • cpp_extensions/gemm.py — _unwrap_hybrid_A/B
  • common/transpose/quantize_transpose_square_blockwise.cu - Block FP8 columnwise-only null-checks
  • Module hooks in module/{base,grouped_linear,layernorm_linear,layernorm_mlp}.py
  • Tests: TestHybridQuantizer*, TestHybridGemmBitwiseIdentical* (proves zero-overhead vs vanilla recipes when both formats match), TestHybridDirectionUnwrap*, TestHybridGroupedLinear*
  2. quantized_model_init + FusedAdam — f80f5d0
  • hybrid_tensor.py::HybridQuantizer.update_quantized — delegates to each sub-quantizer; unblocks workspace-cache quantize_() and FusedAdam writeback
  • module/base.py workspace-cache invalidation
  • Tests: TestHybridQuantizedModelInit, TestHybridFusedAdam, TestHybridQuantizedParamsEndToEnd, TestHybridCheckpoint, TestQuantizedParamsEquivalence*
  3. FSDP2 support — 2185b30
  • New base FSDP2 buffer protocol on QuantizedTensorStorage: fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered. Generic, reusable beyond hybrid.
  • Per-format overrides on Float8TensorStorage (direction-aware) and MXFP8TensorStorage (strips/re-applies scale alignment padding around the gather)
  • hybrid_tensor.py::fsdp_pre/post_all_gather + torch_dispatch for the FSDP2 op set (view, split, as_strided, slice, copy_, new_zeros, clone, detach)
  • None-safety in float8_tensor.py and mxfp8_tensor.py for single-direction sub-storages (columnwise-only on Hopper/L40)
  • Tests: TestHybridTorchDispatchFSDP2Ops, TestHybridFsdpPreAllGatherProtocol, TestHybridFsdpRoundtrip (bitwise-exact against manual all_gather(dequantize(shard))), plus tests/pytorch/distributed/fsdp2_tests/
  4. CPU offloading — 103fffe
  • hybrid_tensor_storage.py::clear() (v1 path) + prepare_for_saving / restore_from_saved chain (v2 path)
  • hybrid_tensor.py::detach() re-wraps each sub-storage via make_like (required by cpu_offload_v2's detach → prepare_for_saving pattern; sharing sub-storage objects would null-out fields on the original)
  • TestHybridCpuOffloadPushPop, plus updates to test_cpu_offloading*.py
  5. Activation recomputation — 16fb371
  • Uses existing QuantizedTensorStorage::prepare_for_saving / restore_from_saved protocol, preserving ordering across both sub-storages
  • Tests: 20 bitwise tests in TestHybridActivationRecompute
  6. TP/SP — a50fd63
  • hybrid_tensor.py::HybridQuantizer.supports_only_rowwise_all_gather — overrides to handle the NVFP4 columnwise-dequantize gap in the BF16 fallback path
  • distributed.py::gather_along_first_dim — hybrid branch re-quantizes with both directions after AG (since hybrid has no _create_transpose synthesis path)
  • Tests: 9 distributed tests in run_hybrid_tp_sp.py / test_hybrid_tp_sp.py
  7. Megatron-LM integration — a164cd3
  • tensor/utils.py::_route_hybrid_to_buckets — per-direction dispatch for quantize_master_weights: iterates both sub-storages, routes each independently into the per-format bucket matching its own sub-quantizer type
  • Hybrid branches in replace_raw_data and post_all_gather_processing
  • Today: per-tensor Float8 sub-quantizers (delayed + current) work in any per-direction combination. Per-block sub-quantizers raise per-direction with in-code TODOs naming the unblocker.
  • Tests: TestHybridQuantizeMasterWeights, TestHybridPostAllGatherProcessing
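
The per-direction routing described for quantize_master_weights can be pictured with a small, dependency-free sketch. Everything here is a hypothetical stand-in (the `Toy*` classes, the bucket names, and `route_hybrid_to_buckets` itself); it is not the code in tensor/utils.py, only an illustration of "each sub-quantizer goes to the bucket matching its own type, and unsupported types raise per direction".

```python
from collections import defaultdict


# Hypothetical stand-ins for sub-quantizer types; not Transformer Engine classes.
class ToyDelayedFloat8Quantizer:
    pass


class ToyCurrentFloat8Quantizer:
    pass


class ToyMXFP8Quantizer:
    pass


# Assumed mapping from sub-quantizer type to a per-format bucket name.
BUCKET_FOR_TYPE = {
    ToyDelayedFloat8Quantizer: "float8_delayed",
    ToyCurrentFloat8Quantizer: "float8_current",
}


def route_hybrid_to_buckets(params):
    """Route each direction of each hybrid parameter into its own bucket.

    `params` is a list of (name, rowwise_quantizer, columnwise_quantizer);
    a quantizer of None means that direction is not present.
    """
    buckets = defaultdict(list)
    for name, rowwise_q, columnwise_q in params:
        for direction, sub_q in (("rowwise", rowwise_q), ("columnwise", columnwise_q)):
            if sub_q is None:
                continue  # single-direction hybrid: nothing to route
            bucket = BUCKET_FOR_TYPE.get(type(sub_q))
            if bucket is None:
                # Mirrors the "raise per-direction" behaviour for unsupported formats.
                raise NotImplementedError(
                    f"{name}: {direction} sub-quantizer "
                    f"{type(sub_q).__name__} is not supported"
                )
            buckets[bucket].append((name, direction))
    return dict(buckets)
```

In this toy version the two directions of one parameter can land in different buckets, which is the point of routing per sub-storage rather than per tensor.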

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented Mar 31, 2026

Greptile Summary

This PR adds hybrid (per-direction) quantization support to TransformerEngine's PyTorch frontend, allowing rowwise and columnwise directions to use different quantization formats (e.g., MXFP8 rowwise + Float8 columnwise). It introduces HybridQuantizedTensor, HybridQuantizer, and HybridQuantizedTensorStorage, plus FSDP2, CPU-offload, activation-recompute, and TP/SP integration.

  • New tensor types (HybridQuantizedTensor, HybridQuantizedTensorStorage, HybridQuantizer) support two-pass quantization, pickle/reduce, and torch-dispatch for FSDP2 ops (split, view, clone, copy_, new_zeros, as_strided, slice).
  • GroupedLinear gains _hybrid_split_quantize which calls tex.split_quantize twice (once per direction) and reassembles per-GEMM HybridQuantizedTensorStorage objects, with a classifier that rejects mixed HybridQuantizer/non-hybrid lists.
  • FSDP2 protocol is extended on Float8TensorStorage and MXFP8TensorStorage with fsdp_buffer_fields, fsdp_extract_buffers, and fsdp_assign_gathered; the hybrid tensor delegates to each sub-storage's implementation.
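
As a rough mental model of the summary above, the following toy sketch quantizes one input twice, once per direction, and picks the direction-appropriate result before a GEMM. The names (`ToyHybridQuantizer`, `ToyHybridStorage`, `unwrap_for_direction`, `make_grid_quantizer`) are hypothetical, the "formats" are plain rounding grids, and real columnwise data would use a different layout; the explicit error on a missing sub-storage is an assumption about what a guard could look like, not the PR's behavior.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


def make_grid_quantizer(step: float) -> Callable[[List[float]], List[float]]:
    """Toy 'format': round every value to the nearest multiple of `step`."""
    def quantize(values: List[float]) -> List[float]:
        return [round(v / step) * step for v in values]
    return quantize


@dataclass
class ToyHybridStorage:
    rowwise: Optional[List[float]]
    columnwise: Optional[List[float]]


@dataclass
class ToyHybridQuantizer:
    rowwise_quantizer: Callable[[List[float]], List[float]]
    columnwise_quantizer: Callable[[List[float]], List[float]]
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def quantize(self, values: List[float]) -> ToyHybridStorage:
        # Two independent passes; a direction that is not needed is skipped.
        row = self.rowwise_quantizer(values) if self.rowwise_usage else None
        col = self.columnwise_quantizer(values) if self.columnwise_usage else None
        return ToyHybridStorage(rowwise=row, columnwise=col)


def unwrap_for_direction(storage: ToyHybridStorage, direction: str) -> List[float]:
    """Return the sub-storage a GEMM operand needs, failing loudly if it is absent."""
    sub = storage.rowwise if direction == "rowwise" else storage.columnwise
    if sub is None:
        raise ValueError(f"{direction} sub-storage is missing")
    return sub
```

A quick usage example: `ToyHybridQuantizer(make_grid_quantizer(0.5), make_grid_quantizer(0.25)).quantize([0.3, 1.1])` yields two differently rounded copies of the same input, one per direction.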

Confidence Score: 4/5

Safe to merge for Hopper and Blackwell workflows; on non-Hopper, FSDP2 sharding assigns the wrong nominal shape to Float8 columnwise sub-storages, which propagates through subsequent all-gather cycles.

The aten.split handler for Float8Tensor uses split_transpose_tensor.shape ([K, M/n]) as the nominal shard shape when _data is None, instead of the correct [M/n, K]. This wrong shape is picked up by _infer_shape inside fsdp_post_all_gather and baked into the assembled hybrid tensor's columnwise sub-storage on every FSDP2 iteration. Everything else in the PR is well-constructed.

transformer_engine/pytorch/tensor/float8_tensor.py (aten.split shape for columnwise-only path) and transformer_engine/pytorch/cpp_extensions/gemm.py (_unwrap_hybrid_A/B None passthrough).
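
The shape issue can be reproduced with plain tuples. This sketch assumes the 2-D case described above, where a shard with nominal shape [M/n, K] keeps its transpose buffer as [K, M/n]; the helper names are hypothetical.

```python
def transposed_shape(shape):
    """Shape of the transpose buffer for a 2-D tensor: [M, K] -> [K, M]."""
    m, k = shape
    return (k, m)


def nominal_shape_from_transpose(t_shape):
    """Recover the nominal shape by moving the leading dim to the end.

    Same expression as the fix proposed later in the review thread.
    """
    return (*t_shape[1:], t_shape[0])


def shard_shapes(full_shape, num_shards):
    """Nominal and transposed shapes of one shard when splitting along dim 0."""
    m, k = full_shape
    shard = (m // num_shards, k)
    return shard, transposed_shape(shard)


# Example: M=8, K=6 split into n=2 shards.
shard, t_shard = shard_shapes((8, 6), 2)
buggy_nominal = t_shard                                 # what the fallback reports
fixed_nominal = nominal_shape_from_transpose(t_shard)   # what the shard should report
```

With these numbers the fallback reports (6, 4) for a shard whose nominal shape is (4, 6).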

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/tensor/hybrid_tensor.py | New file: HybridQuantizer and HybridQuantizedTensor. Core logic is clean; FSDP2 protocol (fsdp_pre/post_all_gather), aten dispatch, and TP/SP gather_along_first_dim override all look well-designed. |
| transformer_engine/pytorch/tensor/float8_tensor.py | Modified clone() and aten.split to handle _data=None for columnwise-only sub-storages. The clone fix is correct. The split fix has a shape bug: when _data=None and _transpose drives the split, it uses split_transpose_tensor.shape ([K, M/n]) as the nominal shape instead of the correct shard shape ([M/n, K]). |
| transformer_engine/pytorch/cpp_extensions/gemm.py | Adds _unwrap_hybrid_A/B to extract the direction-appropriate sub-storage before passing to C++ GEMM. Direction mapping is correct. Missing guard: returns None silently if the required sub-storage was dropped. |
| transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py | Adds fsdp_buffer_fields (direction-aware: returns _transpose when _data=None), fsdp_assign_gathered (clears _transpose_invalid after writing), and a _create_transpose early-return guard for _data=None. |
| transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py | Adds fsdp_buffer_fields, fsdp_extract_buffers (strips MXFP8 block-scale alignment padding before gather), and fsdp_assign_gathered (re-pads to 128/4 after gather). Logic is consistent with the existing split padding in mxfp8_tensor.py. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds _is_hybrid_quantizer_list classifier and _hybrid_split_quantize helper. Correctly rejects None+Hybrid mixed lists. All three forward/backward call sites guarded correctly. |
| transformer_engine/pytorch/tensor/utils.py | Adds replace_raw_data for HybridQuantizedTensor, quantize_master_weights routing via _route_hybrid_to_buckets (per-direction Float8 sub-storages each enter their own bucket), and post_all_gather_processing recursion over Float8 sub-storages. |
| transformer_engine/pytorch/distributed.py | Adds hybrid override in gather_along_first_dim: temporarily sets quantizer to both=True, re-quantizes the BF16 AG output, then restores usage flags via try/finally. |
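
The "enable both directions, re-quantize, restore via try/finally" pattern mentioned for gather_along_first_dim can be sketched with a context manager. `ToyQuantizer` and `both_directions` are hypothetical names for illustration; the only behavior taken from the summary is that the original usage flags are restored even if re-quantization fails.

```python
from contextlib import contextmanager


class ToyQuantizer:
    """Hypothetical quantizer that only tracks which directions it produces."""

    def __init__(self, rowwise=True, columnwise=False):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

    def set_usage(self, *, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise


@contextmanager
def both_directions(quantizer):
    """Temporarily request both directions, then restore the saved flags."""
    saved = (quantizer.rowwise_usage, quantizer.columnwise_usage)
    quantizer.set_usage(rowwise=True, columnwise=True)
    try:
        yield quantizer
    finally:
        # Runs on normal exit and on exceptions alike.
        quantizer.set_usage(rowwise=saved[0], columnwise=saved[1])
```

Inside a `with both_directions(q):` block a caller would re-quantize the gathered high-precision tensor; on exit `q` has the same usage flags it started with.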

Reviews (7): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Member

@timmoon10 timmoon10 left a comment

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
Member

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
- rowwise_result = self.rowwise_quantizer.quantize(tensor)
- columnwise_result = self.columnwise_quantizer.quantize(tensor)
+ rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+ columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

Collaborator Author

    requires_grad: bool = False,
    pin_memory: bool = False,
) -> HybridQuantizedTensor:
    self.rowwise_quantizer.internal = True
Member

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Collaborator Author

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
    if role == "linear_weight":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_mxfp8_quantizer(),
        )
    if role == "linear_input":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    if role in ("linear_grad_output", "linear_grad_input"):
        return HybridQuantizer(
            rowwise_quantizer=_make_mxfp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    return None
Member

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py
negvet and others added 2 commits April 29, 2026 16:02
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
negvet added 3 commits May 13, 2026 12:34
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from ksivaman as a code owner May 21, 2026 13:53
Comment on lines 665 to 677
  outs = [
      Float8Tensor.make_like(
          tensor,
          data=split_tensor,
          data_transpose=split_transpose_tensor,
-         shape=split_tensor.shape,
+         shape=(
+             split_tensor.shape
+             if split_tensor is not None
+             else split_transpose_tensor.shape
+         ),
      )
      for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
  ]
Contributor

P1 When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.

Suggested change
  outs = [
      Float8Tensor.make_like(
          tensor,
          data=split_tensor,
          data_transpose=split_transpose_tensor,
          shape=(
              split_tensor.shape
              if split_tensor is not None
-             else split_transpose_tensor.shape
+             # _transpose has shape [K, M/n] but the shard's nominal shape
+             # is [M/n, K]. Recover the correct shard shape by reversing
+             # the last two dims of the transposed piece.
+             else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0])
          ),
      )
      for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
  ]



2 participants