4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -867,7 +867,7 @@ cmake_dependent_option(
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform"
ON
"USE_CUDA OR USE_ROCM;NOT MSVC"
"(USE_CUDA AND NOT MSVC) OR USE_ROCM"
OFF)

# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
@@ -883,7 +883,7 @@ cmake_dependent_option(
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()
123 changes: 118 additions & 5 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -95,6 +95,72 @@
#endif
#endif

#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
namespace pytorch_flash
{
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor>
mha_fwd(
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
}
#endif

namespace at {

namespace cuda::philox {
@@ -1406,12 +1472,15 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
is_causal = true;
#if AOTRITON_V3_API == 0
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
#endif
}

at::Tensor atomic_counter;
@@ -1436,7 +1505,51 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
using aotriton::v3::flash::WindowValue;
aotriton::v3::flash::attn_fwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.Sm_scale = softmax_scale;
params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2;
params.Out = mk_aotensor(output_t, "Out");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = dropout_p;
params.philox_seed_ptr = seed;
params.philox_offset1 = offset1;
params.philox_offset2 = offset2;
params.philox_seed_output = seed_output;
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
params.window_left = WindowValue::TopLeftAligned;
params.window_right = WindowValue::TopLeftAligned;
} else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
params.window_left = WindowValue::BottomRightAligned;
params.window_right = WindowValue::BottomRightAligned;
}
if (bias.has_value()) {
params.B = mk_aotensor(bias.value(), "bias");
}
if (seqstart_q.has_value()) {
params.varlen_type = VarlenType::CompactVarlen;
params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k");
} else {
params.varlen_type = VarlenType::None;
}
err = aotriton::v3::flash::attn_fwd(params,
aotriton::v3::flash::attn_fwd_params::kVersion,
stream);
#endif // AOTRITON_V3_API
} else if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
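For context, a minimal usage sketch of the new ROCm dispatch path added above. It assumes a setROCmFAPreferredBackend() setter on at::globalContext() that mirrors the getROCmFAPreferredBackend() getter used in the wrapper, and it illustrates the call flow rather than anything introduced by this change:

// Illustration only: prefer the CK backend (when built with USE_ROCM_CK_SDPA),
// then call SDPA, which can route through pytorch_flash::mha_fwd above.
#include <ATen/ATen.h>
#include <ATen/Context.h>

at::Tensor flash_sdpa_example() {
  // Assumed setter, mirroring the getter used in mha_fwd().
  at::globalContext().setROCmFAPreferredBackend(at::ROCmFABackend::Ck);

  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto q = at::randn({2, 8, 128, 64}, opts);  // batch x num_heads x seqlen x head_dim
  auto k = at::randn({2, 8, 128, 64}, opts);
  auto v = at::randn({2, 8, 128, 64}, opts);

  // If the SDPA dispatcher selects the flash kernel, this reaches mha_fwd(),
  // which takes the mha_fwd_ck path when CK is preferred and mha_fwd_aot otherwise.
  return at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.0, /*is_causal=*/true);
}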
70 changes: 65 additions & 5 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -24,6 +24,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_flash_attention_backward.h>
@@ -45,6 +46,7 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#else
#include <ATen/native/transformers/hip/gemm_kernel_utils.h>
// MemoryEfficient Attention Specific Imports for ROCM
#ifndef DISABLE_AOTRITON
#include <ATen/native/transformers/hip/aotriton_adapter.h>
@@ -482,12 +484,15 @@ _efficient_attention_backward(
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
is_causal = true;
#if AOTRITON_V3_API == 0
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
#endif
}
at::Tensor q_t = query.permute({0,2,1,3});
at::Tensor k_t = key.permute({0,2,1,3});
@@ -506,7 +511,62 @@ _efficient_attention_backward(
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
if (cu_seqlens_q.has_value()) {
if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
using aotriton::v3::flash::WindowValue;
aotriton::v3::flash::attn_bwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4;
params.Sm_scale = softmax_scale;
params.Out = mk_aotensor(out_t, "out");
params.DO = mk_aotensor(dout_t, "dout");
params.DK = mk_aotensor(dk_t, "dk");
params.DV = mk_aotensor(dv_t, "dv");
params.DQ = mk_aotensor(dq_t, "dq");
params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4;
params.L = mk_aotensor<2>(softmax_lse, "L");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = float(dropout_p);
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
params.philox_offset1 = mk_aoscalartensor(philox_offset);
params.philox_offset2 = 0;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
params.window_left = WindowValue::TopLeftAligned;
params.window_right = WindowValue::TopLeftAligned;
} else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
params.window_left = WindowValue::BottomRightAligned;
params.window_right = WindowValue::BottomRightAligned;
}
#if AOTRITON_ALWAYS_V3_API
using sdp::aotriton_adapter::mklazy_empty_like;
using sdp::aotriton_adapter::mklazy_fp32zeros;
using sdp::aotriton_adapter::LazyTensorContext;
LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" };
LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
params.D = mklazy_empty_like<2>(&lazy_delta);
params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
params.D = mk_aotensor<2>(delta, "delta");
#endif
if (cu_seqlens_q.has_value()) {
params.varlen_type = VarlenType::CompactVarlen;
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k");
} else {
params.varlen_type = VarlenType::None;
}
err = aotriton::v3::flash::attn_bwd(params,
aotriton::v3::flash::attn_bwd_params::kVersion,
stream);
#endif // AOTRITON_V3_API
} else if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
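The backward path now uses the same AOTriton V3 calling convention as the forward: populate a params struct and pass its compile-time kVersion so the library can reject a caller built against a different struct layout. A self-contained sketch of that versioned-params pattern, with invented names, for readers unfamiliar with the idiom:

// Generic illustration of the "params struct + version tag" idiom used by
// aotriton::v3::flash::attn_fwd/attn_bwd; none of these names are AOTriton's.
#include <cstdint>
#include <stdexcept>

struct attn_params_sketch {
  static constexpr int32_t kVersion = 3;  // bumped whenever fields are added or reordered
  const void* Q = nullptr;
  const void* K = nullptr;
  const void* V = nullptr;
  float sm_scale = 1.0f;
  bool causal = false;
};

// The callee validates the caller's compiled-in version before reading fields,
// so a header/library mismatch fails loudly instead of misreading the struct.
void attn_sketch(const attn_params_sketch& p, int32_t caller_version) {
  if (caller_version != attn_params_sketch::kVersion) {
    throw std::runtime_error("attn params version mismatch");
  }
  // ... launch kernels using p ...
  (void)p;
}

// Call sites mirror the diff: attn_sketch(params, attn_params_sketch::kVersion);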
45 changes: 43 additions & 2 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -16,6 +16,7 @@
#include <c10/util/irange.h>
#include <c10/util/Array.h>
#include <c10/util/Exception.h>
#include <c10/util/string_view.h>

#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
@@ -25,9 +26,12 @@

#if USE_ROCM
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <ATen/native/transformers/hip/aotriton_versions.h>
#include <aotriton/flash.h>
#define USE_ROCM_ATTENTION 1
#endif
#else
#define USE_ROCM_ATTENTION 0
#endif

// Avoid potential compiler -Wall -Werror complains undefined macro
@@ -112,9 +116,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
// caller_is_meff is added to make the TORCH_WARN message showing the correct result
template<bool caller_is_meff = false>
bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9
#if USE_ROCM_ATTENTION
// AOTriton 0.9+ supports head_dim up to 512
const auto max_size = c10::SymInt(512);
const static auto max_hdim = []() {
#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
// gfx11xx only support hdim <= 256 on AOTriton 0.11
auto dprops = at::cuda::getCurrentDeviceProperties();
const c10::basic_string_view<char> arch(dprops->gcnArchName);
if (arch.starts_with("gfx11")) {
return 256;
}
#endif // AOTriton 0.11
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
return 512;
#else
return 256;
#endif
}();
const auto max_size = c10::SymInt(max_hdim);
#else
// All head_dim sizes must be equal and less than 256
const auto max_size = c10::SymInt(256);
@@ -139,6 +158,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
}
return false;
}
if constexpr(caller_is_meff) {
bool is_half = (params.query.dtype() == at::kHalf) ||
(params.query.dtype() == at::kBFloat16);
const int64_t alignment = is_half ? 8 : 4;
if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
value_size_last % alignment == 0 && value_size_last > 0)) {
if (debug) {
TORCH_WARN(
"Mem efficient attention requires last dimension of inputs to be divisible by ",
alignment,
". ",
"Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.sym_size(-1),
", Value.size(-1): ",
params.value.sym_size(-1),
" instead.");
}
return false;
}
}
return true;
}

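The caller_is_meff branch added to check_head_dim_size_flash enforces that the last dimension of query and value is a positive multiple of 8 for fp16/bf16 inputs and of 4 for other dtypes. A standalone restatement of that rule with worked cases; the helper name is illustrative, not part of the change:

#include <cassert>
#include <cstdint>

// Same arithmetic as the divisibility check in the caller_is_meff branch above.
bool meff_last_dim_ok(int64_t query_last, int64_t value_last, bool is_half_or_bf16) {
  const int64_t alignment = is_half_or_bf16 ? 8 : 4;
  return query_last > 0 && value_last > 0 &&
         query_last % alignment == 0 && value_last % alignment == 0;
}

int main() {
  assert(meff_last_dim_ok(64, 64, /*is_half_or_bf16=*/true));     // 64 % 8 == 0  -> accepted
  assert(!meff_last_dim_ok(100, 100, /*is_half_or_bf16=*/true));  // 100 % 8 == 4 -> rejected
  assert(meff_last_dim_ok(100, 100, /*is_half_or_bf16=*/false));  // 100 % 4 == 0 -> accepted
  return 0;
}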