
Phambinh/skip fused attention navi lds limit#720

Merged
phambinhfin merged 3 commits into rocm-jaxlib-v0.8.2 from
phambinh/skip-fused-attention-navi-lds-limit
Mar 4, 2026

Conversation


@phambinhfin commented Mar 3, 2026

This PR addresses the ticket https://ontrack-internal.amd.com/browse/SWDEV-579393, which occurs on Navi.

Navi/RDNA GPU Test Skips

Summary of tests skipped on Navi/RDNA GPUs and the reasons why.

1. Fused Attention Tests — 64KB LDS Limit

Tests:

  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd0
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd1
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd7
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd1
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd4
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd7
  • tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd8

Reason: Navi/RDNA GPUs have 64KB of LDS (Local Data Store / shared memory). When
head_dim is padded to 128+, the Pallas fused attention kernels request up to 90–98KB
of shared memory, exceeding the 64KB limit. The runtime error is:

RESOURCE_EXHAUSTED: Shared memory size limit exceeded: requested 90112, available: 65536

These tests pass on MI300 (Instinct) GPUs which have larger shared memory capacity.
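The numbers in the error message can be sanity-checked directly. Note that 90112 bytes is 88 KiB (binary), which the prose rounds to roughly 90 KB:

```python
# Sanity-check of the figures in the RESOURCE_EXHAUSTED message above.
requested = 90112   # bytes requested by the backward kernel
available = 65536   # Navi/RDNA LDS capacity: 64 KiB

print(requested / 1024)        # 88.0 KiB (~90 KB decimal)
print(available / 1024)        # 64.0 KiB
assert requested > available   # hence the allocation fails
```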

Detection: Skip when "Radeon" is in device_kind and head_dim is padded to ≥128.
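A minimal sketch of this detection rule, written as a standalone helper. The power-of-two padding step is an assumption about how Pallas rounds head_dim up; the real skip reads device_kind from jax.devices() inside the test:

```python
def should_skip_fused_attention(device_kind: str, head_dim: int) -> bool:
    """Skip fused attention tests on Navi/RDNA GPUs when the padded
    head_dim would push LDS usage past the 64 KiB limit.

    Assumes head_dim is padded up to the next power of two (hypothetical
    model of the kernel's padding; the actual rule may differ).
    """
    padded = 1
    while padded < head_dim:
        padded *= 2
    return "Radeon" in device_kind and padded >= 128
```

For example, head_dim=72 pads to 128 and is skipped on a Radeon device, while head_dim=64 is not; Instinct GPUs never match the "Radeon" substring.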

2. TF32 Dot Algorithm Test — Unsupported on Navi

Test:

  • tests/lax_test.py::LaxTest::testDotAlgorithm13

Reason: TF32 (TensorFloat-32) is not supported on Navi/RDNA GPUs. XLA only allows
ALG_DOT_TF32_TF32_F32 on ROCm MI100 and above (gfx9 CDNA series: gfx908, gfx90a,
gfx942). Navi GPUs (gfx11xx, gfx12xx) are not in the gfx9_mi100_or_later() allowlist.
See xla/service/algorithm_util.cc lines 290–295 and
xla/stream_executor/rocm/rocm_compute_capability.h lines 124–126. The runtime error is:

UNIMPLEMENTED: Unsupported algorithm on the current device(s): ALG_DOT_TF32_TF32_F32

Detection: Skip when "Radeon" is in device_kind and algorithm is TF32_TF32_F32.
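The TF32 detection rule can be sketched the same way. This helper is hypothetical; the real skip lives inside tests/lax_test.py and reads the device_kind from JAX:

```python
def should_skip_tf32_dot(device_kind: str, algorithm: str) -> bool:
    """XLA only allows ALG_DOT_TF32_TF32_F32 on MI100-and-later CDNA
    GPUs, so skip the test on Navi/RDNA (reported as "Radeon")."""
    return "Radeon" in device_kind and algorithm == "TF32_TF32_F32"
```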

3. Distributed Visible Devices Test — Requires 4 GPUs

Test:

  • tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices

Reason: The test hardcodes num_gpus = 4 and spawns 4 tasks referencing GPU indices
0–3. On some Navi clusters (e.g., ctr-navi4x-aj50-ws08) there are only 2 GPUs, so the
test fails because GPU indices 2 and 3 do not exist.

Detection: Skip when jax.device_count() < 4.
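A sketch of this last check as a pure helper; in the actual test the device count comes from jax.device_count():

```python
REQUIRED_GPUS = 4  # the test spawns 4 tasks referencing GPU indices 0-3

def should_skip_visible_devices_test(device_count: int) -> bool:
    """Skip on systems with fewer GPUs than the test hardcodes,
    e.g. 2-GPU Navi clusters such as ctr-navi4x-aj50-ws08."""
    return device_count < REQUIRED_GPUS
```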
Results

pytest \
  tests/lax_test.py::LaxTest::testDotAlgorithm13 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd1 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd4 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd7 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd8 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd0 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd1 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4 \
  tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd7 \
  tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
  --tb=short -q
Test session starting on GPU ?
ssssssssss [100%]
10 skipped in 30.75s

@phambinhfin phambinhfin requested a review from a team March 3, 2026 14:50
@phambinhfin phambinhfin force-pushed the phambinh/skip-fused-attention-navi-lds-limit branch from 57aaf2b to 94878f2 on March 3, 2026 14:54
Comment thread tests/pallas/gpu_ops_test.py

@i-chaochen left a comment


can you check and confirm hipblaslt's TF32 support on radeon gpu, please? cc @AleksaArsic wondering do you know the answer?

Comment thread tests/lax_test.py
