Merged

Changes from all commits (27 commits)
33b9345  [ROCm] Select gpu targets according to PYTORCH_ROCM_ARCH when buildin… (vickytsang, Nov 25, 2024)
8853978  Let aotriton.cmake detect the best binary package to use, and depreca… (xinyazhang, Jan 9, 2025)
438d659  [ROCm] Bump AOTriton to 0.8.2b (#145508) (xinyazhang, Jan 28, 2025)
326b5a2  Backport AOTriton 0.9.2b (xinyazhang, Mar 7, 2025)
cef2cbc  AOTriton: add 0.9.2b version built on ROCM 6.5, with gfx950 supported… (xinyazhang, May 8, 2025)
c584d44  [release/2.7] [AOTriton] Support ROCM 7.0 ABI (#2302) (xinyazhang, Jul 1, 2025)
f943edb  Try to fix linking error (xinyazhang, Jul 7, 2025)
e6b625a  Add missing using aotriton::v2::flash::attn_fwd_compact_varlen (xinyazhang, Jul 7, 2025)
2c56ccd  Revert "Try to fix linking error" (xinyazhang, Jul 7, 2025)
110fead  Set RUNPATH so installed tests can find the required shared libraries… (kundaMwiza, Oct 25, 2024)
34adeb7  [CMake] Remove pthread linking (#134436) (cyyever, Oct 29, 2024)
b47f007  build: add missing file (xinyazhang, Jul 7, 2025)
8e80a7b  do not hipify tools/amd_build/build_amd.py (xinyazhang, Jul 7, 2025)
102b3f3  Revert "[CMake] Remove pthread linking (#134436)" (xinyazhang, Jul 7, 2025)
a143ee5  Revert "Set RUNPATH so installed tests can find the required shared l… (xinyazhang, Jul 7, 2025)
e75771d  fix build error (xinyazhang, Jul 7, 2025)
d62e294  fix build error (xinyazhang, Jul 7, 2025)
611f5b2  fix "file INSTALL cannot make directory" when build with non-root users (xinyazhang, Jul 7, 2025)
faa5235  add missing aotriton.images (xinyazhang, Jul 7, 2025)
604c22e  remove files newer than release/2.4 (xinyazhang, Jul 7, 2025)
83f9fca  enable UT for arch supported by AOTriton 0.9.x (xinyazhang, Jul 7, 2025)
ff5e4b1  USE_ROCM_ATTENTION -> USE_AOTRITON (xinyazhang, Jul 8, 2025)
1d3b6c3  flash_api: remove _ck functions (xinyazhang, Jul 8, 2025)
6aa4ae3  Use AOTriton 0.10b instead to pass all UTs (xinyazhang, Jul 16, 2025)
b2fdd04  fix test_invalid_fused_inputs_head_dim. AOTriton supports hdim <= 512 (xinyazhang, Jul 16, 2025)
714e850  AOTriton 0.10b needs slightly larger fudge factor for dq (xinyazhang, Jul 16, 2025)
c3834b3  Fix the adaptor code (xinyazhang, Jul 16, 2025)
6 changes: 6 additions & 0 deletions .ci/docker/ubuntu-rocm/Dockerfile
@@ -106,6 +106,12 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt

# This is needed by sccache
COPY ./common/install_openssl.sh install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl
RUN bash ./install_openssl.sh
ENV OPENSSL_DIR /opt/openssl

# Install ccache/sccache (do this last, so we get priority in PATH)
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH
2 changes: 2 additions & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -174,6 +174,7 @@ file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# flash_attention sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")

#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
@@ -188,6 +189,7 @@ if(USE_FLASH_ATTENTION)
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})

list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip})
list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
endif()

105 changes: 73 additions & 32 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -1077,19 +1077,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}

// AOTriton may accept an aligned logsumexp tensor in the future for better
// performance, but for now it requires a compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
res = at::empty({B, M, num_heads, Kv}, query.options());
at::Tensor softmax_lse;
logsumexp = at::empty(
{ B, num_heads, max_seqlen_q },
{ B, num_heads, compute_logsumexp ? max_seqlen_q : 0},
query.options().dtype(at::ScalarType::Float));
at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
if (compute_logsumexp) {
softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
}
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
@@ -1105,40 +1108,69 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_

const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();

at::Tensor atomic_counter;
if (is_causal) {
atomic_counter = at::zeros({1}, query.options().dtype(at::kInt));
}

using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::mk_atomictensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
if (!compute_logsumexp) {
// Set the tensor to empty when compute_logsumexp is false
logsumexp = at::empty(
{ B * num_heads, max_seqlen_q, 0 },
query.options().dtype(at::ScalarType::Float));
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
}
#else
// CUDA Implementation
@@ -1401,15 +1433,24 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
#if defined(USE_MEM_EFF_ATTENTION)

#ifdef USE_ROCM
using aotriton::v2::flash::debug_fill_dropout_rng;
using aotriton::v2::flash::debug_simulate_encoded_softmax;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
at::cuda::CUDAGuard device_guard(self.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

at::Tensor seed_t, offset_t;
const auto options = at::dtype(at::kLong).device(at::kCUDA);
seed_t = at::scalar_tensor(at::Scalar(seed), options);
offset_t = at::scalar_tensor(at::Scalar(offset), options);
hipError_t err; // TODO: Error handling

err = debug_fill_dropout_rng(mk_aotensor(self, "r"),
static_cast<uint64_t>(seed),
static_cast<uint64_t>(offset),
stream);
err = debug_simulate_encoded_softmax(mk_aotensor(self, "r"),
dropout_p,
mk_aoscalartensor(seed_t),
mk_aoscalartensor(offset_t),
0,
stream);
#else
at::PhiloxCudaState rng_engine_inputs;
rng_engine_inputs = at::PhiloxCudaState(seed, offset);
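For orientation, here is a minimal standalone sketch of the forward-path bookkeeping this hunk introduces: the logsumexp buffer only gets a non-zero sequence-length dimension when compute_logsumexp is set, the varlen entry point (attn_fwd_compact_varlen) is chosen whenever cumulative sequence offsets are present, and a zero-initialized int32 atomic counter is only allocated for causal attention. ForwardPlan and plan_forward are illustrative names, not part of the PR.

#include <cstdint>
#include <iostream>

// Hypothetical summary of the dispatch decisions made in attention.cu.
struct ForwardPlan {
  bool use_varlen;            // attn_fwd_compact_varlen vs. attn_fwd
  int64_t lse_seqlen;         // last dimension of the logsumexp allocation
  bool need_atomic_counter;   // persistent counter is only allocated when causal
};

ForwardPlan plan_forward(bool has_seqstart_q, bool compute_logsumexp,
                         bool is_causal, int64_t max_seqlen_q) {
  ForwardPlan p;
  p.use_varlen = has_seqstart_q;                        // nested / varlen input
  p.lse_seqlen = compute_logsumexp ? max_seqlen_q : 0;  // {B, H, 0} otherwise
  p.need_atomic_counter = is_causal;                    // int32 scratch, zero-initialized
  return p;
}

int main() {
  ForwardPlan p = plan_forward(/*has_seqstart_q=*/false,
                               /*compute_logsumexp=*/true,
                               /*is_causal=*/true,
                               /*max_seqlen_q=*/128);
  std::cout << "varlen=" << p.use_varlen
            << " lse_seqlen=" << p.lse_seqlen
            << " atomic_counter=" << p.need_atomic_counter << "\n";
  return 0;
}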
97 changes: 75 additions & 22 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -383,8 +383,8 @@ _efficient_attention_backward(
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
" (gfx90a/gfx942/gfx1100/gfx1201)")
}
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
bool is_causal;
@@ -404,33 +404,86 @@ _efficient_attention_backward(
at::Tensor dv_t = grad_v.permute({0,2,1,3});
at::Tensor dout_t = grad_out.permute({0,2,1,3});
at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q});
at::Tensor delta = at::empty_like(softmax_lse).contiguous();

hipError_t err;
using aotriton::v2::flash::attn_bwd;
using aotriton::v2::flash::attn_bwd_fused;
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
auto d_head = Kv;
bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512;
if (use_fused_bwd) {
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
}
}
#else
at::Tensor workspace;
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
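The backward hunk above selects between three AOTriton kernels. A minimal sketch of that heuristic follows, lifted directly from the diff (d_head <= 192 && d_head * max_seqlen_q < 64 * 512); choose_backward_kernel and the BwdKernel enum are illustrative names only.

#include <cstdint>
#include <iostream>

enum class BwdKernel { CompactVarlen, Fused, Split };

// Mirrors the selection logic in _efficient_attention_backward: varlen inputs
// always use attn_bwd_compact_varlen; otherwise the fused backward kernel is
// used for small head dims and short sequences, and the split attn_bwd kernel
// (which needs a separate delta tensor) is used elsewhere.
BwdKernel choose_backward_kernel(bool has_cu_seqlens_q,
                                 int64_t d_head, int64_t max_seqlen_q) {
  if (has_cu_seqlens_q) {
    return BwdKernel::CompactVarlen;
  }
  const bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512;
  return use_fused_bwd ? BwdKernel::Fused : BwdKernel::Split;
}

int main() {
  // d_head = 128, max_seqlen_q = 256: 128 * 256 = 32768, which is not strictly
  // less than 64 * 512 = 32768, so the split path is chosen.
  BwdKernel k = choose_backward_kernel(false, 128, 256);
  std::cout << (k == BwdKernel::Fused ? "fused" : "split") << "\n";
  return 0;
}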
33 changes: 27 additions & 6 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -26,6 +26,11 @@
#endif
#endif

// Avoid compiler complaints under -Wall -Werror about an undefined macro
#ifndef AOTRITON_VERSION_MINOR
#define AOTRITON_VERSION_MINOR 0
#endif

/**
* Note [SDPA Runtime Dispatch]
* SDPA relies on a runtime dispatch mechanism to select the appropriate
@@ -83,8 +88,13 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
}

bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
#if USE_AOTRITON && AOTRITON_VERSION_MINOR >= 9
// AOTriton 0.9+ supports head_dim up to 512
const auto max_size = c10::SymInt(512);
#else
// All head_dim sizes must be equal and less than 256
const auto max_size = c10::SymInt(256);
#endif
const auto query_size_last = params.query.sym_size(-1);
const auto key_size_last = params.key.sym_size(-1);
const auto value_size_last = params.value.sym_size(-1);
@@ -207,6 +217,16 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
#if AOTRITON_VERSION_MINOR >= 9
if (aotriton::isArchExperimentallySupported(stream)) {
static const bool enable_experimental = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_experimental) {
TORCH_WARN_ONCE("Flash Efficient attention on Current AMD GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
}
#endif
}
#else
return false;
@@ -243,15 +263,16 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_navi3x) {
TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
#if AOTRITON_VERSION_MINOR >= 9
if (aotriton::isArchExperimentallySupported(stream)) {
static const bool enable_experimental = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_experimental) {
TORCH_WARN_ONCE("Mem Efficient attention on Current AMD GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
}
#endif
#else
return false;
#endif
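To make the version gating above concrete, here is a minimal compile-time sketch: the maximum accepted head_dim for the flash path becomes 512 when building with AOTriton 0.9+ and stays at 256 otherwise, with AOTRITON_VERSION_MINOR defaulted to 0 so strict-warning builds do not trip on an undefined macro. max_flash_head_dim and head_dims_ok are illustrative helpers, simplified relative to the real check_head_dim_size_flash.

#include <cstdint>
#include <iostream>

// Stand-ins for build-time configuration; in the real build these come from
// CMake and the AOTriton headers.
#ifndef USE_AOTRITON
#define USE_AOTRITON 1
#endif
#ifndef AOTRITON_VERSION_MINOR
#define AOTRITON_VERSION_MINOR 0  // same fallback the PR adds for strict-warning builds
#endif

// Maximum head_dim accepted by the flash-attention size check.
constexpr int64_t max_flash_head_dim() {
#if USE_AOTRITON && AOTRITON_VERSION_MINOR >= 9
  return 512;  // AOTriton 0.9+ raises the limit
#else
  return 256;
#endif
}

// Simplified shape check: all last dimensions equal and within the cap.
bool head_dims_ok(int64_t q_last, int64_t k_last, int64_t v_last) {
  return q_last == k_last && q_last == v_last && q_last <= max_flash_head_dim();
}

int main() {
  std::cout << "cap=" << max_flash_head_dim()
            << " ok(384)=" << head_dims_ok(384, 384, 384) << "\n";
  return 0;
}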
6 changes: 6 additions & 0 deletions aten/src/ATen/native/transformers/hip/aotriton_adapter.h
@@ -127,6 +127,12 @@ inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
aotriton::DType::kUInt64); // AOTriton expects unsigned int64
}

inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr)
{
return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
aotriton::DType::kInt32);
}

} // namespace aotriton_adapter

} // namespace sdp
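To show the role of these zero-rank wrappers, here is a small standalone sketch of the pattern they implement: a raw pointer (or nullptr to mean "disabled") is carried together with a dtype tag, so optional scalar slots such as the philox seed write-back or the new causal atomic counter can be passed uniformly to the kernels. ScalarView, mk_philox, and mk_atomic are toy stand-ins for aotriton::TensorView<0> and the adapter helpers, not the real API.

#include <cstdint>
#include <iostream>

// Toy stand-in for aotriton::TensorView<0>: a typed scalar slot identified by
// a raw address; address 0 means "not provided / write-back disabled".
enum class DType { kUInt64, kInt32 };

struct ScalarView {
  intptr_t addr;
  DType dtype;
  bool enabled() const { return addr != 0; }
};

// Analogue of mk_philoxtensor: wraps an int64 pointer as an unsigned 64-bit slot.
ScalarView mk_philox(const int64_t* ptr) {
  return {reinterpret_cast<intptr_t>(ptr), DType::kUInt64};
}

// Analogue of the new mk_atomictensor: wraps an int32 counter pointer.
ScalarView mk_atomic(const int32_t* ptr) {
  return {reinterpret_cast<intptr_t>(ptr), DType::kInt32};
}

int main() {
  int64_t seed_out = 0;
  // Only request seed write-back when capturing philox state (see attention.cu).
  bool use_philox_state = false;
  ScalarView seed_output = mk_philox(use_philox_state ? &seed_out : nullptr);
  ScalarView counter = mk_atomic(nullptr);  // non-causal call: no counter needed
  std::cout << "seed_output enabled: " << seed_output.enabled()
            << ", counter enabled: " << counter.enabled() << "\n";
  return 0;
}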