diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile
index 9de8423640104..e69d64cc5a6b7 100644
--- a/.ci/docker/ubuntu-rocm/Dockerfile
+++ b/.ci/docker/ubuntu-rocm/Dockerfile
@@ -106,6 +106,12 @@
 COPY triton_version.txt triton_version.txt
 RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
 RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
+# This is needed by sccache
+COPY ./common/install_openssl.sh install_openssl.sh
+ENV OPENSSL_ROOT_DIR /opt/openssl
+RUN bash ./install_openssl.sh
+ENV OPENSSL_DIR /opt/openssl
+
 # Install ccache/sccache (do this last, so we get priority in PATH)
 COPY ./common/install_cache.sh install_cache.sh
 ENV PATH /opt/cache/bin:$PATH
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 0087dd95d96ee..8c8dc6662e00e 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -174,6 +174,7 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
 # flash_attention sources
 file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
 file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
+file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
 
 #Mem_eff attention sources
 file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
@@ -188,6 +189,7 @@ if(USE_FLASH_ATTENTION)
   list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})
 
   list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
+  list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip})
   list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
 endif()
 
diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index 0f9356a7f3063..84b07e010cd67 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1077,8 +1077,8 @@ std::tuple _efficient_
   auto ret = aotriton::v2::flash::check_gpu(stream);
   if (hipSuccess != ret) {
     TORCH_CHECK(false,
-                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
-                " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
+                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
+                " (gfx90a/gfx942/gfx1100/gfx1201)")
   }
 
   // AOTriton may accept aligned on logsumexp tensor in the future for better
@@ -1086,10 +1086,13 @@ std::tuple _efficient_
   // compute_logsumexp is false
   constexpr int kAlignLSE = 1;
   res = at::empty({B, M, num_heads, Kv}, query.options());
+  at::Tensor softmax_lse;
   logsumexp = at::empty(
-      { B, num_heads, max_seqlen_q },
+      { B, num_heads, compute_logsumexp ?
max_seqlen_q : 0}, query.options().dtype(at::ScalarType::Float)); - at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + if (compute_logsumexp) { + softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q}); + } at::Tensor q_t = query.transpose(1, 2); at::Tensor k_t = key.transpose(1, 2); at::Tensor v_t = value.transpose(1, 2); @@ -1105,40 +1108,69 @@ std::tuple _efficient_ const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, query.options().dtype(at::kInt)); + } + using aotriton::v2::flash::attn_fwd; + using aotriton::v2::flash::attn_fwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); + aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32); at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); const bool use_philox_state = in_capture_stream; auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; - auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); - auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); + auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); + auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); hipError_t err; // TODO: Error handling - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, - softmax_scale, - mk_aotensor<2>(softmax_lse, "M"), - mk_aotensor(output_t, "Out"), - dropout_p, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); - if (!compute_logsumexp) { - // Set the tensor to empty when compute_logsumexp is false - logsumexp = at::empty( - { B * num_heads, max_seqlen_q, 0 }, - query.options().dtype(at::ScalarType::Float)); + if (seqstart_q.has_value()) { + // varlen aka nested tensor + err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + softmax_scale, + compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2, + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + } else { + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + compute_logsumexp ? 
mk_aotensor<2>(softmax_lse, "M") : empty_t2, + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); } #else // CUDA Implementation @@ -1401,15 +1433,24 @@ at::Tensor& _fill_mem_eff_dropout_mask_( #if defined(USE_MEM_EFF_ATTENTION) #ifdef USE_ROCM - using aotriton::v2::flash::debug_fill_dropout_rng; + using aotriton::v2::flash::debug_simulate_encoded_softmax; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + at::cuda::CUDAGuard device_guard(self.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + at::Tensor seed_t, offset_t; + const auto options = at::dtype(at::kLong).device(at::kCUDA); + seed_t = at::scalar_tensor(at::Scalar(seed), options); + offset_t = at::scalar_tensor(at::Scalar(offset), options); hipError_t err; // TODO: Error handling - err = debug_fill_dropout_rng(mk_aotensor(self, "r"), - static_cast(seed), - static_cast(offset), - stream); + err = debug_simulate_encoded_softmax(mk_aotensor(self, "r"), + dropout_p, + mk_aoscalartensor(seed_t), + mk_aoscalartensor(offset_t), + 0, + stream); #else at::PhiloxCudaState rng_engine_inputs; rng_engine_inputs = at::PhiloxCudaState(seed, offset); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index e809f97265774..017661a3f3637 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -383,8 +383,8 @@ _efficient_attention_backward( auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" - " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs" + " (gfx90a/gfx942/gfx1100/gfx1201)") } const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); bool is_causal; @@ -404,33 +404,86 @@ _efficient_attention_backward( at::Tensor dv_t = grad_v.permute({0,2,1,3}); at::Tensor dout_t = grad_out.permute({0,2,1,3}); at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q}); - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); hipError_t err; using aotriton::v2::flash::attn_bwd; + using aotriton::v2::flash::attn_bwd_fused; + using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - mk_aotensor<2>(delta, "delta"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); + if (cu_seqlens_q.has_value()) { + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + // varlen aka Nested tensor + err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + auto d_head = Kv; + bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512; + if (use_fused_bwd) { + err = attn_bwd_fused(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } + } #else at::Tensor workspace; cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index a61d95312fbe3..34f1daa27b9ce 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -26,6 +26,11 @@ #endif #endif +// Avoid potential compiler -Wall -Werror complains undefined macro +#ifndef AOTRITON_VERSION_MINOR +#define AOTRITON_VERSION_MINOR 0 +#endif + /** * Note [SDPA Runtime Dispatch] * SDPA relies on a runtime dispatch mechanism to select the appropriate @@ -83,8 +88,13 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { } bool check_head_dim_size_flash(sdp_params const& params, bool debug) { +#if USE_AOTRITON && AOTRITON_VERSION_MINOR >= 9 + // AOTriton 0.9+ supports head_dim up to 512 + const auto max_size = c10::SymInt(512); +#else // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); +#endif const auto query_size_last = params.query.sym_size(-1); const auto key_size_last = params.key.sym_size(-1); const auto value_size_last = params.value.sym_size(-1); @@ -207,6 +217,16 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); return false; } +#if AOTRITON_VERSION_MINOR >= 9 + if (aotriton::isArchExperimentallySupported(stream)) { + static const bool enable_experimental = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_experimental) { + TORCH_WARN_ONCE("Flash Efficient attention on Current AMD GPU is still experimental." + " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); + return false; + } + } +#endif } #else return false; @@ -243,15 +263,16 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { - static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; - if (!enable_navi3x) { - TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental." - " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); +#if AOTRITON_VERSION_MINOR >= 9 + if (aotriton::isArchExperimentallySupported(stream)) { + static const bool enable_experimental = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_experimental) { + TORCH_WARN_ONCE("Mem Efficient attention on Current AMD GPU is still experimental." 
+ " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); return false; } } +#endif #else return false; #endif diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index 57d5c34444390..1623852b249fe 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -127,6 +127,12 @@ inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) aotriton::DType::kUInt64); // AOTriton excepts unsigned int64 } +inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kInt32); +} + } // namespace aotriton_adapter } // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip new file mode 100644 index 0000000000000..22203f22079e7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -0,0 +1,787 @@ +/****************************************************************************** + * Copyright (c) 2023, Advanced Micro Devices, Inc. + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ ******************************************************************************/
+#include
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+
+#include
+#include
+
+#include
+
+#ifdef USE_FLASH_ATTENTION
+#include
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#include
+#else
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#endif
+
+#include
+#include
+
+#include
+
+// AOTriton headers
+#include
+#include
+#include
+
+#if AOTRITON_VERSION_MINOR < 9
+#error "This adaptor code is only tested with AOTriton 0.9+"
+#endif
+
+namespace pytorch_flash {
+
+namespace {
+
+void check_gpu_arch(hipStream_t stream) {
+  auto ret = aotriton::v2::flash::check_gpu(stream);
+  if (hipSuccess != ret) {
+    TORCH_CHECK(false,
+                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
+                " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
+  }
+}
+
+// We want to checkpoint and save the RNG state for backward if dropout
+// We get the default generator and return the seed and offset which will
+// be used in the backward function
+std::tuple<at::Tensor, at::Tensor, at::PhiloxCudaState, bool>
+prepare_philox_arguments(float p_dropout, int64_t counter_offset) {
+  at::Tensor seed_t, offset_t;
+  at::PhiloxCudaState philox_state;
+  bool use_philox_state = false;
+  if (p_dropout <= 0.0) {
+    seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+    offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+    return { seed_t, offset_t, philox_state, use_philox_state };
+  }
+  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
+  std::lock_guard<std::mutex> lock(gen->mutex_);
+  philox_state = gen->philox_cuda_state(counter_offset);
+  if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
+    auto [seed, offset] = at::cuda::philox::unpack(philox_state);
+    seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
+    offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
+  } else {
+    // See Note [CUDA Graph-safe RNG states] about the design
+    use_philox_state = true;
+    seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+    offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+  }
+
+  return { seed_t, offset_t, philox_state, use_philox_state };
+}
+
+
+}
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...)
TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::tuple +mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_) { + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + // FIXME: ROCM probably does not need this + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + q_padded = q; + k_padded = k; + v_padded = v; + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } + } else { + out = at::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); + + // Transpose tensors to meet AOTriton's Flash API + at::Tensor q_t = q_padded.permute({0,2,1,3}); + at::Tensor k_t = k_padded.permute({0,2,1,3}); + at::Tensor v_t = v_padded.permute({0,2,1,3}); + 
at::Tensor output_t = out.permute({0,2,1,3}); + + auto opts = q.options(); + at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, opts.dtype(at::kFloat)); // aka softmax_lse + + at::Tensor softmax_fa_t; + if (return_softmax) { + softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } else { + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); + } + + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); + } + + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_fwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); + auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(output_t, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + + return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; +} + +std::tuple +mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
+ std::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_) { + TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); + + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { + is_causal = false; + } // causal=true is the same as causal=false in this case + + at::Tensor temp_q = q; + const int total_q = temp_q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= max_seqlen_k) { + window_size_right = -1; + } + + CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + // AOTriton's varlen API needs input shapes be + // (1, num_heads, total sequence lenght, head dimension) + at::Tensor q_padded, k_padded, v_padded; + at::Tensor out, out_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + } else { + out = at::empty_like(q); + } + out_padded = out.unsqueeze(0).transpose(1, 2); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = head_size_og; + + auto 
opts = q.options(); + + auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q}); + at::Tensor softmax_fa_t; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + softmax_fa_t = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); + } else { + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { + softmax_fa_t.zero_(); + } + } + + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); + + if (max_seqlen_k > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_fwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::cast_dtype; + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, q.options()); + } + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto nullscalar = mk_philoxtensor(nullptr); + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; + auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + empty_bias, + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(out_padded, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + persistent_counter, + stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
+ out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + return {out, q, k, v, softmax_lse, seed_t, offset_t, softmax_fa_t}; +} + +std::tuple +mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + bool is_dropout = p_dropout > 0.0; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + + if (is_causal){ + TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); + } + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + + TORCH_CHECK(head_size == 
round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dv = at::empty_like(k); + } + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + at::Tensor q_t = q.permute({0,2,1,3}); + at::Tensor k_t = k.permute({0,2,1,3}); + at::Tensor v_t = v.permute({0,2,1,3}); + at::Tensor out_t = out.permute({0,2,1,3}); + at::Tensor dq_t = dq.permute({0,2,1,3}); + at::Tensor dk_t = dk.permute({0,2,1,3}); + at::Tensor dv_t = dv.permute({0,2,1,3}); + at::Tensor dout_t = dout.permute({0,2,1,3}); + + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, seqlen_q}).contiguous(); + + int d_head = head_size_og; + bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512; + hipError_t err; // TODO: Error handling + if (use_fused_bwd) { + using aotriton::v2::flash::attn_bwd_fused; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_fused(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + empty_bias, // dbb + mk_aotensor<2>(softmax_lse_cont, "L"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + using aotriton::v2::flash::attn_bwd; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + empty_bias, // db + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + 
p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } + + return { dq, dk, dv, softmax_d }; +} + +std::tuple +mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset) +{ + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); + + if (is_causal) { + window_size_right = 0; + } + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); + + bool is_dropout = p_dropout > 0.0; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + 
TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + + at::Tensor q_padded, k_padded, v_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + at::Tensor out_t, dout_t; + out_t = out.unsqueeze(0).transpose(1, 2); + dout_t = dout.unsqueeze(0).transpose(1, 2); + + at::Tensor dq, dk, dv; + at::Tensor dq_padded, dk_padded, dv_padded; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = at::empty_like(v); + } + dq_padded = dq.unsqueeze(0).transpose(1, 2); + dk_padded = dk.unsqueeze(0).transpose(1, 2); + dv_padded = dv.unsqueeze(0).transpose(1, 2); + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + if( zero_tensors ) { + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + at::PhiloxCudaState philox_args; + if (is_dropout) { + if (at::cuda::currentStreamCaptureStatus() == + at::cuda::CaptureStatus::None) + { + philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); + } else { // dropout + capture + philox_args = at::PhiloxCudaState( + philox_seed.data_ptr(), philox_offset.data_ptr(), 0); + } + } + if (max_seqlen_q > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_bwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + 
empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_padded, "dq"), + mk_aotensor(dk_padded, "dk"), + mk_aotensor(dv_padded, "dv"), + empty_bias, + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + return { dq, dk, dv, softmax_d }; +} +} // namespace pytorch_flash + +#endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h new file mode 100644 index 0000000000000..cddd6dfb7a885 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -0,0 +1,322 @@ +#pragma once +#include + +#include +#include +#include + +namespace pytorch_flash { + +// AOTriton Implementation +TORCH_API +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd_aot( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd_aot( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. 
+ std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + const std::optional& gen_); + +std::tuple mha_bwd_aot( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +std::tuple mha_varlen_bwd_aot( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset); + +TORCH_API +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + 
at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_varlen_fwd( + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& + seqused_k, // b. If given, only this many elements of each batch + // element's keys are used. + std::optional& alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const bool return_softmax, + std::optional gen_) { + return mha_varlen_fwd_aot( + q, + k, + v, + out_, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} + +inline std::tuple mha_bwd( + const at::Tensor& dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x seqlen_q + std::optional& + dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional& + dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { + return mha_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +inline std::tuple mha_varlen_bwd( + const at::Tensor& dout, // total_q x num_heads, x head_size + const at::Tensor& + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor& + k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& + v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& out, // total_q x num_heads x head_size + const at::Tensor& softmax_lse, // b x h x s softmax logsumexp + std::optional& + dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional& + dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional& + dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + std::optional& alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + 
const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const bool deterministic, + const at::Tensor philox_seed, + const at::Tensor philox_offset) { + return mha_varlen_bwd_aot( + dout, + q, + k, + v, + out, + softmax_lse, + dq_, + dk_, + dv_, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + max_seqlen_q, + max_seqlen_k, + p_dropout, + softmax_scale, + zero_tensors, + is_causal, + window_size_left, + window_size_right, + deterministic, + philox_seed, + philox_offset); +} + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip deleted file mode 100644 index 9b0820a501bf4..0000000000000 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ /dev/null @@ -1,504 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Advanced Micro Devices, Inc. - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ -#include -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS - -#include -#include - -#include - -#ifdef USE_FLASH_ATTENTION -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#include -#include -#include -#include -#include -#include -#endif - -#include -#include - -#include -#include - -// AOTriton headers -#include -#include - -namespace pytorch_flash { - -namespace { - -void check_gpu_arch(hipStream_t stream) { - auto ret = aotriton::v2::flash::check_gpu(stream); - if (hipSuccess != ret) { - TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" - " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") - } -} - -} - -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -std::tuple -mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, - const float softmax_scale, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - c10::optional gen_) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - - // FIXME: ROCM probably does not need this - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - int seqlen_q = sizes[1]; - int num_heads = sizes[2]; - const int head_size_og = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!"); - TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - if (is_causal) { window_size_right = 0; } - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); - - at::Tensor q_padded, k_padded, v_padded; - q_padded = q; - k_padded = k; - v_padded = v; - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); } - } else { - out = at::empty_like(q_padded); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - - // We want to 
checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::Tensor seed_t, offset_t; - - at::PhiloxCudaState philox_state; - bool use_philox_state = false; - if (p_dropout > 0.0) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = batch_size * num_heads * 32; - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); - } else { - // See Note [CUDA Graph-safe RNG states] about the design - use_philox_state = true; - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); - } - } - - // Transpose tensors to meet AOTriton's Flash API - at::Tensor q_t = q_padded.permute({0,2,1,3}); - at::Tensor k_t = k_padded.permute({0,2,1,3}); - at::Tensor v_t = v_padded.permute({0,2,1,3}); - at::Tensor output_t = out.permute({0,2,1,3}); - - at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse - - at::Tensor softmax_fa_t; - if (return_softmax) { - softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, - at::dtype(q.dtype()).device(q.device())); - } else { - softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device())); - } - - hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_fwd; - using aotriton::TensorView; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; - using sdp::aotriton_adapter::mk_philoxtensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); - auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); - auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; - auto seed_output = use_philox_state ? 
mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); - auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); - - return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t}; -} - -std::tuple -mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - c10::optional &alibi_slopes_, // num_heads or b x num_heads - int max_seqlen_q, - const int max_seqlen_k, - const float p_dropout, - const float softmax_scale, - const bool zero_tensors, - bool is_causal, - int window_size_left, - int window_size_right, - const bool return_softmax, - c10::optional gen_) { - - TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm"); - - at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat)); - at::Tensor p = at::empty({}, at::dtype(at::kFloat)); - at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor out = at::empty({}, at::dtype(at::kFloat)); - - return {out, q, k, v, softmax_lse, seed_t, offset_t, p}; -} - -std::tuple -mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const float p_dropout, // probability to drop - const float softmax_scale, - const bool is_causal, - int window_size_left, - int window_size_right, - const bool deterministic, - const at::Tensor philox_seed, - const at::Tensor philox_offset) { - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - check_gpu_arch(stream); - - bool is_dropout = p_dropout > 0.0; - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, - "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - 
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - - const auto sizes = q.sizes(); - - const int batch_size = sizes[0]; - const int seqlen_q = sizes[1]; - const int num_heads = sizes[2]; - const int head_size_og = dout.size(3); - const int head_size = sizes[3]; - const int seqlen_k = k.size(1); - const int num_heads_k = k.size(2); - - if (is_causal){ - TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels"); - } - - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); - TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - - TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); - - CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); - - at::Tensor dq, dk, dv; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); - } else { - dq = at::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dk = at::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); - } else { - dv = at::empty_like(k); - } - - // const at::Tensor& dout_padded = dout; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - - auto opts = q.options(); - auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - - at::Tensor dk_expanded, dv_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded 
= at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } - } - - at::Tensor q_t = q.permute({0,2,1,3}); - at::Tensor k_t = k.permute({0,2,1,3}); - at::Tensor v_t = v.permute({0,2,1,3}); - at::Tensor out_t = out.permute({0,2,1,3}); - at::Tensor dq_t = dq.permute({0,2,1,3}); - at::Tensor dk_t = dk.permute({0,2,1,3}); - at::Tensor dv_t = dv.permute({0,2,1,3}); - at::Tensor dout_t = dout.permute({0,2,1,3}); - - at::Tensor softmax_lse_cont = softmax_lse.contiguous(); - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); - - int d_head = head_size_og; - hipError_t err; // TODO: Error handling - { - using aotriton::v2::flash::attn_bwd; - using sdp::aotriton_adapter::mk_aotensor; - using sdp::aotriton_adapter::mk_aoscalartensor; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - empty_bias, - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } - - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - } - return { dq, dk, dv, softmax_d }; -#undef CALL_BWD_DROPOUT -#undef CALL_BWD -} - -std::tuple -mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x s softmax logsumexp - c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - c10::optional &alibi_slopes_, // num_heads or b x num_heads - const int max_seqlen_q, - const int max_seqlen_k, // max sequence length to choose the kernel - const float p_dropout, // probability to drop - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - int window_size_left, - int window_size_right, - const bool deterministic, - const at::Tensor philox_seed, - const at::Tensor philox_offset) { - TORCH_CHECK(false, 
"mha_varlen_bwd not supported on ROCm"); - - at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat)); - - return { q, k, v, softmax_d }; -} -} // namespace pytorch_fmha - -#endif diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 3cb4b81f81504..395b2695ff879 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1915,9 +1915,4 @@ if(BUILD_PYTHON) add_custom_target(python_copy_files ALL DEPENDS ${build_files}) - - # Install commands - # Pick up static python files - install(DIRECTORY ${CMAKE_BINARY_DIR}/caffe2 DESTINATION ${PYTHON_LIB_REL_PATH} - FILES_MATCHING PATTERN "*.py") endif() diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index bc8535a88ef80..49415bf5ebc2a 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -1,3 +1,16 @@ +macro(get_target_gpus_from_pytorch target_gpus) + set(gfx90a_key MI200) + set(gfx942_key MI300X) + set(gfx1100_key Navi31) + + foreach(X IN LISTS PYTORCH_ROCM_ARCH) + set(key ${X}) + string(APPEND key "_key") + string(APPEND target_gpus ${${key}}) + string(APPEND target_gpus "|") + endforeach() +endmacro() + if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_INCLUDED TRUE) @@ -9,25 +22,31 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.7.1b") + set(__AOTRITON_RELEASE_PAGE "0.10b") + set(__AOTRITON_VER_LIST + "0.10b" # rocm6.3 + "0.10b" # rocm6.4 + "0.10b" # rocm6.5 + "0.10b" # rocm7.0 + ) set(__AOTRITON_MANYLINUX_LIST - "manylinux_2_17" # rocm6.1 - "manylinux_2_17" # rocm6.2 - "manylinux_2_28" # rocm6.2 "manylinux_2_28" # rocm6.3 + "manylinux_2_28" # rocm6.4 + "manylinux_2_28" # rocm6.5 + "manylinux_2_28" # rocm7.0 ) set(__AOTRITON_ROCM_LIST - "rocm6.1" - "rocm6.2" - "rocm6.2" "rocm6.3" + "rocm6.4" + "rocm6.5" + "rocm7.0" ) - set(__AOTRITON_CI_COMMIT "f6b28a9b7265b69e3df54ea6ba0237e8a8d6f736") + set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477") # source of rocm6.5 with gfx950 set(__AOTRITON_SHA256_LIST - "4f73c9271f95d18c1ef0d824bb6ca0ac63fe7795cfe786ffe4964287be5ecff2" # rocm6.1 - "df00412ae36fe5732d0a4601802bd3622b5dec12df7ec86027c5147adeb54c25" # rocm6.2 - "852d0e6e280cee3256fc5c7c3abed657594d7f56081d768ff8616c08bf9098b2" # rocm6.2 - "e4e3b06d2431e68e0096fcc8d3668cd5034ca0fd6fe236fb3b96774427d934b8" # rocm6.3 + "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 + "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 + "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm6.5 + "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0 ) set(__AOTRITON_Z "gz") @@ -57,13 +76,14 @@ if(NOT __AOTRITON_INCLUDED) list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_ROCM}" __AOTRITON_ROCM_INDEX) list(GET __AOTRITON_SHA256_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_SHA256) list(GET __AOTRITON_MANYLINUX_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_MANYLINUX) + list(GET __AOTRITON_VER_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_VER) set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) string(CONCAT __AOTRITON_FILE "aotriton-" "${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}" "_${__AOTRITON_ARCH}-rocm${__AOTRITON_ROCM}" "-shared.tar.${__AOTRITON_Z}") string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" - "${__AOTRITON_VER}/${__AOTRITON_FILE}") + "${__AOTRITON_RELEASE_PAGE}/${__AOTRITON_FILE}") 
ExternalProject_Add(aotriton_external URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} diff --git a/setup.py b/setup.py index e263fffd5e1c8..12fb8b063ad6f 100644 --- a/setup.py +++ b/setup.py @@ -1396,6 +1396,13 @@ def main(): "lib/*.lib", ] ) + aotriton_image_path = os.path.join(lib_path, "aotriton.images") + aks2_files = [] + for root, dirs, files in os.walk(aotriton_image_path): + subpath = os.path.relpath(root, start=aotriton_image_path) + for fn in files: + aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn)) + torch_package_data += aks2_files if get_cmake_cache_vars()["BUILD_CAFFE2"]: torch_package_data.extend( [ diff --git a/test/test_transformers.py b/test/test_transformers.py index d3992a7776aa9..38397e2dab25f 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1462,7 +1462,7 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend): make_tensor = partial(torch.rand, device=device, dtype=dtype) size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257) if TEST_WITH_ROCM: # On ROCM, FA and EA share the backend GPU kernels - size = SdpaShape(2, 2, 3, 257) + size = SdpaShape(2, 2, 3, 513) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -2692,6 +2692,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0 query_fudge_factor = dropout_fudge_factor + if TEST_WITH_ROCM: + query_fudge_factor += 1.0 grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) # TODO: Investigate why grad_k needs larger tolerances @@ -2814,6 +2816,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, mask_fudge_factor = 1.0 if attn_mask is None else 1.5 query_fudge_factor = dropout_fudge_factor + if TEST_WITH_ROCM: + query_fudge_factor += 1.0 grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor) # TODO: Investigate why grad_k needs larger tolerances diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 96047e61f0304..3b035df1d29e8 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -97,7 +97,6 @@ "aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h", "aten/src/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h", - "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h", "aten/src/THC/*", "aten/src/ATen/test/*", # CMakeLists.txt isn't processed by default, but there are a few diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index e93042e21929d..9d6c41949dcec 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -33,32 +33,36 @@ IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)]) -def CDNA2OrLater(): - if TEST_WITH_ROCM: - gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName - return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"}) - return False - -def evaluate_gfx_arch_exact(matching_arch): +def evaluate_gfx_arch_within(arch_list): if not torch.cuda.is_available(): return False 
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName - arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) - return arch == matching_arch + effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) + # gcnArchName can be a decorated string such as gfx90a:sramecc+:xnack-, + # so each supported arch is matched as a substring rather than by exact equality + return any(arch in effective_arch for arch in arch_list) -GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-')) -GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')) +def CDNA2OrLater(): + return evaluate_gfx_arch_within(["gfx90a", "gfx942"]) def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') + arch_list = ["gfx90a", "gfx942", "gfx1100"] + version = _get_torch_rocm_version() + if version >= (6, 5): + arch_list += ["gfx950"] + return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return not IS_WINDOWS and SM80OrLater return False def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: - return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-') + arch_list = ["gfx90a", "gfx942", "gfx1100"] + version = _get_torch_rocm_version() + if version >= (6, 5): + arch_list += ["gfx950"] + return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return True return False
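
A minimal standalone sketch of the substring-based arch matching that the common_cuda.py hunk above introduces. It is not part of the patch: the torch.cuda device query is replaced by an explicit gcn_arch_name parameter so it runs without a ROCm device, and the sample arch strings are illustrative only; the override environment variable name is taken from the diff.

# Standalone illustration of evaluate_gfx_arch_within-style matching.
# gcn_arch_name stands in for torch.cuda.get_device_properties('cuda').gcnArchName.
import os

def evaluate_gfx_arch_within(arch_list, gcn_arch_name):
    # The debug override, if set, replaces the reported arch string.
    effective_arch = os.environ.get(
        "PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE", gcn_arch_name)
    # gcnArchName may carry feature flags (e.g. "gfx90a:sramecc+:xnack-"),
    # so each candidate arch is tested as a substring instead of the exact
    # comparison used by the removed evaluate_gfx_arch_exact helper.
    return any(arch in effective_arch for arch in arch_list)

if __name__ == "__main__":
    archs = ["gfx90a", "gfx942", "gfx1100"]
    print(evaluate_gfx_arch_within(archs, "gfx90a:sramecc+:xnack-"))  # True
    print(evaluate_gfx_arch_within(archs, "gfx1030"))                 # False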