From 75daf1e30ee4073a39b7a09fbd2adef987e240b9 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 3 Sep 2025 20:45:39 +0000 Subject: [PATCH 1/4] [ROCm] Bump AOTriton to 0.11b (#161754) Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b: * Invoke AITER Assembly kernels on gfx942/gfx950 when inputs meet requirements - AITER ASM kernels deliver over 500TFLOPS training performance. See [AOTriton 0.11b Release Page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for more details. * Now returns natural based `logsumexp` tensor, matching CUDA's behavior - PR #156903 is reverted in this PR as well since it is not needed anymore. * Enables `CausalVariant.LOWER_RIGHT` The build system changes drastically along with new packaging scheme of AOTriton 0.11 * AOTriton 0.11 packs GPU images separately from AOTriton runtime * `aotriton.cmake` now selectively downloads image packs according to `PYTORCH_ROCM_ARCH` * `aotriton.cmake` now only use pre-compiled runtime library that exactly matches the ROCM in the build environment. For PyTorch builds with ROCm versions not listed in the file, the build process will build AOTriton runtime without GPU images from source - This avoids any further ABI breaks like ROCM 6.4 -> 7.0 - recursive git clone is disabled since building AOTriton runtime does not require submodules. Bug fixes: * Fix a kernel bug introduced when implementing SWA Known Problems: * gfx1100 target (Radeon RX 7000 Series) is moved back to experimental status due to accuracy issues. Triton compiler fixes are needed to restore the support status. * Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCM 7.0. This issue is under investigation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161754 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily --- .../native/transformers/cuda/attention.cu | 57 +++++- .../transformers/cuda/attention_backward.cu | 70 +++++++- .../native/transformers/cuda/sdp_utils.cpp | 23 ++- .../transformers/hip/aotriton_adapter.h | 59 ++++++ .../transformers/hip/aotriton_versions.h | 20 +++ .../hip/flash_attn/aot/mha_all_aot.hip | 104 ++++++----- .../transformers/hip/gemm_kernel_utils.h | 32 ++++ cmake/External/aotriton.cmake | 170 ++++++++++++------ test/test_transformers.py | 75 +++++--- torch/testing/_internal/common_cuda.py | 10 +- torch/utils/_triton.py | 11 ++ 11 files changed, 487 insertions(+), 144 deletions(-) create mode 100644 aten/src/ATen/native/transformers/hip/aotriton_versions.h create mode 100644 aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 80049aa9a832..dae3332430f1 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1406,12 +1406,15 @@ std::tuple _efficient_ at::Tensor v_t = value.transpose(1, 2); at::Tensor output_t = res.transpose(1, 2); bool is_causal; - if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { - is_causal = true; - } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { is_causal = false; } else { - TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + is_causal = true; +#if AOTRITON_V3_API == 0 + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) { + TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + } +#endif } at::Tensor atomic_counter; @@ -1436,7 +1439,51 @@ std::tuple _efficient_ auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); hipError_t err; // TODO: Error handling - if (seqstart_q.has_value()) { + if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef +#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + using aotriton::v3::flash::WindowValue; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2; + params.Out = mk_aotensor(output_t, "Out"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = dropout_p; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + params.window_left = WindowValue::TopLeftAligned; + params.window_right = WindowValue::TopLeftAligned; + } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { + params.window_left = WindowValue::BottomRightAligned; + params.window_right = WindowValue::BottomRightAligned; + } + if (bias.has_value()) { + params.B = mk_aotensor(bias.value(), "bias"); + } + if (seqstart_q.has_value()) { + params.varlen_type = VarlenType::CompactVarlen; + params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"); + } else { + params.varlen_type = VarlenType::None; + } + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); +#endif // AOTRITON_V3_API + } else if (seqstart_q.has_value()) { // varlen aka nested tensor err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 8940bea9a27f..0339f6eec055 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -24,6 +24,7 @@ #include #include #else +#include #include #include #include @@ -45,6 +46,7 @@ #include #include #else +#include // MemoryEfficient Attention Specific Imports for ROCM #ifndef DISABLE_AOTRITON #include @@ -482,12 +484,15 @@ _efficient_attention_backward( } const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); bool is_causal; - if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { - is_causal = true; - } else if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { + if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { is_causal = false; } else { - TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now"); + is_causal = true; +#if AOTRITON_V3_API == 0 + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) { + TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now"); + } +#endif } at::Tensor q_t = query.permute({0,2,1,3}); at::Tensor k_t = key.permute({0,2,1,3}); @@ -506,7 +511,62 @@ _efficient_attention_backward( using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - if (cu_seqlens_q.has_value()) { + if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef +#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + using aotriton::v3::flash::WindowValue; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4; + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DK = mk_aotensor(dk_t, "dk"); + params.DV = mk_aotensor(dv_t, "dv"); + params.DQ = mk_aotensor(dq_t, "dq"); + params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4; + params.L = mk_aotensor<2>(softmax_lse, "L"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = float(dropout_p); + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + params.window_left = WindowValue::TopLeftAligned; + params.window_right = WindowValue::TopLeftAligned; + } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { + params.window_left = WindowValue::BottomRightAligned; + params.window_right = WindowValue::BottomRightAligned; + } +#if AOTRITON_ALWAYS_V3_API + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); +#else + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + params.D = mk_aotensor<2>(delta, "delta"); +#endif + if (cu_seqlens_q.has_value()) { + params.varlen_type = VarlenType::CompactVarlen; + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"); + } else { + params.varlen_type = VarlenType::None; + } + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream); +#endif // AOTRITON_V3_API + } else if (cu_seqlens_q.has_value()) { at::Tensor delta = at::empty_like(softmax_lse).contiguous(); // varlen aka Nested tensor err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 45b4cf118c1b..4570ac8f03fc 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #if AT_CUDNN_ENABLED() #include @@ -25,9 +26,12 @@ #if USE_ROCM #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) +#include #include #define USE_ROCM_ATTENTION 1 #endif +#else +#define USE_ROCM_ATTENTION 0 #endif // Avoid potential compiler -Wall -Werror complains undefined macro @@ -112,9 +116,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { // caller_is_meff is added to make the TORCH_WARN message showing the correct result template bool check_head_dim_size_flash(sdp_params const& params, bool debug) { -#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9 +#if USE_ROCM_ATTENTION // AOTriton 0.9+ supports head_dim up to 512 - const auto max_size = c10::SymInt(512); + const static auto max_hdim = []() { +#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11) + // gfx11xx only support hdim <= 256 on AOTriton 0.11 + auto dprops = at::cuda::getCurrentDeviceProperties(); + const c10::basic_string_view arch(dprops->gcnArchName); + if (arch.starts_with("gfx11")) { + return 256; + } +#endif // AOTriton 0.11 +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9) + return 512; +#else + return 256; +#endif + }(); + const auto max_size = c10::SymInt(max_hdim); #else // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index b38122248db8..a80d4053b27b 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -2,8 +2,12 @@ #ifdef USE_ROCM +// Expect to be included after headers of at::zeros_like and at::empty_like + #include #include +#include +#include //////////////////////////////////////////////////////////////////////////////// // Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h @@ -111,6 +115,61 @@ inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr) aotriton::DType::kInt32); } +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11) + +struct LazyTensorContext { + at::Tensor like_tensor; + std::string_view tensor_name; + at::Tensor tensor; +}; + +template +struct LazyTensorFunctions : public LazyTensorContext { + static aotriton::TensorView acquire(void* cookie) { + auto ctx = (LazyTensorContext*)cookie; + if (!ctx->tensor.defined()) { + auto q = ctx->like_tensor; + if constexpr (kRequireZeros) { + ctx->tensor = at::zeros(q.sizes(), + q.options().dtype(at::kFloat)); + } else { + ctx->tensor = at::empty_like(q); + } + } + return mk_aotensor(ctx->tensor, ctx->tensor_name); + } + + static void dispose(void* cookie) { + } +}; + +template +aotriton::LazyTensor mklazy_common(LazyTensorContext* cookie) +{ + using LTF = LazyTensorFunctions; + return aotriton::LazyTensor { + .cookie = cookie, + .acquire = <F::acquire, + .dispose = <F::dispose + }; +} + +template +auto mklazy_empty_like(LazyTensorContext* cookie) +{ + return mklazy_common(cookie); +} + + +// Note: this will not keep the original strides +template +auto mklazy_fp32zeros(LazyTensorContext* cookie) +{ + return mklazy_common(cookie); +} + +#endif // >= 0.11 + } // namespace aotriton_adapter } // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/aotriton_versions.h b/aten/src/ATen/native/transformers/hip/aotriton_versions.h new file mode 100644 index 000000000000..2f5d3f0e1222 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/aotriton_versions.h @@ -0,0 +1,20 @@ +#pragma once + +#ifdef USE_ROCM + +#define AOTRITON_VERSION_INT(x, y) (x * 100 + y) +#define AOTRITON_VERSION_CURRENT (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) + +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11) +#define AOTRITON_ALWAYS_V3_API 1 +#else +#define AOTRITON_ALWAYS_V3_API 0 +#endif + +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 10) +#define AOTRITON_V3_API 1 +#else +#define AOTRITON_V3_API 0 +#endif + +#endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 1908096e2f6f..8540fa992d75 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -60,20 +60,13 @@ #include // AOTriton headers -#include #include #include -#if AOTRITON_VERSION_MINOR < 9 +#if AOTRITON_VERSION_CURRENT < AOTRITON_VERSION_INT(0, 9) #error "This adaptor code is only tested with AOTriton >= 0.9" #endif -#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10 -#define V3_API 1 -#else -#define V3_API 0 -#endif - namespace pytorch_flash { namespace { @@ -93,15 +86,15 @@ calculate_swa(std::optional window_size_left, int max_seqlen_q, int max_seqlen_k, bool is_causal) { -#if V3_API // SWA is exposed through V3 API +#if AOTRITON_V3_API // SWA is exposed through V3 API bool needs_swa = false; using aotriton::v3::flash::WindowValue; // Default values when std::optional window_size_left/right have no value int window_left = max_seqlen_q; int window_right = max_seqlen_k; if (is_causal) { - window_left = WindowValue::TopLeftAligned; - window_right = WindowValue::TopLeftAligned; + window_left = WindowValue::BottomRightAligned; + window_right = WindowValue::BottomRightAligned; } if (window_size_left.has_value() || window_size_right.has_value()) { needs_swa = true; @@ -254,10 +247,10 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x seqlen_q, seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). constexpr bool uses_swa = false; #endif @@ -276,8 +269,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_fwd_params params; @@ -297,7 +290,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x params.philox_offset_output = offset_output; params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); params.persistent_atomic_counter = persistent_counter; - params.causal_type = CausalType::WindowedAttention; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.varlen_type = VarlenType::None; params.window_left = window_left; params.window_right = window_right; @@ -447,10 +440,10 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot max_seqlen_q, max_seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). constexpr bool uses_swa = false; #endif @@ -477,8 +470,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_fwd_params params; @@ -500,7 +493,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot params.philox_offset_output = offset_output; params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); params.persistent_atomic_counter = persistent_counter; - params.causal_type = CausalType::WindowedAttention; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.varlen_type = VarlenType::CompactVarlen; params.window_left = window_left; params.window_right = window_right; @@ -594,10 +587,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - if (is_causal){ - TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); - } - TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); @@ -649,10 +638,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea seqlen_q, seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). constexpr bool uses_swa = false; #endif @@ -676,10 +665,9 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea hipError_t err; // TODO: Error handling using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API // Fused BWD does not support SWA - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_bwd_params params; @@ -689,21 +677,32 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea params.Sm_scale = softmax_scale; params.Out = mk_aotensor(out_t, "out"); params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dq_t, "dq"); - params.DV = mk_aotensor(dk_t, "dk"); - params.DQ = mk_aotensor(dv_t, "dv"); + params.DQ = mk_aotensor(dq_t, "dq"); + params.DK = mk_aotensor(dk_t, "dk"); + params.DV = mk_aotensor(dv_t, "dv"); params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.D = mk_aotensor<2>(delta, "delta"); params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty params.dropout_p = p_dropout; params.philox_seed_ptr = mk_aoscalartensor(philox_seed); params.philox_offset1 = mk_aoscalartensor(philox_offset); params.philox_offset2 = 0; - params.causal_type = CausalType::WindowedAttention; - params.varlen_type = VarlenType::None; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.window_left = window_left; params.window_right = window_right; + params.varlen_type = VarlenType::None; +#if AOTRITON_ALWAYS_V3_API + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); +#else + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + params.D = mk_aotensor<2>(delta, "delta"); +#endif err = aotriton::v3::flash::attn_bwd(params, aotriton::v3::flash::attn_bwd_params::kVersion, stream); @@ -838,7 +837,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_k, batch_size + 1); at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); at::Tensor q_padded, k_padded, v_padded; q_padded = q.unsqueeze(0).transpose(1, 2); @@ -896,10 +894,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size max_seqlen_q, max_seqlen_k, is_causal); -#if V3_API +#if AOTRITON_V3_API const bool uses_swa = needs_swa; #else - // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be + // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be // optimized out (hopefully). constexpr bool uses_swa = false; #endif @@ -919,8 +917,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size hipError_t err; // TODO: Error handling using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa) { -#if V3_API + if (uses_swa || AOTRITON_ALWAYS_V3_API) { +#if AOTRITON_V3_API using aotriton::v3::flash::CausalType; using aotriton::v3::flash::VarlenType; aotriton::v3::flash::attn_bwd_params params; @@ -930,11 +928,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size params.Sm_scale = softmax_scale; params.Out = mk_aotensor(out_t, "out"); params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dq_padded, "dq"); - params.DV = mk_aotensor(dk_padded, "dk"); - params.DQ = mk_aotensor(dv_padded, "dv"); + params.DK = mk_aotensor(dk_padded, "dk"); + params.DV = mk_aotensor(dv_padded, "dv"); + params.DQ = mk_aotensor(dq_padded, "dq"); params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.D = mk_aotensor<2>(delta, "delta"); params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty @@ -943,17 +940,30 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size params.philox_seed_ptr = mk_aoscalartensor(philox_seed); params.philox_offset1 = mk_aoscalartensor(philox_offset); params.philox_offset2 = 0; - params.causal_type = CausalType::WindowedAttention; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; params.varlen_type = VarlenType::CompactVarlen; params.window_left = window_left; params.window_right = window_right; +#if AOTRITON_ALWAYS_V3_API + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_padded, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); +#else + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + params.D = mk_aotensor<2>(delta, "delta"); +#endif err = aotriton::v3::flash::attn_bwd(params, aotriton::v3::flash::attn_bwd_params::kVersion, stream); -#endif +#endif // AOTRITON_ALWAYS_V3_API } else { using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::cast_dtype; + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), mk_aotensor(k_padded, "k"), diff --git a/aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h b/aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h new file mode 100644 index 000000000000..c18744afc1ff --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file is a trimmed version of cuda/mem_eff_attention/gemm_kernel_utils.h +#pragma once + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + TORCH_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + TORCH_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 8b380d24f6c8..5d9158774654 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -9,94 +9,160 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.10b") + set(__AOTRITON_VER "0.11b") set(__AOTRITON_MANYLINUX_LIST + "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 "manylinux_2_28" # rocm7.0 ) set(__AOTRITON_ROCM_LIST + "rocm6.2" "rocm6.3" "rocm6.4" "rocm7.0" ) - set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477") + set(__AOTRITON_CI_COMMIT "972223c501ffc22068bb035ac5d64cf54318d895") set(__AOTRITON_SHA256_LIST - "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 - "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 - "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0 + "6cae3d5de75ee205d22e088f7dfaab1227056d02ea67f29ccdbc09f2be4e8c8f" # rocm6.2 + "72a153549ea20707331e8a1f1e3d1b8de2913f9d5af2b900c56235d578b57efe" # rocm6.3 + "c7f319dd7448cbbbab81889dd8a37d47dbc25ebcbd89760f09e6a0904e556393" # rocm6.4 + "a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0 ) + set(__AOTRITON_IMAGE_LIST + "amd-gfx90a" + "amd-gfx942" + "amd-gfx950" + "amd-gfx11xx" + "amd-gfx120x" + ) + set(__AOTRITON_IMAGE_SHA256_LIST + "c19a41c9480510ab32e6fb05e6ed0a3832d6b07634f050b836b760200befa735" # amd-gfx90a + "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942 + "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950 + "ec134032087344176695505db659387374d1916adfee16f0db47dee38d9c8603" # amd-gfx11xx + "fec05205747ff51649b1e151545267d5aa2037ba9d0338cad286882915b941b0" # amd-gfx120x + ) + set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") - - # Note it is INSTALL"ED" - if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX}) - install(DIRECTORY - $ENV{AOTRITON_INSTALLED_PREFIX}/lib - $ENV{AOTRITON_INSTALLED_PREFIX}/include - DESTINATION ${__AOTRITON_INSTALL_DIR}) - set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") - message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") - elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE}) - ExternalProject_Add(aotriton_external + function(aotriton_build_from_source noimage project) + if(noimage) + SET(RECURSIVE "OFF") + else() + SET(RECURSIVE "ON") + endif() + message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}") + ExternalProject_Add(${project} GIT_REPOSITORY https://github.com/ROCm/aotriton.git + GIT_SUBMODULES_RECURSE ${RECURSIVE} GIT_TAG ${__AOTRITON_CI_COMMIT} PREFIX ${__AOTRITON_EXTERN_PREFIX} - INSTALL_DIR ${__AOTRITON_INSTALL_DIR} - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} + CMAKE_CACHE_ARGS -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH} + -DCMAKE_INSTALL_PREFIX:FILEPATH=${__AOTRITON_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DAOTRITON_GPU_BUILD_TIMEOUT=0 -DAOTRITON_NO_PYTHON=ON - -DAOTRITON_NO_SHARED=OFF - # CONFIGURE_COMMAND "" - BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system. + -DAOTRITON_NOIMAGE_MODE=${noimage} BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE USES_TERMINAL_INSTALL TRUE - # INSTALL_COMMAND ${MAKE_COMMAND} install - ) - add_dependencies(__caffe2_aotriton aotriton_external) - message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_EXTERN_PREFIX}") - else() - set(__AOTRITON_SYSTEM_ROCM "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}") - list(GET __AOTRITON_ROCM_LIST 0 __AOTRITON_ROCM_DEFAULT_STR) - # Initialize __AOTRITON_ROCM to lowest version, in case all builds > system's ROCM - string(SUBSTRING ${__AOTRITON_ROCM_DEFAULT_STR} 4 -1 __AOTRITON_ROCM) - foreach(AOTRITON_ROCM_BUILD_STR IN LISTS __AOTRITON_ROCM_LIST) - # len("rocm") == 4 - string(SUBSTRING ${AOTRITON_ROCM_BUILD_STR} 4 -1 AOTRITON_ROCM_BUILD) - # Find the last build that <= system's ROCM - # Assume the list is from lower to higher - if(AOTRITON_ROCM_BUILD VERSION_GREATER __AOTRITON_SYSTEM_ROCM) - break() - endif() - set(__AOTRITON_ROCM ${AOTRITON_ROCM_BUILD}) - endforeach() - list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_ROCM}" __AOTRITON_ROCM_INDEX) - list(GET __AOTRITON_SHA256_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_SHA256) - list(GET __AOTRITON_MANYLINUX_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_MANYLINUX) - set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + ) + endfunction() + + set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + function(aotriton_download_runtime index project) + list(GET __AOTRITON_ROCM_LIST ${index} __AOTRITON_ROCM) + list(GET __AOTRITON_MANYLINUX_LIST ${index} __AOTRITON_MANYLINUX) + list(GET __AOTRITON_SHA256_LIST ${index} __AOTRITON_SHA256) + string(CONCAT __AOTRITON_FILE "aotriton-" "${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}" - "_${__AOTRITON_ARCH}-rocm${__AOTRITON_ROCM}" + "_${__AOTRITON_ARCH}-${__AOTRITON_ROCM}" "-shared.tar.${__AOTRITON_Z}") - string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" # @lint-ignore - "${__AOTRITON_VER}/${__AOTRITON_FILE}") - ExternalProject_Add(aotriton_external + string(CONCAT __AOTRITON_URL + "${__AOTRITON_BASE_URL}" + "${__AOTRITON_VER}/${__AOTRITON_FILE}") + ExternalProject_Add(${project} URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball" + "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" ) - add_dependencies(__caffe2_aotriton aotriton_external) - message(STATUS "Using AOTriton from pre-compiled binary ${__AOTRITON_URL}.\ + message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") + endfunction() + + function(aotriton_download_image image project) + list(FIND __AOTRITON_IMAGE_LIST ${image} index) + list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_SHA256) + + string(CONCAT __AOTRITON_FILE + "aotriton-${__AOTRITON_VER}-images-" + "${image}.tar.${__AOTRITON_Z}") + string(CONCAT __AOTRITON_URL + "${__AOTRITON_BASE_URL}" + "${__AOTRITON_VER}/${__AOTRITON_FILE}") + ExternalProject_Add(${project} + URL "${__AOTRITON_URL}" + URL_HASH SHA256=${__AOTRITON_SHA256} + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory + "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}" + "${__AOTRITON_INSTALL_DIR}" + BUILD_BYPRODUCTS + "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" + ) + message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.") + endfunction() + + # Note it is INSTALL"ED" + if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX}) + install(DIRECTORY + $ENV{AOTRITON_INSTALLED_PREFIX}/lib + $ENV{AOTRITON_INSTALLED_PREFIX}/include + DESTINATION ${__AOTRITON_INSTALL_DIR}) + set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}") + message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}") + elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE}) + aotriton_build_from_source(OFF aotriton_external) + add_dependencies(__caffe2_aotriton aotriton_external) + message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_EXTERN_PREFIX}") + else() + set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}") + list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX) + if(${__AOTRITON_RUNTIME_INDEX} LESS 0) + message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. \ + Build runtime from source") + aotriton_build_from_source(ON aotriton_runtime) + else() + aotriton_download_runtime(${__AOTRITON_RUNTIME_INDEX} aotriton_runtime) + endif() + add_dependencies(__caffe2_aotriton aotriton_runtime) + set(__AOTRITON_CHAINED_IMAGE "aotriton_runtime") + foreach(image ${__AOTRITON_IMAGE_LIST}) + string(SUBSTRING ${image} 7 -1 gfx_pattern) + string(REPLACE "x" "." gfx_regex ${gfx_pattern}) + foreach(target ${PYTORCH_ROCM_ARCH}) + if(target MATCHES ${gfx_regex}) + set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern}) + aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET}) + add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET}) + set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET}) + break() + endif() + endforeach() + endforeach() endif() target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so) target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) diff --git a/test/test_transformers.py b/test/test_transformers.py index 8bdad854cd22..df6448a7d98c 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -52,6 +52,7 @@ SM90OrLater, tf32_on_and_off, tf32_enabled, + ROCM_VERSION, ) if TEST_FAIRSEQ: @@ -340,7 +341,7 @@ def test_train_with_pad_and_catch_error(self, device): l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL") - @tf32_on_and_off(0.001) + @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) @parametrize("attn_mask_dim", [2, 3, None]) @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) @@ -524,7 +525,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) self.assertEqual(fastpath_output_expanded, slowpath_output) - @tf32_on_and_off(0.001) + @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @parametrize("enable_nested_tensor", [False]) @@ -1110,7 +1111,7 @@ def forward( return_all_hiddens=False, )[0] - @tf32_on_and_off(0.003) + @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) @parametrize("input_dim,attn_mask_dim,is_causal", [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], @@ -3351,6 +3352,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, 'grad_value': 8.5, } if TEST_WITH_ROCM: + fudge_factors['out'] = 5.0 fudge_factors['grad_key'] = 45.0 fudge_factors['grad_query'] = 360.0 if seq_len_k >= 1024: @@ -3360,6 +3362,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, fudge_factors['grad_query'] = 670.0 if dtype == torch.float32: fudge_factors['grad_key'] = 90.0 + if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName: + fudge_factors['grad_value'] = 16.0 check_out_and_grad( (out_ref, out_lp_ref, out), @@ -3472,6 +3476,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, "grad_attn_mask": 45.0, } if TEST_WITH_ROCM: + fudge_factors['out'] = 6.0 fudge_factors['grad_key'] = 45.0 fudge_factors['grad_query'] = 360.0 if seq_len_k >= 1024: @@ -3481,6 +3486,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, fudge_factors['grad_query'] = 670.0 # gfx90a if dtype == torch.float32: fudge_factors['grad_key'] = 90.0 + if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName: + fudge_factors['grad_value'] = 16.0 check_out_and_grad( (out_ref, out_lp_ref, out), @@ -3601,17 +3608,33 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le 'grad_value': 4, } if TEST_WITH_ROCM: - fudge_factors['grad_key'] = 45.0 - fudge_factors['grad_query'] = 360.0 - if seq_len_k >= 1024: - fudge_factors['grad_key'] = 70.0 - if seq_len_k >= 2048: - fudge_factors['grad_key'] = 190.0 - fudge_factors['grad_query'] = 650.0 - if seq_len_q >= 2048: - fudge_factors['grad_query'] = 1100.0 - if dtype == torch.float32: - fudge_factors['grad_key'] = 90.0 + fudge_factors['grad_value'] = 6.0 + if TEST_WITH_CK: + fudge_factors['out'] = 5.0 + fudge_factors['grad_key'] = 145.0 + fudge_factors['grad_query'] = 855.0 # ck min = 855.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 190.0 + fudge_factors['grad_query'] = 1550.0 # NEW CK MIN + if seq_len_q >= 2048: + fudge_factors['grad_query'] = 1100.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + else: + fudge_factors['out'] = 6.0 + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 190.0 + fudge_factors['grad_query'] = 650.0 + if seq_len_q >= 2048: + fudge_factors['grad_query'] = 1100.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 check_out_and_grad( (out_ref, out_lp_ref, out), @@ -3764,15 +3787,19 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad) + fudge_factors = { + 'out': 3.0, + 'grad_query': 110.0, + 'grad_key': 8.0, + 'grad_value': 3.0, + } + if TEST_WITH_ROCM: + fudge_factors['out'] = 6.0 + fudge_factors['grad_value'] = 6.0 check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - 'out': 3.0, - 'grad_query': 110.0, - 'grad_key': 8.0, - 'grad_value': 3.0, - } + fudge_factors=fudge_factors ) @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @@ -4384,10 +4411,6 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True ) - if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: - self.skipTest("No support for LOWER_RIGHT variant for now") - return - bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)) @@ -4418,10 +4441,6 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") @skipIfTorchDynamo("This function already calls torch.compile.") def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): - if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: - self.skipTest("No support for LOWER_RIGHT variant for now") - return - cnts = CompileCounterWithBackend("aot_eager") make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 2620c64a95ef..9c4dfd1d7d44 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -24,6 +24,7 @@ TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))) TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0) +ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0)) SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)) SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)) @@ -93,7 +94,6 @@ def evaluate_platform_supports_cudnn_attention(): def evaluate_platform_supports_fp8(): if torch.cuda.is_available(): if torch.version.hip: - ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2]) archs = ['gfx94'] if ROCM_VERSION >= (6, 3): archs.extend(['gfx120']) @@ -111,7 +111,8 @@ def evaluate_platform_supports_fp8(): def _platform_supports_mx_gemm(): if torch.cuda.is_available(): if torch.version.hip: - return 'gfx95' in torch.cuda.get_device_properties(0).gcnArchName + if ROCM_VERSION >= (7, 0): + return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName else: return SM100OrLater return False @@ -222,7 +223,7 @@ def tf32_enabled(): # if device is specified, it will check if device is cuda # if dtype is specified, it will check if dtype is float32 or complex64 # tf32 and fp32 are different only when all the three checks pass -def tf32_on_and_off(tf32_precision=1e-5): +def tf32_on_and_off(tf32_precision=1e-5, only_if=True): def with_tf32_disabled(self, function_call): with tf32_off(): function_call() @@ -238,7 +239,7 @@ def wrapper(f): @functools.wraps(f) def wrapped(*args, **kwargs): kwargs.update(zip(arg_names, args)) - cond = torch.cuda.is_tf32_supported() + cond = torch.cuda.is_tf32_supported() and only_if if 'device' in kwargs: cond = cond and (torch.device(kwargs['device']).type == 'cuda') if 'dtype' in kwargs: @@ -252,7 +253,6 @@ def wrapped(*args, **kwargs): return wrapped return wrapper - # This is a wrapper that wraps a test to run it with TF32 turned off. # This wrapper is designed to be used when a test uses matmul or convolutions # but the purpose of that test is not testing matmul or convolutions. diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index fa2355463bb3..f260f5781f96 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -15,6 +15,17 @@ def has_triton_package() -> bool: return False +@functools.cache +def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]: + try: + import triton # noqa: F401 + + major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2]) + return (major, minor) + except ImportError: + return fallback + + @functools.cache def _device_supports_tma() -> bool: import torch From 3345d297a360b96c67e6a70118efcb6c8ccd119e Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Mon, 15 Sep 2025 16:13:03 +0000 Subject: [PATCH 2/4] [ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330) Enables flash attention and/or memory efficient attention on Windows with scaled_dot_product_attention via. aotriton. Already tested to be working on Windows with TheRock. Steps to enable: simply set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162330 Approved by: https://github.com/jeffdaily Co-authored-by: Scott Todd --- CMakeLists.txt | 4 +- .../native/transformers/cuda/attention.cu | 66 ++++++++++ .../transformers/hip/flash_attn/flash_api.h | 39 +----- cmake/External/aotriton.cmake | 113 +++++++++++++++++- tools/linter/dictionary.txt | 1 + 5 files changed, 179 insertions(+), 44 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a5d25e6afa0f..63025c26a05e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -867,7 +867,7 @@ cmake_dependent_option( "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM;NOT MSVC" + "(USE_CUDA AND NOT MSVC) OR USE_ROCM" OFF) # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem @@ -883,7 +883,7 @@ cmake_dependent_option( # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake # if(USE_ROCM) - if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) include(cmake/External/aotriton.cmake) endif() endif() diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index dae3332430f1..76a62b3f7f8a 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -95,6 +95,72 @@ #endif #endif +#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)) +namespace pytorch_flash +{ +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} +} +#endif + namespace at { namespace cuda::philox { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 17298aae9485..e578847e3273 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -270,7 +270,7 @@ std::tuple mha_varle #endif TORCH_API -inline std::tuple< +std::tuple< at::Tensor, at::Tensor, at::Tensor, @@ -294,42 +294,7 @@ mha_fwd( std::optional window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { -#if defined(USE_CK_FLASH_ATTENTION) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); - std::optional dummy_attn_bias = std::nullopt; - return mha_fwd_ck( - q, - k, - v, - out_, - p_dropout, - softmax_scale, - is_causal, - non_null_window_left, - non_null_window_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } -#endif - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); -} + std::optional gen_); inline std::tuple< at::Tensor, diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 5d9158774654..4f7a79a78bfc 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -45,13 +45,88 @@ if(NOT __AOTRITON_INCLUDED) ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") + # Set the default __AOTRITON_LIB path + set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so") + if(WIN32) + set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib") + endif() + + function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + # Windows-specific dependencies - build these first + if(NOT noimage) + message(FATAL_ERROR "noimage must be ON for Windows builds") + endif() + # Build dlfcn-win32 + set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32") + set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install") + + ExternalProject_Add(${dlfcn-win32_external} + GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git + GIT_TAG v1.4.2 + PREFIX ${__DLFCN_WIN32_PREFIX} + INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=cl + -DCMAKE_CXX_COMPILER=cl + -DBUILD_SHARED_LIBS=ON + -DBUILD_TESTS=OFF + BUILD_BYPRODUCTS + "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib" + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + ) + ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE) + + # Build xz/liblzma + set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz") + set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install") + + ExternalProject_Add(${xz_external} + GIT_REPOSITORY https://github.com/tukaani-project/xz.git + GIT_TAG v5.8.1 + PREFIX ${__XZ_PREFIX} + INSTALL_DIR ${__XZ_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DBUILD_SHARED_LIBS=ON + -DENABLE_NLS=OFF + -DXZ_TOOL_LZMAINFO=OFF + -DXZ_TOOL_XZ=OFF + -DXZ_TOOL_XZDEC=OFF + -DXZ_TOOL_LZMADEC=OFF + BUILD_BYPRODUCTS + "${__XZ_INSTALL_DIR}/lib/lzma.lib" + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + ) + ExternalProject_Add_Step(${xz_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE) + endfunction() + function(aotriton_build_from_source noimage project) if(noimage) SET(RECURSIVE "OFF") else() SET(RECURSIVE "ON") endif() + if(WIN32) + message(STATUS "Building AOTriton Windows dependencies") + aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + endif() message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}") + ExternalProject_Add(${project} GIT_REPOSITORY https://github.com/ROCm/aotriton.git GIT_SUBMODULES_RECURSE ${RECURSIVE} @@ -65,12 +140,19 @@ if(NOT __AOTRITON_INCLUDED) -DAOTRITON_GPU_BUILD_TIMEOUT=0 -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NOIMAGE_MODE=${noimage} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + -DHIP_PLATFORM=amd + $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> + $<$:-Dliblzma_DIR=${liblzma_DIR}> + BUILD_BYPRODUCTS + "${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE USES_TERMINAL_INSTALL TRUE ) + if(WIN32) + add_dependencies(${project} dlfcn-win32_external xz_external) + endif() endfunction() set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) @@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED) INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + BUILD_BYPRODUCTS "${__AOTRITON_LIB}" ) message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") @@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED) string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" "${__AOTRITON_VER}/${__AOTRITON_FILE}") + + # Set up directories + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}) + set(__DOWNLOAD_NO_EXTRACT "") + set(__BUILD_COMMANDS "") + + # On Windows, we need custom tar extraction with UTF-8 support + if(WIN32) + set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE") + set(__BUILD_COMMANDS + COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}" + COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}" + ) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton) + endif() + ExternalProject_Add(${project} URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image} + DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR} + ${__DOWNLOAD_NO_EXTRACT} + SOURCE_DIR ${__AOTRITON_EXTRACT_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" + ${__BUILD_COMMANDS} INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}" + "${__AOTRITON_INSTALL_SOURCE_DIR}" "${__AOTRITON_INSTALL_DIR}" BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" @@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED) endforeach() endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so) + target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB}) target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index a3da2299cf23..2a7c3b9d1acd 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -3,6 +3,7 @@ BU contiguities contiguity coo +DEPENDEES Din Dout dOut From a86374d594a831acad26751b47d8df4dea488d3e Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Mon, 22 Sep 2025 15:01:18 +0000 Subject: [PATCH 3/4] [ROCm] Fix environment variable AOTRITON_INSTALLED_PREFIX (#163373) Early assignment of `__AOTRITON_LIB` breaks the usage of environment variable `$AOTRITON_INSTALLED_PREFIX` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163373 Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily --- cmake/External/aotriton.cmake | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 4f7a79a78bfc..f09f77bedb80 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -46,9 +46,10 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") # Set the default __AOTRITON_LIB path - set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so") - if(WIN32) - set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib") + if(NOT WIN32) + set(__AOTRITON_LIB "lib/libaotriton_v2.so") + else() + set(__AOTRITON_LIB "lib/aotriton_v2.lib") endif() function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) @@ -143,8 +144,7 @@ if(NOT __AOTRITON_INCLUDED) -DHIP_PLATFORM=amd $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> $<$:-Dliblzma_DIR=${liblzma_DIR}> - BUILD_BYPRODUCTS - "${__AOTRITON_LIB}" + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE @@ -177,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED) INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_LIB}" + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" ) message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") @@ -267,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED) endforeach() endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB}) + target_link_libraries(__caffe2_aotriton INTERFACE "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}") target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED From c497508b3a83a1a6293f7c5c842614eec613f2f7 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 25 Sep 2025 17:14:16 +0000 Subject: [PATCH 4/4] [ROCm] Transformer/SDPA unit test parity (#163745) * Efficient Attention on ROCM requires last dimensions of input tensors align with 16 bytes. - Unlike FA, ME does not pad input tensors in `scaled_dot_product_attention` and hence this is required. * Fix `atomic_counter` handling in varlen FA API * Unskips a few unit tests. Fixes #157120 Fixes #157121 Fixes #157122 Fixes #157167 Fixes #155217 Fixes #157043 Fixes #157060 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163745 Approved by: https://github.com/jeffdaily --- .../native/transformers/cuda/sdp_utils.cpp | 22 +++++++++++++++++++ .../hip/flash_attn/aot/mha_all_aot.hip | 5 +++-- test/nn/test_multihead_attention.py | 2 -- test/test_flop_counter.py | 3 --- test/test_nn.py | 10 +++------ test/test_transformers.py | 13 +---------- 6 files changed, 29 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 4570ac8f03fc..0df958c4c010 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -158,6 +158,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { } return false; } + if constexpr(caller_is_meff) { + bool is_half = (params.query.dtype() == at::kHalf) || + (params.query.dtype() == at::kBFloat16); + const int64_t alignment = is_half ? 8 : 4; + if (!(query_size_last % alignment == 0 && query_size_last > 0 && + value_size_last % alignment == 0 && value_size_last > 0)) { + if (debug) { + TORCH_WARN( + "Mem efficient attention requires last dimension of inputs to be divisible by ", + alignment, + ". ", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + } return true; } diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 8540fa992d75..acadb67ae171 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -457,10 +457,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; using sdp::aotriton_adapter::cast_dtype; at::Tensor atomic_counter; if (is_causal) { - atomic_counter = at::zeros({1}, q.options()); + atomic_counter = at::zeros({1}, q.options().dtype(at::kInt)); } aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); @@ -469,7 +470,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto nullscalar = mk_philoxtensor(nullptr); auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; - auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); if (uses_swa || AOTRITON_ALWAYS_V3_API) { #if AOTRITON_V3_API using aotriton::v3::flash::CausalType; diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py index c0419664d009..40dca90b1648 100644 --- a/test/nn/test_multihead_attention.py +++ b/test/nn/test_multihead_attention.py @@ -17,7 +17,6 @@ instantiate_parametrized_tests, parametrize as parametrize_test, run_tests, - skipIfRocm, TEST_NUMPY, TEST_WITH_CROSSREF, ) @@ -746,7 +745,6 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self): class TestMultiheadAttentionNNDeviceType(NNTestCase): - @skipIfRocm(msg="To investigate: yields NaN") def test_multihead_self_attn_two_masks_fast_path(self, device): """ Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index c44d5e5d4145..17e699e04e58 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -15,7 +15,6 @@ ) from torch.testing._internal.common_utils import ( run_tests, - skipIfRocm, TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -463,7 +462,6 @@ def get_flops( self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -683,7 +681,6 @@ def split_tensor(x): ), ) - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, diff --git a/test/test_nn.py b/test/test_nn.py index f3b1764af69d..e65f5d53147a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -36,8 +36,9 @@ download_file, get_function_arglist, load_tests, skipIfMPS, \ IS_PPC, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ - skipIfTorchDynamo, skipIfRocmVersionLessThan, gcIfJetson, set_default_dtype -from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version + skipIfTorchDynamo, gcIfJetson, set_default_dtype +from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ + _get_torch_rocm_version from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input @@ -3148,7 +3149,6 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) - @skipIfRocm(msg='Large numerical errors') def test_transformerdecoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): d_model = 4 @@ -12937,8 +12937,6 @@ def test_skip_init(self, device): @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) def test_transformerencoderlayer(self, device, dtype): - if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: - self.skipTest("Skip on ROCM due to Flash Attention tolerances") # this is a deterministic test for TransformerEncoderLayer d_model = 4 nhead = 2 @@ -13160,8 +13158,6 @@ def test_transformerencoderlayer_fast_path(self, device, dtype): @dtypes(torch.float) @dtypesIfCUDA(torch.half, torch.float) def test_transformerencoderlayer_gelu(self, device, dtype): - if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: - self.skipTest("Skip on ROCM due to Flash Attention tolerances") # this is a deterministic test for TransformerEncoderLayer with gelu activation d_model = 4 nhead = 2 diff --git a/test/test_transformers.py b/test/test_transformers.py index df6448a7d98c..a68ed8e10576 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -431,7 +431,6 @@ def hook(module, inputs, output): # remove hook handle.remove() - @skipIfRocm @tf32_on_and_off(0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @@ -1422,7 +1421,6 @@ def ones_tensor(*shape): _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True) torch.cuda.synchronize() - @skipIfRocm # Missing EFFICIENT_ATTENTION @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware" ) @@ -1728,7 +1726,7 @@ def test_unaligned_tensors(self, device): make_tensor = partial(torch.rand, size, device=device, dtype=dtype) q, k, v = make_tensor(), make_tensor(), make_tensor() with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): - ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext() + ctxmgr = self.assertRaises(RuntimeError) with ctxmgr: torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) @@ -2542,7 +2540,6 @@ def convert_flash_attn_S_to_softmax( S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) return S_converted[:, :, :seqlen_q, :seqlen_k] - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2566,7 +2563,6 @@ def test_cudnn_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_gqa(self, device): batch = 4 @@ -2590,7 +2586,6 @@ def test_cudnn_attention_gqa(self, device): self.assertEqual(output_math, output_cudnn) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 @@ -2615,7 +2610,6 @@ def test_cudnn_attention_d256_heuristic(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_fused_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2639,7 +2633,6 @@ def test_fused_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 @@ -2662,7 +2655,6 @@ def test_cudnn_attention_fail_d128(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel."): torch.nn.functional.scaled_dot_product_attention(q, k, v) - @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_trivial_output_transpose(self, device): # see also: https://github.com/pytorch/pytorch/issues/134001 @@ -2678,7 +2670,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device): o.backward(o) torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_nonmodulo64seqlen(self, device): # see also: https://github.com/pytorch/pytorch/issues/137347 @@ -2718,7 +2709,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device): torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) - @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_preserves_query_layout(self, device): @@ -3131,7 +3121,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value - @skipIfRocm @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")