
aiter mha kernels (ASM+CK) integration#747

Open
zahiqbal wants to merge 1 commit into aiter-mha from aiter-mha-integration-draft-1

Conversation

@zahiqbal

@zahiqbal zahiqbal commented Apr 7, 2026

Integrated AMD's AITER library to provide high-performance multi-head
attention (MHA) forward and backward kernels on ROCm GPUs. AITER
dispatches internally between CK (Composable Kernel) and hand-tuned
ASM v3 assembly kernels for optimal performance on supported
architectures (e.g., gfx942, gfx950+).
Public API:

  • jax._src.aiter.flash_attn_func: batch flash attention with custom_vjp
  • jax._src.aiter.flash_attn_varlen: variable-length flash attention with custom_vjp

Both APIs support causal masking, sliding window attention, dropout,
ALiBi, GQA/MQA head layouts, and non-standard head dimensions
(automatically padded to multiples of 8).

Implementation:

  • C++ FFI handlers (hip_aiter_mha_fwd.cc, hip_aiter_mha_bwd.cc)
    wrap aiter::mha_fwd / aiter::mha_bwd with unified batch/varlen
    dispatch based on tensor rank (4D = batch, 3D = varlen).
  • Common utilities (hip_aiter_mha_common_utils.{h,cc}) provide
    stride calculation, mask/bias construction, MQA/GQA reduction
    kernels, and RNG state management.
  • Nanobind module (hip_aiter.cc) registers FFI handler symbols
    hip_mha_fwd_ffi and hip_mha_bwd_ffi.
  • Python layer (aiter_mha.py) implements custom_vjp wrappers,
    head-dim padding, ASM v3 eligibility checks, and gfx950-specific
    guards.
  • Plugin loader (gpu_aiter.py) discovers _aiter.so from the
    jax-rocm plugin wheel via import_from_plugin.

ASM v3 kernel eligibility on gfx950:

  • Forward: disabled for head_dim >= 96 (kernel symbols missing)
  • Backward: disabled for head_dim >= 96 (kernel symbols missing)
  • Also disabled for dropout, GQA/MQA, bias, SWA, and
    causal with sq > sk configurations.
  • Fallback to CK kernels is automatic and transparent.

New files:

    jax/_src/aiter/__init__.py
    jax/_src/aiter/aiter_mha.py
    jaxlib/gpu_aiter.py
    jaxlib/gpu/hip_aiter.h
    jaxlib/gpu/hip_aiter.cc
    jaxlib/gpu/hip_aiter_mha_fwd.cc
    jaxlib/gpu/hip_aiter_mha_bwd.cc
    jaxlib/gpu/hip_aiter_mha_common_utils.h
    jaxlib/gpu/hip_aiter_mha_common_utils.cc
    third_party/aiter/{BUILD,include/}
    tests/test_aiter_mha.py

Modified files:

    jax/_src/BUILD (add aiter target)
    jax/_src/lib/__init__.py (import gpu_aiter)
    jaxlib/BUILD (add gpu_aiter.py)
    jaxlib/gpu/BUILD (export aiter sources)
    jaxlib/rocm/BUILD (aiter library + nanobind targets)
    jaxlib/tools/build_wheel.py (include gpu_aiter.py in jaxlib wheel)
    jaxlib/tools/build_gpu_kernels_wheel.py (include _aiter.so in plugin wheel)
    jaxlib/tools/BUILD.bazel (aiter runtime .so in plugin wheel)

@phambinhfin

phambinhfin commented Apr 7, 2026

Hi, do you have test cases for your implementation?
PS: I saw the tests folder there, let me check.

@zahiqbal
Author

zahiqbal commented Apr 7, 2026

Hi, do you have test cases for your implementation?

tests/test_aiter_mha.py is the file

I attached the shell script that runs all use cases.
run_mha_tests.sh

Results are like
root@smci355-ccs-aus-m12-05:/workspaces/aiter-work/jax_review/jax# bash run_mha_tests.sh

AITer MHA Test Suite
Tue Apr 7 17:18:29 UTC 2026
GPU: gfx950
unknown

Running test_batch_fwd_shape ... PASSED (50 passed)
Running test_batch_fwd_accuracy ... PASSED (22 passed)
Running test_batch_bwd_shape ... PASSED (50 passed)
Running test_batch_bwd_accuracy ... PASSED (11 passed)
Running test_dropout_fwd ... PASSED (4 passed)
Running test_dropout_bwd ... PASSED (2 passed)
Running test_swa_fwd ... PASSED (2 passed)
Running test_swa_bwd ... PASSED (2 passed)
Running test_bias_fwd ... PASSED (2 passed)
Running test_bias_bwd ... PASSED (2 passed)
Running test_alibi_fwd ... PASSED (4 passed)
Running test_alibi_causal ... PASSED (2 passed)
Running test_return_lse ... PASSED (1 passed)
Running test_return_attn_probs_with_dropout ... PASSED (1 passed)
Running test_padded_head_dim_fwd ... PASSED (9 passed)
Running test_padded_head_dim_bwd ... PASSED (2 passed)
Running test_deterministic_consistency ... PASSED (2 passed)
Running test_deterministic_bwd ... PASSED (1 passed)
Running test_varlen_fwd ... PASSED (18 passed)
Running test_varlen_bwd ... PASSED (9 passed)
Running test_decode_sq1_fwd_bwd ... PASSED (1 passed)
Running test_sq_gt_sk_nomask ... PASSED (1 passed)
Running test_sq_gt_sk_causal ... PASSED (1 passed)
Running test_large_batch ... PASSED (1 passed)
Running test_single_head ... PASSED (1 passed)
Running test_many_heads ... PASSED (1 passed)
Running test_v3_bwd_sq_gt_sk_causal ... PASSED (1 passed)
Running test_1024_1023_causal ... PASSED (6 passed)
Running test_mqa_gqa_bwd_routing ... PASSED (1 passed)
Running test_varlen_large_sk_causal ... PASSED (2 passed)
Running test_gfx950_1block_override ... PASSED (1 passed)
Running test_swa_not_v3_bwd ... PASSED (1 passed)
Running test_all_head_dims_bwd ... PASSED (4 passed)

RESULTS

Passed: 218
Failed: 0
Crashed: 0

@zahiqbal zahiqbal force-pushed the aiter-mha-integration-draft-1 branch from 0df02f7 to a6041a3 on April 7, 2026 at 19:31

@i-chaochen i-chaochen left a comment



Code Review: AITER MHA Kernels (ASM+CK) Integration

Commit: a6041a3b2c8ad9d6f86f071855e8523a3a49894a
Author: Zahid Iqbal
Scope: 3,644 lines added across 24 files -- integrates AITER multi-head attention forward/backward kernels for ROCm (CK + ASM v3) into JAX via XLA FFI.


1. Architecture Overview

The commit introduces a well-structured layered integration:

  • Python API layer (jax/_src/aiter/aiter_mha.py): Public flash_attn_func and flash_attn_varlen with custom_vjp, dispatching to unified forward/backward wrappers.
  • FFI bridge (jaxlib/gpu/hip_aiter_mha_fwd.cc, hip_aiter_mha_bwd.cc): C++ handlers bound via XLA FFI, translating JAX buffers into aiter::mha_fwd_args / aiter::mha_bwd_args.
  • Third-party headers (third_party/aiter/): AITER's own mha_fwd.h, mha_bwd.h, ASM kernel loader infrastructure.
  • Nanobind glue (hip_aiter.cc): Registration of FFI targets.

The design of a single unified handler detecting batch (4D) vs varlen (3D) from tensor rank is clean and avoids handler proliferation.


2. Thread Safety and Data Race Issues (Critical Section)

2.1 CRITICAL: hipMemcpyAsync from stack-local buffer (use-after-scope risk)

In hip_aiter_mha_common_utils.cc, prepare_rng_state_for_fwd:

uint64_t host_rng[2] = {seed_value, offset_value};
hipError_t err =
    hipMemcpyAsync(rng_state->untyped_data(), host_rng,
                   2 * sizeof(int64_t), hipMemcpyHostToDevice, stream);

host_rng is a stack-allocated array. hipMemcpyAsync with hipMemcpyHostToDevice from pageable host memory is synchronous with respect to the host in the current HIP/CUDA runtime (the DMA engine must stage through a pinned bounce buffer), so this is de facto safe. However, this is an implementation-defined behavior -- future ROCm runtime changes could make this truly async, at which point host_rng would be read after it goes out of scope. The same pattern appears in the forward handler:

std::vector<float> neg_inf(n, -std::numeric_limits<float>::infinity());
HIP_CHECK(hipMemcpyAsync(lse->untyped_data(), neg_inf.data(),
                         n * sizeof(float), hipMemcpyHostToDevice, stream));

Recommendation: Either use hipMemcpy (synchronous) explicitly, or use pinned memory, or add a hipStreamSynchronize after the copy.

2.2 CRITICAL: Reading device pointers on the host

In prepare_rng_state_for_fwd:

const auto *gen_data = static_cast<const int64_t *>(gen->untyped_data());
seed_value = static_cast<uint64_t>(gen_data[0]);    // Host reads device memory!
offset_value = static_cast<uint64_t>(gen_data[1]);  // Host reads device memory!
VLOG(1) << "Using provided generator with seed: " << seed_value << ...;

If gen is a device buffer (which is the typical case in JAX -- all FFI buffers are device-side), dereferencing gen_data[0] from host code is undefined behavior. It will either segfault or return garbage. The subsequent hipMemcpyAsync(DeviceToDevice) is correct, but the VLOG will log incorrect values. The real concern is the UB itself.

Recommendation: Remove the host-side reads of gen_data[0]/gen_data[1], or use hipMemcpy to bring the values to host first if logging is needed.

2.3 HIGH: BwdDeviceBuffers destructor races with GPU work

In hip_aiter_mha_bwd.cc, BwdDeviceBuffers::~BwdDeviceBuffers() calls hipFree on buffers that were just passed to async kernels:

void free_all() {
    if (dbias_expanded) { hipFree(dbias_expanded); ... }
    if (dk_expanded) { hipFree(dk_expanded); ... }
    ...
}

hipFree performs an implicit device synchronization (waits for ALL streams), so the data won't actually be freed while in use. However:

  • This causes a full device sync on every backward pass, destroying any pipeline overlap.
  • The implicit sync behavior is an implementation detail that could change.
  • It makes the handler incompatible with CUDA Graphs / HIP Graphs and the kCmdBufferCompatible trait the handler claims.

Recommendation: Use a workspace pattern -- have JAX allocate the scratch space as additional output buffers, or use XLA's buffer allocation to pre-allocate scratch.

2.4 MEDIUM: Non-thread-safe get_gpu_arch() in header

In aiter_hip_common.h:

static const std::string get_gpu_arch() {
    hipDevice_t dev;
    HIP_CALL(hipGetDevice(&dev));
    HIP_CALL(hipGetDeviceProperties(&dev_prop, dev));
    ...
}

This is marked static in a header, so every translation unit gets its own copy. hipGetDevice returns the thread-local current device. If this is called from different threads with different active devices, each will get a different answer, which is the intended behavior. However, the combined pattern with get_num_cu_func():

static uint32_t get_num_cu_func() {
    static const uint32_t num_cu = get_num_cu_local();  // initialized once
    return num_cu;
}

This caches the CU count from whichever device happens to be current during the first call. If the system has heterogeneous GPUs, subsequent calls from threads using different devices will get the wrong CU count. C++11 guarantees thread-safe initialization of static locals, but the value may be wrong for non-default devices.
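
A per-device cache avoids the stale-CU-count problem. A minimal Python illustration of the pattern (the real fix belongs in the C++ header; `query` here is a stand-in for the hipGetDeviceProperties lookup, and the C++ version would need a mutex or per-device std::call_once):

```python
# Illustration only: cache the CU count keyed by device index instead of
# in a single static local. `query` stands in for the HIP property lookup.
_num_cu_cache: dict[int, int] = {}

def get_num_cu(dev_idx: int, query) -> int:
    # First call per device pays the query cost; later calls hit the cache,
    # and heterogeneous GPUs no longer share one stale value.
    if dev_idx not in _num_cu_cache:
        _num_cu_cache[dev_idx] = query(dev_idx)
    return _num_cu_cache[dev_idx]
```

With this shape, a thread whose current device is GPU 1 can never observe the CU count cached for GPU 0.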

2.5 MEDIUM: RNG seed collision under concurrency

When no generator is provided, the seed is derived from a timestamp:

seed_value = static_cast<uint64_t>(timestamp) ^ static_cast<uint64_t>(dev_idx);

Two concurrent forward passes on the same device within the same microsecond will produce identical seeds, leading to identical dropout masks. This is a correctness issue for training with dropout.

Recommendation: Use a monotonic atomic counter or incorporate the stream pointer / buffer address as additional entropy.
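
A sketch of the counter-based fix in Python (the real change would be in the C++ handler; names are illustrative). Reserving the low bits of the seed for a process-wide counter guarantees distinct seeds even when the timestamp and device index collide:

```python
import itertools
import time

_seed_counter = itertools.count()  # monotonically increasing per process

def derive_fallback_seed(dev_idx: int) -> int:
    counter = next(_seed_counter) & 0xFFFFF   # low 20 bits: per-process counter
    timestamp_us = time.time_ns() // 1_000    # microsecond timestamp, as in the PR
    # The counter occupies the low bits, so any two of the first 2**20 calls
    # produce different seeds even within the same microsecond on one device.
    return (((timestamp_us ^ dev_idx) << 20) | counter) & (2**64 - 1)
```

In C++ the counter would be a `std::atomic<uint64_t>` incremented with `fetch_add`.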

2.6 LOW: _make_fwd_call / _make_bwd_call are not cached

Each call to mha_fwd_unified / mha_bwd_unified creates a fresh jax.ffi.ffi_call + jax.jit wrapper:

fn = _make_fwd_call(out_shape, lse_shape, p_shape, rng_shape, q.dtype)

While JAX's tracing and compilation have their own caches, repeatedly creating new ffi_call objects and JIT wrappers adds unnecessary overhead. These should be memoized by (out_shape, lse_shape, p_shape, rng_shape, dtype).


3. Memory Management Issues

3.1 hipMalloc / hipFree on every backward call

The backward handler allocates up to 5 device buffers per invocation:

HIP_CHECK(hipMalloc(&bufs.dq_acc, dq_acc_bytes));       // always
HIP_CHECK(hipMalloc(&bufs.dk_expanded, dk_sz));          // if MQA/GQA
HIP_CHECK(hipMalloc(&bufs.dv_expanded, dv_sz));          // if MQA/GQA
HIP_CHECK(hipMalloc(&bufs.dbias_expanded, dbias_sz));    // if has_dbias
HIP_CHECK(hipMalloc(&bufs.dummy_rng, 2*sizeof(uint64_t)));// if no rng

hipMalloc is synchronous and expensive (~microseconds to milliseconds). For training workloads where backward is called millions of times, this is a significant performance bottleneck.

Recommendation: Allocate these as JAX-managed output buffers in the Python wrapper (passing workspace shapes through the FFI), or use a persistent memory pool.

3.2 Large stack-allocated vector for LSE initialization

std::vector<float> neg_inf(n, -std::numeric_limits<float>::infinity());

For large batch/sequence lengths, n can be very large (e.g., batch=64, heads=32, seq=2048 => n=4M floats = 16MB heap allocation). This should use a fill kernel instead.


4. Correctness Issues

4.1 Wrong sequence length passed to v3 eligibility check

In _flash_attn_backward:

_, sq, hq, dq = q.shape
_, sk, hk, _ = k.shape
...
use_v3 = _compute_v3_eligibility_bwd(
    dropout_p, hq, hk, dq, causal, wl, wr, bias, sq, gfx  # <-- sq, not sk
)

The parameter is named sq_or_max_sk, and the check is:

if causal and gfx == "gfx950" and sq_or_max_sk > 256:
    use_v3 = False

For the batch path, sq (query sequence length) is being passed where sk (key sequence length) is expected. When sq != sk (cross-attention), this can incorrectly enable or disable ASM v3. Compare with the varlen backward which correctly passes max_seqlen_k:

use_v3 = _compute_v3_eligibility_bwd(
    res_dp, hq, hk, dq, causal, window_size[0], window_size[1],
    None, max_seqlen_k, gfx  # <-- correct: max_seqlen_k
)

This is a bug for cross-attention cases on gfx950 with sq <= 256 < sk (causal).
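
The effect is easy to demonstrate with a condensed restatement of the eligibility check quoted later in this review (restated here purely for illustration): on gfx950 with causal cross-attention, sq=128, sk=1024, head_dim=64, passing sq leaves v3 enabled while passing sk correctly disables it.

```python
def _compute_v3_eligibility_bwd(dropout_p, hq, hk, dq, causal, wl, wr,
                                bias, sq_or_max_sk, gfx):
    # Condensed restatement of the PR's check, for illustration.
    swa = (wl > 0) or (wr >= 0 and wr != -1)
    use_v3 = True
    if dropout_p > 0 or hq != hk or bias is not None or swa:
        use_v3 = False
    if causal and gfx == "gfx950" and sq_or_max_sk > 256:
        use_v3 = False
    if gfx == "gfx950" and dq >= 96:
        use_v3 = False
    return use_v3

# Buggy call site passes sq=128; the correct call site would pass sk=1024.
buggy = _compute_v3_eligibility_bwd(0.0, 8, 8, 64, True, 0, -1, None, 128, "gfx950")
fixed = _compute_v3_eligibility_bwd(0.0, 8, 8, 64, True, 0, -1, None, 1024, "gfx950")
# buggy is True (v3 wrongly enabled); fixed is False (v3 correctly disabled).
```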

4.2 kCmdBufferCompatible trait is incorrect

Both handlers are declared with:

{xla::ffi::Traits::kCmdBufferCompatible}

But the backward handler calls hipMalloc, hipFree, hipPointerGetAttributes, and other operations that are not compatible with command buffer recording. The forward handler similarly calls hipPointerGetAttributes via device_from_ptr. This trait should be removed or the handlers refactored.

4.3 Variable shadowing in mha_fwd_unified

def mha_fwd_unified(q, k, v, ...):
    ...
    dq = q.shape[-1]  # shadows name 'dq' which in MHA context means gradient of q
    use_v3_fwd = not (get_gfx() == "gfx950" and dq >= 96)

While not a bug, dq meaning "head dimension of q" vs. "gradient of q" is confusing in MHA code. Consider renaming to hdim_q.

4.4 _flash_attn_func_bwd gradient count vs. custom_vjp nondiff argnums

@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
def flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size,
                    bias, alibi_slopes, deterministic, return_lse,
                    return_attn_probs, cu_seqlens_q, cu_seqlens_kv):

Differentiable args: indices 0,1,2,7,8 = q, k, v, bias, alibi_slopes. The backward returns:

return (dq, dk, dv, dbias, None)  # 5 values for 5 diff args

This is correct but note that None is returned for alibi_slopes gradient. If someone passes alibi_slopes that requires grad in the future, this will silently drop gradients. Consider adding a check or comment.

4.5 softmax_scale default handling: varlen vs batch

In flash_attn_varlen, the default scale uses q.shape[-1], which is the head dimension for 3D input. In flash_attn_func, it uses q.shape[-1], which is also the head dimension for 4D input. Both are correct, but the defensive check should be unified between the two paths.


5. Error Handling Issues

5.1 HIP_CHECK calls std::abort() -- no graceful degradation

inline void hipCheck(hipError_t err, const char *file, int line) {
    if (err != hipSuccess) {
        LOG(ERROR) << ...;
        std::abort();
    }
}

In the backward handler, if hipMalloc fails (e.g., OOM), the process is killed. For an FFI handler, it would be better to return ffi::Error with an appropriate error code and let JAX handle the error gracefully.

5.2 Bare catch (...) swallows errors

In hip_aiter_mha_bwd.cc:

try {
    auto [s, o] = mha_utils::get_rng_seed_offset_ptrs(rng_state_, dropout_p);
    seed_ptr = s; offset_ptr = o;
} catch (...) { /* fallthrough to dummy */ }

This silently swallows any exception, including errors that indicate real problems (e.g., buffer too small). At minimum, log the exception.

5.3 dtype_to_string throws std::runtime_error

inline std::string dtype_to_string(xla::ffi::DataType dtype) {
    ...
    default:
        throw std::runtime_error("Unsupported dtype for MHA");
}

Throwing across FFI boundaries is undefined behavior in many contexts. Should return ffi::Error instead.


6. Python-Side Issues

6.1 get_gfx() shells out to rocminfo at import time

@functools.lru_cache(maxsize=1)
def get_gfx() -> str:
    ...
    result = subprocess.run([os.path.realpath(rocminfo)], ...)

This is called during tracing, which happens at JIT compilation time. It's expensive, fragile (depends on rocminfo being on PATH), and won't work in sandboxed environments. The lru_cache helps but the first call can still be slow.

Recommendation: Query the GPU architecture through the HIP runtime API in C++ and expose it via an FFI call or Python binding, rather than shelling out.

6.2 FFI target name mismatch

In hip_aiter.cc:

dict[JAX_GPU_PREFIX "_mha_fwd_ffi"] = EncapsulateFfiHandler(aiter_mha_fwd);

With JAX_GPU_PREFIX = "hip", this registers as "hip_mha_fwd_ffi". But the Python side calls:

jax.ffi.ffi_call("hip_mha_fwd_ffi", ...)

The name must match exactly. If JAX_GPU_PREFIX is ever changed (e.g., to "rocm"), this silently breaks. Consider using a constant or deriving the name.
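
On the Python side, deriving the target name from one shared prefix constant would at least confine the coupling to a single line (hypothetical helper; the constant still has to be kept in sync with JAX_GPU_PREFIX in hip_aiter.cc):

```python
GPU_PREFIX = "hip"  # must mirror JAX_GPU_PREFIX in hip_aiter.cc

def ffi_target(name: str) -> str:
    """Build the registered FFI target name from the shared prefix."""
    return f"{GPU_PREFIX}_{name}_ffi"

# e.g. use ffi_target("mha_fwd") instead of the hard-coded "hip_mha_fwd_ffi".
```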

6.3 Redundant import

import functools
...
from functools import partial

Both functools and functools.partial are imported. Minor cleanup opportunity.

6.4 Missing newline at end of file

Multiple files (aiter_mha.py, __init__.py, several C++ files) are missing the trailing newline. This causes \ No newline at end of file in diffs and can confuse some tools.


7. Test Coverage Assessment

The test file (test_aiter_mha.py, 551 lines) is quite comprehensive:

Strengths:

  • TE-style tolerance computation (eps^(2/3)) is appropriate for mixed-precision
  • Good parametric coverage: head dims 32-256, both dtypes, MHA/GQA/MQA
  • Regression guards for specific historical bugs (good practice)
  • Both batch and varlen paths tested
  • Edge cases: sq=1 (decode), sq>sk, sq<sk, large batch, single head

Gaps:

  • No concurrent execution tests -- no test for thread-safety or multi-stream correctness
  • No dropout numerical correctness test -- only shape/crash tests (acknowledged in docstring, but this means dropout correctness is untested)
  • No test for the logits_soft_cap parameter -- it's accepted but never tested
  • No test for zero_tensors=True -- the flag is accepted but the default False path is always taken
  • No test for min_seqlen_q parameter
  • No negative tests -- no tests for invalid inputs (wrong dtypes, mismatched shapes, etc.)
  • No multi-GPU tests
  • Reference implementation in tests doesn't support GQA/MQA, so GQA/MQA accuracy is only tested for shape/finiteness, not numerical correctness

8. Build System Issues

8.1 Missing .so files in third_party/aiter/

The BUILD file references:

cc_import(
    name = "mha_fwd_so",
    shared_library = "libmha_fwd.so",
)

But libmha_fwd.so and libmha_bwd.so are not included in the commit. There's no documentation for how these shared libraries should be built or obtained.

8.2 linkopts = ["-Wl,-rpath,$$ORIGIN"]

This sets RPATH to look for .so files in the same directory as the binary. This is correct for deployment but needs the .so files to be co-located at runtime.


9. Summary of Findings by Severity

Severity   Count   Key Items
Critical   2       Host reads of device pointers (UB); hipMemcpyAsync from stack buffers
High       2       hipFree in destructor syncs device (perf); hipMalloc per backward call
Medium     3       Wrong sq vs sk in v3 eligibility; incorrect kCmdBufferCompatible; RNG seed collision
Low        6+      Missing tests for dropout/logits_soft_cap/zero_tensors; std::abort on OOM; swallowed exceptions; style issues

The most actionable items are:

  1. Fix the host-side device pointer reads in prepare_rng_state_for_fwd
  2. Fix sq vs sk in _flash_attn_backward's v3 eligibility check
  3. Remove kCmdBufferCompatible from both handlers
  4. Replace per-call hipMalloc/hipFree with a workspace pattern for the backward pass

Comment on lines +86 to +104
def _pad_to_multiple_of_8(q, k, v):
  """Pad head dimensions of Q/K/V to the next multiple of 8 if needed.

  Returns (q_padded, k_padded, v_padded, hd_q_original, hd_v_original).
  """
  hd_q = q.shape[-1]
  hd_v = v.shape[-1]
  q_p, k_p, v_p = q, k, v
  ndim = q.ndim
  if hd_q % 8 != 0:
    pad = 8 - hd_q % 8
    pw = tuple((0, 0) for _ in range(ndim - 1)) + ((0, pad),)
    q_p = jnp.pad(q, pw)
    k_p = jnp.pad(k, pw)
  if hd_v % 8 != 0:
    pad = 8 - hd_v % 8
    pw = tuple((0, 0) for _ in range(ndim - 1)) + ((0, pad),)
    v_p = jnp.pad(v, pw)
  return q_p, k_p, v_p, hd_q, hd_v
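
The pad-width arithmetic can be sanity-checked in isolation; a pure-Python restatement of the computation (for illustration only):

```python
def pad_width_for(hd: int) -> int:
    """Padding needed to round a head dimension up to the next multiple of 8."""
    return (8 - hd % 8) % 8  # 0 when hd is already a multiple of 8
```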

Is this padding for enabling the ck/aiter flow, or is it purely for performance? We haven't seen configs with hdim not a multiple of 8 from external customers before, so we didn't do it in TE. But I was curious about the overall performance comparison between padding -> ck/aiter -> unpadding versus passing the original config to ck/aiter.

Author

@zahiqbal zahiqbal Apr 9, 2026


The padding is a hard functional requirement, not a performance optimization. The CK and ASM v3 kernels will produce incorrect results or crash if head dimensions are not multiples of 8.

The upstream AITER mha.py (the PyTorch version) has the identical logic in every public entry point, at line 1760.

And then after the kernel runs, the output is sliced back (line 1845):

dq = dq[..., :head_size_q_og]  # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
dv = dv[..., :head_size_v_og]

This is done in flash_attn_func, flash_attn_varlen_func, and flash_attn_fp8_pertensor_func — every single public API.
ASM v3 eligibility explicitly requires hdim % 8 == 0:
aiter mha.py, line 1543: ret &= hdim_q >= 64 and hdim_q <= 192 and hdim_q % 8 == 0

Comment on lines +107 to +125
def _compute_v3_eligibility_bwd(
    dropout_p, hq, hk, dq, causal, wl, wr, bias, sq_or_max_sk, gfx
):
  """Shared ASM v3 eligibility check for backward pass (batch & varlen)."""
  swa = (wl > 0) or (wr >= 0 and wr != -1)
  use_v3 = True
  if dropout_p > 0:
    use_v3 = False
  if hq != hk:
    use_v3 = False
  if bias is not None and bias.size > 0:
    use_v3 = False
  if swa:
    use_v3 = False
  if causal and gfx == "gfx950" and sq_or_max_sk > 256:
    use_v3 = False
  if gfx == "gfx950" and dq >= 96:
    use_v3 = False
  return use_v3

AITER has its own internal v3 API check and will fall back to v2 even if you request the v3 ASM path.

Author


33 tests crashed without this check. v3_api_check is user-controlled in aiter and is set in the benchmark tests; I think it switches between v3 and v2.

gen = _empty(jnp.int64)

rng_shape = (2,)
bf16_cvt = 0 if get_gfx() == "gfx950" else 1

bf16_cvt can also be set to different values in gfx942, refer to TE readme for more details: https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#aiter-fa-v3-kernels

Comment on lines +40 to +44
if (dbias_expanded) { hipFree(dbias_expanded); dbias_expanded = nullptr; }
if (dummy_rng) { hipFree(dummy_rng); dummy_rng = nullptr; }
if (dq_acc) { hipFree(dq_acc); dq_acc = nullptr; }
if (dk_expanded) { hipFree(dk_expanded); dk_expanded = nullptr; }
if (dv_expanded) { hipFree(dv_expanded); dv_expanded = nullptr; }

Is it possible to request a buffer from JAX? In other words, could JAX manage those extra buffers like dq_acc, softmax_lse_buffer, and so on? Calling hipMalloc and hipFree could be heavy for e2e training if you run this every iteration.
