Fix memory overheads with FP4 native weights#2834

Merged
ksivaman merged 4 commits intoNVIDIA:mainfrom
WanZzzzzz:fix-fp4-mem
Apr 6, 2026
Conversation

@WanZzzzzz
Contributor

@WanZzzzzz WanZzzzzz commented Apr 3, 2026

Description

The previous implementation concatenated the master weights into one tensor and performed the fp32->bf16 conversion in a single call. However, that torch.cat materializes a full FP32 copy of ALL master weights in one contiguous buffer, increasing peak memory usage and diminishing the memory savings of FP4 native weights. This PR reverts the change and sticks with per-parameter conversion.
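To illustrate the tradeoff described above, here is a minimal sketch (not the actual transformer_engine code; function names are hypothetical) contrasting the batched torch.cat conversion with the per-parameter conversion this PR restores:

```python
import torch

def convert_batched(master_weights, dtype=torch.bfloat16):
    # Old approach: a single conversion kernel, but torch.cat materializes
    # a full FP32 copy of every master weight at once, spiking peak memory.
    flat = torch.cat([w.flatten() for w in master_weights])  # full FP32 copy
    flat = flat.to(dtype)
    out, offset = [], 0
    for w in master_weights:
        out.append(flat[offset:offset + w.numel()].view_as(w))
        offset += w.numel()
    return out

def convert_per_param(master_weights, dtype=torch.bfloat16):
    # New approach: N small conversion kernels, but at most one extra
    # per-parameter buffer is alive at any point in time.
    return [w.to(dtype) for w in master_weights]
```

Both produce identical bf16 values; the difference is peak memory, which for the batched path grows with the total size of all master weights rather than the largest single parameter.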

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

WanZzzzzz and others added 3 commits April 3, 2026 14:02
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR reverts a previous optimization that batched FP32→bf16/fp16 dtype conversion of master weights via torch.cat into a single kernel call, replacing it with per-parameter .to(model_weight.dtype) conversion. The batch approach introduced a full-copy FP32 allocation of all master weights simultaneously, negating the peak memory savings that FP4 native weights were meant to provide.

  • Removes the has_nvfp4 batch-conversion block (which called torch.cat on all master weights, then split) and replaces it with a simple per-parameter master_weight.to(model_weight.dtype) call applied before quantizer dispatch.
  • The if master_weight is not None: guard is preserved so None master weights (empty FSDP shards) are handled correctly.
  • The if-elif dispatch chain is restructured to be flat and unconditional — previously NVFP4 was handled separately from FP8 for dtype conversion — making the logic cleaner and consistent across all quantizer types.
  • Tradeoff is explicit: N individual conversion kernels instead of one batched kernel, but peak memory is substantially lower.
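The per-parameter conversion with the None guard can be sketched as follows (a hypothetical simplification based on the summary above; the real logic lives in transformer_engine/pytorch/tensor/utils.py and the function name here is an assumption):

```python
import torch

def to_model_dtype(model_weights, master_weights):
    # Per-parameter dtype conversion applied before quantizer dispatch.
    out = []
    for model_w, master_w in zip(model_weights, master_weights):
        if master_w is not None:  # None => empty FSDP shard, leave as-is
            # .to() is a no-op when dtypes already match, so there is no
            # extra cost for parameters that need no conversion.
            master_w = master_w.to(model_w.dtype)
        out.append(master_w)
    return out
```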

Confidence Score: 5/5

This PR is safe to merge — it is a targeted revert of a memory-regression optimization with no functional side-effects.

The change is simple and narrowly scoped. None master_weight is guarded correctly, the per-parameter .to() is a no-op when dtypes match, and the restructured if-elif chain is behaviorally equivalent to the old nested structure. No P1 or P0 issues found.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/utils.py Removed batched torch.cat master-weight dtype conversion and replaced with per-parameter .to() calls, reducing peak FP32 memory overhead; logic is correct and handles None master_weights safely.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize_master_weights called] --> B[Loop over model_weights]
    B --> C[clear_high_precision_init_val]
    C --> D{master_weight is not None?}
    D -- Yes --> E[master_weight.to\nmodel_weight.dtype]
    D -- No --> F[get quantizer]
    E --> F
    F --> G{quantizer type?}
    G -- NVFP4Quantizer --> H[nvfp4_params list]
    G -- Float8Quantizer --> I[delayed_scaling_params list]
    G -- Float8CurrentScalingQuantizer --> J[current_scaling_params list]
    G -- Float8BlockQuantizer --> K[blockwise_scaling_params list]
    G -- MXFP8Quantizer --> L[mxfp8_scaling_params list]
    G -- Other --> M[raise ValueError]
    H & I & J & K & L --> N[End loop]
    N --> O[_cast_master_weights_to_fp8_delayed_scaling]
    N --> P[_cast_master_weights_to_fp8_current_scaling]
    N --> Q[_cast_master_weights_to_fp8_blockwise_scaling]
    N --> R[_cast_master_weights_to_fp8_mxfp8_scaling]
    N --> S[_cast_master_weights_to_nvfp4_2d]

Reviews (2): Last reviewed commit: "Merge branch 'main' into fix-fp4-mem"

Collaborator

@timmoon10 timmoon10 left a comment

LGTM

@timmoon10
Collaborator

/te-ci pytorch L1

@ksivaman
Member

ksivaman commented Apr 6, 2026

/te-ci pytorch

Member

@ksivaman ksivaman left a comment

LGTM

@ksivaman ksivaman merged commit ac96651 into NVIDIA:main Apr 6, 2026
21 of 24 checks passed