Draft: 46 commits

29bf4d2  update CK (alugorey, Jan 14, 2025)
412aeb5  Remove generated files (alugorey, Jan 14, 2025)
0d27528  bring new fmha_bwd/fwd.hpp files (alugorey, Jan 14, 2025)
cf1e5be  add_subdirectory to ATen/CMakeLists.txt (alugorey, Feb 10, 2025)
92c437e  Add script for swapping make_kernel with make_kernel_pt (alugorey, Feb 10, 2025)
c884aa6  Create ck flash attention CMakeLists.txt file (alugorey, Feb 10, 2025)
e3ab46d  Update CK to pick up receipt 4 (alugorey, Feb 11, 2025)
22940d3  Update to receipt 4 (alugorey, Feb 11, 2025)
1b3ac11  Add warning if CK is requested when unsupported (alugorey, Feb 11, 2025)
da17d67  Add generated files to .gitignore (alugorey, Feb 11, 2025)
4cc20da  lint (alugorey, Feb 12, 2025)
1898200  Initial plumbing for mem_eff path (alugorey, Jan 28, 2025)
78f9f55  Began writing function signature (alugorey, Jan 29, 2025)
03d5fda  parameters aligned (pre-compile) (alugorey, Jan 30, 2025)
5fd4a30  parameters aligned (post-compile) (alugorey, Feb 3, 2025)
c9752e9  Called my new func (pre-compile) (alugorey, Feb 4, 2025)
8b3256f  call my func (post-compile) (alugorey, Feb 4, 2025)
af6d879  feed attn_bias to mha_fwd* (pre-compile) (alugorey, Feb 5, 2025)
2599a68  feed attn_bias to mha_fwd* (post-compile) (alugorey, Feb 5, 2025)
26d71a8  fighting linker runtime error (alugorey, Feb 5, 2025)
ab012d6  Move declaration to new header (alugorey, Feb 6, 2025)
b09806e  Compiled linker error and not hitting runtime error (alugorey, Feb 6, 2025)
6bee61e  descriptive changes to just test_transformers.py DELETE LATER (alugorey, Feb 6, 2025)
6700952  some logging (alugorey, Feb 6, 2025)
518c0cb  Finished initial implementation for fwd (post-compile, pre-run) (alugorey, Feb 7, 2025)
20f5401  Saving place, BLOCKED on codegen (alugorey, Feb 7, 2025)
e484624  debug traces DELETE LATER (alugorey, Feb 12, 2025)
8f9f165  remove backward from test. START BWD AFTER THIS COMMIT (alugorey, Feb 12, 2025)
84e8197  First draft of bwd function signature (pre-compile) (alugorey, Feb 12, 2025)
7f003d1  Calling my function as a no-op (pre-compile) (alugorey, Feb 13, 2025)
11099ab  Calling my function as a no-op (post-compile) (alugorey, Feb 13, 2025)
371a907  getting ready to implement wrapper. just some comments for that (alugorey, Feb 14, 2025)
c73aed1  calling mha stuff in wrapper (pre-compile) (alugorey, Feb 14, 2025)
cc5db9f  calling mha stuff in wrapper (post-compile) (alugorey, Feb 17, 2025)
27e2bbd  Start feeding grad_bias through (post-compile) (alugorey, Feb 18, 2025)
fc1c418  Remove unneeded re-naming scripts (alugorey, Feb 18, 2025)
84b3d7d  mha_varlen_fwd plumbing and replacing alibi_slopes (alugorey, Feb 18, 2025)
1967d05  Feeding grad_bias through to mha_bwd/varlen_ck (alugorey, Feb 18, 2025)
ab7740b  passed dbias through to CK (pre-compile) (alugorey, Feb 19, 2025)
a7152d0  pass dbias (post-compile) (alugorey, Feb 19, 2025)
58f4c62  Returning dbias up (pre-compile) (alugorey, Feb 20, 2025)
49aafd8  returning bias up (post-compile) (alugorey, Feb 20, 2025)
db02be9  Add branch on CK preferred backend1 (alugorey, Feb 20, 2025)
1d727fe  Sanity is working E2E needs clean up and final verification (alugorey, Feb 20, 2025)
2877d4c  cleaned up lse bug traces (alugorey, Feb 20, 2025)
5c9fd0b  Chasing varlen bug. saving place (alugorey, Feb 20, 2025)

7 changes: 7 additions & 0 deletions .gitignore
@@ -124,6 +124,13 @@ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_convert*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fwd_blob*
aten/src/ATen/native/transformers/hip/flash_attn/ck/bwd_blob*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_api*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_api*
# Root level file used in CI to specify certain env configs.
# E.g., see .circleci/config.yaml
env
2 changes: 2 additions & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -183,6 +183,8 @@ if(USE_FLASH_ATTENTION)
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif()
3 changes: 2 additions & 1 deletion aten/src/ATen/native/transformers/attention.cpp
@@ -74,7 +74,7 @@
#include <ATen/ops/_safe_softmax_native.h>
#include <ATen/ops/all.h>
#endif

#include <iostream>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at::native {

@@ -741,6 +741,7 @@ Tensor scaled_dot_product_attention(
if (attn_mask.has_value()) {
attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);;
}
//std::cout << "OUTERMOST Q SHAPE: " << query_.sizes() << std::endl;
auto out_and_lse = at::_scaled_dot_product_efficient_attention(
query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
return std::get<0>(out_and_lse);
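
For reference, the hunk above sits in the branch of scaled_dot_product_attention that calls at::_scaled_dot_product_efficient_attention. Below is a minimal libtorch sketch, not part of this PR, that exercises that public entry point; the tensor layout is the public (Batch, Num_heads, Seq_len, Dim_per_head) convention, and whether the mem-efficient kernel is actually chosen depends on runtime backend selection, so treat that routing as an assumption rather than a guarantee.

// Illustrative only: a small libtorch program driving the public SDPA entry point.
#include <torch/torch.h>
#include <iostream>

int main() {
  // (Batch, Num_heads, Seq_len, Dim_per_head) layout expected by the public API.
  auto q = torch::randn({2, 8, 128, 64});
  auto k = torch::randn({2, 8, 128, 64});
  auto v = torch::randn({2, 8, 128, 64});

  // Defaults: no attention mask, no dropout, non-causal, scale = 1/sqrt(head_dim).
  auto out = at::scaled_dot_product_attention(q, k, v);
  std::cout << "output sizes: " << out.sizes() << std::endl;  // [2, 8, 128, 64]
  return 0;
}
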
236 changes: 154 additions & 82 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -82,6 +82,7 @@
#include <ATen/native/transformers/hip/aotriton_adapter.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#endif
#endif

@@ -848,13 +849,25 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
std::optional<double> scale) {
// Used for tracking usage statistics
C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
//std::cout << std::endl;
//std::cout << "what we want vvvvvvvvvvvvvv" << std::endl;
//std::cout << "MAX_SEQLEN_Q: " << b4_max_seqlen_batch_q << std::endl;
//std::cout << "MAX_SEQLEN_K: " << b4_max_seqlen_batch_k << std::endl;
//std::cout << "MAX_SEQLEN_V: " << b4_max_seqlen_batch_v << std::endl;
//std::cout << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << std::endl;
// Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
// Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
std::cout << std::endl;
std::cout << "sdpa_ef" << std::endl;
std::cout << "q.sizes : " << query.sizes() << std::endl;
Tensor q_t = query.transpose(1, 2);
Tensor k_t = key.transpose(1, 2);
Tensor v_t = value.transpose(1, 2);

std::cout << "q_t.sizes: " << q_t.sizes() << std::endl;

std::cout << "qagain.sizes: " << query.sizes() << std::endl;
sdp::CustomMaskType custom_mask_type = is_causal
? sdp::CustomMaskType::CausalFromTopLeft
: sdp::CustomMaskType::NoCustomMask;
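
The comments and transposes in this hunk encode the layout contract: callers hand in (Batch, Num_heads, Seq_len, Dim_per_head) tensors, and the kernel-facing code works on (Batch, Seq_len, Num_heads, Dim_per_head). A standalone sketch of that convention follows (an illustration under those assumptions, not code from the PR):

// transpose(1, 2) swaps the head and sequence dimensions by adjusting strides;
// no data is copied, so q_t aliases q's storage.
#include <ATen/ATen.h>

void layout_demo() {
  at::Tensor q = at::randn({2, 8, 128, 64});   // (Batch, Num_heads, Seq_len, Dim)
  at::Tensor q_t = q.transpose(1, 2);          // (Batch, Seq_len, Num_heads, Dim)
  TORCH_CHECK(q_t.size(1) == 128 && q_t.size(2) == 8);
  TORCH_CHECK(q_t.data_ptr() == q.data_ptr()); // same storage, different strides
}
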
@@ -1026,6 +1039,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
// TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a
// machine that is >= 5.0. In practice, this is not a problem but since
// this would avoid runtime architecture checks, we should look into it
const int64_t new_max_seqlen_batch_q = query.size(1);
const int64_t new_max_seqlen_batch_k = key.size(1);
const int64_t new_max_seqlen_batch_v = value.size(1);
std::cout << std::endl;
std::cout << "MEMORY_EFFICIENT VVVVVVVVVV" << std::endl;
std::cout << "MAX_SEQLEN_Q: " << new_max_seqlen_batch_q << std::endl;
std::cout << "MAX_SEQLEN_K: " << new_max_seqlen_batch_k << std::endl;
std::cout << "MAX_SEQLEN_V: " << new_max_seqlen_batch_v << std::endl;
std::cout << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << std::endl;
TORCH_CHECK(query.dim() == 4);

TORCH_CHECK(query.dim() == 4);
TORCH_CHECK(key.dim() == 4);
@@ -1128,97 +1151,145 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_

#ifdef USE_ROCM
// ROCM Implementation
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
if( bias.has_value() ) {
std::cout << std::endl;
std::cout << "Attn_bias sizes : " << bias.value().sizes() << std::endl;
std::cout << "attn_bias device: " << bias.value().device() << std::endl;
std::cout << "last dim stride: " << bias.value().stride(-1) << std::endl;
}

// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
// Need this in both aot and CK case
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
res = at::empty({B, M, num_heads, Kv}, query.options());
logsumexp = at::empty(

std::cout << "CK Enabled?: " << at::globalContext().getROCmFAPreferredBackend() << std::endl;
if(at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
//forward_attention_ck(...);
std::cout << "In my branch" << std::endl;
std::optional<Tensor> out(res);
std::optional<Tensor> seqused_k = std::nullopt;
std::optional<Tensor> alibi_slopes = std::nullopt;
std::cout << "out(res) dtype " << out.value().dtype();

auto
[out_,
q,
k,
v,
lse,
seed_t,
offset_t,
p] =
pytorch_flash::mem_eff_forward_ck(
query,
key,
value,
dropout_p,
false, // return dropout_randval
custom_mask_type == 0 ? false : true, // is_causal
softmax_scale,
bias,
out,
std::nullopt, // cu_seqlens_q: sending in nothing since CKFA works this way
std::nullopt, // cu_seqlens_k
seqstart_q,
seqstart_k,
std::nullopt,// not passing in optional gen_
seqused_k);// not passing in optional seqused_k_

logsumexp = lse;
} else { // use aotriton
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
" (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
}

// AOTriton may accept aligned on logsumexp tensor in the future for better
// performance, but for now it requires compact logsumexp tensor, even if
// compute_logsumexp is false
constexpr int kAlignLSE = 1;
logsumexp = at::empty(
{ B, num_heads, max_seqlen_q },
query.options().dtype(at::ScalarType::Float));
at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
at::Tensor q_t = query.transpose(1, 2);
at::Tensor k_t = key.transpose(1, 2);
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}

const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();

using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
}
if (!compute_logsumexp) {
// Set the tensor to empty when compute_logsumexp is false
logsumexp = at::empty(
using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
const bool use_philox_state = in_capture_stream;
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
}
if (!compute_logsumexp) {
// Set the tensor to empty when compute_logsumexp is false
logsumexp = at::empty(
{ B * num_heads, max_seqlen_q, 0 },
query.options().dtype(at::ScalarType::Float));
}
}
} // CK BACKEND
#else
// CUDA Implementation
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
@@ -1391,6 +1462,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
AT_CUDA_CHECK(cudaGetLastError());

#endif // USE_ROCM
std::cout << "res dtype: " << res.dtype() << std::endl;
return std::make_tuple(
std::move(res),
std::move(logsumexp),
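
The heart of the large hunk above is a runtime branch: when the preferred ROCm flash-attention backend is CK, the forward calls pytorch_flash::mem_eff_forward_ck (passing the optional bias, the preallocated output, and the seqstart tensors, and keeping the logsumexp it returns); otherwise the existing AOTriton attn_fwd / attn_fwd_compact_varlen path runs as before. The sketch below isolates the backend query that the branch keys on, using only the accessor visible in the diff; it is an illustration, and the assumption that <ATen/Context.h> is the right header for this enum is mine, not something the PR states.

// Illustrative helper: report whether the CK flash-attention backend is currently
// preferred, using the same accessor the new branch keys on.
#include <ATen/Context.h>
#include <iostream>

bool ck_flash_attention_preferred() {
  return at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck;
}

int main() {
  std::cout << "CK flash attention preferred: " << std::boolalpha
            << ck_flash_attention_preferred() << std::endl;
  return 0;
}
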