Modifying the NVFP4 block quantization kernel to handle the single level quantization correctly #5680

jjsjann123 merged 3 commits into main from

Conversation
Review updated until commit c9b7cf1

Description

Relevant files: Bug fix
PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Numerical Stability

`1.0f / scaled_max` could cause numerical issues if `scaled_max` is zero or very close to zero. Consider adding bounds checking or using a safe reciprocal function to prevent division by zero or overflow.
Test failures

- (Medium, 1) CUDA OOM in nvFuser `TmaPointwiseTestF.SplitGridDim2D` on H100 (runner dlcluster_h100) ❌
Greptile Summary

This PR fixes a mathematical bug in the NVFP4 block quantization kernel for single-level quantization (when no global scale is provided).

Key Changes:

Technical Context:

The bug was that step 4 was missing for the single-level case, while it was correctly implemented for the two-level case (line 132). This PR adds the symmetrical logic for single-level quantization.

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Input as Input Data (FP32/BF16)
participant Kernel as block_quantize_to_nvfp4
participant MaxCalc as Max Calculation
participant FP8Conv as FP8 Conversion
participant Scale as Scale Computation
participant Quant as Quantization
Input->>Kernel: Array of values (ITEMS_PER_THREAD)
Kernel->>MaxCalc: Convert to float & compute local max
MaxCalc->>MaxCalc: Reduce across threads (16/ITEMS_PER_THREAD)
MaxCalc-->>Kernel: block_max
alt USE_GLOBAL_SCALE = true
Kernel->>FP8Conv: scaled_max = block_max * global_scale * (1/6)
else USE_GLOBAL_SCALE = false
Kernel->>FP8Conv: scaled_max = block_max * (1/6)
end
FP8Conv->>FP8Conv: clamped_max_fp8 = __float2e4m3(scaled_max)
FP8Conv->>FP8Conv: scaled_max = __e4m32float(clamped_max_fp8)
FP8Conv-->>Scale: quantized block scale
alt USE_GLOBAL_SCALE = true
Scale->>Scale: scaled_max = global_scale / scaled_max
else USE_GLOBAL_SCALE = false (NEW)
Scale->>Scale: scaled_max = 1.0 / scaled_max
end
Scale-->>Quant: scaling factor (scaled_max)
Quant->>Quant: Write block_scales to global memory
Quant->>Quant: Scale input values by scaled_max
Quant->>Quant: Convert to FP4: __float2e2m1
Quant-->>Kernel: Quantized FP4 output
!test

!test

!test
The quantization to NVFP4 when there's no global scale needs a minor correction in the math.

This PR also updates the existing C++ test to reflect that.

We can't test this against TE yet, since there's no implementation without the global scale.