Skip to content

[JAX] Fix bf16 precision loss in TestGroupedDense reference dbias#2942

Merged
tdophung merged 2 commits into
NVIDIA:mainfrom
tdophung:fix-groupeddense-mismatch
Apr 30, 2026
Merged

[JAX] Fix bf16 precision loss in TestGroupedDense reference dbias#2942
tdophung merged 2 commits into
NVIDIA:mainfrom
tdophung:fix-groupeddense-mismatch

Conversation

@tdophung
Copy link
Copy Markdown
Collaborator

@tdophung tdophung commented Apr 29, 2026

Description

TestGroupedDense::test_grouped_dense_grad_fp8 fails on MXFP8_1D_SCALING with input_shape=(8, 64, 128, 256): prim_dbias vs ref_dbias differs by ~6% on the largest group, exceeding the bf16 rtol (other rows pass; fwd, dgrad, wgrad pass).

Reference path adds bias in bf16 inside _ref_grouped_dense, so JAX autodiff
lowers the bias-grad sum-over-m in bf16 and saturates on the largest group. The primitive's grouped_dbias casts to fp32 before segment_sum, so it's accurate. The two computed dbias values therefore disagree even though the kernel is correct. Test-construction issue, not a kernel regression.

To fix this, in the reference impl, add NO bias to the grouped dense call, then add bias externally in fp32, so the autodiff bias-grad accumulates in fp32 --> This matches what grouped_dbias does for the primitive.

Reason why this bug just surface now was because of commit 70af730 ("[JAX] MXFP8 Grouped Quant+GEMM (#2763)") introduced group_size_multiplier=128 for MXFP8, which quadrupled M for this shape and pushed the largest group's bf16 bias-grad sum past the bf16 rtol.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • _ref_sum_grouped_dense: pass bias=None to _ref_grouped_dense, add bias externally in fp32, sum and divisor in fp32
  • _primitive_sum_grouped_dense: cast out and divisor to fp32 to mirror the reference

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…cumulated numerical error

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review April 29, 2026 23:50
@tdophung tdophung requested a review from phu0ngng April 29, 2026 23:50
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 29, 2026

Greptile Summary

This test-only fix addresses a bf16 precision loss in the reference bias-grad computation for TestGroupedDense. By moving bias addition out of _ref_grouped_dense (where JAX autodiff accumulated the bias gradient in bf16) and into an explicit fp32 addition, the reference now matches the primitive's grouped_dbias behavior (which casts the cotangent to fp32 before segment_sum). The _primitive_sum_grouped_dense change to cast out to fp32 before summing is a complementary tweak to keep the forward loss values and cotangent dtype consistent between reference and primitive paths. The fix is correct and well-explained.

Confidence Score: 5/5

Safe to merge — test-only fix with correct fp32 accumulation logic and no production code changes.

No P0 or P1 findings. The fix correctly aligns the reference bias-grad accumulation dtype with the primitive, the forward and backward math is sound, zero-size group edge cases are handled correctly by broadcasting, and all changes are confined to the test file.

No files require special attention.

Important Files Changed

Filename Overview
tests/jax/test_custom_call_compute.py Fixes bf16 precision loss in reference dbias accumulation by adding bias externally in fp32; also casts primitive output to fp32 before summing for consistency — logic is correct and well-commented.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    subgraph REF ["Reference path (_ref_sum_grouped_dense) — AFTER FIX"]
        R1["_ref_grouped_dense(x, kernel, bias=None)"] -->|"out_i (bf16, zero bias)"| R2["out_i.astype(fp32) + bias_i.astype(fp32)"]
        R2 -->|"out_with_bias_fp32"| R3["jnp.sum (fp32)"]
        R3 --> R4["jnp.sum(out_sum_list) / sqrt(x.size)"]
        R4 -->|"JAX autodiff"| R5["ref_dbias (fp32 sum-over-m)"]
    end

    subgraph PRIM ["Primitive path (_primitive_sum_grouped_dense)"]
        P1["grouped_dense(x, kernel, bias=bias)"] -->|"out (bf16 + bf16 bias)"| P2["out.astype(fp32)"]
        P2 --> P3["jnp.sum(fp32) / sqrt(x.size)"]
        P3 -->|"primitive kernel"| P4["prim_dbias (fp32 segment_sum)"]
    end

    R5 -->|"assert_allclose"| CMP{"Match?"}
    P4 -->|"assert_allclose"| CMP
    CMP -->|"✅ Both fp32"| PASS["Test passes"]
Loading

Reviews (2): Last reviewed commit: "address greptile comments" | Re-trigger Greptile

Comment thread tests/jax/test_custom_call_compute.py Outdated
Comment on lines +1929 to +1935
if bias is None:
for out_i in out_list:
out_sum_list.append(jnp.sum(out_i.astype(jnp.float32)))
else:
for out_i, bias_i in zip(out_list, bias):
out_with_bias_fp32 = out_i.astype(jnp.float32) + bias_i.astype(jnp.float32)
out_sum_list.append(jnp.sum(out_with_bias_fp32))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead bias is None branch — unreachable in current grad tests

_ref_sum_grouped_dense is only called via value_and_grad(..., (0, 1, 2)) in test_grouped_dense_grad_fp16 and test_grouped_dense_grad_fp8, both of which always pass a concrete bias array (with_bias=True). The if bias is None branch is therefore never exercised by any existing test. If it is meant as a defensive fallback, consider adding a comment; if dbias is expected to be None, the autodiff of value_and_grad would return None for argument 2 and the downstream assert_allclose(prim_dbias, ref_dbias, ...) would need to handle that case.

Comment thread tests/jax/test_custom_call_compute.py Outdated
for out_i, bias_i in zip(out_list, bias):
out_with_bias_fp32 = out_i.astype(jnp.float32) + bias_i.astype(jnp.float32)
out_sum_list.append(jnp.sum(out_with_bias_fp32))
return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size).astype(jnp.float32)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant .astype(jnp.float32) on jnp.sqrt

jnp.sqrt(x.size) already returns float32 in JAX's default 32-bit mode (x64 disabled), so the .astype(jnp.float32) call is a no-op. Same pattern appears in _primitive_sum_grouped_dense. This is harmless but adds visual noise.

Suggested change
return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size).astype(jnp.float32)
return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

Copy link
Copy Markdown
Collaborator

@phu0ngng phu0ngng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks

@tdophung tdophung merged commit cc05742 into NVIDIA:main Apr 30, 2026
21 of 24 checks passed
KshitijLakhani pushed a commit that referenced this pull request May 2, 2026
)

* accumulate bias in fp32 instead of bf16 in ref impl dbias to avoid accumulated numerical error

Signed-off-by: tdophung <tdophung@nvidia.com>
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
…IDIA#2942)

* accumulate bias in fp32 instead of bf16 in ref impl dbias to avoid accumulated numerical error

Signed-off-by: tdophung <tdophung@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants