diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index fde9be0d974a7..079051c970cb7 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -32,17 +32,29 @@ func : addcmul_xpu data_type : x +- op : blha_get_max_len + args : (Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor batch_size) + output : Tensor(max_enc_len_this_time), Tensor(max_dec_len_this_time) + infer_meta : + func : BlhaGetMaxLenInferMeta + kernel : + func : blha_get_max_len + data_type : seq_lens_encoder + support_dygraph_mode : true + - op : block_multihead_attention_ - args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default") + args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, Tensor max_enc_len_this_time, Tensor max_dec_len_this_time, int max_seq_len, int block_size, 
bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default") output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out) infer_meta : func : BlockMultiheadAttentionInferMeta kernel : func : block_multihead_attention data_type : qkv - optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth + optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth, max_enc_len_this_time, max_dec_len_this_time inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out) support_dygraph_mode : true + data_transform : + skip_transform : max_enc_len_this_time, max_dec_len_this_time - op : bn_act_xpu args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, float momentum, float epsilon, str data_format, int act_type) diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 91a515a667f65..79022ff4d2e00 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -232,6 +232,17 @@ void FusedMultiTransformerInferMeta( out->set_dims(x.dims()); } +void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder, + const MetaTensor& seq_lens_decoder, + const MetaTensor& batch_size, + MetaTensor* max_enc_len_this_time, + MetaTensor* max_dec_len_this_time) { + max_enc_len_this_time->set_dims({1}); + max_enc_len_this_time->set_dtype(phi::DataType::INT32); + max_dec_len_this_time->set_dims({1}); + max_dec_len_this_time->set_dtype(phi::DataType::INT32); +} + void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& key_cache, const 
MetaTensor& value_cache, @@ -256,6 +267,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& qkv_bias, const MetaTensor& out_shift, const MetaTensor& out_smooth, + const MetaTensor& max_enc_len_this_time, + const MetaTensor& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 78d5791396186..6ecffd7675979 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -77,6 +77,12 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x, float epsilon, MetaTensor* out); +void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder, + const MetaTensor& seq_lens_decoder, + const MetaTensor& batch_size, + MetaTensor* max_enc_len_this_time, + MetaTensor* max_dec_len_this_time); + void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& key_cache, const MetaTensor& value_cache, @@ -101,6 +107,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& qkv_bias, const MetaTensor& out_shift, const MetaTensor& out_smooth, + const MetaTensor& max_enc_len_this_time, + const MetaTensor& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 17665623b56c1..0aca647dd6a49 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -215,6 +215,7 @@ if(WITH_ROCM) "gpu/qr_kernel.cu" "gpu/svd_kernel.cu" "gpudnn/mha_cudnn_frontend.cu" + "fusion/gpu/blha_get_max_len.cu" "fusion/gpu/block_multi_head_attention_kernel.cu" "fusion/gpu/fused_bn_add_activation_grad_kernel.cu" "fusion/gpu/fused_bn_add_activation_kernel.cu" diff --git a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu new file mode 100644 index 0000000000000..78a46e16989e5 --- /dev/null +++ 
b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu @@ -0,0 +1,68 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/flash_attn_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.h" +#include "paddle/phi/kernels/fusion/gpu/block_attn.h" +#include "paddle/phi/kernels/gpu/flash_attn_utils.h" +#include "paddle/phi/kernels/memcpy_kernel.h" +#include "paddle/utils/none.h" + +namespace phi { +namespace fusion { + +void GetMaxLenTensor(const phi::GPUContext& dev_ctx, + const phi::DenseTensor& seq_lens_tensor, + const phi::DenseTensor& batch_size, + DenseTensor* out) { + phi::DenseTensor max_len_tensor; + max_len_tensor.Resize({{1}}); + auto* max_len_tensor_data = dev_ctx.template Alloc( + &max_len_tensor, max_len_tensor.numel() * sizeof(int)); + const int bsz = batch_size.dims()[0]; + constexpr int blockSize = 128; + int max_len_cpu = 0; + GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( + seq_lens_tensor.data(), max_len_tensor.data(), bsz); + MemcpyD2HKernel(dev_ctx, max_len_tensor, 0, out); +} + +template +void BlhaGetMaxLenKernel(const Context& dev_ctx, + const DenseTensor& seq_lens_encoder, + const 
DenseTensor& seq_lens_decoder, + const phi::DenseTensor& batch_size, + DenseTensor* max_enc_len_this_time, + DenseTensor* max_dec_len_this_time) { + // decoder + max_dec_len_this_time->Resize({{1}}); + GetMaxLenTensor(dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time); + + // encoder + max_enc_len_this_time->Resize({{1}}); + GetMaxLenTensor(dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(blha_get_max_len, + GPU, + ALL_LAYOUT, + phi::fusion::BlhaGetMaxLenKernel, + int, + int64_t) {} diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h index 500ffe939870f..f433a426435a2 100644 --- a/paddle/phi/kernels/fusion/gpu/block_attn.h +++ b/paddle/phi/kernels/fusion/gpu/block_attn.h @@ -2717,23 +2717,6 @@ __global__ void GetMaxLenKernel(const int *seq_lens, } } -int GetMaxLen(const phi::GPUContext &dev_ctx, - const phi::DenseTensor &seq_lens_tensor, - phi::DenseTensor *max_len_tensor, - const int batch_size) { - constexpr int blockSize = 128; - int max_len_cpu = 0; - GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( - seq_lens_tensor.data(), max_len_tensor->data(), batch_size); - memory_utils::Copy(phi::CPUPlace(), - &max_len_cpu, - dev_ctx.GetPlace(), - max_len_tensor->data(), - sizeof(int), - dev_ctx.stream()); - return max_len_cpu; -} - template __global__ void InitOutValueKernel(T *output_data, const int64_t numel, diff --git a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu index 403472e6d945c..e94b5209bacf5 100644 --- a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu @@ -29,6 +29,23 @@ namespace phi { namespace fusion { +int GetMaxLen(const phi::GPUContext& dev_ctx, + const phi::DenseTensor& seq_lens_tensor, + phi::DenseTensor* max_len_tensor, + 
const int batch_size) { + constexpr int blockSize = 128; + int max_len_cpu = 0; + GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( + seq_lens_tensor.data(), max_len_tensor->data(), batch_size); + memory_utils::Copy(phi::CPUPlace(), + &max_len_cpu, + dev_ctx.GetPlace(), + max_len_tensor->data(), + sizeof(int), + dev_ctx.stream()); + return max_len_cpu; +} + template inline HOSTDEVICE data_t RoundWithTiesToEven(data_t x) { data_t xLower = floor(x); @@ -234,6 +251,8 @@ void DispatchWithDtype( const paddle::optional& qkv_bias, const paddle::optional& out_shift, const paddle::optional& out_smooth, + const paddle::optional& max_enc_len_this_time, + const paddle::optional& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, @@ -287,22 +306,44 @@ void DispatchWithDtype( VLOG(3) << "token_num: " << token_num << " pre_cache_length: " << pre_cache_length; - phi::DenseTensor max_dec_len_tensor; - max_dec_len_tensor.Resize({{1}}); - auto* max_dec_len_data = dev_ctx.template Alloc( - &max_dec_len_tensor, max_dec_len_tensor.numel() * sizeof(int)); - int max_dec_len_this_time = - GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz); + int max_dec_len_this_time_data(0); + if (!max_dec_len_this_time) { + phi::DenseTensor max_dec_len_tensor; + max_dec_len_tensor.Resize({{1}}); + auto* max_dec_len_data = dev_ctx.template Alloc( + &max_dec_len_tensor, max_dec_len_tensor.numel() * sizeof(int)); + max_dec_len_this_time_data = + GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz); + } else { + PADDLE_ENFORCE_EQ( + max_dec_len_this_time.get().place().GetType(), + phi::AllocationType::CPU, + errors::InvalidArgument( + "The place of input max_dec_len_this_time must be CPU, but got %s.", + max_dec_len_this_time.get().place())); + max_dec_len_this_time_data = *max_dec_len_this_time.get().data(); + } - phi::DenseTensor max_enc_len_tensor; - max_enc_len_tensor.Resize({{1}}); - auto* max_enc_len_data = dev_ctx.template Alloc( - 
&max_enc_len_tensor, max_enc_len_tensor.numel() * sizeof(int)); - int max_enc_len_this_time = - GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz); + int max_enc_len_this_time_data(0); + if (!max_enc_len_this_time) { + phi::DenseTensor max_enc_len_tensor; + max_enc_len_tensor.Resize({{1}}); + auto* max_enc_len_data = dev_ctx.template Alloc( + &max_enc_len_tensor, max_enc_len_tensor.numel() * sizeof(int)); + max_enc_len_this_time_data = + GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz); + } else { + PADDLE_ENFORCE_EQ( + max_enc_len_this_time.get().place().GetType(), + phi::AllocationType::CPU, + errors::InvalidArgument( + "The place of input max_enc_len_this_time must be CPU, but got %s.", + max_enc_len_this_time.get().place())); + max_enc_len_this_time_data = *max_enc_len_this_time.get().data(); + } phi::DenseTensor qkv_out_decoder; - if (max_dec_len_this_time > 0) { + if (max_dec_len_this_time_data > 0) { qkv_out_decoder.Resize({{bsz, 3, num_head, dim_head}}); auto* qkv_out_decoder_data = dev_ctx.template Alloc( &qkv_out_decoder, qkv_out_decoder.numel() * sizeof(T)); @@ -311,7 +352,7 @@ void DispatchWithDtype( phi::DenseTensor unpadding_q, unpadding_k, unpadding_v; phi::DenseTensor softmax_out, softmax_lse, seed_offset; phi::DenseTensor q_trans, k_trans, v_trans, qktv_out; - if (max_enc_len_this_time > 0) { + if (max_enc_len_this_time_data > 0) { if (!use_pre_cache) { unpadding_q.Resize({{token_num, num_head, dim_head}}); unpadding_k.Resize({{token_num, num_head, dim_head}}); @@ -321,16 +362,16 @@ void DispatchWithDtype( dev_ctx.template Alloc(&unpadding_k, unpadding_k.numel() * sizeof(T)); dev_ctx.template Alloc(&unpadding_v, unpadding_v.numel() * sizeof(T)); } else { - q_trans.Resize({{bsz, num_head, max_enc_len_this_time, dim_head}}); + q_trans.Resize({{bsz, num_head, max_enc_len_this_time_data, dim_head}}); k_trans.Resize({{bsz, num_head, - max_enc_len_this_time + pre_cache_length, + max_enc_len_this_time_data + pre_cache_length, 
dim_head}}); v_trans.Resize({{bsz, num_head, - max_enc_len_this_time + pre_cache_length, + max_enc_len_this_time_data + pre_cache_length, dim_head}}); - qktv_out.Resize({{bsz, num_head, max_enc_len_this_time, dim_head}}); + qktv_out.Resize({{bsz, num_head, max_enc_len_this_time_data, dim_head}}); dev_ctx.template Alloc(&q_trans, q_trans.numel() * sizeof(T)); dev_ctx.template Alloc(&k_trans, k_trans.numel() * sizeof(T)); @@ -339,7 +380,7 @@ void DispatchWithDtype( } } VLOG(3) << "encoder"; - VLOG(3) << "max_enc_len_this_time: " << max_enc_len_this_time; + VLOG(3) << "max_enc_len_this_time: " << max_enc_len_this_time_data; if (qkv_out_scale) { VLOG(1) << "qkv_out_scale: " << qkv_out_scale.get_ptr()->dims(); @@ -372,7 +413,7 @@ void DispatchWithDtype( dev_ctx, ins, &outs, phi::funcs::AddFunctor()); } - if (max_enc_len_this_time > 0) { + if (max_enc_len_this_time_data > 0) { const int* sequence_lengths_data = seq_lens_encoder.data(); if (rope_emb) { rotary_qk_variable(dev_ctx, @@ -416,8 +457,8 @@ void DispatchWithDtype( cu_seqlens_k, paddle::none /*fixed_seed_offset*/, causual ? 
paddle::none : mask, - max_enc_len_this_time, - max_enc_len_this_time, + max_enc_len_this_time_data, + max_enc_len_this_time_data, 1.0f / sqrt(static_cast(dim_head)), 0.0, causual, @@ -444,7 +485,7 @@ void DispatchWithDtype( token_num, bsz, num_head, - max_enc_len_this_time, + max_enc_len_this_time_data, max_seq_len, pre_cache_length, dim_head); @@ -471,7 +512,7 @@ void DispatchWithDtype( fmha_buf.data(), bsz, num_head, - max_enc_len_this_time, + max_enc_len_this_time_data, max_seq_len, dim_head, token_num, @@ -523,8 +564,8 @@ void DispatchWithDtype( VLOG(3) << "cache end"; } VLOG(3) << "encoder done"; - VLOG(3) << "max_dec_len_this_time: " << max_dec_len_this_time; - if (max_dec_len_this_time > 0) { + VLOG(3) << "max_dec_len_this_time: " << max_dec_len_this_time_data; + if (max_dec_len_this_time_data > 0) { GetDecoderTensor(dev_ctx, qkv_buf, nullptr, @@ -564,7 +605,7 @@ void DispatchWithDtype( pre_cache_length, num_head, dim_head, - max_dec_len_this_time, + max_dec_len_this_time_data, rope_emb ? 1 : 0, 1. 
/ sqrt(dim_head), /*compute_bias*/ false, @@ -642,6 +683,8 @@ void BlockMultiheadAttentionKernel( const paddle::optional& qkv_bias, const paddle::optional& out_shift, const paddle::optional& out_smooth, + const paddle::optional& max_enc_len_this_time, + const paddle::optional& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, @@ -684,6 +727,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -724,6 +769,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -767,6 +814,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -807,6 +856,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -835,12 +886,18 @@ PD_REGISTER_KERNEL(block_multihead_attention, phi::fusion::BlockMultiheadAttentionKernel, phi::dtype::bfloat16, phi::dtype::float16, - int32_t) {} + int32_t) { + kernel->InputAt(24).SetBackend(phi::Backend::CPU); + kernel->InputAt(25).SetBackend(phi::Backend::CPU); +} #else PD_REGISTER_KERNEL(block_multihead_attention, GPU, ALL_LAYOUT, phi::fusion::BlockMultiheadAttentionKernel, phi::dtype::float16, - int32_t) {} + int32_t) { + kernel->InputAt(24).SetBackend(phi::Backend::CPU); + kernel->InputAt(25).SetBackend(phi::Backend::CPU); +} #endif diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 1b6b153201990..e967d699416b9 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # 
limitations under the License. +from .blha_get_max_len import blha_get_max_len from .block_multihead_attention import block_multihead_attention from .fused_dot_product_attention import ( fused_dot_product_attention, # noqa: F401 @@ -54,6 +55,7 @@ "fused_rms_norm", "fused_layer_norm", "masked_multihead_attention", + "blha_get_max_len", "block_multihead_attention", "swiglu", ] diff --git a/python/paddle/incubate/nn/functional/blha_get_max_len.py b/python/paddle/incubate/nn/functional/blha_get_max_len.py new file mode 100644 index 0000000000000..f330803e1b2fa --- /dev/null +++ b/python/paddle/incubate/nn/functional/blha_get_max_len.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import _C_ops +from paddle.framework import LayerHelper, in_dynamic_or_pir_mode + + +def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size): + """ + Apply the fused BlhaGetMaxLen kernel, which computes the maximum encoder and decoder sentence lengths of the batch. Typically used before the block_multihead_attention operator. + + Args: + seq_lens_encoder (Tensor): Sentence length of the encoder. + seq_lens_decoder (Tensor): Sentence length of the decoder. + batch_size (Tensor): A Tensor whose first dimension is the batch size; only its shape is read, not its values. + + Returns: + tuple(Tensor, Tensor): (max_enc_len_this_time, max_dec_len_this_time), each an int32 Tensor with shape [1]. + + Examples: + .. 
code-block:: python + + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + + >>> seq_lens_encoder = paddle.cast(paddle.randn(shape=[10]), dtype=paddle.int32) + >>> seq_lens_decoder = paddle.cast(paddle.randn(shape=[10]), dtype=paddle.int32) + >>> bsz = 10 + >>> batch_size = paddle.ones(shape=[bsz]) + >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size) + """ + if in_dynamic_or_pir_mode(): + return _C_ops.blha_get_max_len( + seq_lens_encoder, seq_lens_decoder, batch_size + ) + + helper = LayerHelper('blha_get_max_len', **locals()) + max_enc_len_this_time = helper.create_variable_for_type_inference( + dtype="int32" + ) + max_dec_len_this_time = helper.create_variable_for_type_inference( + dtype="int32" + ) + + inputs = {} + inputs['seq_lens_encoder'] = seq_lens_encoder + inputs['seq_lens_decoder'] = seq_lens_decoder + inputs['batch_size'] = batch_size + + outputs = { + 'max_enc_len_this_time': max_enc_len_this_time, + 'max_dec_len_this_time': max_dec_len_this_time, + } + helper.append_op( + type='blha_get_max_len', + inputs=inputs, + outputs=outputs, + ) + return max_enc_len_this_time, max_dec_len_this_time diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py index 6409f160aaf69..a55f61de2c678 100644 --- a/python/paddle/incubate/nn/functional/block_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py @@ -38,6 +38,8 @@ def block_multihead_attention( qkv_bias=None, out_shift=None, out_smooth=None, + max_enc_len_this_time=None, + max_dec_len_this_time=None, rope_emb=None, mask=None, tgt_mask=None, @@ -76,6 +78,8 @@ def block_multihead_attention( qkv_bias (Tensor): The bias of qkv. Its shape is [3 * num_head * head_size]. out_shift (Tensor): Shift bias of fmha_out, which is the 1st return value. 
Its shape is [num_head * head_size]. out_smooth (Tensor): Smooth weight of fmha_out. Its shape is [num_head * head_size]. + max_enc_len_this_time (Tensor, optional): Maximum sentence length of the encoder over the batch this time, e.g. the output of blha_get_max_len. It must be an int32 Tensor placed on CPU. Its shape is [1]. If None, it is computed inside the kernel. + max_dec_len_this_time (Tensor, optional): Maximum sentence length of the decoder over the batch this time, e.g. the output of blha_get_max_len. It must be an int32 Tensor placed on CPU. Its shape is [1]. If None, it is computed inside the kernel. rope_emb (Tensor): The RoPE embedding. Its shape is [2, batchsize, max_seq_len, 1, head_size // 2]. mask (Tensor): The mask of qk_matmul in encoder. Its shape is [batchsize, 1, max_seq_len, max_seq_len]. tgt_mask (Tensor): The mask of qk_matmul in decoder. Its shape is [batchsize, 1, 1, max_seq_len]. @@ -251,6 +255,8 @@ def block_multihead_attention( ... None, # qkv_bias ... None, # out_shift ... None, # out_smooth + ... None, # max_enc_len_this_time + ... None, # max_dec_len_this_time ... None, # rotary_embs ... None, # attn_mask ... None, # tgt_mask @@ -301,6 +307,8 @@ def block_multihead_attention( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -353,6 +361,10 @@ def block_multihead_attention( inputs["out_shift"] = out_shift if out_smooth is not None: inputs["out_smooth"] = out_smooth + if max_enc_len_this_time is not None: + inputs["max_enc_len_this_time"] = max_enc_len_this_time + if max_dec_len_this_time is not None: + inputs["max_dec_len_this_time"] = max_dec_len_this_time outputs = { 'fmha_out': out, diff --git a/test/legacy_test/test_blha_get_max_len_op.py b/test/legacy_test/test_blha_get_max_len_op.py new file mode 100644 index 0000000000000..914dce275975f --- /dev/null +++ b/test/legacy_test/test_blha_get_max_len_op.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.base import core +from paddle.incubate.nn.functional import blha_get_max_len + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "Only support GPU in CUDA mode." +) +class TestBlhaGetMaxLenOp(unittest.TestCase): + def setUp(self): + self.name = "TestBlhaGetMaxLenOpDynamic" + self.place = paddle.CUDAPlace(0) + self.batch_size = 10 + self.test_encoder_data = np.random.randint(1, 100, size=self.batch_size) + self.test_encoder_data_res = paddle.to_tensor( + np.max(self.test_encoder_data), "int32" + ) + self.test_decoder_data = np.random.randint(1, 100, size=self.batch_size) + self.test_decoder_data_res = paddle.to_tensor( + np.max(self.test_decoder_data), "int32" + ) + self.seq_lens_encoder = paddle.to_tensor( + self.test_encoder_data, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + self.test_decoder_data, + "int32", + ) + self.batch_size_tensor = paddle.ones([self.batch_size]) + + def test_dynamic_api(self): + paddle.disable_static() + max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.batch_size_tensor, + ) + assert ( + max_enc_len_this_time == self.test_encoder_data_res + and max_dec_len_this_time == self.test_decoder_data_res + ) + + def test_static_api(self): + paddle.enable_static() + max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.batch_size_tensor, + ) + assert ( + max_enc_len_this_time == self.test_encoder_data_res + and 
max_dec_len_this_time == self.test_decoder_data_res + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_block_multihead_attention.py b/test/legacy_test/test_block_multihead_attention.py index 7f3033044e1c5..e36605cf8ea07 100644 --- a/test/legacy_test/test_block_multihead_attention.py +++ b/test/legacy_test/test_block_multihead_attention.py @@ -392,6 +392,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -496,6 +498,286 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time + None, # rotary_embs + None, # attn_mask + self.tgt_mask, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=5e-02, + atol=5e-02, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecSkipGetMaxLen(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDecSkipGetMaxLen" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, 
+ "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.max_enc_len_this_time = paddle.to_tensor( + [self.seq_len], "int32" + ).cpu() + self.max_dec_len_this_time = paddle.to_tensor([0], "int32").cpu() + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + + self.tgt_mask = paddle.randn( + [self.batch_size, self.num_head, 1, self.seq_len + 1], + dtype=self.dtype, + ) + + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + 
q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + self.max_enc_len_this_time, # max_enc_len_this_time + self.max_dec_len_this_time, # max_dec_len_this_time + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=5e-03, + atol=1e-03, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.max_enc_len_this_time = paddle.to_tensor([0], "int32").cpu() + self.max_dec_len_this_time = paddle.to_tensor( + [self.seq_len], "int32" + ).cpu() + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = 
np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + self.tgt_mask, + self.scale, + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + self.max_enc_len_this_time, # max_enc_len_this_time + self.max_dec_len_this_time, # max_dec_len_this_time None, # rotary_embs None, # attn_mask self.tgt_mask, # tgt_mask @@ -690,6 +972,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time self.rope_emb, # rotary_embs None, # attn_mask None, # tgt_mask @@ -800,6 +1084,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time 
self.rope_emb, # rotary_embs None, # attn_mask None, # tgt_mask @@ -979,6 +1265,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs self.attention_mask, # attn_mask None, # tgt_mask @@ -1083,6 +1371,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs self.attention_mask, # attn_mask None, # tgt_mask @@ -1285,6 +1575,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1498,6 +1790,8 @@ def test_all(self): qkv_bias, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1643,6 +1937,8 @@ def test_all(self): qkv_bias, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1858,6 +2154,8 @@ def test_all(self): qkv_bias, # qkv_bias shift, # out_shift smooth, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2021,6 +2319,8 @@ def test_all(self): qkv_bias, # qkv_bias shift, # out_shift smooth, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2186,6 +2486,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2298,6 +2600,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time 
None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2469,6 +2773,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2579,6 +2885,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2762,6 +3070,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2871,6 +3181,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask