Fix memory overheads with FP4 native weights #2834
Conversation
Signed-off-by: qiyuw <qiyuw@nvidia.com>
Greptile Summary: This PR reverts a previous optimization that batched the FP32→bf16/fp16 dtype conversion of master weights via a single `torch.cat`, restoring the original per-parameter conversion to avoid the peak-memory spike caused by the concatenated FP32 buffer.
Confidence Score: 5/5. This PR is safe to merge: it is a targeted revert of a memory-regression optimization with no functional side effects, and the change is simple and narrowly scoped. The `None` master_weight case is guarded correctly, the per-parameter `.to()` is a no-op when dtypes already match, and the restructured if-elif chain is behaviorally equivalent to the old nested structure. No P0 or P1 issues found; no files require special attention.
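The two properties the review calls out (the `None` guard and the dtype-match no-op) can be sketched in plain Python. `FakeTensor` and `cast_master_weight` are hypothetical stand-ins for illustration, not TE's actual code; `torch.Tensor.to` has the same no-copy behavior when the dtype already matches.

```python
# Hypothetical sketch of the guarded per-parameter cast described above:
# skip None master weights, and rely on .to() being a no-op on dtype match.
class FakeTensor:
    """Minimal stand-in for a torch.Tensor to illustrate the control flow."""
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        if dtype == self.dtype:
            return self          # no-op: same object, no new allocation
        return FakeTensor(dtype) # real conversion allocates a new tensor

def cast_master_weight(master_weight, model_weight):
    # Guard: parameters without a master copy are skipped entirely.
    if master_weight is None:
        return None
    # Per-parameter cast; transient memory covers one parameter at a time.
    return master_weight.to(model_weight.dtype)

master = FakeTensor("fp32")
model = FakeTensor("bf16")
assert cast_master_weight(None, model) is None       # None guard
assert cast_master_weight(model, model) is model     # dtype match: no-op
assert cast_master_weight(master, model).dtype == "bf16"
```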
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize_master_weights called] --> B[Loop over model_weights]
    B --> C[clear_high_precision_init_val]
    C --> D{master_weight is not None?}
    D -- Yes --> E[master_weight.to\nmodel_weight.dtype]
    D -- No --> F[get quantizer]
    E --> F
    F --> G{quantizer type?}
    G -- NVFP4Quantizer --> H[nvfp4_params list]
    G -- Float8Quantizer --> I[delayed_scaling_params list]
    G -- Float8CurrentScalingQuantizer --> J[current_scaling_params list]
    G -- Float8BlockQuantizer --> K[blockwise_scaling_params list]
    G -- MXFP8Quantizer --> L[mxfp8_scaling_params list]
    G -- Other --> M[raise ValueError]
    H & I & J & K & L --> N[End loop]
    N --> O[_cast_master_weights_to_fp8_delayed_scaling]
    N --> P[_cast_master_weights_to_fp8_current_scaling]
    N --> Q[_cast_master_weights_to_fp8_blockwise_scaling]
    N --> R[_cast_master_weights_to_fp8_mxfp8_scaling]
    N --> S[_cast_master_weights_to_nvfp4_2d]
```
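The dispatch step in the flowchart can be sketched in plain Python. The class names come from the flowchart; the stand-in class bodies and the `group_params_by_quantizer` helper are hypothetical illustrations, not TE's actual implementation.

```python
# Hypothetical stand-ins for the quantizer classes named in the flowchart;
# the real classes live in TransformerEngine.
class NVFP4Quantizer: ...
class Float8Quantizer: ...
class Float8CurrentScalingQuantizer: ...
class Float8BlockQuantizer: ...
class MXFP8Quantizer: ...

def group_params_by_quantizer(params):
    """Bucket (param, quantizer) pairs into per-recipe lists, mirroring the
    flowchart's if-elif chain, and raise ValueError for unknown quantizers."""
    nvfp4, delayed, current, blockwise, mxfp8 = [], [], [], [], []
    for param, quantizer in params:
        if isinstance(quantizer, NVFP4Quantizer):
            nvfp4.append(param)
        elif isinstance(quantizer, Float8CurrentScalingQuantizer):
            current.append(param)
        elif isinstance(quantizer, Float8Quantizer):
            delayed.append(param)
        elif isinstance(quantizer, Float8BlockQuantizer):
            blockwise.append(param)
        elif isinstance(quantizer, MXFP8Quantizer):
            mxfp8.append(param)
        else:
            raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
    return nvfp4, delayed, current, blockwise, mxfp8

# Each non-empty list is then handed to its dedicated cast routine,
# e.g. the nvfp4 list goes to _cast_master_weights_to_nvfp4_2d.
```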
Reviews (2). Last reviewed commit: "Merge branch 'main' into fix-fp4-mem"

/te-ci pytorch L1

/te-ci pytorch
Description
The previous implementation concatenated the master weights into one tensor and performed the FP32→bf16 conversion in a single call. However, that `torch.cat` creates a full FP32 copy of all master weights in one contiguous buffer, increasing peak memory usage and diminishing the memory savings of FP4 native weights. This PR reverts the change and returns to per-parameter conversion.
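A back-of-envelope comparison of the transient allocations under the two strategies; the parameter sizes below are illustrative, not taken from the PR:

```python
# Illustrative peak-transient-memory comparison (assumed sizes, not real data).
param_sizes = [4_000_000, 2_000_000, 1_000_000]  # element counts per parameter
FP32, BF16 = 4, 2  # bytes per element

# Batched: torch.cat materializes one contiguous FP32 copy of ALL master
# weights before converting, so the transient covers every parameter at once.
batched_peak = sum(n * FP32 for n in param_sizes)

# Per-parameter: only one parameter's bf16 output is in flight at a time
# (the FP32 masters already exist either way), so the transient is bounded
# by the largest single parameter.
per_param_peak = max(n * BF16 for n in param_sizes)

print(batched_peak, per_param_peak)  # 28000000 8000000
```

The batched path's transient grows with total model size, which is exactly the overhead that erased the FP4 memory savings.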
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: