From d83c6e366c3bbe6f28221186b27ee87508c742ce Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Mon, 13 May 2024 05:26:48 +0000 Subject: [PATCH 01/10] [Inference]add blha_get_max_len op & modify block_multihead_attention op --- paddle/phi/api/yaml/fused_ops.yaml | 16 +++- paddle/phi/infermeta/fusion.cc | 13 +++ paddle/phi/infermeta/fusion.h | 8 ++ .../kernels/fusion/gpu/blha_get_max_len.cu | 63 +++++++++++++ paddle/phi/kernels/fusion/gpu/block_attn.h | 17 ---- .../gpu/block_multi_head_attention_kernel.cu | 93 +++++++++++++------ .../paddle/incubate/nn/functional/__init__.py | 2 + .../nn/functional/blha_get_max_len.py | 70 ++++++++++++++ .../functional/block_multihead_attention.py | 8 ++ 9 files changed, 244 insertions(+), 46 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu create mode 100644 python/paddle/incubate/nn/functional/blha_get_max_len.py diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 304c543d1a463..00ea646b83508 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -32,17 +32,29 @@ func : addcmul_xpu data_type : x +- op : blha_get_max_len + args : (Tensor seq_lens_encoder, Tensor seq_lens_decoder, int bsz) + output : Tensor(max_enc_len_this_time), Tensor(max_dec_len_this_time) + infer_meta : + func : BlhaGetMaxLenInferMeta + kernel : + func : blha_get_max_len + data_type : seq_lens_encoder + support_dygraph_mode : true + - op : block_multihead_attention_ - args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default") + args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, Tensor max_enc_len_this_time, Tensor max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default") output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out) infer_meta : func : BlockMultiheadAttentionInferMeta kernel : func : block_multihead_attention data_type : qkv - optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth + optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, 
cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth, max_enc_len_this_time, max_dec_len_this_time inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out) support_dygraph_mode : true + data_transform : + skip_transform : max_enc_len_this_time, max_dec_len_this_time - op : bn_act_xpu args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, float momentum, float epsilon, str data_format, int act_type) diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 590473bd2094e..c036bf30539fb 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -232,6 +232,17 @@ void FusedMultiTransformerInferMeta( out->set_dims(x.dims()); } +void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder, + const MetaTensor& seq_lens_decoder, + const int bsz, + MetaTensor* max_enc_len_this_time, + MetaTensor* max_dec_len_this_time) { + max_enc_len_this_time->set_dims({1}); + max_enc_len_this_time->set_dtype(phi::DataType::INT32); + max_dec_len_this_time->set_dims({1}); + max_dec_len_this_time->set_dtype(phi::DataType::INT32); +} + void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& key_cache, const MetaTensor& value_cache, @@ -256,6 +267,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& qkv_bias, const MetaTensor& out_shift, const MetaTensor& out_smooth, + const MetaTensor& max_enc_len_this_time, + const MetaTensor& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 488fa6b9904c8..3a19bda6182d7 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -77,6 +77,12 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x, float epsilon, MetaTensor* out); +void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder, + const MetaTensor& seq_lens_decoder, + const int bsz, + MetaTensor* max_enc_len_this_time, + MetaTensor* max_dec_len_this_time); + void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& key_cache, const MetaTensor& value_cache, @@ -101,6 +107,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& qkv_bias, const MetaTensor& out_shift, const MetaTensor& out_smooth, + const MetaTensor& max_enc_len_this_time, + const MetaTensor& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, diff --git a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu new file mode 100644 index 0000000000000..3f7ea5d2c7688 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu @@ -0,0 +1,63 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/flash_attn_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.h" +#include "paddle/phi/kernels/fusion/gpu/block_attn.h" +#include "paddle/phi/kernels/gpu/flash_attn_utils.h" +#include "paddle/phi/kernels/memcpy_kernel.h" +#include "paddle/utils/none.h" + +namespace phi { +namespace fusion { + +void GetMaxLenTensor(const phi::GPUContext& dev_ctx, + const phi::DenseTensor& seq_lens_tensor, + const int batch_size, + DenseTensor* out) { + phi::DenseTensor max_len_tensor; + max_len_tensor.Resize({{1}}); + auto* max_len_tensor_data = dev_ctx.template Alloc( + &max_len_tensor, max_len_tensor.numel() * sizeof(int)); + constexpr int blockSize = 128; + int max_len_cpu = 0; + GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( + seq_lens_tensor.data(), max_len_tensor.data(), batch_size); + MemcpyD2HKernel(dev_ctx, max_len_tensor, 0, out); +} + +template +void BlhaGetMaxLenKernel(const Context& dev_ctx, + const DenseTensor& seq_lens_encoder, + const DenseTensor& seq_lens_decoder, + const int bsz, + DenseTensor* max_enc_len_this_time, + DenseTensor* max_dec_len_this_time) { + // decoder + max_dec_len_this_time->Resize({{1}}); + GetMaxLenTensor(dev_ctx, seq_lens_decoder, bsz, max_dec_len_this_time); + + // encoder + max_enc_len_this_time->Resize({{1}}); + GetMaxLenTensor(dev_ctx, seq_lens_encoder, bsz, max_enc_len_this_time); +} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL( + blha_get_max_len, GPU, ALL_LAYOUT, phi::fusion::BlhaGetMaxLenKernel, int) {} diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h index 500ffe939870f..f433a426435a2 100644 --- a/paddle/phi/kernels/fusion/gpu/block_attn.h +++ b/paddle/phi/kernels/fusion/gpu/block_attn.h @@ -2717,23 +2717,6 @@ __global__ void GetMaxLenKernel(const int *seq_lens, } } -int GetMaxLen(const phi::GPUContext &dev_ctx, - const phi::DenseTensor &seq_lens_tensor, - phi::DenseTensor *max_len_tensor, - const int batch_size) { - constexpr int blockSize = 128; - int max_len_cpu = 0; - GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( - seq_lens_tensor.data(), max_len_tensor->data(), batch_size); - memory_utils::Copy(phi::CPUPlace(), - &max_len_cpu, - dev_ctx.GetPlace(), - max_len_tensor->data(), - sizeof(int), - dev_ctx.stream()); - return max_len_cpu; -} - template __global__ void InitOutValueKernel(T *output_data, const int64_t numel, diff --git a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu index 57754fd3b82aa..2bfcdb6f557c9 100644 --- a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu @@ -25,6 +25,23 @@ namespace phi { namespace fusion { +int GetMaxLen(const phi::GPUContext& dev_ctx, + const phi::DenseTensor& seq_lens_tensor, + phi::DenseTensor* max_len_tensor, + const int batch_size) { + constexpr int blockSize = 128; + int max_len_cpu = 0; + GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( + seq_lens_tensor.data(), max_len_tensor->data(), batch_size); + memory_utils::Copy(phi::CPUPlace(), + &max_len_cpu, + dev_ctx.GetPlace(), + max_len_tensor->data(), + sizeof(int), + dev_ctx.stream()); + return max_len_cpu; +} + 
template inline HOSTDEVICE data_t RoundWithTiesToEven(data_t x) { data_t xLower = floor(x); @@ -230,6 +247,8 @@ void DispatchWithDtype( const paddle::optional& qkv_bias, const paddle::optional& out_shift, const paddle::optional& out_smooth, + const paddle::optional& max_enc_len_this_time, + const paddle::optional& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, @@ -283,22 +302,32 @@ void DispatchWithDtype( VLOG(3) << "token_num: " << token_num << " pre_cache_length: " << pre_cache_length; - phi::DenseTensor max_dec_len_tensor; - max_dec_len_tensor.Resize({{1}}); - auto* max_dec_len_data = dev_ctx.template Alloc( - &max_dec_len_tensor, max_dec_len_tensor.numel() * sizeof(int)); - int max_dec_len_this_time = - GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz); + int max_dec_len_this_time_data(0); + if (!max_dec_len_this_time) { + phi::DenseTensor max_dec_len_tensor; + max_dec_len_tensor.Resize({{1}}); + auto* max_dec_len_data = dev_ctx.template Alloc( + &max_dec_len_tensor, max_dec_len_tensor.numel() * sizeof(int)); + max_dec_len_this_time_data = + GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz); + } else { + max_dec_len_this_time_data = *max_dec_len_this_time.get().data(); + } - phi::DenseTensor max_enc_len_tensor; - max_enc_len_tensor.Resize({{1}}); - auto* max_enc_len_data = dev_ctx.template Alloc( - &max_enc_len_tensor, max_enc_len_tensor.numel() * sizeof(int)); - int max_enc_len_this_time = - GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz); + int max_enc_len_this_time_data(0); + if (!max_enc_len_this_time) { + phi::DenseTensor max_enc_len_tensor; + max_enc_len_tensor.Resize({{1}}); + auto* max_enc_len_data = dev_ctx.template Alloc( + &max_enc_len_tensor, max_enc_len_tensor.numel() * sizeof(int)); + max_enc_len_this_time_data = + GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz); + } else { + max_enc_len_this_time_data = *max_enc_len_this_time.get().data(); + } phi::DenseTensor qkv_out_decoder; - if (max_dec_len_this_time > 0) { + if (max_dec_len_this_time_data > 0) { qkv_out_decoder.Resize({{bsz, 3, num_head, dim_head}}); auto* qkv_out_decoder_data = dev_ctx.template Alloc( &qkv_out_decoder, qkv_out_decoder.numel() * sizeof(T)); @@ -307,7 +336,7 @@ void DispatchWithDtype( phi::DenseTensor unpadding_q, unpadding_k, unpadding_v; phi::DenseTensor softmax_out, softmax_lse, seed_offset; phi::DenseTensor q_trans, k_trans, v_trans, qktv_out; - if (max_enc_len_this_time > 0) { + if (max_enc_len_this_time_data > 0) { if (!use_pre_cache) { unpadding_q.Resize({{token_num, num_head, dim_head}}); unpadding_k.Resize({{token_num, num_head, dim_head}}); @@ -317,16 +346,16 @@ void DispatchWithDtype( dev_ctx.template Alloc(&unpadding_k, unpadding_k.numel() * sizeof(T)); dev_ctx.template Alloc(&unpadding_v, unpadding_v.numel() * sizeof(T)); } else { - q_trans.Resize({{bsz, num_head, max_enc_len_this_time, dim_head}}); + q_trans.Resize({{bsz, num_head, max_enc_len_this_time_data, dim_head}}); k_trans.Resize({{bsz, num_head, - max_enc_len_this_time + pre_cache_length, + max_enc_len_this_time_data + pre_cache_length, dim_head}}); v_trans.Resize({{bsz, num_head, - max_enc_len_this_time + pre_cache_length, + max_enc_len_this_time_data + pre_cache_length, dim_head}}); - qktv_out.Resize({{bsz, num_head, max_enc_len_this_time, dim_head}}); + qktv_out.Resize({{bsz, num_head, max_enc_len_this_time_data, dim_head}}); dev_ctx.template Alloc(&q_trans, q_trans.numel() * sizeof(T)); dev_ctx.template Alloc(&k_trans, k_trans.numel() * 
sizeof(T)); @@ -335,7 +364,7 @@ void DispatchWithDtype( } } VLOG(3) << "encoder"; - VLOG(3) << "max_enc_len_this_time: " << max_enc_len_this_time; + VLOG(3) << "max_enc_len_this_time: " << max_enc_len_this_time_data; if (qkv_out_scale) { VLOG(1) << "qkv_out_scale: " << qkv_out_scale.get_ptr()->dims(); @@ -368,7 +397,7 @@ void DispatchWithDtype( dev_ctx, ins, &outs, phi::funcs::AddFunctor()); } - if (max_enc_len_this_time > 0) { + if (max_enc_len_this_time_data > 0) { const int* sequence_lengths_data = seq_lens_encoder.data(); if (rope_emb) { rotary_qk_variable(dev_ctx, @@ -408,8 +437,8 @@ void DispatchWithDtype( cu_seqlens_k, paddle::none /*fixed_seed_offset*/, causual ? paddle::none : mask, - max_enc_len_this_time, - max_enc_len_this_time, + max_enc_len_this_time_data, + max_enc_len_this_time_data, 1.0f / sqrt(static_cast(dim_head)), 0.0, causual, @@ -434,7 +463,7 @@ void DispatchWithDtype( token_num, bsz, num_head, - max_enc_len_this_time, + max_enc_len_this_time_data, max_seq_len, pre_cache_length, dim_head); @@ -461,7 +490,7 @@ void DispatchWithDtype( fmha_buf.data(), bsz, num_head, - max_enc_len_this_time, + max_enc_len_this_time_data, max_seq_len, dim_head, token_num, @@ -513,8 +542,8 @@ void DispatchWithDtype( VLOG(3) << "cache end"; } VLOG(3) << "encoder done"; - VLOG(3) << "max_dec_len_this_time: " << max_dec_len_this_time; - if (max_dec_len_this_time > 0) { + VLOG(3) << "max_dec_len_this_time: " << max_dec_len_this_time_data; + if (max_dec_len_this_time_data > 0) { GetDecoderTensor(dev_ctx, qkv_buf, nullptr, @@ -554,7 +583,7 @@ void DispatchWithDtype( pre_cache_length, num_head, dim_head, - max_dec_len_this_time, + max_dec_len_this_time_data, rope_emb ? 1 : 0, 1. / sqrt(dim_head), /*compute_bias*/ false, @@ -632,6 +661,8 @@ void BlockMultiheadAttentionKernel( const paddle::optional& qkv_bias, const paddle::optional& out_shift, const paddle::optional& out_smooth, + const paddle::optional& max_enc_len_this_time, + const paddle::optional& max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, @@ -674,6 +705,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -714,6 +747,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -757,6 +792,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -797,6 +834,8 @@ void BlockMultiheadAttentionKernel( qkv_bias, out_shift, out_smooth, + max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index 1b6b153201990..e967d699416b9 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .blha_get_max_len import blha_get_max_len from .block_multihead_attention import block_multihead_attention from .fused_dot_product_attention import ( fused_dot_product_attention, # noqa: F401 @@ -54,6 +55,7 @@ "fused_rms_norm", "fused_layer_norm", "masked_multihead_attention", + "blha_get_max_len", "block_multihead_attention", "swiglu", ] diff --git a/python/paddle/incubate/nn/functional/blha_get_max_len.py b/python/paddle/incubate/nn/functional/blha_get_max_len.py new file mode 100644 index 0000000000000..ef3df049641d7 --- /dev/null +++ b/python/paddle/incubate/nn/functional/blha_get_max_len.py @@ -0,0 +1,70 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import _C_ops +from paddle.framework import LayerHelper, in_dynamic_mode + + +def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz): + """ + Apply Fused BlhaGetMaxLen kernel. Typically used before the block_multihead_attention operator. + + Args: + seq_lens_encoder (Tensor): Sentence length of the encoder. + seq_lens_decoder (Tensor): Sentence length of the decoder. + bsz (Int): the batch size. + + Returns: + Tensor|(max_enc_len_this_time, max_dec_len_this_time) + + Examples: + .. code-block:: python + + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + + >>> seq_lens_encoder = paddle.cast(paddle.randn(shape=[1]), dtype=paddle.int32) + >>> seq_lens_decoder = paddle.cast(paddle.randn(shape=[1]), dtype=paddle.int32) + >>> bsz = 1 + >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz) + """ + if in_dynamic_mode(): + return _C_ops.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz) + + helper = LayerHelper('blha_get_max_len', **locals()) + max_enc_len_this_time = helper.create_variable_for_type_inference( + dtype="int32" + ) + max_dec_len_this_time = helper.create_variable_for_type_inference( + dtype="int32" + ) + + inputs = {} + inputs['seq_lens_encoder'] = seq_lens_encoder + inputs['seq_lens_decoder'] = seq_lens_decoder + + outputs = { + 'max_enc_len_this_time': max_enc_len_this_time, + 'max_dec_len_this_time': max_dec_len_this_time, + } + helper.append_op( + type='blha_get_max_len', + inputs=inputs, + outputs=outputs, + attrs={ + 'bsz': bsz, + }, + ) + return max_enc_len_this_time, max_dec_len_this_time diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py index 6409f160aaf69..9ab6b71a7157d 100644 --- a/python/paddle/incubate/nn/functional/block_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py @@ -38,6 +38,8 @@ def block_multihead_attention( qkv_bias=None, out_shift=None, out_smooth=None, + max_enc_len_this_time=None, + max_dec_len_this_time=None, rope_emb=None, mask=None, tgt_mask=None, @@ -301,6 +303,8 @@ def block_multihead_attention( qkv_bias, out_shift, out_smooth, + 
max_enc_len_this_time, + max_dec_len_this_time, max_seq_len, block_size, use_neox_style, @@ -353,6 +357,10 @@ def block_multihead_attention( inputs["out_shift"] = out_shift if out_smooth is not None: inputs["out_smooth"] = out_smooth + if max_enc_len_this_time is not None: + inputs["max_enc_len_this_time"] = max_enc_len_this_time + if max_dec_len_this_time is not None: + inputs["max_dec_len_this_time"] = max_dec_len_this_time outputs = { 'fmha_out': out, From f5f05f9ee384827cc6568a285f0b3e8ed03a344f Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Mon, 13 May 2024 07:38:39 +0000 Subject: [PATCH 02/10] modify batch_size from int attr to tensor(shape[0]) --- paddle/phi/api/yaml/fused_ops.yaml | 2 +- paddle/phi/infermeta/fusion.cc | 2 +- paddle/phi/infermeta/fusion.h | 2 +- .../phi/kernels/fusion/gpu/blha_get_max_len.cu | 13 +++++++------ .../incubate/nn/functional/blha_get_max_len.py | 17 +++++++++-------- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 00ea646b83508..375760682e073 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -33,7 +33,7 @@ data_type : x - op : blha_get_max_len - args : (Tensor seq_lens_encoder, Tensor seq_lens_decoder, int bsz) + args : (Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor batch_size) output : Tensor(max_enc_len_this_time), Tensor(max_dec_len_this_time) infer_meta : func : BlhaGetMaxLenInferMeta diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index c036bf30539fb..a0d7b7a920509 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -234,7 +234,7 @@ void FusedMultiTransformerInferMeta( void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder, const MetaTensor& seq_lens_decoder, - const int bsz, + const MetaTensor& batch_size, MetaTensor* max_enc_len_this_time, MetaTensor* max_dec_len_this_time) { max_enc_len_this_time->set_dims({1}); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 3a19bda6182d7..4192a5d7f1796 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -79,7 +79,7 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x, void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder, const MetaTensor& seq_lens_decoder, - const int bsz, + const MetaTensor& batch_size, MetaTensor* max_enc_len_this_time, MetaTensor* max_dec_len_this_time); diff --git a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu index 3f7ea5d2c7688..a2e8b715bba24 100644 --- a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu +++ b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -28,16 +28,17 @@ namespace fusion { void GetMaxLenTensor(const phi::GPUContext& dev_ctx, const phi::DenseTensor& seq_lens_tensor, - const int batch_size, + const phi::DenseTensor& batch_size, DenseTensor* out) { phi::DenseTensor max_len_tensor; max_len_tensor.Resize({{1}}); auto* max_len_tensor_data = dev_ctx.template Alloc( &max_len_tensor, max_len_tensor.numel() * sizeof(int)); + const int bsz = batch_size.dims()[0]; constexpr int blockSize = 128; int max_len_cpu = 0; GetMaxLenKernel<<<1, blockSize, 0, dev_ctx.stream()>>>( - seq_lens_tensor.data(), max_len_tensor.data(), batch_size); + seq_lens_tensor.data(), max_len_tensor.data(), bsz); MemcpyD2HKernel(dev_ctx, max_len_tensor, 0, out); } @@ -45,16 +46,16 @@ template void BlhaGetMaxLenKernel(const Context& dev_ctx, const DenseTensor& seq_lens_encoder, const DenseTensor& seq_lens_decoder, - const int bsz, + const phi::DenseTensor& batch_size, DenseTensor* max_enc_len_this_time, DenseTensor* max_dec_len_this_time) { // decoder max_dec_len_this_time->Resize({{1}}); - GetMaxLenTensor(dev_ctx, seq_lens_decoder, bsz, max_dec_len_this_time); + GetMaxLenTensor(dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time); // encoder max_enc_len_this_time->Resize({{1}}); - GetMaxLenTensor(dev_ctx, seq_lens_encoder, bsz, max_enc_len_this_time); + GetMaxLenTensor(dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time); } } // namespace fusion } // namespace phi diff --git a/python/paddle/incubate/nn/functional/blha_get_max_len.py b/python/paddle/incubate/nn/functional/blha_get_max_len.py index ef3df049641d7..f54a34c022def 100644 --- a/python/paddle/incubate/nn/functional/blha_get_max_len.py +++ b/python/paddle/incubate/nn/functional/blha_get_max_len.py @@ -16,14 +16,14 @@ from paddle.framework import LayerHelper, in_dynamic_mode -def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz): +def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size): """ Apply Fused BlhaGetMaxLen kernel. Typically used before the block_multihead_attention operator. Args: seq_lens_encoder (Tensor): Sentence length of the encoder. seq_lens_decoder (Tensor): Sentence length of the decoder. - bsz (Int): the batch size. + batch_size (Tensor): the batch size. 
Returns: Tensor|(max_enc_len_this_time, max_dec_len_this_time) @@ -37,11 +37,14 @@ def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz): >>> seq_lens_encoder = paddle.cast(paddle.randn(shape=[1]), dtype=paddle.int32) >>> seq_lens_decoder = paddle.cast(paddle.randn(shape=[1]), dtype=paddle.int32) - >>> bsz = 1 - >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz) + >>> bsz = 10 + >>> batch_size = paddle.ones(shape=[bsz]) + >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size) """ if in_dynamic_mode(): - return _C_ops.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz) + return _C_ops.blha_get_max_len( + seq_lens_encoder, seq_lens_decoder, batch_size + ) helper = LayerHelper('blha_get_max_len', **locals()) max_enc_len_this_time = helper.create_variable_for_type_inference( @@ -54,6 +57,7 @@ def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz): inputs = {} inputs['seq_lens_encoder'] = seq_lens_encoder inputs['seq_lens_decoder'] = seq_lens_decoder + inputs['batch_size'] = batch_size outputs = { 'max_enc_len_this_time': max_enc_len_this_time, @@ -63,8 +67,5 @@ def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, bsz): type='blha_get_max_len', inputs=inputs, outputs=outputs, - attrs={ - 'bsz': bsz, - }, ) return max_enc_len_this_time, max_dec_len_this_time From 03ee47c1d42a2d3649a6e7545335e3ffa6895b95 Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Mon, 13 May 2024 11:08:04 +0000 Subject: [PATCH 03/10] fix bug --- paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu | 8 ++++++-- .../incubate/nn/functional/block_multihead_attention.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu index a2e8b715bba24..78a46e16989e5 100644 --- a/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu +++ b/paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu @@ -60,5 +60,9 @@ void BlhaGetMaxLenKernel(const Context& dev_ctx, } // namespace fusion } // namespace phi -PD_REGISTER_KERNEL( - blha_get_max_len, GPU, ALL_LAYOUT, phi::fusion::BlhaGetMaxLenKernel, int) {} +PD_REGISTER_KERNEL(blha_get_max_len, + GPU, + ALL_LAYOUT, + phi::fusion::BlhaGetMaxLenKernel, + int, + int64_t) {} diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py index 9ab6b71a7157d..5ee22aa0ee532 100644 --- a/python/paddle/incubate/nn/functional/block_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py @@ -78,6 +78,8 @@ def block_multihead_attention( qkv_bias (Tensor): The bias of qkv. Its shape is [3 * num_head * head_size]. out_shift (Tensor): Shift bias of fmha_out, which is the 1st return value. Its shape is [num_head * head_size]. out_smooth (Tensor): Smooth weight of fmha_out. Its shape is [num_head * head_size]. + max_enc_len_this_time (Tensor): Sentence length of the encoder this time. Its shape is [1]. + max_dec_len_this_time (Tensor): Sentence length of the decoder this time. Its shape is [1]. rope_emb (Tensor): The RoPE embedding. Its shape is [2, batchsize, max_seq_len, 1, head_size // 2]. mask (Tensor): The mask of qk_matmul in encoder. Its shape is [batchsize, 1, max_seq_len, max_seq_len]. tgt_mask (Tensor): The mask of qk_matmul in decoder. Its shape is [batchsize, 1, 1, max_seq_len]. 
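Usage note: at this point in the series the new op and the extended attention kernel are wired together. Below is a minimal dygraph sketch of the intended call pattern (the batch size, the random sequence-length values, and the variable names are illustrative assumptions, not taken from the patch): compute the two max lengths once per step with blha_get_max_len, then forward the results to every block_multihead_attention call through the new optional max_enc_len_this_time / max_dec_len_this_time inputs, so each attention layer no longer needs to run its own GetMaxLen reduction plus device-to-host copy.

    import numpy as np
    import paddle
    from paddle.incubate.nn.functional import blha_get_max_len

    paddle.device.set_device('gpu')

    bsz = 8  # illustrative batch size
    seq_lens_encoder = paddle.to_tensor(
        np.random.randint(1, 64, size=bsz), dtype='int32'
    )
    seq_lens_decoder = paddle.zeros([bsz], dtype='int32')
    # Only shape[0] of this tensor is read by the op; it merely carries the batch size.
    batch_size = paddle.ones([bsz])

    max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len(
        seq_lens_encoder, seq_lens_decoder, batch_size
    )
    # Both outputs are int32 tensors of shape [1]; they can be passed unchanged
    # to each block_multihead_attention call as the new optional
    # max_enc_len_this_time / max_dec_len_this_time inputs.
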
From 72a5479c4349e2e74e800f55e8881b6717683093 Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Mon, 13 May 2024 12:03:38 +0000 Subject: [PATCH 04/10] add test --- paddle/phi/kernels/CMakeLists.txt | 1 + .../nn/functional/blha_get_max_len.py | 4 +- test/legacy_test/test_blha_get_max_len_op.py | 55 +++++++++++++++++++ .../test_block_multihead_attention.py | 34 ++++++++++++ 4 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 test/legacy_test/test_blha_get_max_len_op.py diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 17665623b56c1..0aca647dd6a49 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -215,6 +215,7 @@ if(WITH_ROCM) "gpu/qr_kernel.cu" "gpu/svd_kernel.cu" "gpudnn/mha_cudnn_frontend.cu" + "fusion/gpu/blha_get_max_len.cu" "fusion/gpu/block_multi_head_attention_kernel.cu" "fusion/gpu/fused_bn_add_activation_grad_kernel.cu" "fusion/gpu/fused_bn_add_activation_kernel.cu" diff --git a/python/paddle/incubate/nn/functional/blha_get_max_len.py b/python/paddle/incubate/nn/functional/blha_get_max_len.py index f54a34c022def..0474a235cfc13 100644 --- a/python/paddle/incubate/nn/functional/blha_get_max_len.py +++ b/python/paddle/incubate/nn/functional/blha_get_max_len.py @@ -35,8 +35,8 @@ def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size): >>> import paddle >>> paddle.device.set_device('gpu') - >>> seq_lens_encoder = paddle.cast(paddle.randn(shape=[1]), dtype=paddle.int32) - >>> seq_lens_decoder = paddle.cast(paddle.randn(shape=[1]), dtype=paddle.int32) + >>> seq_lens_encoder = paddle.cast(paddle.randn(shape=[10]), dtype=paddle.int32) + >>> seq_lens_decoder = paddle.cast(paddle.randn(shape=[10]), dtype=paddle.int32) >>> bsz = 10 >>> batch_size = paddle.ones(shape=[bsz]) >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size) diff --git a/test/legacy_test/test_blha_get_max_len_op.py b/test/legacy_test/test_blha_get_max_len_op.py new file mode 100644 index 0000000000000..675c022011862 --- /dev/null +++ b/test/legacy_test/test_blha_get_max_len_op.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +from paddle.incubate.nn.functional import blha_get_max_len + + +class TestBlhaGetMaxLenOp(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlhaGetMaxLenOp" + self.place = paddle.CUDAPlace(0) + self.batch_size = 10 + self.test_encoder_data = np.random.randint(1, 100, size=self.batch_size) + self.test_encoder_data_res = np.max(self.test_encoder_data) + self.test_decoder_data = np.random.randint(1, 100, size=self.batch_size) + self.test_decoder_data_res = np.max(self.test_decoder_data) + self.seq_lens_encoder = paddle.to_tensor( + self.test_encoder_data, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + self.test_decoder_data, + "int32", + ) + self.batch_size_tensor = paddle.ones([self.batch_size]) + + def test_all(self): + paddle.disable_static() + max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( + self.seq_lens_encoder, self.seq_lens_decoder, self.batch_size_tensor + ) + assert ( + max_enc_len_this_time.numpy() == self.test_encoder_data_res + and max_dec_len_this_time.numpy() == self.test_decoder_data_res + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_block_multihead_attention.py b/test/legacy_test/test_block_multihead_attention.py index 7f3033044e1c5..3bf5fca0e6adc 100644 --- a/test/legacy_test/test_block_multihead_attention.py +++ b/test/legacy_test/test_block_multihead_attention.py @@ -392,6 +392,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -496,6 +498,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask self.tgt_mask, # tgt_mask @@ -690,6 +694,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time self.rope_emb, # rotary_embs None, # attn_mask None, # tgt_mask @@ -800,6 +806,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time self.rope_emb, # rotary_embs None, # attn_mask None, # tgt_mask @@ -979,6 +987,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs self.attention_mask, # attn_mask None, # tgt_mask @@ -1083,6 +1093,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs self.attention_mask, # attn_mask None, # tgt_mask @@ -1285,6 +1297,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1498,6 +1512,8 @@ def test_all(self): qkv_bias, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1643,6 +1659,8 @@ def test_all(self): qkv_bias, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1858,6 +1876,8 @@ def test_all(self): qkv_bias, # qkv_bias shift, # out_shift smooth, # 
out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2021,6 +2041,8 @@ def test_all(self): qkv_bias, # qkv_bias shift, # out_shift smooth, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2186,6 +2208,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2298,6 +2322,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2469,6 +2495,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2579,6 +2607,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2762,6 +2792,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -2871,6 +2903,8 @@ def test_all(self): None, # qkv_bias None, # out_shift None, # out_smooth + None, # max_enc_len_this_time + None, # max_dec_len_this_time None, # rotary_embs None, # attn_mask None, # tgt_mask From 81f0b15994f31069e8f54ed6cac3d0b5a2a25439 Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Tue, 14 May 2024 03:03:17 +0000 Subject: [PATCH 05/10] only gpu --- test/legacy_test/test_blha_get_max_len_op.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/test/legacy_test/test_blha_get_max_len_op.py b/test/legacy_test/test_blha_get_max_len_op.py index 675c022011862..4028d83a14d4c 100644 --- a/test/legacy_test/test_blha_get_max_len_op.py +++ b/test/legacy_test/test_blha_get_max_len_op.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.base import core from paddle.incubate.nn.functional import blha_get_max_len @@ -41,14 +42,17 @@ def setUp(self): self.batch_size_tensor = paddle.ones([self.batch_size]) def test_all(self): - paddle.disable_static() - max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( - self.seq_lens_encoder, self.seq_lens_decoder, self.batch_size_tensor - ) - assert ( - max_enc_len_this_time.numpy() == self.test_encoder_data_res - and max_dec_len_this_time.numpy() == self.test_decoder_data_res - ) + if core.is_compiled_with_cuda(): + paddle.disable_static() + max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.batch_size_tensor, + ) + assert ( + max_enc_len_this_time.numpy() == self.test_encoder_data_res + and max_dec_len_this_time.numpy() == self.test_decoder_data_res + ) if __name__ == '__main__': From 87e6daf7c38f84a041aa33a02b3592f192a640bc Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Tue, 14 May 2024 05:20:49 +0000 Subject: [PATCH 06/10] fix bug --- test/legacy_test/test_blha_get_max_len_op.py | 22 ++++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/legacy_test/test_blha_get_max_len_op.py b/test/legacy_test/test_blha_get_max_len_op.py index 4028d83a14d4c..b650e1f0d149b 100644 --- 
a/test/legacy_test/test_blha_get_max_len_op.py +++ b/test/legacy_test/test_blha_get_max_len_op.py @@ -21,6 +21,7 @@ from paddle.incubate.nn.functional import blha_get_max_len +@unittest.skipIf(not core.is_compiled_with_cuda()) class TestBlhaGetMaxLenOp(unittest.TestCase): def setUp(self): paddle.disable_static() @@ -42,17 +43,16 @@ def setUp(self): self.batch_size_tensor = paddle.ones([self.batch_size]) def test_all(self): - if core.is_compiled_with_cuda(): - paddle.disable_static() - max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( - self.seq_lens_encoder, - self.seq_lens_decoder, - self.batch_size_tensor, - ) - assert ( - max_enc_len_this_time.numpy() == self.test_encoder_data_res - and max_dec_len_this_time.numpy() == self.test_decoder_data_res - ) + paddle.disable_static() + max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.batch_size_tensor, + ) + assert ( + max_enc_len_this_time.numpy() == self.test_encoder_data_res + and max_dec_len_this_time.numpy() == self.test_decoder_data_res + ) if __name__ == '__main__': From 98728648512325fd0d1d93f815a4e44d0634b048 Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Tue, 14 May 2024 07:13:46 +0000 Subject: [PATCH 07/10] fix bug --- test/legacy_test/test_blha_get_max_len_op.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_blha_get_max_len_op.py b/test/legacy_test/test_blha_get_max_len_op.py index b650e1f0d149b..903368b471ca4 100644 --- a/test/legacy_test/test_blha_get_max_len_op.py +++ b/test/legacy_test/test_blha_get_max_len_op.py @@ -21,7 +21,9 @@ from paddle.incubate.nn.functional import blha_get_max_len -@unittest.skipIf(not core.is_compiled_with_cuda()) +@unittest.skipIf( + not core.is_compiled_with_cuda(), "Only support GPU in CUDA mode." +) class TestBlhaGetMaxLenOp(unittest.TestCase): def setUp(self): paddle.disable_static() From ea065b5b03abd0f99a1f4b787b435b325eaa9ade Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Tue, 14 May 2024 08:23:39 +0000 Subject: [PATCH 08/10] fix doc bug --- .../paddle/incubate/nn/functional/block_multihead_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py index 5ee22aa0ee532..a55f61de2c678 100644 --- a/python/paddle/incubate/nn/functional/block_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py @@ -255,6 +255,8 @@ def block_multihead_attention( ... None, # qkv_bias ... None, # out_shift ... None, # out_smooth + ... None, # max_enc_len_this_time + ... None, # max_dec_len_this_time ... None, # rotary_embs ... None, # attn_mask ... 
None, # tgt_mask From 642ab05a813a3ef0d2acc89f79939351be5db81d Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Wed, 15 May 2024 09:08:47 +0000 Subject: [PATCH 09/10] fix bug --- .../gpu/block_multi_head_attention_kernel.cu | 18 +- test/legacy_test/test_blha_get_max_len_op.py | 29 +- .../test_block_multihead_attention.py | 278 ++++++++++++++++++ 3 files changed, 316 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu index 2bfcdb6f557c9..239cf361ca29a 100644 --- a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu @@ -311,6 +311,10 @@ void DispatchWithDtype( max_dec_len_this_time_data = GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz); } else { + PADDLE_ENFORCE_EQ(max_dec_len_this_time.get().place().GetType(), + phi::AllocationType::CPU, + "max_dec_len_this_time must be on CPU, but Got %s.", + max_dec_len_this_time.get().place()); max_dec_len_this_time_data = *max_dec_len_this_time.get().data(); } @@ -323,6 +327,10 @@ void DispatchWithDtype( max_enc_len_this_time_data = GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz); } else { + PADDLE_ENFORCE_EQ(max_enc_len_this_time.get().place().GetType(), + phi::AllocationType::CPU, + "max_enc_len_this_time must be on CPU, but Got %s.", + max_enc_len_this_time.get().place()); max_enc_len_this_time_data = *max_enc_len_this_time.get().data(); } @@ -864,12 +872,18 @@ PD_REGISTER_KERNEL(block_multihead_attention, phi::fusion::BlockMultiheadAttentionKernel, phi::dtype::bfloat16, phi::dtype::float16, - int32_t) {} + int32_t) { + kernel->InputAt(24).SetBackend(phi::Backend::CPU); + kernel->InputAt(25).SetBackend(phi::Backend::CPU); +} #else PD_REGISTER_KERNEL(block_multihead_attention, GPU, ALL_LAYOUT, phi::fusion::BlockMultiheadAttentionKernel, phi::dtype::float16, - int32_t) {} + int32_t) { + kernel->InputAt(24).SetBackend(phi::Backend::CPU); + kernel->InputAt(25).SetBackend(phi::Backend::CPU); +} #endif diff --git a/test/legacy_test/test_blha_get_max_len_op.py b/test/legacy_test/test_blha_get_max_len_op.py index 903368b471ca4..914dce275975f 100644 --- a/test/legacy_test/test_blha_get_max_len_op.py +++ b/test/legacy_test/test_blha_get_max_len_op.py @@ -26,14 +26,17 @@ ) class TestBlhaGetMaxLenOp(unittest.TestCase): def setUp(self): - paddle.disable_static() - self.name = "TestBlhaGetMaxLenOp" + self.name = "TestBlhaGetMaxLenOpDynamic" self.place = paddle.CUDAPlace(0) self.batch_size = 10 self.test_encoder_data = np.random.randint(1, 100, size=self.batch_size) - self.test_encoder_data_res = np.max(self.test_encoder_data) + self.test_encoder_data_res = paddle.to_tensor( + np.max(self.test_encoder_data), "int32" + ) self.test_decoder_data = np.random.randint(1, 100, size=self.batch_size) - self.test_decoder_data_res = np.max(self.test_decoder_data) + self.test_decoder_data_res = paddle.to_tensor( + np.max(self.test_decoder_data), "int32" + ) self.seq_lens_encoder = paddle.to_tensor( self.test_encoder_data, "int32", @@ -44,7 +47,7 @@ def setUp(self): ) self.batch_size_tensor = paddle.ones([self.batch_size]) - def test_all(self): + def test_dynamic_api(self): paddle.disable_static() max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( self.seq_lens_encoder, @@ -52,8 +55,20 @@ def test_all(self): self.batch_size_tensor, ) assert ( - max_enc_len_this_time.numpy() == self.test_encoder_data_res - and 
max_dec_len_this_time.numpy() == self.test_decoder_data_res + max_enc_len_this_time == self.test_encoder_data_res + and max_dec_len_this_time == self.test_decoder_data_res + ) + + def test_static_api(self): + paddle.enable_static() + max_enc_len_this_time, max_dec_len_this_time = blha_get_max_len( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.batch_size_tensor, + ) + assert ( + max_enc_len_this_time == self.test_encoder_data_res + and max_dec_len_this_time == self.test_decoder_data_res ) diff --git a/test/legacy_test/test_block_multihead_attention.py b/test/legacy_test/test_block_multihead_attention.py index 3bf5fca0e6adc..e36605cf8ea07 100644 --- a/test/legacy_test/test_block_multihead_attention.py +++ b/test/legacy_test/test_block_multihead_attention.py @@ -516,6 +516,284 @@ def test_all(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecSkipGetMaxLen(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDecSkipGetMaxLen" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.max_enc_len_this_time = paddle.to_tensor( + [self.seq_len], "int32" + ).cpu() + self.max_dec_len_this_time = paddle.to_tensor([0], "int32").cpu() + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + + self.tgt_mask = paddle.randn( + [self.batch_size, self.num_head, 1, self.seq_len + 1], + dtype=self.dtype, + ) + + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, 
place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + self.max_enc_len_this_time, # max_enc_len_this_time + self.max_dec_len_this_time, # max_dec_len_this_time + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=5e-03, + atol=1e-03, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.max_enc_len_this_time = paddle.to_tensor([0], "int32").cpu() + self.max_dec_len_this_time = paddle.to_tensor( + [self.seq_len], "int32" + ).cpu() + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + self.tgt_mask, + self.scale, + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # 
cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + self.max_enc_len_this_time, # max_enc_len_this_time + self.max_dec_len_this_time, # max_dec_len_this_time + None, # rotary_embs + None, # attn_mask + self.tgt_mask, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=5e-02, + atol=5e-02, + ) + + @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11040 From 16840b1e268f68469f6b859551e2fc76c760a9db Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Thu, 16 May 2024 12:55:29 +0000 Subject: [PATCH 10/10] fix bug --- .../gpu/block_multi_head_attention_kernel.cu | 20 +++++++++++-------- .../nn/functional/blha_get_max_len.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu index 239cf361ca29a..9455dcc6e1368 100644 --- a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu @@ -311,10 +311,12 @@ void DispatchWithDtype( max_dec_len_this_time_data = GetMaxLen(dev_ctx, seq_lens_decoder, &max_dec_len_tensor, bsz); } else { - PADDLE_ENFORCE_EQ(max_dec_len_this_time.get().place().GetType(), - phi::AllocationType::CPU, - "max_dec_len_this_time must be on CPU, but Got %s.", - max_dec_len_this_time.get().place()); + PADDLE_ENFORCE_EQ( + max_dec_len_this_time.get().place().GetType(), + phi::AllocationType::CPU, + errors::InvalidArgument( + "The place of input max_dec_len_this_time must be CPU, but got %s.", + max_dec_len_this_time.get().place())); max_dec_len_this_time_data = *max_dec_len_this_time.get().data(); } @@ -327,10 +329,12 @@ void DispatchWithDtype( max_enc_len_this_time_data = GetMaxLen(dev_ctx, seq_lens_encoder, &max_enc_len_tensor, bsz); } else { - PADDLE_ENFORCE_EQ(max_enc_len_this_time.get().place().GetType(), - phi::AllocationType::CPU, - "max_enc_len_this_time must be on CPU, but Got %s.", - max_enc_len_this_time.get().place()); + PADDLE_ENFORCE_EQ( + max_enc_len_this_time.get().place().GetType(), + phi::AllocationType::CPU, + errors::InvalidArgument( + "The place of input max_enc_len_this_time must be CPU, but got %s.", + max_enc_len_this_time.get().place())); max_enc_len_this_time_data = *max_enc_len_this_time.get().data(); } diff --git a/python/paddle/incubate/nn/functional/blha_get_max_len.py b/python/paddle/incubate/nn/functional/blha_get_max_len.py index 0474a235cfc13..f330803e1b2fa 100644 --- a/python/paddle/incubate/nn/functional/blha_get_max_len.py +++ b/python/paddle/incubate/nn/functional/blha_get_max_len.py @@ -13,7 +13,7 @@ # limitations under the License. from paddle import _C_ops -from paddle.framework import LayerHelper, in_dynamic_mode +from paddle.framework import LayerHelper, in_dynamic_or_pir_mode def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size): @@ -41,7 +41,7 @@ def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size): >>> batch_size = paddle.ones(shape=[bsz]) >>> max_enc_len_this_time, max_dec_len_this_time = paddle.incubate.nn.functional.blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.blha_get_max_len( seq_lens_encoder, seq_lens_decoder, batch_size )
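
Usage note: in the final state of the series the two precomputed lengths are read on the host. DispatchWithDtype enforces with PADDLE_ENFORCE_EQ that they are CPU tensors before dereferencing data<int>(), and the kernel registration pins inputs 24 and 25 to the CPU backend. A minimal sketch of the call-site preparation, mirroring the new TestBlockMultiHeadAttnEncDecSkipGetMaxLen test (all other attention inputs are omitted; seq_len is an illustrative value):

    import paddle

    paddle.device.set_device('gpu')

    seq_len = 64  # illustrative prompt length for the encoder pass
    # Shape-[1] int32 tensors placed explicitly on CPU, as required by the
    # placement checks added in this series.
    max_enc_len_this_time = paddle.to_tensor([seq_len], 'int32').cpu()
    max_dec_len_this_time = paddle.to_tensor([0], 'int32').cpu()

    # These are passed right after out_smooth in the block_multihead_attention
    # call; when both are left as None, the kernel falls back to computing the
    # max lengths itself via GetMaxLen, as before.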