Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ cmake_dependent_option(
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform"
ON
"USE_CUDA OR USE_ROCM;NOT MSVC"
"(USE_CUDA AND NOT MSVC) OR USE_ROCM"
OFF)

cmake_dependent_option(
Expand Down Expand Up @@ -908,7 +908,7 @@ cmake_dependent_option(
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()
Expand Down
21 changes: 2 additions & 19 deletions aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) {
}

bool Context::allowTF32CuDNN(const std::string& op) const {
if (op.size() == 0){
if (op.empty()){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
TORCH_CHECK(
Expand Down Expand Up @@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const {

static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
#ifdef USE_ROCM
static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
#endif

bool Context::checkCuBLASConfigDeterministic() {
// If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
Expand Down Expand Up @@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) {
}

bool Context::allowTF32CuBLAS() const {
#ifdef USE_ROCM
const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
if (allow_tf32 != true) {
return false;
}
#endif
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
TORCH_CHECK(
Expand All @@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const {
}

void Context::setAllowTF32CuBLAS(bool b) {
#ifdef USE_ROCM
const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
if (allow_tf32 != true) {
C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
<< "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
return;
}
#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}
Expand Down Expand Up @@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string&
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
for (auto p : iterp->second) {
for (const auto& p : iterp->second) {
msg += p;
msg += " ";
}
Expand Down
66 changes: 66 additions & 0 deletions aten/src/ATen/native/transformers/cuda/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,72 @@
#endif
#endif

#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
namespace pytorch_flash
{
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor>
mha_fwd(
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
}
#endif

namespace at {

namespace cuda::philox {
Expand Down
22 changes: 22 additions & 0 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
}
return false;
}
if constexpr(caller_is_meff) {
bool is_half = (params.query.dtype() == at::kHalf) ||
(params.query.dtype() == at::kBFloat16);
const int64_t alignment = is_half ? 8 : 4;
if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
value_size_last % alignment == 0 && value_size_last > 0)) {
if (debug) {
TORCH_WARN(
"Mem efficient attention requires last dimension of inputs to be divisible by ",
alignment,
". ",
"Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.sym_size(-1),
", Value.size(-1): ",
params.value.sym_size(-1),
" instead.");
}
return false;
}
}
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::mk_atomictensor;
using sdp::aotriton_adapter::cast_dtype;
at::Tensor atomic_counter;
if (is_causal) {
atomic_counter = at::zeros({1}, q.options());
atomic_counter = at::zeros({1}, q.options().dtype(at::kInt));
}
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
Expand All @@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
auto nullscalar = mk_philoxtensor(nullptr);
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
using aotriton::v3::flash::CausalType;
Expand Down
39 changes: 2 additions & 37 deletions aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
#endif

TORCH_API
inline std::tuple<
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
Expand All @@ -294,42 +294,7 @@ mha_fwd(
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
std::optional<at::Generator> gen_);

inline std::tuple<
at::Tensor,
Expand Down
Loading