[Pytorch] Add QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. #2753
Conversation
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Greptile Summary

This PR adds QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Trainer
    participant FusedAdam
    participant AdamKernel as multi_tensor_adam (FP32)
    participant QTensor as QuantizedTensor<br/>(MXFP8/Float8Block/NVFP4)
    participant MasterW as FP32 Master Weight
    Trainer->>FusedAdam: step()
    FusedAdam->>FusedAdam: For each QuantizedTensor param:<br/>• assert master_weights=True<br/>• assert not capturable<br/>• cast grad to float32
    FusedAdam->>MasterW: append master_param.data → p_f32_model
    FusedAdam->>AdamKernel: apply_multi_tensor_adam([grad_f32, master_param, exp_avg, exp_avg_sq])
    AdamKernel-->>MasterW: update FP32 master weight in-place
    loop Post-step writeback
        FusedAdam->>QTensor: local_p.quantize_(master_w.data)
        QTensor-->>QTensor: re-quantize updated FP32 values
    end
    FusedAdam-->>Trainer: return loss
    note over QTensor,MasterW: FSDP2 path (Float8BlockwiseQTensor)
    Trainer->>QTensor: fsdp_pre_all_gather()
    QTensor->>QTensor: transpose columnwise (K,M)→(M,K)<br/>unpad scale_inv dim1
    QTensor-->>Trainer: (rw_data, rw_scale [, cw_data, cw_scale]), metadata
    Trainer->>Trainer: FSDP2 all-gather across ranks
    Trainer->>QTensor: fsdp_post_all_gather()
    QTensor->>QTensor: transpose columnwise (full_M,K)→(K,full_M)<br/>repad scale_inv to multiple of 4
    QTensor-->>Trainer: Float8BlockwiseQTensor (unsharded), all_gather_outputs
```
Last reviewed commit: b8e61e6
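The scale_inv unpad/repad step in the FSDP2 half of the diagram can be sketched as below. This is a hedged toy version: `round_up_to_nearest_multiple` is imported in this PR's diff and the padding granularity of 4 comes from the diagram, but the nested lists here are stand-ins, not real `Float8BlockwiseQTensor` internals.

```python
# Hedged sketch of the scale_inv unpad (pre-all-gather) / repad (post-all-gather)
# step from the diagram above. Plain nested lists stand in for tensors.

def round_up_to_nearest_multiple(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return ((value + multiple - 1) // multiple) * multiple

def unpad_scale_inv(scale_inv, valid_cols: int):
    """Drop padding columns (dim1) before the all-gather ships the data."""
    return [row[:valid_cols] for row in scale_inv]

def repad_scale_inv(scale_inv, multiple: int = 4):
    """Re-pad dim1 with zeros to the kernel-required multiple after the gather."""
    cols = len(scale_inv[0])
    padded = round_up_to_nearest_multiple(cols, multiple)
    return [row + [0.0] * (padded - cols) for row in scale_inv]

scales = [[0.5, 1.0, 2.0], [0.25, 0.5, 1.0]]  # 3 valid columns per row
padded = repad_scale_inv(scales)               # each row padded to 4 columns
assert [len(r) for r in padded] == [4, 4]
assert unpad_scale_inv(padded, 3) == scales    # round-trips back
```

Unpadding before the gather keeps the communication volume minimal; repadding afterwards restores the layout the GEMM kernels expect.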
```python
elif isinstance(p, QuantizedTensor) or (
    isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
):
    # Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor,
    # NVFP4Tensor). Operate on FP32 master weights, requantize back after
    # Adam update.
    if not self.master_weights:
        local_p = p._local_tensor if isinstance(p, DTensor) else p
        raise RuntimeError(
            "FusedAdam without master_weights does not support "
            f"{type(local_p).__name__} parameters. Use master_weights=True."
        )
    # Route to the FP32 master-weight path: Adam updates the FP32 master,
    # then we write back to the quantized param after kernels run.
    p_f32_model.append(unscaled_state["master_param"].data)
    g_of_f32_model.append(p_grad.data)
    m_of_f32_model.append(unscaled_state["exp_avg"])
    v_of_f32_model.append(unscaled_state["exp_avg_sq"])
    quantized_params_to_update.append((p, unscaled_state["master_param"]))
```
Missing capturable=True guard for QuantizedTensor path
The existing code at line 664 raises a RuntimeError when capturable=True is combined with Float8Tensor (FP8) params. No equivalent guard was added for the new block-scaling QuantizedTensor path.
When capturable=True and block-scaling quantized params are present, the code routes their FP32 master weights through p_f32_model/g_of_f32_model (which is handled by the capturable Adam kernel), then calls quantize_ outside the CUDA graph in the writeback loop. If the CUDA graph captures the Adam kernel launch, the quantize_ writes back to the quantized tensor after graph execution — which breaks the captured computation graph semantics and may silently produce incorrect parameter values or cause memory violations.
A guard consistent with the FP8 check should be added inside the parameter loop:
```python
# After: quantized_params_to_update.append((p, unscaled_state["master_param"]))
if self.capturable and len(quantized_params_to_update) > 0:
    raise RuntimeError(
        "FusedAdam does not support block-scaling quantized weights with capturable=True."
    )
```

```python
has_meta_params = any(p.is_meta for p in model.parameters())
custom_attrs = save_custom_attrs(model)
mesh = DeviceMesh("cuda", list(range(world_size)))
for child in model.children():
    fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)
if has_meta_params:
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
return model
```
reset_parameters() called after restore_custom_attrs() may overwrite quantizer metadata
restore_custom_attrs is called first to re-attach attributes (e.g. custom quantization metadata saved from the pre-sharding meta-device params) to the post-fully_shard DTensor parameters. reset_parameters() is then called to materialize the parameter values from the meta device.
However, save_custom_attrs explicitly ignores all private _-prefixed attributes for QuantizedTensor params:
```python
if isinstance(param, QuantizedTensor):
    ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
```

Since most of the critical quantization state lives in private attributes (`_quantizer`, `_fp8_dtype`, `_scale_inv`, `_data`, etc.), those are not saved or restored. The only attributes that are saved are public ones — which may be empty or meaningless on a meta-device tensor. If `reset_parameters()` sets up these private quantizer attributes correctly (as part of actual initialization), the ordering is fine. However, if any correctness depends on public attrs being visible to `reset_parameters()` during its execution, this order is fragile.
Please confirm that:

- `reset_parameters()` in `te.TransformerLayer` fully sets up quantizer metadata from scratch (not relying on anything `restore_custom_attrs` sets), or
- Add a comment documenting the intended ordering and why `restore_custom_attrs` before `reset_parameters` is safe/necessary.
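The underscore-filtering behavior this comment is about can be illustrated with a toy version of the save/restore pattern. `save_custom_attrs`/`restore_custom_attrs` exist in the PR's test helpers; the `FakeQuantizedTensor` class and the exact helper bodies below are stand-ins for illustration only.

```python
# Toy sketch of the save/restore-custom-attrs pattern: private "_"-prefixed
# attributes on quantized params are deliberately NOT saved, so only public
# attrs survive the fully_shard() re-wrapping.

class FakeQuantizedTensor:
    def __init__(self):
        self._quantizer = "private-state"   # private attrs are dropped
        self.custom_tag = "keep-me"         # public attrs are kept

def save_custom_attrs(params):
    saved = {}
    for name, param in params.items():
        ignore_keys = [k for k in vars(param) if k.startswith("_")]
        saved[name] = {k: v for k, v in vars(param).items() if k not in ignore_keys}
    return saved

def restore_custom_attrs(params, saved):
    for name, attrs in saved.items():
        for k, v in attrs.items():
            setattr(params[name], k, v)

params = {"weight": FakeQuantizedTensor()}
saved = save_custom_attrs(params)
assert saved == {"weight": {"custom_tag": "keep-me"}}  # _quantizer dropped

fresh = {"weight": FakeQuantizedTensor()}  # e.g. params re-created after fully_shard
restore_custom_attrs(fresh, saved)
assert fresh["weight"].custom_tag == "keep-me"
```

This makes the ordering concern concrete: anything `reset_parameters()` needs must come from its own initialization, because the private quantizer state never round-trips through save/restore.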
I addressed this
```python
p_f32_model.append(unscaled_state["master_param"].data)
g_of_f32_model.append(p_grad.data)
m_of_f32_model.append(unscaled_state["exp_avg"])
v_of_f32_model.append(unscaled_state["exp_avg_sq"])
quantized_params_to_update.append((p, unscaled_state["master_param"]))
```
Potential gradient dtype mismatch with the F32 Adam kernel
p_grad.data is appended to g_of_f32_model, which is consumed by multi_tensor_adam (the pure FP32 Adam kernel). For block-scaling quantized parameters in a standard BF16 training loop the backward pass produces BF16 gradients, not FP32 ones.
In contrast, the per-tensor Float8Tensor path deliberately places gradients in g_of_fp8_model and dispatches to the specialised multi_tensor_adam_fp8 kernel that is designed to accept mixed-precision inputs.
If multi_tensor_adam requires all tensors in a list to be FP32, this will either raise a CUDA error or silently compute with truncated precision. Consider either:
- asserting / explicitly casting `p_grad.data` to FP32 before appending, or
- documenting that `multi_tensor_adam` is verified to accept BF16 gradients mixed with FP32 params.
```python
# Option A – explicit cast (safe, matches what DelayedScaling already does internally)
g_of_f32_model.append(p_grad.data.float())
```

@jomitchellnv this comment makes sense since grads can be bf16. Data can be misinterpreted in that case. Can we add this?
Good catch. The backward pass produces gradients in the compute dtype (typically BF16), but this path appends them to `g_of_f32_model`, which feeds into the FP32 `multi_tensor_adam` kernel. Fixed by explicitly casting:

```python
g_of_f32_model.append(p_grad.data.float())
```
This matches the intent — master weights, exp_avg, and exp_avg_sq are all FP32, so the gradients should be too. The existing per-tensor Float8Tensor path doesn't need this because it uses a dedicated mixed-precision FP8 kernel (multi_tensor_adam_fp8), but our QuantizedTensor path routes through the plain FP32 kernel.
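The dtype-normalization idea can be sketched with a stand-in tensor class (real code uses torch tensors, where `.float()` returns an FP32 copy; `FakeTensor` and `append_grad_for_fp32_kernel` below are hypothetical names):

```python
# Hedged sketch: normalize gradient dtype before handing it to a kernel
# that expects uniform float32 inputs.

class FakeTensor:
    def __init__(self, dtype):
        self.dtype = dtype

    def float(self):
        # Mimics torch.Tensor.float(): returns an FP32 copy.
        return FakeTensor("float32")

def append_grad_for_fp32_kernel(g_list, grad):
    # The FP32 multi_tensor_adam path expects uniform float32 tensor lists,
    # so cast BF16 grads up front rather than relying on kernel behavior.
    g_list.append(grad if grad.dtype == "float32" else grad.float())

grads = []
append_grad_for_fp32_kernel(grads, FakeTensor("bfloat16"))
append_grad_for_fp32_kernel(grads, FakeTensor("float32"))
assert all(g.dtype == "float32" for g in grads)
```

The guard makes the invariant explicit instead of depending on the kernel tolerating mixed dtypes.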
```python
):
    # Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor,
    # NVFP4Tensor). Operate on FP32 master weights, requantize back after
    # Adam update.
```
Minor comment: can we add a TODO?

Adam calculation and quantization are unfused for now for these block-scaling parameters. Fusion needs to be done later.

It's just that the optimizer is called "FusedAdam" and we are indeed doing unfused Adam for these 3 cases. So we know it needs to be done for all 3 of the cases.
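The unfused flow the TODO refers to — an FP32 Adam update followed by a separate requantize pass — can be sketched with plain floats standing in for tensors. The `quantize_` stub below just snaps to a coarse grid to mimic writing the FP32 master back into a quantized param; it is not the real quantization math.

```python
# Toy sketch of "unfused Adam": update the FP32 master weight, then
# requantize in a separate step (what the diagram calls the writeback loop).
import math

def adam_step(master, grad, m, v, lr=1e-1, b1=0.9, b2=0.999, eps=1e-8):
    # Standard (bias-correction-free) Adam moment updates on FP32 values.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    return master - lr * m / (math.sqrt(v) + eps), m, v

def quantize_(master, step_size=0.125):
    """Stand-in for QuantizedTensor.quantize_: round to the nearest grid point."""
    return round(master / step_size) * step_size

master, m, v = 1.0, 0.0, 0.0
master, m, v = adam_step(master, grad=2.0, m=m, v=v)  # Adam updates FP32 master
param = quantize_(master)                             # separate requantize pass
assert param == 0.625                                 # lies on the 0.125 grid
```

A fused kernel would perform both stages in one pass over the data; here they are two distinct loops, which is exactly the gap the TODO documents.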
```python
from ._quantization_helpers import _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple


try:
```
Yes, I'll remove it.
vthumbe1503
left a comment
Thanks for the PR. LGTM. Left a few minor questions/comments. Let's get this merged post CI success.
/te-ci L1 pytorch
```diff
-    """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init."""
-    if fp_recipe in ("Float8BlockScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"):
+    """FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (meta device init)."""
+    if fp_recipe in ("NVFP4BlockScaling",):
```
Could you please add a pytest.skip in case MXFP8/block-scaling support is missing? This test fails on H100 systems since they don't have MXFP8 support. The function to check support is already above, I think. It's also done in a lot of other tests in the repo that you can refer to.
Done! I changed `_parametrize_fp8_recipes()` to use `pytest.mark.skipif` instead of `pytest.mark.xfail` for hardware support checks. On systems without MXFP8/block-scaling support, tests are now skipped outright instead of running the torchrun subprocess and failing noisily. Per-test xfail marks for known functional gaps (e.g., NVFP4 lacking FSDP2 hooks) remain unchanged since those are feature limitations, not hardware issues.
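The skip-vs-xfail policy described here can be summarized as a small decision function (the name `select_mark` and its arguments are hypothetical; in the actual tests this logic lives in `pytest.mark.skipif`/`pytest.mark.xfail` decorators):

```python
# Sketch of the policy: missing *hardware* support should skip the test
# entirely, while a known *feature* gap (e.g. NVFP4 without FSDP2 hooks)
# stays an expected failure that still runs.

def select_mark(hw_supported: bool, feature_implemented: bool):
    if not hw_supported:
        return "skip"       # e.g. MXFP8 recipe on an H100 -> don't run at all
    if not feature_implemented:
        return "xfail"      # runs, but a failure is expected and non-fatal
    return None             # run normally

assert select_mark(hw_supported=False, feature_implemented=True) == "skip"
assert select_mark(hw_supported=True, feature_implemented=False) == "xfail"
assert select_mark(hw_supported=True, feature_implemented=True) is None
```

The distinction matters for CI signal: a skip is silent on unsupported hardware, while an xfail still exercises the code path and flags an unexpected pass.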
```python
"""
model, recipe = self._build_model()

# Verify weight params are QuantizedTensors (bias stays bf16)
```
Could you also please add a pytest skip if Float8 block scaling is missing? I see you have already added it to the MXFP8 test. We need it here similarly to avoid CI test failures.
Done — added te.is_fp8_block_scaling_available(return_reason=True) check in setup_method and pytest.skip guards in both test_float8block_linear_fused_adam_master_weights and test_float8block_linear_forward_backward_step, matching the MXFP8 pattern.
[Pytorch] Add QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. (NVIDIA#2753)

* Updates FusedAdam with FSDP2 and MXFP8
* removes xfailing unit test for MXFP8
* addresses comments related to reset parameters and guard against self.capturable
* adds e2e unit test
* adds test to non meta device init
* attempts to add float8block scaling fsdp hooks
* adds e2e test for Float8BlockScaling
* addresses review comments and code cleanup
* more review comments addressed
* removes unused block_len param
* fixes failing unit test because we still need to xfail nvfp4 dcp
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* lint - replacing todo with note

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@umb-b300-dp-147.ipp4a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
1. nvfp4_tensor.py — as_strided hardening + with_2d_quantization assertion
2. test_nvfp4_fsdp2_hooks.py — remove unused tex import
3. run_fsdp2_model.py — remove stale Float8BlockScaling xfail

PR comments to leave:
- greptile P1 (as_strided): Fixed — validates storage_offset == 0 and raises NotImplementedError for non-identity calls.
- vthumbe1503 (2D quantization): Added assertion that with_2d_quantization == True when columnwise data is in the all-gather.
- greptile P2 (unused import): Removed.
- vthumbe1503 (Float8BlockScaling xfail): Correct, fixed in NVIDIA#2753 — removed the xfail

Signed-off-by: Jonathan Mitchell <jomitchell@r6515-0097.ipp1a1.colossus.nvidia.com>
Summary

- … NVFP4) through the FP32 Adam kernel on master weights, then writes back via `quantize_()` after the optimizer step
- … training with Float8BlockScaling recipe
- … sharding), required because wrapper subclass tensors have `data_ptr() == 0`

Details

FusedAdam QuantizedTensor path (`fused_adam.py`):

- … and routes them to the FP32 master-weight Adam kernel
- … both rowwise and columnwise quantized data
- … CUDA-graph captured)

Float8BlockwiseQTensor FSDP2 hooks (`float8_blockwise_tensor.py`):

- … works correctly. Always sends both rowwise and columnwise (the GEMM kernel needs both forms since they have different shapes, unlike MXFP8 which stores both in the same shape)
- … constructs/updates the tensor
- … reconstructed tensor via QuantizedTensor's default dispatch

Meta-device init for FSDP2 (`run_fsdp2_fused_adam.py`):

- … due to inaccessible storage

Tests:

- … and Float8BlockScaling. NVFP4BlockScaling remains xfailed (no FSDP2 hooks)
- … (xfailed)
- … forward/backward/step e2e test
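The FSDP2 extension-hook pair summarized above can be sketched as a stub class. The method names `fsdp_pre_all_gather`/`fsdp_post_all_gather` match the hooks this PR implements, but the class, fields, and bodies below are simplified stand-ins showing only the contract shape (raw buffers out, subclass rebuilt after the gather), not the real transpose/pad logic.

```python
# Hedged sketch of the FSDP2 hook contract for a quantized tensor subclass.

class FakeBlockwiseTensor:
    def __init__(self, rw_data, scale_inv):
        self.rw_data = rw_data
        self.scale_inv = scale_inv

    def fsdp_pre_all_gather(self):
        # Hand FSDP2 the raw shards to all-gather, plus metadata needed to
        # rebuild the subclass afterwards (shapes, dtypes, scaling info).
        all_gather_inputs = (self.rw_data, self.scale_inv)
        metadata = {"cls": type(self).__name__}
        return all_gather_inputs, metadata

    @classmethod
    def fsdp_post_all_gather(cls, all_gather_outputs, metadata):
        # Rebuild the quantized subclass from the gathered raw buffers and
        # return the outputs so FSDP2 can manage their storage.
        rw_data, scale_inv = all_gather_outputs
        return cls(rw_data, scale_inv), all_gather_outputs

shard = FakeBlockwiseTensor([1, 2], [0.5])
inputs, meta = shard.fsdp_pre_all_gather()
rebuilt, outs = FakeBlockwiseTensor.fsdp_post_all_gather(inputs, meta)
assert rebuilt.rw_data == [1, 2] and outs == inputs
```

Splitting the subclass into raw buffers for the gather is what sidesteps the `data_ptr() == 0` problem noted in the summary: FSDP2 only ever touches plain storage, and the wrapper is reconstructed afterwards.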