[JAX] Fix bf16 precision loss in TestGroupedDense reference dbias#2942
Conversation
…cumulated numerical error Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile SummaryThis test-only fix addresses a bf16 precision loss in the reference bias-grad computation for Confidence Score: 5/5Safe to merge — test-only fix with correct fp32 accumulation logic and no production code changes. No P0 or P1 findings. The fix correctly aligns the reference bias-grad accumulation dtype with the primitive, the forward and backward math is sound, zero-size group edge cases are handled correctly by broadcasting, and all changes are confined to the test file. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
subgraph REF ["Reference path (_ref_sum_grouped_dense) — AFTER FIX"]
R1["_ref_grouped_dense(x, kernel, bias=None)"] -->|"out_i (bf16, zero bias)"| R2["out_i.astype(fp32) + bias_i.astype(fp32)"]
R2 -->|"out_with_bias_fp32"| R3["jnp.sum (fp32)"]
R3 --> R4["jnp.sum(out_sum_list) / sqrt(x.size)"]
R4 -->|"JAX autodiff"| R5["ref_dbias (fp32 sum-over-m)"]
end
subgraph PRIM ["Primitive path (_primitive_sum_grouped_dense)"]
P1["grouped_dense(x, kernel, bias=bias)"] -->|"out (bf16 + bf16 bias)"| P2["out.astype(fp32)"]
P2 --> P3["jnp.sum(fp32) / sqrt(x.size)"]
P3 -->|"primitive kernel"| P4["prim_dbias (fp32 segment_sum)"]
end
R5 -->|"assert_allclose"| CMP{"Match?"}
P4 -->|"assert_allclose"| CMP
CMP -->|"✅ Both fp32"| PASS["Test passes"]
Reviews (2): Last reviewed commit: "address greptile comments" | Re-trigger Greptile |
| if bias is None: | ||
| for out_i in out_list: | ||
| out_sum_list.append(jnp.sum(out_i.astype(jnp.float32))) | ||
| else: | ||
| for out_i, bias_i in zip(out_list, bias): | ||
| out_with_bias_fp32 = out_i.astype(jnp.float32) + bias_i.astype(jnp.float32) | ||
| out_sum_list.append(jnp.sum(out_with_bias_fp32)) |
There was a problem hiding this comment.
Dead
bias is None branch — unreachable in current grad tests
_ref_sum_grouped_dense is only called via value_and_grad(..., (0, 1, 2)) in test_grouped_dense_grad_fp16 and test_grouped_dense_grad_fp8, both of which always pass a concrete bias array (with_bias=True). The if bias is None branch is therefore never exercised by any existing test. If it is meant as a defensive fallback, consider adding a comment; if dbias is expected to be None, the autodiff of value_and_grad would return None for argument 2 and the downstream assert_allclose(prim_dbias, ref_dbias, ...) would need to handle that case.
| for out_i, bias_i in zip(out_list, bias): | ||
| out_with_bias_fp32 = out_i.astype(jnp.float32) + bias_i.astype(jnp.float32) | ||
| out_sum_list.append(jnp.sum(out_with_bias_fp32)) | ||
| return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size).astype(jnp.float32) |
There was a problem hiding this comment.
Redundant
.astype(jnp.float32) on jnp.sqrt
jnp.sqrt(x.size) already returns float32 in JAX's default 32-bit mode (x64 disabled), so the .astype(jnp.float32) call is a no-op. Same pattern appears in _primitive_sum_grouped_dense. This is harmless but adds visual noise.
| return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size).astype(jnp.float32) | |
| return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size) |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Signed-off-by: tdophung <tdophung@nvidia.com>
|
/te-ci jax |
…IDIA#2942) * accumulate bias in fp32 instead of bf16 in ref impl dbias to avoid accumulated numerical error Signed-off-by: tdophung <tdophung@nvidia.com>
Description
TestGroupedDense::test_grouped_dense_grad_fp8fails onMXFP8_1D_SCALINGwithinput_shape=(8, 64, 128, 256):prim_dbiasvsref_dbiasdiffers by ~6% on the largest group, exceeding the bf16 rtol (other rows pass; fwd, dgrad, wgrad pass).Reference path adds bias in bf16 inside
_ref_grouped_dense, so JAX autodifflowers the bias-grad sum-over-m in bf16 and saturates on the largest group. The primitive's
grouped_dbiascasts to fp32 beforesegment_sum, so it's accurate. The two computed dbias values therefore disagree even though the kernel is correct. Test-construction issue, not a kernel regression.To fix this, in the reference impl, add NO bias to the grouped dense call, then add bias externally in fp32, so the autodiff bias-grad accumulates in fp32 --> This matches what
grouped_dbiasdoes for the primitive.Reason why this bug just surface now was because of commit 70af730 ("[JAX] MXFP8 Grouped Quant+GEMM (#2763)") introduced
group_size_multiplier=128for MXFP8, which quadrupled M for this shape and pushed the largest group's bf16 bias-grad sum past the bf16 rtol.Fixes # (issue)
Type of change
Changes
_ref_sum_grouped_dense: passbias=Noneto_ref_grouped_dense, add bias externally in fp32, sum and divisor in fp32_primitive_sum_grouped_dense: castoutand divisor to fp32 to mirror the referenceChecklist: