
[Pytorch] Add QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. #2753

Merged
vthumbe1503 merged 15 commits into NVIDIA:main from jomitchellnv:jm/fix-fused-adam-quantized-weights-fsdp2
Mar 13, 2026

Conversation

@jomitchellnv
Contributor

@jomitchellnv jomitchellnv commented Mar 11, 2026

Summary

  • Add a QuantizedTensor code path in FusedAdam.step() that routes block-scaling quantized parameters (MXFP8, Float8Blockwise,
    NVFP4) through the FP32 Adam kernel on master weights, then writes back via quantize_() after the optimizer step
  • Add FSDP2 all-gather hooks (fsdp_pre_all_gather / fsdp_post_all_gather) to Float8BlockwiseQTensor, enabling end-to-end FSDP2
    training with Float8BlockScaling recipe
  • Add meta-device initialization pattern for FSDP2 + block-scaling quantized tensors (device="meta" + reset_parameters() after
    sharding), required because wrapper subclass tensors have data_ptr() == 0

Details

FusedAdam QuantizedTensor path (fused_adam.py):

  • After the existing Float8Tensor check (per-tensor FP8 → FP8 kernel), a new QuantizedTensor check catches block-scaling types
    and routes them to the FP32 master-weight Adam kernel
  • After all kernels run, a post-step loop calls local_p.quantize_(master_w.data) to write updated FP32 master weights back to
    both rowwise and columnwise quantized data
  • Raises RuntimeError if master_weights=False (required for block-scaling) or capturable=True (quantize writeback can't be
    CUDA-graph captured)
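The update-then-writeback flow described above can be sketched in plain Python (scalar stand-ins; `adam_step` omits bias correction and `quantize_` is a hypothetical round-to-grid stand-in, not the real TE kernel):

```python
def adam_step(master_w, grad, exp_avg, exp_avg_sq,
              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One FP32 Adam update on the master weight (bias correction omitted)."""
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    master_w = master_w - lr * exp_avg / (exp_avg_sq ** 0.5 + eps)
    return master_w, exp_avg, exp_avg_sq

def quantize_(value, step=1.0 / 8):
    """Hypothetical block-scaling quantizer: snap to a fixed grid."""
    return round(value / step) * step

# Step 1: Adam updates the FP32 master weight (the quantized param is untouched).
master_w, m, v = adam_step(master_w=0.5, grad=0.1, exp_avg=0.0, exp_avg_sq=0.0)

# Step 2: the post-step writeback loop re-quantizes the updated master weight;
# in the real implementation this refreshes both rowwise and columnwise data.
quantized_p = quantize_(master_w)
```

The two steps stay unfused: the FP32 kernel never sees the quantized representation, and the writeback is a separate pass after all kernels complete.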

Float8BlockwiseQTensor FSDP2 hooks (float8_blockwise_tensor.py):

  • fsdp_pre_all_gather: Transposes columnwise data/scales from (K, M) to (M, K) before all-gather so FSDP2's dim0 concatenation
    works correctly. Always sends both rowwise and columnwise (the GEMM kernel needs both forms since they have different shapes,
    unlike MXFP8 which stores both in the same shape)
  • fsdp_post_all_gather: Transposes columnwise back to (K, M), repads scale_inv to GEMM alignment (multiples of 4),
    constructs/updates the tensor
  • Adds aten.as_strided and aten.slice no-op handlers in torch_dispatch to prevent FSDP2 from silently dequantizing the
    reconstructed tensor via QuantizedTensor's default dispatch
  • Only supports 2D block scaling (block_scaling_dim=2, the default for weights)
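The transpose/repad round-trip in the two hooks can be illustrated with list-of-lists stand-ins (`round_up` and `transpose` are illustrative helpers, not the real TE API):

```python
def round_up(n, multiple=4):
    """Round n up to the next multiple (scale_inv GEMM alignment)."""
    return ((n + multiple - 1) // multiple) * multiple

def transpose(mat):
    """(rows, cols) -> (cols, rows) on a list-of-lists stand-in."""
    return [list(row) for row in zip(*mat)]

# Pre-all-gather: columnwise data is stored (K, M), but FSDP2 concatenates
# shards along dim0, so it must travel as (M, K).
cw = [[1, 2, 3],
      [4, 5, 6]]                    # K=2, M=3
gathered_form = transpose(cw)       # (M, K) = (3, 2)

# Post-all-gather: transpose back to (K, M) and repad the scale dimension
# to a multiple of 4 for GEMM alignment.
restored = transpose(gathered_form)
padded_dim = round_up(len(restored[0]))   # 3 -> 4
```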

Meta-device init for FSDP2 (run_fsdp2_fused_adam.py):

  • _build_model() accepts use_meta_device param; creates on device="meta" when fp8_init=True
  • _shard_model() detects meta params and calls reset_parameters() after fully_shard(), then restore_custom_attrs()
  • This is the only viable path for block-scaling QuantizedTensors with FSDP2 — direct CUDA init crashes in reset_sharded_param()
    due to inaccessible storage
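The shard-then-materialize ordering can be sketched abstractly (`FakeParam`/`FakeModule` are stand-ins; the real path uses torch meta tensors and `fully_shard`):

```python
class FakeParam:
    """Stand-in for a wrapper-subclass tensor; meta params have no storage."""
    def __init__(self, is_meta):
        self.is_meta = is_meta
        self.materialized = not is_meta

class FakeModule:
    def __init__(self, use_meta_device):
        self.params = [FakeParam(is_meta=use_meta_device)]

    def reset_parameters(self):
        """Materialize parameters with real storage (runs post-sharding)."""
        for p in self.params:
            p.is_meta = False
            p.materialized = True

def shard_model(model):
    """Mimics _shard_model(): detect meta params, shard, then materialize."""
    has_meta = any(p.is_meta for p in model.params)
    # fully_shard(model) would run here; it never touches meta storage,
    # which is why this ordering avoids the reset_sharded_param() crash.
    if has_meta:
        model.reset_parameters()
    return model

model = shard_model(FakeModule(use_meta_device=True))
```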

Tests:

  • Multi-GPU FSDP2: test_fsdp2_fused_adam_fp8_master_weights passes for DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling,
    and Float8BlockScaling. NVFP4BlockScaling remains xfailed (no FSDP2 hooks)
  • Multi-GPU FSDP2: test_fsdp2_fused_adam_fp8_master_weights_no_meta documents the CUDA-init failure for block-scaling types
    (xfailed)
  • Single-GPU: TestFusedAdamMXFP8 and TestFusedAdamFloat8Block — each with a synthetic-grad master-weight test and a full
    forward/backward/step e2e test

Test plan

  • pytest tests/pytorch/distributed/test_torch_fsdp2.py -v -k fused_adam (18 passed, 7 xfailed)
  • pytest tests/pytorch/test_fused_optimizer.py::TestFusedAdamMXFP8 -v (2 passed)
  • pytest tests/pytorch/test_fused_optimizer.py::TestFusedAdamFloat8Block -v (2 passed)
  • Verify no regressions in existing DelayedScaling / Float8CurrentScaling FSDP2 tests

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Jonathan Mitchell added 2 commits March 10, 2026 17:42
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR adds QuantizedTensor support to FusedAdam.step() for block-scaling recipes (MXFP8BlockScaling, Float8BlockScaling, NVFP4BlockScaling), FSDP2 all-gather hooks to Float8BlockwiseQTensor, and a meta-device initialization pattern for FSDP2 + block-scaling quantized tensors.

Key changes:

  • fused_adam.py: A new QuantizedTensor branch in the optimizer's parameter dispatch loop routes block-scaling params through the FP32 multi_tensor_adam kernel on their master weights, casts BF16/FP16 gradients to float32 before the kernel (correctly matching the FP32 kernel expectations), and writes back via quantize_() after all kernels run. Guards for master_weights=False and capturable=True are present and correct.
  • float8_blockwise_tensor.py: fsdp_pre_all_gather transposes columnwise data/scales from (K, M) → (M, K) and strips M-block padding so FSDP2's dim0 all-gather works correctly. fsdp_post_all_gather reverses the transpose and repads to multiples of 4 for GEMM alignment. Two new __torch_dispatch__ no-op handlers for aten.as_strided and aten.slice prevent FSDP2 from silently dequantizing the reconstructed tensor. An unused import (_get_module_fsdp_state) is left inside fsdp_pre_all_gather. The aten.slice handler uses a variable named length for what is actually the end argument of the ATen op, and neither new dispatch handler guards against storage_offset != 0 (mirrors existing MXFP8Tensor patterns).
  • run_fsdp2_fused_adam.py / test_torch_fsdp2.py: Meta-device init path is added and a new _no_meta test variant documents the known crash for block-scaling types without meta-device init. xfail handling is correctly extended to Float8BlockScaling and NVFP4BlockScaling in relevant distributed tests.
  • test_fused_optimizer.py: New single-GPU test classes TestFusedAdamMXFP8 and TestFusedAdamFloat8Block provide solid coverage: synthetic-gradient master-weight parity checks against reference Adam, and full end-to-end forward/backward/step tests with loss-decrease and state-dtype verification.

Confidence Score: 4/5

  • This PR is safe to merge — the core optimizer and FSDP2 hook logic is correct, previously flagged issues have been addressed, and the new functionality is well-tested.
  • The QuantizedTensor optimizer path is logically sound: gradient dtype cast to float32 is present, capturable guard is in place, and master-weight writeback via quantize_() happens after all kernels complete. The FSDP2 all-gather hooks correctly handle the (K,M)↔(M,K) transpose and scale_inv padding round-trip. Three minor style issues remain: an unused import in fsdp_pre_all_gather, a misleadingly named variable in the aten.slice handler, and a missing storage_offset check in the aten.as_strided handler (the latter two also exist in the existing MXFP8Tensor implementation and do not affect current test-passing behavior).
  • transformer_engine/pytorch/tensor/float8_blockwise_tensor.py — the new FSDP2 hooks warrant attention for the unused import and the dispatch handler edge-case gaps.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Adds FSDP2 pre/post all-gather hooks and two new __torch_dispatch__ no-op handlers for aten.as_strided and aten.slice. The all-gather logic correctly handles the (K,M)→(M,K) transpose for columnwise data and scale_inv padding/unpadding. Minor issues: unused _get_module_fsdp_state import, misleadingly named length variable (should be end), and as_strided handler doesn't check storage_offset.
transformer_engine/pytorch/optimizers/fused_adam.py Adds QuantizedTensor (MXFP8, Float8Blockwise, NVFP4) code path to FusedAdam.step(): routes master-weight Adam update through the FP32 kernel, explicitly casts BF16 gradients to float32, and writes back via quantize_(). Also adds correct capturable=True guard. The implementation is correct and the previously reviewed issues (capturable guard, gradient dtype cast) are both addressed.
tests/pytorch/distributed/run_fsdp2_fused_adam.py Adds meta-device model init path (use_meta_device=True) and a new test_fused_adam_fp8_master_weights_no_meta test that documents the CUDA-init crash for block-scaling types. The ordering of reset_parameters() before restore_custom_attrs() is correctly documented and justified.
tests/pytorch/distributed/test_torch_fsdp2.py Adds new test_fsdp2_fused_adam_fp8_master_weights_no_meta test, extends xfail guards to cover Float8BlockScaling and NVFP4BlockScaling in test_fsdp2_fused_adam_fp8_no_master, and updates the async DCP test's xfail logic. Also correctly changes parametrize from xfail to skipif for unsupported hardware.
tests/pytorch/test_fused_optimizer.py Adds thorough single-GPU test classes TestFusedAdamMXFP8 and TestFusedAdamFloat8Block. Each class has a synthetic-gradient master-weight comparison test and an end-to-end forward/backward/step test with loss-decrease verification and optimizer state dtype checks.

Sequence Diagram

sequenceDiagram
    participant Trainer
    participant FusedAdam
    participant AdamKernel as multi_tensor_adam (FP32)
    participant QTensor as QuantizedTensor<br/>(MXFP8/Float8Block/NVFP4)
    participant MasterW as FP32 Master Weight

    Trainer->>FusedAdam: step()
    FusedAdam->>FusedAdam: For each QuantizedTensor param:<br/>• assert master_weights=True<br/>• assert not capturable<br/>• cast grad to float32
    FusedAdam->>MasterW: append master_param.data → p_f32_model
    FusedAdam->>AdamKernel: apply_multi_tensor_adam([grad_f32, master_param, exp_avg, exp_avg_sq])
    AdamKernel-->>MasterW: update FP32 master weight in-place

    loop Post-step writeback
        FusedAdam->>QTensor: local_p.quantize_(master_w.data)
        QTensor-->>QTensor: re-quantize updated FP32 values
    end

    FusedAdam-->>Trainer: return loss

    note over QTensor,MasterW: FSDP2 path (Float8BlockwiseQTensor)

    Trainer->>QTensor: fsdp_pre_all_gather()
    QTensor->>QTensor: transpose columnwise (K,M)→(M,K)<br/>unpad scale_inv dim1
    QTensor-->>Trainer: (rw_data, rw_scale [, cw_data, cw_scale]), metadata

    Trainer->>Trainer: FSDP2 all-gather across ranks

    Trainer->>QTensor: fsdp_post_all_gather()
    QTensor->>QTensor: transpose columnwise (full_M,K)→(K,full_M)<br/>repad scale_inv to multiple of 4
    QTensor-->>Trainer: Float8BlockwiseQTensor (unsharded), all_gather_outputs

Last reviewed commit: b8e61e6

Comment on lines +626 to +644
elif isinstance(p, QuantizedTensor) or (
    isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
):
    # Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor,
    # NVFP4Tensor). Operate on FP32 master weights, requantize back after
    # Adam update.
    if not self.master_weights:
        local_p = p._local_tensor if isinstance(p, DTensor) else p
        raise RuntimeError(
            "FusedAdam without master_weights does not support "
            f"{type(local_p).__name__} parameters. Use master_weights=True."
        )
    # Route to the FP32 master-weight path: Adam updates the FP32 master,
    # then we write back to the quantized param after kernels run.
    p_f32_model.append(unscaled_state["master_param"].data)
    g_of_f32_model.append(p_grad.data)
    m_of_f32_model.append(unscaled_state["exp_avg"])
    v_of_f32_model.append(unscaled_state["exp_avg_sq"])
quantized_params_to_update.append((p, unscaled_state["master_param"]))
Contributor
Missing capturable=True guard for QuantizedTensor path

The existing code at line 664 raises a RuntimeError when capturable=True is combined with Float8Tensor (FP8) params. No equivalent guard was added for the new block-scaling QuantizedTensor path.

When capturable=True and block-scaling quantized params are present, the code routes their FP32 master weights through p_f32_model/g_of_f32_model (which is handled by the capturable Adam kernel), then calls quantize_ outside the CUDA graph in the writeback loop. If the CUDA graph captures the Adam kernel launch, the quantize_ writes back to the quantized tensor after graph execution — which breaks the captured computation graph semantics and may silently produce incorrect parameter values or cause memory violations.

A guard consistent with the FP8 check should be added inside the parameter loop:

# After: quantized_params_to_update.append((p, unscaled_state["master_param"]))
if self.capturable and len(quantized_params_to_update) > 0:
    raise RuntimeError(
        "FusedAdam does not support block-scaling quantized weights with capturable=True."
    )

Contributor Author

added

Comment on lines +118 to 129
has_meta_params = any(p.is_meta for p in model.parameters())
custom_attrs = save_custom_attrs(model)
mesh = DeviceMesh("cuda", list(range(world_size)))
for child in model.children():
    fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)
if has_meta_params:
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
return model
Contributor
reset_parameters() called after restore_custom_attrs() may overwrite quantizer metadata

restore_custom_attrs is called first to re-attach attributes (e.g. custom quantization metadata saved from the pre-sharding meta-device params) to the post-fully_shard DTensor parameters. reset_parameters() is then called to materialize the parameter values from the meta device.

However, save_custom_attrs explicitly ignores all private _-prefixed attributes for QuantizedTensor params:

if isinstance(param, QuantizedTensor):
    ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]

Since most of the critical quantization state lives in private attributes (_quantizer, _fp8_dtype, _scale_inv, _data, etc.), those are not saved or restored. The only attributes that are saved are public ones — which may be empty or meaningless on a meta-device tensor. If reset_parameters() sets up these private quantizer attributes correctly (as part of actual initialization), the ordering is fine. However, if any correctness depends on public attrs being visible to reset_parameters() during its execution, this order is fragile.

Please confirm that:

  1. reset_parameters() in te.TransformerLayer fully sets up quantizer metadata from scratch (not relying on anything restore_custom_attrs sets), or
  2. Add a comment documenting the intended ordering and why restore_custom_attrs before reset_parameters is safe/necessary.
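A minimal sketch of the filtering behavior quoted above (`FakeQuantizedTensor` is a hypothetical stand-in; only the attribute-name filtering mirrors the real helper):

```python
class FakeQuantizedTensor:
    """Hypothetical stand-in; real QuantizedTensors carry far more state."""
    def __init__(self):
        self._quantizer = "quantizer-state"   # private: ignored by the filter
        self._scale_inv = "scales"            # private: ignored by the filter
        self.main_grad = "public-attr"        # public: saved and restored

def save_custom_attrs(param):
    """Mimics the quoted filter: skip _-prefixed attrs on quantized params."""
    ignore = {k for k in vars(param) if k.startswith("_")}
    return {k: v for k, v in vars(param).items() if k not in ignore}

saved = save_custom_attrs(FakeQuantizedTensor())   # only public attrs survive
```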

Contributor Author
I addressed this

….capturable

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 6461d34 to eb31beb Compare March 11, 2026 01:31
Comment on lines +640 to +644
p_f32_model.append(unscaled_state["master_param"].data)
g_of_f32_model.append(p_grad.data)
m_of_f32_model.append(unscaled_state["exp_avg"])
v_of_f32_model.append(unscaled_state["exp_avg_sq"])
quantized_params_to_update.append((p, unscaled_state["master_param"]))
Contributor
Potential gradient dtype mismatch with the F32 Adam kernel

p_grad.data is appended to g_of_f32_model, which is consumed by multi_tensor_adam (the pure FP32 Adam kernel). For block-scaling quantized parameters in a standard BF16 training loop the backward pass produces BF16 gradients, not FP32 ones.

In contrast, the per-tensor Float8Tensor path deliberately places gradients in g_of_fp8_model and dispatches to the specialised multi_tensor_adam_fp8 kernel that is designed to accept mixed-precision inputs.

If multi_tensor_adam requires all tensors in a list to be FP32, this will either raise a CUDA error or silently compute with truncated precision. Consider either:

  • asserting / explicitly casting p_grad.data to FP32 before appending, or
  • documenting that multi_tensor_adam is verified to accept BF16 gradients mixed with FP32 params.
# Option A – explicit cast (safe, matches what DelayedScaling already does internally)
g_of_f32_model.append(p_grad.data.float())

Collaborator
@jomitchellnv this comment makes sense since grads can be bf16. Data can be misinterpreted in that case. Can we add this?

Contributor Author
Good catch. The backward pass produces gradients in the compute dtype (typically BF16), but this path appends them to g_of_f32_model which feeds into the FP32 multi_tensor_adam kernel. Fixed by explicitly casting:

g_of_f32_model.append(p_grad.data.float())

This matches the intent — master weights, exp_avg, and exp_avg_sq are all FP32, so the gradients should be too. The existing per-tensor Float8Tensor path doesn't need this because it uses a dedicated mixed-precision FP8 kernel (multi_tensor_adam_fp8), but our QuantizedTensor path routes through the plain FP32 kernel.

@jomitchellnv
Contributor Author

@greptileai

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 3383496 to c04d866 Compare March 11, 2026 01:57
@jomitchellnv jomitchellnv changed the title [DRAFT - Do not review yet] Fix fused adam quantized weights fsdp2 Add QuantizedTensor support in FusedAdam.step() for Block-scaling quantized parameters Mar 11, 2026
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 462de10 to 6c1e0f9 Compare March 11, 2026 02:30
@jomitchellnv jomitchellnv changed the title Add QuantizedTensor support in FusedAdam.step() for Block-scaling quantized parameters Add QuantizedTensor support in FusedAdam.step() for MXFP8BlockScaling quantized parameters Mar 11, 2026
@jomitchellnv jomitchellnv changed the title Add QuantizedTensor support in FusedAdam.step() for MXFP8BlockScaling quantized parameters [DRAFT/ DUP] Add QuantizedTensor support in FusedAdam.step() for MXFP8BlockScaling quantized parameters Mar 11, 2026
@jomitchellnv jomitchellnv changed the title [DRAFT/ DUP] Add QuantizedTensor support in FusedAdam.step() for MXFP8BlockScaling quantized parameters [DRAFT] Add QuantizedTensor support in FusedAdam.step() for MXFP8BlockScaling quantized parameters Mar 11, 2026
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from b23898f to b2738e5 Compare March 11, 2026 03:33
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 12e51ed to 0518eea Compare March 11, 2026 03:47
@jomitchellnv jomitchellnv changed the title [DRAFT] Add QuantizedTensor support in FusedAdam.step() for MXFP8BlockScaling quantized parameters [DRAFT] Add QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. Mar 11, 2026
):
# Block-scaling quantized params (MXFP8Tensor, Float8BlockwiseQTensor,
# NVFP4Tensor). Operate on FP32 master weights, requantize back after
# Adam update.
Collaborator
Minor comment: can we add a TODO?
Adam calculation and quantization are unfused for now for these block-scaling parameters; fusion needs to be done later.

It's just that the optimizer is called "fused adam" and we are indeed doing unfused Adam for these three cases, so we know fusion needs to be done for all three.

Contributor Author
Added

from ._quantization_helpers import _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple

try:
Collaborator
Is this dead code?

Contributor Author
Yes, I'll remove it.

vthumbe1503
vthumbe1503 previously approved these changes Mar 11, 2026
Collaborator

@vthumbe1503 vthumbe1503 left a comment
Thanks for the PR. LGTM. Left a few minor questions/comments. Let's get this merged post CI success.

@vthumbe1503
Collaborator

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 changed the title [DRAFT] Add QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. [Pytorch] Add QuantizedTensor support in FusedAdam.step for MXFP8BlockScaling and Float8BlockScaling quantized model init. Mar 11, 2026
"""FusedAdam(master_weights=True) + FSDP2 + quantized_model_init."""
if fp_recipe in ("Float8BlockScaling", "MXFP8BlockScaling", "NVFP4BlockScaling"):
"""FusedAdam(master_weights=True) + FSDP2 + quantized_model_init (meta device init)."""
if fp_recipe in ("NVFP4BlockScaling",):
Collaborator

@vthumbe1503 vthumbe1503 Mar 11, 2026
Could you please add a pytest.skip in case MXFP8/block-scaling support is missing? This test fails on H100 systems since they don't have MXFP8 support. The function to check support is already above, I think; it's also done in a lot of other tests in the repo for you to refer to.

Contributor Author
Done! I changed _parametrize_fp8_recipes() to use pytest.mark.skipif instead of pytest.mark.xfail for hardware support checks. On systems without MXFP8/block-scaling support, tests are now skipped outright instead of running the torchrun subprocess and failing noisily. Per-test xfail marks for known functional gaps (e.g., NVFP4 lacking FSDP2 hooks) remain unchanged since those are feature limitations, not hardware issues.

"""
model, recipe = self._build_model()

# Verify weight params are QuantizedTensors (bias stays bf16)
Collaborator
Could you also please add a pytest skip if float8 blockscaling is missing? I see you have added it already to mxfp8 test. We need it here similarly to avoid ci test failures.

Contributor Author
Done — added te.is_fp8_block_scaling_available(return_reason=True) check in setup_method and pytest.skip guards in both test_float8block_linear_fused_adam_master_weights and test_float8block_linear_forward_backward_step, matching the MXFP8 pattern.

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 389a526 to fe248fe Compare March 11, 2026 19:28
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 83e468d to 7726070 Compare March 11, 2026 19:57
@jomitchellnv
Contributor Author

/te-ci L1 pytorch

1 similar comment
@vthumbe1503
Collaborator

/te-ci L1 pytorch

Signed-off-by: Jonathan Mitchell <jomitchell@umb-b300-dp-147.ipp4a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 24c5649 to 7ffa352 Compare March 11, 2026 23:05
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/fix-fused-adam-quantized-weights-fsdp2 branch from 84f215a to 4c1d280 Compare March 12, 2026 03:06
@jomitchellnv
Contributor Author

/te-ci L1 pytorch

2 similar comments
@pstjohn
Contributor

pstjohn commented Mar 12, 2026

/te-ci L1 pytorch

@vthumbe1503
Collaborator

/te-ci L1 pytorch

@vthumbe1503
Collaborator

/te-ci L1 pytorch

@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 merged commit fcceeb9 into NVIDIA:main Mar 13, 2026
22 of 25 checks passed
vthumbe1503 added a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
…kScaling and Float8BlockScaling quantized model init. (NVIDIA#2753)

* Updates FusedAdam with FSDP2 and MXFP8

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* removes xfailing unit test for MXFPr MXFP8

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* addresses comments related to reset parameters and guard against self.capturable

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* adds e2e unit test

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* adds test to non meta device init

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* attempts to add float8block scaling fsdp hooks

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* adds e2e test for Float8BlockScaling

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* addresses review comments and code cleanup

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* more review comments addressed

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* removes unused block_len param

Signed-off-by: Jonathan Mitchell <jomitchell@umb-b300-dp-147.ipp4a1.colossus.nvidia.com>

* fixes failing unit test because we still need to xfail nvfp4 dcp

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint - replacing todo with note

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>

---------

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@umb-b300-dp-147.ipp4a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Co-authored-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Co-authored-by: Jonathan Mitchell <jomitchell@umb-b300-dp-147.ipp4a1.colossus.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 mentioned this pull request Apr 6, 2026
13 tasks
jomitchellnv pushed a commit to jomitchellnv/TransformerEngine that referenced this pull request Apr 8, 2026
  1. nvfp4_tensor.py — as_strided hardening + with_2d_quantization assertion
  2. test_nvfp4_fsdp2_hooks.py — remove unused tex import
  3. run_fsdp2_model.py — remove stale Float8BlockScaling xfail

  PR comments to leave:
  - greptile P1 (as_strided): Fixed — validates storage_offset == 0 and raises NotImplementedError for
  non-identity calls.
  - vthumbe1503 (2D quantization): Added assertion that with_2d_quantization == True when columnwise data is in
  the all-gather.
  - greptile P2 (unused import): Removed.
  - vthumbe1503 (Float8BlockScaling xfail): Correct, fixed in NVIDIA#2753 — removed the xfail

Signed-off-by: Jonathan Mitchell <jomitchell@r6515-0097.ipp1a1.colossus.nvidia.com>
jomitchellnv pushed a commit to jomitchellnv/TransformerEngine that referenced this pull request Apr 9, 2026
