Conversation

> Hi, do you have test cases for your implementation?

tests/test_aiter_mha.py is the test file; I attached the shell script that runs all the use cases. Results look like: "AITer MHA Test Suite ... Running test_batch_fwd_shape ... PASSED (50 passed) ... RESULTS Passed: 218".
Code Review: AITER MHA Kernels (ASM+CK) Integration
Commit: a6041a3b2c8ad9d6f86f071855e8523a3a49894a
Author: Zahid Iqbal
Scope: 3,644 lines added across 24 files -- integrates AITER multi-head attention forward/backward kernels for ROCm (CK + ASM v3) into JAX via XLA FFI.
1. Architecture Overview
The commit introduces a well-structured layered integration:
- Python API layer (`jax/_src/aiter/aiter_mha.py`): public `flash_attn_func` and `flash_attn_varlen` with `custom_vjp`, dispatching to unified forward/backward wrappers.
- FFI bridge (`jaxlib/gpu/hip_aiter_mha_fwd.cc`, `hip_aiter_mha_bwd.cc`): C++ handlers bound via XLA FFI, translating JAX buffers into `aiter::mha_fwd_args` / `aiter::mha_bwd_args`.
- Third-party headers (`third_party/aiter/`): AITER's own `mha_fwd.h`, `mha_bwd.h`, and ASM kernel loader infrastructure.
- Nanobind glue (`hip_aiter.cc`): registration of FFI targets.
The design of a single unified handler detecting batch (4D) vs varlen (3D) from tensor rank is clean and avoids handler proliferation.
2. Thread Safety and Data Race Issues (Critical Section)
2.1 CRITICAL: hipMemcpyAsync from stack-local buffer (use-after-scope risk)
In `hip_aiter_mha_common_utils.cc`, `prepare_rng_state_for_fwd`:

```cpp
uint64_t host_rng[2] = {seed_value, offset_value};
hipError_t err =
    hipMemcpyAsync(rng_state->untyped_data(), host_rng,
                   2 * sizeof(int64_t), hipMemcpyHostToDevice, stream);
```

`host_rng` is a stack-allocated array. `hipMemcpyAsync` with `hipMemcpyHostToDevice` from pageable host memory is synchronous with respect to the host in the current HIP/CUDA runtime (the DMA engine must stage through a pinned bounce buffer), so this is de facto safe. However, this is implementation-defined behavior -- a future ROCm runtime change could make the copy truly asynchronous, at which point `host_rng` would be read after it goes out of scope. The same pattern appears in the forward handler:

```cpp
std::vector<float> neg_inf(n, -std::numeric_limits<float>::infinity());
HIP_CHECK(hipMemcpyAsync(lse->untyped_data(), neg_inf.data(),
                         n * sizeof(float), hipMemcpyHostToDevice, stream));
```

Recommendation: either use `hipMemcpy` (synchronous) explicitly, use pinned memory, or add a `hipStreamSynchronize` after the copy.
2.2 CRITICAL: Reading device pointers on the host
In `prepare_rng_state_for_fwd`:

```cpp
const auto *gen_data = static_cast<const int64_t *>(gen->untyped_data());
seed_value = static_cast<uint64_t>(gen_data[0]);    // Host reads device memory!
offset_value = static_cast<uint64_t>(gen_data[1]);  // Host reads device memory!
VLOG(1) << "Using provided generator with seed: " << seed_value << ...;
```

If `gen` is a device buffer (the typical case in JAX -- all FFI buffers are device-side), dereferencing `gen_data[0]` from host code is undefined behavior: it will either segfault or return garbage. The subsequent `hipMemcpyAsync(DeviceToDevice)` is correct, but the `VLOG` will log incorrect values. The real concern is the UB itself.

Recommendation: remove the host-side reads of `gen_data[0]`/`gen_data[1]`, or use `hipMemcpy` to bring the values to host first if logging is needed.
2.3 HIGH: BwdDeviceBuffers destructor races with GPU work
In `hip_aiter_mha_bwd.cc`, `BwdDeviceBuffers::~BwdDeviceBuffers()` calls `hipFree` on buffers that were just passed to async kernels:

```cpp
void free_all() {
  if (dbias_expanded) { hipFree(dbias_expanded); ... }
  if (dk_expanded) { hipFree(dk_expanded); ... }
  ...
}
```

`hipFree` performs an implicit device synchronization (it waits for ALL streams), so the data won't actually be freed while in use. However:

- This causes a full device sync on every backward pass, destroying any pipeline overlap.
- The implicit sync behavior is an implementation detail that could change.
- It makes the handler incompatible with CUDA Graphs / HIP Graphs and the `kCmdBufferCompatible` trait the handler claims.
Recommendation: Use a workspace pattern -- have JAX allocate the scratch space as additional output buffers, or use XLA's buffer allocation to pre-allocate scratch.
2.4 MEDIUM: Non-thread-safe get_gpu_arch() in header
In `aiter_hip_common.h`:

```cpp
static const std::string get_gpu_arch() {
  hipDevice_t dev;
  HIP_CALL(hipGetDevice(&dev));
  HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));
  ...
}
```

This is marked `static` in a header, so every translation unit gets its own copy. `hipGetDevice` returns the thread-local current device; if this is called from different threads with different active devices, each gets a different answer, which is the intended behavior. However, consider the combined pattern with `get_num_cu_func()`:

```cpp
static uint32_t get_num_cu_func() {
  static const uint32_t num_cu = get_num_cu_local();  // initialized once
  return num_cu;
}
```

This caches the CU count from whichever device happens to be current during the first call. On a system with heterogeneous GPUs, subsequent calls from threads using different devices will get the wrong CU count. C++11 guarantees thread-safe initialization of static locals, but the cached value may be wrong for non-default devices.
2.5 MEDIUM: RNG seed collision under concurrency
When no generator is provided, the seed is derived from a timestamp:

```cpp
seed_value = static_cast<uint64_t>(timestamp) ^ static_cast<uint64_t>(dev_idx);
```

Two concurrent forward passes on the same device within the same microsecond will produce identical seeds, and therefore identical dropout masks. This is a correctness issue for training with dropout.
Recommendation: Use a monotonic atomic counter or incorporate the stream pointer / buffer address as additional entropy.
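A minimal sketch of the monotonic-counter idea, in Python for illustration (the names `_seed_counter` and `derive_seed` are hypothetical; the real fix would live in the C++ handler with `std::atomic`):

```python
import itertools
import time

# Hypothetical sketch: mix a process-wide monotonic counter into the seed so
# that two launches in the same microsecond still get distinct dropout streams.
_seed_counter = itertools.count()  # next() is effectively atomic under CPython's GIL

def derive_seed(dev_idx: int) -> int:
    timestamp = time.monotonic_ns()
    counter = next(_seed_counter)
    # Fold the counter into the high bits: even identical timestamps on the
    # same device can no longer collide.
    return (timestamp ^ dev_idx ^ (counter << 48)) & 0xFFFFFFFFFFFFFFFF

s1 = derive_seed(0)
s2 = derive_seed(0)
print(s1 != s2)  # True: distinct even when called back-to-back
```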
2.6 LOW: _make_fwd_call / _make_bwd_call are not cached
Each call to `mha_fwd_unified` / `mha_bwd_unified` creates a fresh `jax.ffi.ffi_call` + `jax.jit` wrapper:

```python
fn = _make_fwd_call(out_shape, lse_shape, p_shape, rng_shape, q.dtype)
```

While JAX's tracing and compilation have their own caches, repeatedly creating new `ffi_call` objects and JIT wrappers adds unnecessary overhead. These should be memoized by `(out_shape, lse_shape, p_shape, rng_shape, dtype)`.
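A sketch of the suggested memoization (the name `_make_fwd_call_cached` is hypothetical; the stub body stands in for building the real `ffi_call` + `jit` wrapper). The key point is that the cache key must be hashable, so shapes are passed as tuples and the dtype as a string:

```python
import functools

@functools.lru_cache(maxsize=None)
def _make_fwd_call_cached(out_shape, lse_shape, p_shape, rng_shape, dtype_name):
    # Hypothetical: build the jax.ffi.ffi_call + jit wrapper once per unique
    # shape/dtype signature; subsequent calls return the cached wrapper.
    return ("fwd_call", out_shape, lse_shape, p_shape, rng_shape, dtype_name)

a = _make_fwd_call_cached((2, 8), (2,), (0,), (2,), "bfloat16")
b = _make_fwd_call_cached((2, 8), (2,), (0,), (2,), "bfloat16")
print(a is b)  # True: the same wrapper object is reused
```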
3. Memory Management Issues
3.1 hipMalloc / hipFree on every backward call
The backward handler allocates up to five device buffers per invocation:

```cpp
HIP_CHECK(hipMalloc(&bufs.dq_acc, dq_acc_bytes));            // always
HIP_CHECK(hipMalloc(&bufs.dk_expanded, dk_sz));              // if MQA/GQA
HIP_CHECK(hipMalloc(&bufs.dv_expanded, dv_sz));              // if MQA/GQA
HIP_CHECK(hipMalloc(&bufs.dbias_expanded, dbias_sz));        // if has_dbias
HIP_CHECK(hipMalloc(&bufs.dummy_rng, 2 * sizeof(uint64_t))); // if no rng
```

`hipMalloc` is synchronous and expensive (microseconds to milliseconds). For training workloads where backward runs millions of times, this is a significant performance bottleneck.
Recommendation: Allocate these as JAX-managed output buffers in the Python wrapper (passing workspace shapes through the FFI), or use a persistent memory pool.
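To illustrate why a pool helps, here is a toy size-bucketed pool in Python (the class `ScratchPool` is hypothetical; `bytearray` stands in for a `hipMalloc`'d device buffer). Only cold misses pay the allocation cost; steady-state iterations recycle buffers:

```python
import collections

class ScratchPool:
    """Toy size-bucketed pool illustrating the reuse pattern suggested above."""
    def __init__(self):
        self._free = collections.defaultdict(list)
        self.alloc_calls = 0

    def acquire(self, nbytes: int) -> bytearray:
        bucket = self._free[nbytes]
        if bucket:
            return bucket.pop()   # reuse: no new backing allocation
        self.alloc_calls += 1     # only cold misses allocate
        return bytearray(nbytes)

    def release(self, buf: bytearray) -> None:
        self._free[len(buf)].append(buf)

pool = ScratchPool()
for _ in range(1000):             # e.g. 1000 training iterations
    buf = pool.acquire(1 << 20)   # 1 MiB dq_acc-sized scratch
    pool.release(buf)
print(pool.alloc_calls)  # 1: a single allocation amortized over all iterations
```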
3.2 Large host-side vector for LSE initialization

```cpp
std::vector<float> neg_inf(n, -std::numeric_limits<float>::infinity());
```

For large batch/sequence lengths, `n` can be very large (e.g., batch=64, heads=32, seq=2048 => n = 4M floats = 16 MB of host allocation, plus a host-to-device copy). This should use a device fill kernel instead.
4. Correctness Issues
4.1 Wrong sequence length passed to v3 eligibility check
In `_flash_attn_backward`:

```python
_, sq, hq, dq = q.shape
_, sk, hk, _ = k.shape
...
use_v3 = _compute_v3_eligibility_bwd(
    dropout_p, hq, hk, dq, causal, wl, wr, bias, sq, gfx  # <-- sq, not sk
)
```

The parameter is named `sq_or_max_sk`, and the check is:

```python
if causal and gfx == "gfx950" and sq_or_max_sk > 256:
    use_v3 = False
```

For the batch path, `sq` (query sequence length) is being passed where `sk` (key sequence length) is expected. When `sq != sk` (cross-attention), this can incorrectly enable or disable ASM v3. Compare with the varlen backward, which correctly passes `max_seqlen_k`:

```python
use_v3 = _compute_v3_eligibility_bwd(
    res_dp, hq, hk, dq, causal, window_size[0], window_size[1],
    None, max_seqlen_k, gfx  # <-- correct: max_seqlen_k
)
```

This is a bug for causal cross-attention on gfx950 when sq <= 256 < sk.
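To make the impact concrete, here is a condensed, self-contained copy of the eligibility logic (the helper name `v3_eligible_bwd` is hypothetical; the body mirrors `_compute_v3_eligibility_bwd` quoted later in this review) showing how passing `sq` instead of `sk` flips the v3 decision:

```python
def v3_eligible_bwd(dropout_p, hq, hk, hdim, causal, wl, wr, bias, sq_or_max_sk, gfx):
    # Condensed from _compute_v3_eligibility_bwd (bias check simplified to None).
    swa = (wl > 0) or (wr >= 0 and wr != -1)
    use_v3 = dropout_p == 0 and hq == hk and bias is None and not swa
    if causal and gfx == "gfx950" and sq_or_max_sk > 256:
        use_v3 = False
    if gfx == "gfx950" and hdim >= 96:
        use_v3 = False
    return use_v3

# Causal cross-attention on gfx950: sq = 128 <= 256 < sk = 1024.
sq, sk = 128, 1024
with_sq = v3_eligible_bwd(0.0, 8, 8, 64, True, 0, -1, None, sq, "gfx950")
with_sk = v3_eligible_bwd(0.0, 8, 8, 64, True, 0, -1, None, sk, "gfx950")
print(with_sq, with_sk)  # True False: passing sq wrongly keeps v3 enabled
```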
4.2 kCmdBufferCompatible trait is incorrect
Both handlers are declared with:

```cpp
{xla::ffi::Traits::kCmdBufferCompatible}
```

But the backward handler calls `hipMalloc`, `hipFree`, `hipPointerGetAttributes`, and other operations that are not compatible with command buffer recording. The forward handler similarly calls `hipPointerGetAttributes` via `device_from_ptr`. This trait should be removed, or the handlers refactored.
4.3 Variable shadowing in mha_fwd_unified
```python
def mha_fwd_unified(q, k, v, ...):
    ...
    dq = q.shape[-1]  # shadows name 'dq', which in MHA context means gradient of q
    use_v3_fwd = not (get_gfx() == "gfx950" and dq >= 96)
```

While not a bug, `dq` meaning "head dimension of q" vs. "gradient of q" is confusing in MHA code. Consider renaming it to `hdim_q`.
4.4 _flash_attn_func_bwd gradient count vs. custom_vjp nondiff argnums
```python
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
def flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size,
                    bias, alibi_slopes, deterministic, return_lse,
                    return_attn_probs, cu_seqlens_q, cu_seqlens_kv):
```

Differentiable args are indices 0, 1, 2, 7, 8 = `q`, `k`, `v`, `bias`, `alibi_slopes`. The backward returns:

```python
return (dq, dk, dv, dbias, None)  # 5 values for 5 diff args
```

This is correct, but note that `None` is returned for the `alibi_slopes` gradient. If someone later passes `alibi_slopes` that requires grad, gradients will be silently dropped. Consider adding a check or comment.
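The index bookkeeping above can be sanity-checked with plain Python (parameter names taken from the signature quoted above):

```python
# With these nondiff_argnums, exactly five positional parameters stay
# differentiable, matching the five returned gradients (dq, dk, dv, dbias, None).
nondiff_argnums = (3, 4, 5, 6, 9, 10, 11, 12, 13)
params = ["q", "k", "v", "dropout_p", "softmax_scale", "causal",
          "window_size", "bias", "alibi_slopes", "deterministic",
          "return_lse", "return_attn_probs", "cu_seqlens_q", "cu_seqlens_kv"]
diff_args = [p for i, p in enumerate(params) if i not in nondiff_argnums]
print(diff_args)  # ['q', 'k', 'v', 'bias', 'alibi_slopes']
```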
4.5 Missing softmax_scale default in varlen path differs from batch
In flash_attn_varlen, the default scale uses q.shape[-1] which is the head dimension for 3D input. In flash_attn_func, it uses q.shape[-1] which is also head dimension for 4D input. Both are correct, but the defensive check should be unified.
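For reference, the conventional flash-attention default scale both paths rely on is `1/sqrt(head_dim)` with `head_dim = q.shape[-1]` in either layout (the helper name below is hypothetical):

```python
import math

def default_softmax_scale(head_dim: int) -> float:
    # Conventional flash-attention default: 1/sqrt(head_dim).
    return 1.0 / math.sqrt(head_dim)

print(default_softmax_scale(64))  # 0.125
```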
5. Error Handling Issues
5.1 HIP_CHECK calls std::abort() -- no graceful degradation
```cpp
inline void hipCheck(hipError_t err, const char *file, int line) {
  if (err != hipSuccess) {
    LOG(ERROR) << ...;
    std::abort();
  }
}
```

In the backward handler, if `hipMalloc` fails (e.g., OOM), the whole process is killed. For an FFI handler, it would be better to return an `ffi::Error` with an appropriate error code and let JAX handle the failure gracefully.
5.2 Bare catch (...) swallows errors
In `hip_aiter_mha_bwd.cc`:

```cpp
try {
  auto [s, o] = mha_utils::get_rng_seed_offset_ptrs(rng_state_, dropout_p);
  seed_ptr = s; offset_ptr = o;
} catch (...) { /* fallthrough to dummy */ }
```

This silently swallows any exception, including errors that indicate real problems (e.g., a buffer that is too small). At minimum, log the exception.
5.3 dtype_to_string throws std::runtime_error
```cpp
inline std::string dtype_to_string(xla::ffi::DataType dtype) {
  ...
  default:
    throw std::runtime_error("Unsupported dtype for MHA");
}
```

Throwing across FFI boundaries is undefined behavior in many contexts. This should return an `ffi::Error` instead.
6. Python-Side Issues
6.1 get_gfx() shells out to rocminfo at import time
```python
@functools.lru_cache(maxsize=1)
def get_gfx() -> str:
    ...
    result = subprocess.run([os.path.realpath(rocminfo)], ...)
```

This is called during tracing, which happens at JIT compilation time. It is expensive, fragile (it depends on `rocminfo` being on `PATH`), and won't work in sandboxed environments. The `lru_cache` helps, but the first call can still be slow.
Recommendation: Query the GPU architecture through the HIP runtime API in C++ and expose it via an FFI call or Python binding, rather than shelling out.
6.2 FFI target name mismatch
In `hip_aiter.cc`:

```cpp
dict[JAX_GPU_PREFIX "_mha_fwd_ffi"] = EncapsulateFfiHandler(aiter_mha_fwd);
```

With `JAX_GPU_PREFIX = "hip"`, this registers as `"hip_mha_fwd_ffi"`. The Python side calls:

```python
jax.ffi.ffi_call("hip_mha_fwd_ffi", ...)
```

The names must match exactly. If `JAX_GPU_PREFIX` is ever changed (e.g., to `"rocm"`), this silently breaks. Consider using a shared constant or deriving the name.
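A minimal sketch of the "derive the name" suggestion (the names `GPU_PREFIX` and `ffi_target` are hypothetical): both the registration and the call site would consume one constant, so a prefix change cannot desynchronize them.

```python
GPU_PREFIX = "hip"  # would mirror JAX_GPU_PREFIX on the C++ side

def ffi_target(op: str) -> str:
    # Single source of truth for FFI target names.
    return f"{GPU_PREFIX}_{op}_ffi"

# The Python side would then call jax.ffi.ffi_call(ffi_target("mha_fwd"), ...)
print(ffi_target("mha_fwd"))  # hip_mha_fwd_ffi
```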
6.3 Redundant import
```python
import functools
...
from functools import partial
```

Both `functools` and `functools.partial` are imported. Minor cleanup opportunity.
6.4 Missing newline at end of file
Multiple files (aiter_mha.py, __init__.py, several C++ files) are missing the trailing newline. This causes \ No newline at end of file in diffs and can confuse some tools.
7. Test Coverage Assessment
The test file (test_aiter_mha.py, 551 lines) is quite comprehensive:
Strengths:

- TE-style tolerance computation (`eps^(2/3)`) is appropriate for mixed precision
- Good parametric coverage: head dims 32-256, both dtypes, MHA/GQA/MQA
- Regression guards for specific historical bugs (good practice)
- Both batch and varlen paths tested
- Edge cases: sq=1 (decode), sq>sk, sq<sk, large batch, single head

Gaps:

- No concurrent execution tests -- nothing covers thread safety or multi-stream correctness
- No dropout numerical correctness test -- only shape/crash tests (acknowledged in a docstring, but this means dropout correctness is untested)
- No test for the `logits_soft_cap` parameter -- it is accepted but never exercised
- No test for `zero_tensors=True` -- the flag is accepted but the default `False` path is always taken
- No test for the `min_seqlen_q` parameter
- No negative tests for invalid inputs (wrong dtypes, mismatched shapes, etc.)
- No multi-GPU tests
- The reference implementation in the tests doesn't support GQA/MQA, so GQA/MQA is only checked for shape/finiteness, not numerical correctness
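The TE-style tolerance rule mentioned under Strengths can be sketched directly: tolerances scale as machine epsilon raised to the 2/3 power of the compute dtype, which tracks how rounding error accumulates through the attention reductions (the epsilon constants below are the standard values for each format):

```python
def te_tolerance(eps: float) -> float:
    # TE-style rule: tolerance = machine_epsilon ** (2/3).
    return eps ** (2.0 / 3.0)

EPS_FP16 = 2.0 ** -10  # machine epsilon of IEEE binary16 (10 mantissa bits)
EPS_BF16 = 2.0 ** -7   # machine epsilon of bfloat16 (7 mantissa bits)
print(te_tolerance(EPS_FP16))  # ~9.8e-3
print(te_tolerance(EPS_BF16))  # ~3.9e-2: looser, as bf16 has fewer mantissa bits
```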
8. Build System Issues
8.1 Missing .so files in third_party/aiter/
The BUILD file references:

```
cc_import(
    name = "mha_fwd_so",
    shared_library = "libmha_fwd.so",
)
```

But `libmha_fwd.so` and `libmha_bwd.so` are not included in the commit, and there is no documentation for how these shared libraries should be built or obtained.
8.2 linkopts = ["-Wl,-rpath,$$ORIGIN"]
This sets RPATH to look for .so files in the same directory as the binary. This is correct for deployment but needs the .so files to be co-located at runtime.
9. Summary of Findings by Severity
| Severity | Count | Key Items |
|---|---|---|
| Critical | 2 | Host reads of device pointers (UB); hipMemcpyAsync from stack buffers |
| High | 2 | hipFree in destructor syncs device (perf); hipMalloc per backward call |
| Medium | 3 | Wrong sq vs sk in v3 eligibility; incorrect kCmdBufferCompatible; RNG seed collision |
| Low | 6+ | Missing tests for dropout/logits_soft_cap/zero_tensors; std::abort on OOM; swallowed exceptions; style issues |
The most actionable items are:

- Fix the host-side device pointer reads in `prepare_rng_state_for_fwd`
- Fix `sq` vs `sk` in `_flash_attn_backward`'s v3 eligibility check
- Remove `kCmdBufferCompatible` from both handlers
- Replace per-call `hipMalloc`/`hipFree` with a workspace pattern for the backward pass
```python
def _pad_to_multiple_of_8(q, k, v):
    """Pad head dimensions of Q/K/V to the next multiple of 8 if needed.

    Returns (q_padded, k_padded, v_padded, hd_q_original, hd_v_original).
    """
    hd_q = q.shape[-1]
    hd_v = v.shape[-1]
    q_p, k_p, v_p = q, k, v
    ndim = q.ndim
    if hd_q % 8 != 0:
        pad = 8 - hd_q % 8
        pw = tuple((0, 0) for _ in range(ndim - 1)) + ((0, pad),)
        q_p = jnp.pad(q, pw)
        k_p = jnp.pad(k, pw)
    if hd_v % 8 != 0:
        pad = 8 - hd_v % 8
        pw = tuple((0, 0) for _ in range(ndim - 1)) + ((0, pad),)
        v_p = jnp.pad(v, pw)
    return q_p, k_p, v_p, hd_q, hd_v
```
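The pad arithmetic used above can be checked in isolation (the helper `pad_amount` is extracted here purely for illustration):

```python
def pad_amount(hdim: int) -> int:
    # Amount added to reach the next multiple of 8, same arithmetic as
    # _pad_to_multiple_of_8 above.
    return 0 if hdim % 8 == 0 else 8 - hdim % 8

print(pad_amount(100))  # 4 -> padded head dim 104
print(pad_amount(33))   # 7 -> padded head dim 40
```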
Is this padding for enabling the ck/aiter flow or for pure performance? We didn't see configs with hdim not 8 multiples from external customers before so we didn't do it in TE. But I was curious about the overall performance comparison between padding-->ck/aiter-->unpadding vs passing original config to ck/aiter
The padding is a hard functional requirement, not a performance optimization: the CK and ASM v3 kernels produce incorrect results or crash if head dimensions are not multiples of 8.
The upstream AITer mha.py (the PyTorch version) has identical padding logic in every public entry point (line 1760), and after the kernel runs the output is sliced back (line 1845):

```python
dq = dq[..., :head_size_q_og]  # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
dv = dv[..., :head_size_v_og]
```

This is done in `flash_attn_func`, `flash_attn_varlen_func`, and `flash_attn_fp8_pertensor_func` -- every single public API.

ASM v3 eligibility also explicitly requires `hdim % 8 == 0` (aiter mha.py, line 1543):

```python
ret &= hdim_q >= 64 and hdim_q <= 192 and hdim_q % 8 == 0
```
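The quoted eligibility condition can be exercised standalone (the wrapper name `hdim_ok_for_v3` is hypothetical; the predicate mirrors the aiter line above):

```python
def hdim_ok_for_v3(hdim_q: int) -> bool:
    # Mirrors the quoted aiter check: hdim_q in [64, 192] and a multiple of 8.
    return 64 <= hdim_q <= 192 and hdim_q % 8 == 0

print(hdim_ok_for_v3(128))  # True
print(hdim_ok_for_v3(100))  # False: 100 is not a multiple of 8
```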
```python
def _compute_v3_eligibility_bwd(
    dropout_p, hq, hk, dq, causal, wl, wr, bias, sq_or_max_sk, gfx
):
    """Shared ASM v3 eligibility check for backward pass (batch & varlen)."""
    swa = (wl > 0) or (wr >= 0 and wr != -1)
    use_v3 = True
    if dropout_p > 0:
        use_v3 = False
    if hq != hk:
        use_v3 = False
    if bias is not None and bias.size > 0:
        use_v3 = False
    if swa:
        use_v3 = False
    if causal and gfx == "gfx950" and sq_or_max_sk > 256:
        use_v3 = False
    if gfx == "gfx950" and dq >= 96:
        use_v3 = False
    return use_v3
```
AITER has its own internal v3 API check and will fall back to v2 even if you requested the v3 ASM kernels.
33 tests crashed without this check. `v3_api_check` is user-controlled in aiter -- it is set in the benchmark tests, I think, to switch between v3 and v2.
```python
gen = _empty(jnp.int64)

rng_shape = (2,)
bf16_cvt = 0 if get_gfx() == "gfx950" else 1
```
`bf16_cvt` can also be set to different values on gfx942; refer to the TE README for more details: https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#aiter-fa-v3-kernels
```cpp
if (dbias_expanded) { hipFree(dbias_expanded); dbias_expanded = nullptr; }
if (dummy_rng) { hipFree(dummy_rng); dummy_rng = nullptr; }
if (dq_acc) { hipFree(dq_acc); dq_acc = nullptr; }
if (dk_expanded) { hipFree(dk_expanded); dk_expanded = nullptr; }
if (dv_expanded) { hipFree(dv_expanded); dv_expanded = nullptr; }
```
Is it possible to request a buffer from jax? In other words, the jax will manage those extra buffers like dq_acc, softmax_lse_buffer and so on? Calling hipMalloc and hipFree could be heavy for e2e training if you need to run this every iteration
Integrated AMD's AITER library to provide high-performance multi-head
attention (MHA) forward and backward kernels on ROCm GPUs. AITER
dispatches internally between CK (Composable Kernel) and hand-tuned
ASM v3 assembly kernels for optimal performance on supported
architectures (e.g., gfx942 and gfx950+).
Public API: flash_attn_func and flash_attn_varlen.

Both APIs support causal masking, sliding window attention, dropout,
ALiBi, GQA/MQA head layouts, and non-standard head dimensions
(automatically padded to multiples of 8).

Implementation:
- Wrap aiter::mha_fwd / aiter::mha_bwd with unified batch/varlen
  dispatch based on tensor rank (4D=batch, 3D=varlen).
- Stride calculation, mask/bias construction, MQA/GQA reduction
  kernels, and RNG state management.
- FFI targets hip_mha_fwd_ffi and hip_mha_bwd_ffi.
- Head-dim padding, ASM v3 eligibility checks, and gfx950-specific
  guards.
- Shipped in the jax-rocm plugin wheel via import_from_plugin.
ASM v3 kernel eligibility on gfx950:
causal with sq > sk configurations.
New files:
jax/_src/aiter/__init__.py
jax/_src/aiter/aiter_mha.py
jaxlib/gpu_aiter.py
jaxlib/gpu/hip_aiter.h
jaxlib/gpu/hip_aiter.cc
jaxlib/gpu/hip_aiter_mha_fwd.cc
jaxlib/gpu/hip_aiter_mha_bwd.cc
jaxlib/gpu/hip_aiter_mha_common_utils.h
jaxlib/gpu/hip_aiter_mha_common_utils.cc
third_party/aiter/{BUILD,include/}
tests/test_aiter_mha.py
Modified files:
jax/_src/BUILD (add aiter target)
jax/_src/lib/__init__.py (import gpu_aiter)
jaxlib/BUILD (add gpu_aiter.py)
jaxlib/gpu/BUILD (export aiter sources)
jaxlib/rocm/BUILD (aiter library + nanobind targets)
jaxlib/tools/build_wheel.py (include gpu_aiter.py in jaxlib wheel)
jaxlib/tools/build_gpu_kernels_wheel.py (include _aiter.so in plugin wheel)
jaxlib/tools/BUILD.bazel (aiter runtime .so in plugin wheel)