[WIP] Integration flash attention 2 #55758
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
Force-pushed from b1cd6cf to 12d4b4c
  num_splits = 1;
}
bool zero_tensors = false;
const int total_q = dims[0];
Use `int64_t` here.
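To illustrate why the reviewer asks for `int64_t`: a total token count such as `total_q` is a product of batch size and sequence length, which can exceed `INT32_MAX` for large workloads. A minimal sketch (the function name is hypothetical, not Paddle's API):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration: batch_size * seqlen_q can exceed INT32_MAX,
// so storing the result in a 32-bit int would overflow. Doing the
// arithmetic in int64_t keeps the value exact.
int64_t total_tokens(int64_t batch_size, int64_t seqlen_q) {
  return batch_size * seqlen_q;
}
```

For example, 70000 sequences of length 32768 give 2,293,760,000 tokens, which is already past the 32-bit limit of 2,147,483,647.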
@@ -55,110 +55,75 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);

cudaStream_t stream = ctx.stream();
bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false;
const cudaStream_t stream = ctx.stream();
This parameter-computation logic seems to be shared by the Pad and UnPad variants, in both the forward and backward kernels. Could we define a struct to unify the computation?
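A minimal sketch of what such a struct might look like (the name `FlashAttnParams` and its members are hypothetical, not the actual Paddle implementation): compute the shared shape parameters once and reuse them across the four kernels.

```cpp
#include <cassert>

// Hypothetical sketch of the reviewer's suggestion: a single struct that
// derives the shared FlashAttention shape parameters, so the padded and
// unpadded forward/backward kernels all use the same computation.
struct FlashAttnParams {
  int seqlen_q;
  int seqlen_k;
  int seqlen_q_rounded;
  int seqlen_k_rounded;

  // Round x up to the next multiple of m (matches round_multiple in the diff).
  static int RoundMultiple(int x, int m) { return (x + m - 1) / m * m; }

  FlashAttnParams(int sq, int sk)
      : seqlen_q(sq),
        seqlen_k(sk),
        seqlen_q_rounded(RoundMultiple(sq, 128)),
        seqlen_k_rounded(RoundMultiple(sk, 128)) {}
};
```

This keeps the rounding policy in one place, so a future change (say, a different tile size) only touches one constructor.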
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

softmax_lse->Resize({batch_size, num_heads, seqlen_q_rounded});
This should be `softmax_lse->Resize({batch_size, num_heads, seqlen_q})`. See https://github.com/Dao-AILab/flash-attention/blob/d30f2e1cd50185c98ed88c0684b4a603f15bee37/csrc/flash_attn/flash_api.cpp#L273 . That said, over-allocating here shouldn't cause any problem. @Xreki could you also take a look?
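Some illustrative arithmetic for the over-allocation (helper names here are hypothetical): the upstream `flash_api.cpp` sizes `softmax_lse` as `{batch_size, num_heads, seqlen_q}`, so rounding the last dimension up to a multiple of 128 allocates extra elements without affecting correctness.

```cpp
#include <cassert>
#include <cstdint>

// Number of elements in a {batch, heads, seqlen} softmax_lse tensor.
int64_t LseElems(int64_t batch, int64_t heads, int64_t seqlen) {
  return batch * heads * seqlen;
}

// Round x up to the next multiple of m, as round_multiple does in the diff.
int RoundMultiple(int x, int m) { return (x + m - 1) / m * m; }
```

In the extreme case `seqlen_q = 1`, rounding allocates 128x the required elements per (batch, head) pair; for typical sequence lengths the waste is at most 127 elements per pair, which is why the reviewer notes it is harmless, just unnecessary.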
.gitmodules
Outdated
[submodule "third_party/gtest"]
path = third_party/gtest
url = https://github.com/google/googletest.git
ignore = dirty
[submodule "third_party/flashattn"]
path = third_party/flashattn
url = https://github.com/umiswing/flash-attention.git
Change the submodule back to Paddle's repository.
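A sketch of the requested change, assuming the target is Paddle's own fork of flash-attention (the exact URL below is an assumption based on the PaddlePaddle GitHub organization, not taken from this diff):

```ini
[submodule "third_party/flashattn"]
	path = third_party/flashattn
	url = https://github.com/PaddlePaddle/flash-attention.git
```

After editing `.gitmodules`, running `git submodule sync third_party/flashattn` propagates the new URL into the local repository configuration.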
DenseTensor* _softmax,
DenseTensor* _softmax_lse,
DenseTensor* _seed_offset,
const DenseTensor* const fixed_seed_offset_ptr)
By convention, const parameters come first.
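A minimal sketch of the reordering the reviewer asks for (the stub `DenseTensor` and `Kernel` here are stand-ins, not Paddle's real types): const input parameters move ahead of the mutable output parameters.

```cpp
#include <cassert>

// Stand-in for the real phi::DenseTensor, just enough for illustration.
struct DenseTensor { int id = 0; };

// Before: DenseTensor* softmax, ..., const DenseTensor* fixed_seed_offset
// After:  const input first, mutable outputs last.
void Kernel(const DenseTensor* fixed_seed_offset,
            DenseTensor* softmax,
            DenseTensor* softmax_lse,
            DenseTensor* seed_offset) {
  // Dummy body: record whether the const input was provided.
  if (softmax) softmax->id = (fixed_seed_offset != nullptr) ? 1 : 0;
  (void)softmax_lse;
  (void)seed_offset;
}
```

Ordering inputs before outputs makes call sites easier to read and matches the common inputs-then-outputs convention in C++ kernel code.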
int num_splits = 0;  // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
  num_splits = 1;
}
How about keeping these deterministic-related lines, but commenting them out and adding a TODO?
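A sketch of what that could look like (the function wrapper and TODO owner are illustrative, not the actual patch): the deterministic branch stays in the source as a commented-out block so it is easy to re-enable later.

```cpp
// Hypothetical wrapper around the num_splits choice from the diff above.
int ChooseNumSplits(bool cudnn_deterministic) {
  int num_splits = 0;  // 0 lets the library pick an internal heuristic
  // TODO(umiswing): re-enable once deterministic mode is supported here.
  // if (cudnn_deterministic) {
  //   num_splits = 1;  // a single split fixes the reduction order
  // }
  (void)cudnn_deterministic;  // currently unused while the branch is disabled
  return num_splits;
}
```

With the branch disabled, the heuristic value is returned regardless of the flag, but the intent and the eventual fix remain visible in the code.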
LGTM. Great work~
* Work for fa-2 padded fwd. Code to be cleaned.
* Work for fa2 unpadded fwd.
* Work for padded-bwd, dk get small diff on np.random.seed(0)
* Anyway I pass paddle's utest, except return softmax without dropout.
* Clean code.
* Modify interface.
* Clean code and add some check.
* Easy compile for dev.
* Fix ci.
* Fix ci-build.
* Add std c++17 option again.
* Limit max job when compiling fa2.
* Remove const_cast
* Add fwd params, to be cleaned.
* Clean code.
* Add bwd params.
* Clean code.
* Add enforce.
* Use v2.0.4
* Pass RNG state to fa2 capi
* Fix review.
* Add assert
* Skip compile for sm less than 80.
* part-3 cherry from: add check for cembedding (#55621)
* part-3 fix cherry from: add check for cembedding
* part-3 fix c_embedding
* fix test_gpt_with_pir caused by pir
* part-3 cherry from: [Distributed] Support dp/sharding overlap in virtual pp (#55651)
  * Add virtual pp and dp overlap
  * add sharding/dp overlap
  * add dp/vpp overlap
  * fix code
  * fix log
* part-3 cherry from: [cherry-pick] Integration flash attention 2 (#56015), with the same commit list as above
* part-4 cherry from: fix codestyle (#56066)
* part-4 cherry from (no change): Add assert for static and other platform (#56044)
* part-4 cherry-pick from: dp and sharding coexist (#56096)
  * dp and sharding coexist
  * dp
* part-4 cherry from: [Distributed] Add debug information for processgroupnccl (#56441)
  * add debug information
  * fix log
  * add detach for pp
* part-4 cherry from: [BugFix] Fix bug in paddle.device.cuda.synchronize() (#56451)
  * fix bug in synchronize
* part-4 cherry from: add fused gradient (#57048)
* part-4 cherry from: [Distributed] add eager_communication_connection for eager mode in nccl (#57517)
  * add eager_nccl_connection
  * add eager_connection
* part-4 cherry from: Add auto growth allocator for CUDA pinned allocator (#57625)
  * fix h2d bandwidth
  * remove useless flags
  * fix cherry-pick #56066
* part-5 cherry from: Add allocation debug FLAGS (#57797)
  * add sync after value set
  * refine flags
* part-5 cherry from: fix softmax backward (#57971)
* part-5 cherry from: [Distributed] Optimize memory in processgroup (#58299)
* part-5 cherry from: [Distributed] Add unbalance batch for virtual pp (#58383)
  * add unbalanced batch for vpp
  * fix comments
* fix kunlun compatibility issues
* fix test_fused_rotary_position_embedding.py
* fix allocator.h
* tinyfix
* fix conflicts
* fix new ir translator c_embedding failure

Co-authored-by: ShenLiang <1422485404@qq.com>
Co-authored-by: umiswing <umiswing@foxmail.com>
Co-authored-by: Chitsing KUI <kuizhiqing@msn.com>
Co-authored-by: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Co-authored-by: liuzhenhai93 <liuzhenhai93@outlook.com>
Co-authored-by: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
PR types
New features
PR changes
OPs
Description
Pcard-70459
Integrates flash-attention-2 into PaddlePaddle.
Differences: