Skip to content

Skip test_batch_axis_sharding_jvp13 test 0.8.0#708

Merged
AratiGanesh merged 1 commit intorocm-jaxlib-v0.8.0from
skip-jvp13-v0.8.0
Feb 16, 2026
Merged

Skip test_batch_axis_sharding_jvp13 test 0.8.0#708
AratiGanesh merged 1 commit intorocm-jaxlib-v0.8.0from
skip-jvp13-v0.8.0

Conversation

@AratiGanesh
Copy link

Motivation

The test_batch_axis_sharding_jvp13 test (qr with complex64) fails on ROCm devices due to numerical precision issues in rocSolver as of ROCm 7.2.

Technical Details

Added a conditional skip in tests/linalg_sharding_test.py for test_batch_axis_sharding_jvp that only skips when:
Running on ROCm devices (jtu.is_device_rocm())
Function is lax.linalg.qr (fun_and_shapes[0] is lax.linalg.qr)
Dtype is np.complex64 (dtype == np.complex64)
This targets only the failing test case (test_batch_axis_sharding_jvp13). All other test cases continue to run normally.

Test Result

The test is now skipped on ROCm for the specific failing case, preventing test failures while preserving coverage for other variants.
image

Upstream PR - jax-ml#34966

@AratiGanesh AratiGanesh requested a review from a team February 10, 2026 20:41
@AratiGanesh AratiGanesh added open-upstream Tag when you want a copy of this PR to be opened on upstream cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. labels Feb 10, 2026
@AratiGanesh AratiGanesh merged commit e0b2a1e into rocm-jaxlib-v0.8.0 Feb 16, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick-candidate Mark a PR to be cherry-picked into the next ROCm JAX. Remove IIF the latest upstream contain the PR. open-upstream Tag when you want a copy of this PR to be opened on upstream

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants