Phambinh/skip fused attention navi lds limit #720
Merged
phambinhfin merged 3 commits into rocm-jaxlib-v0.8.2 on Mar 4, 2026
Conversation
Navi/RDNA GPUs (AMD Radeon) have 64KB of LDS (shared memory), which is insufficient for the Pallas fused attention kernels when head_dim is padded to 128+. The forward kernel requests up to 98KB and the backward kernel up to 90KB, both exceeding the 64KB limit. These tests pass on MI300 (Instinct) GPUs which have larger shared memory capacity. The skip is gated on "Radeon" in device_kind so it only affects Navi/RDNA consumer GPUs.
The test hardcodes num_gpus=4 and spawns 4 tasks referencing GPU indices 0-3. This fails on systems with fewer than 4 GPUs.
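Taken together, the two conditions above can be sketched as a pair of predicates. This is an illustrative sketch, not the PR's actual code: the helper names are made up here, and only the `"Radeon"` substring check, the 128 head_dim threshold, and the 4-GPU requirement come from this description.

```python
# Hedged sketch of the two skip conditions described above.
# Helper names are illustrative, not the PR's actual code.

def is_navi_gpu(device_kind: str) -> bool:
    """Navi/RDNA consumer GPUs report 'Radeon' in device_kind;
    MI300-class Instinct GPUs do not, so they keep running the tests."""
    return "Radeon" in device_kind

def should_skip_fused_attention(device_kind: str, padded_head_dim: int) -> bool:
    # Skip only when the padded head_dim pushes LDS usage past 64KB.
    return is_navi_gpu(device_kind) and padded_head_dim >= 128

def should_skip_distributed_test(device_count: int) -> bool:
    # The test spawns 4 tasks pinned to GPU indices 0-3.
    return device_count < 4
```

In the real suite, `device_kind` would come from `jax.devices()[0].device_kind` and `device_count` from `jax.device_count()`.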
Force-pushed 57aaf2b to 94878f2
Ruturaj4 approved these changes on Mar 3, 2026
Arech8 reviewed on Mar 3, 2026
TF32 (TensorFloat-32) is an NVIDIA-only format. AMD Radeon (Navi/RDNA) GPUs do not support it, causing UNIMPLEMENTED errors at runtime.
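The resulting allowlist logic can be sketched as a simple architecture check. The gfx target set below follows the `gfx9_mi100_or_later()` description later in this PR; the function and variable names here are illustrative, not XLA's actual API.

```python
# Hedged sketch of the TF32 allowlist check described in this review.
# Names are illustrative; the real check lives in XLA's C++ sources.

GFX9_MI100_OR_LATER = {"gfx908", "gfx90a", "gfx942"}  # MI100 / MI200 / MI300 (CDNA)

def tf32_dot_supported(gfx_arch: str) -> bool:
    """ALG_DOT_TF32_TF32_F32 is only allowed on gfx9 CDNA parts;
    Navi/RDNA targets (gfx11xx, gfx12xx) are rejected."""
    return gfx_arch in GFX9_MI100_OR_LATER

print(tf32_dot_supported("gfx942"))   # True  (MI300)
print(tf32_dot_supported("gfx1100"))  # False (Navi 31 / RDNA3)
```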
i-chaochen reviewed on Mar 3, 2026
Force-pushed 94878f2 to 171e73c
This PR addresses ticket https://ontrack-internal.amd.com/browse/SWDEV-579393, which occurs on Navi.
Navi/RDNA GPU Test Skips
Summary of tests skipped on Navi/RDNA GPUs and the reasons why.
1. Fused Attention Tests — 64KB LDS Limit
Tests:

- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd0
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd1
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd7
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd1
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd4
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd7
- tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd8

Reason: Navi/RDNA GPUs have 64KB of LDS (Local Data Store / shared memory). When `head_dim` is padded to 128+, the Pallas fused attention kernels request up to 90–98KB of shared memory, exceeding the 64KB limit. The runtime error is:
These tests pass on MI300 (Instinct) GPUs which have larger shared memory capacity.
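As a rough illustration of why the limit is exceeded, here is a back-of-envelope LDS estimate. The block sizes and the assumption that Q, K, and V tiles all live in LDS are illustrative, not the Pallas kernel's actual tiling; only the 64KB limit and the 90–98KB figures come from this report.

```python
# Back-of-envelope LDS footprint estimate (assumed tiling, fp32 tiles).
LDS_LIMIT_BYTES = 64 * 1024  # Navi/RDNA shared-memory capacity per workgroup

def lds_bytes(block_q: int, block_kv: int, head_dim: int, dtype_bytes: int = 4) -> int:
    """Rough LDS usage if Q, K, and V tiles are all staged in shared memory."""
    q_tile = block_q * head_dim * dtype_bytes
    k_tile = block_kv * head_dim * dtype_bytes
    v_tile = block_kv * head_dim * dtype_bytes
    return q_tile + k_tile + v_tile

# With head_dim padded to 128, even modest 64-wide blocks need
# 64 * 128 * 4 bytes per tile, i.e. 3 * 32KB = 96KB in total:
print(lds_bytes(64, 64, 128))                    # 98304 (96KB)
print(lds_bytes(64, 64, 128) > LDS_LIMIT_BYTES)  # True
```

The same tiling with the unpadded `head_dim = 64` fits comfortably (48KB), which matches the report that only the padded-to-128 configurations fail.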
Detection: Skip when `"Radeon"` is in `device_kind` and `head_dim` is padded to ≥128.

2. TF32 Dot Algorithm Test — Unsupported on Navi
Test:

- tests/lax_test.py::LaxTest::testDotAlgorithm13

Reason: TF32 (TensorFloat-32) is not supported on Navi/RDNA GPUs. XLA only allows `ALG_DOT_TF32_TF32_F32` on ROCm MI100 and above (gfx9 CDNA series: gfx908, gfx90a, gfx942). Navi GPUs (gfx11xx, gfx12xx) are not in the `gfx9_mi100_or_later()` allowlist. See `xla/service/algorithm_util.cc` lines 290–295 and `xla/stream_executor/rocm/rocm_compute_capability.h` lines 124–126. The runtime error is:

Detection: Skip when `"Radeon"` is in `device_kind` and the algorithm is `TF32_TF32_F32`.

3. Distributed Visible Devices Test — Requires 4 GPUs
Test:

- tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices

Reason: The test hardcodes `num_gpus = 4` and spawns 4 tasks referencing GPU indices 0–3. On some Navi clusters (e.g., ctr-navi4x-aj50-ws08) there are only 2 GPUs, so the test fails because GPU indices 2 and 3 do not exist.
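A generic way to express this kind of guard in a `unittest`-based suite is a `skipIf` decorator. This is a sketch, not necessarily the PR's exact code; the hardcoded `available=2` stands in for what `jax.device_count()` would return on a 2-GPU workstation.

```python
# Illustrative skip guard for the 4-GPU requirement (generic unittest pattern).
import unittest

def require_gpus(n: int, available: int):
    """Skip the test unless at least n GPUs are visible.
    In the real suite, `available` would come from jax.device_count()."""
    return unittest.skipIf(available < n, f"requires {n} GPUs, found {available}")

class MultiProcessGpuTest(unittest.TestCase):
    @require_gpus(4, available=2)  # e.g. a 2-GPU Navi workstation
    def test_distributed_jax_visible_devices(self):
        # Would spawn 4 tasks pinned to GPU indices 0-3; never reached
        # when fewer than 4 GPUs are available.
        self.fail("should have been skipped on a 2-GPU machine")
```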
Detection: Skip when `jax.device_count() < 4`.

Results
```
pytest tests/lax_test.py::LaxTest::testDotAlgorithm13 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd1 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd4 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd7 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd8 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd0 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd1 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd7 \
  tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
  --tb=short -q
```

```
Test session starting on GPU ?
ssssssssss                                                              [100%]
10 skipped in 30.75s
```