From 70f8bee1fc8b7783bb8cb66a1eac995e55a3f886 Mon Sep 17 00:00:00 2001
From: ronny1996
Date: Mon, 20 May 2024 16:33:22 +0800
Subject: [PATCH] [NPU] add blha_get_max_len

---
 .../atb_ops/fused_blha_layer_op_utils.cc      |  4 +-
 .../llama_infer/save_with_output_msg.cc       |  3 +-
 .../custom_op/llama_infer/update_inputs.cc    | 74 ++++++++++++-------
 .../npu/kernels/fusion/blha_get_max_len.cc    | 35 +++++++++
 .../block_multihead_attention_kernel.cc       |  7 +-
 backends/npu/kernels/scale_kernel.cc          |  9 ---
 backends/npu/passes/common.py                 |  2 +
 backends/npu/passes/llama.py                  | 19 +++++
 8 files changed, 114 insertions(+), 39 deletions(-)
 create mode 100644 backends/npu/kernels/fusion/blha_get_max_len.cc

diff --git a/backends/npu/custom_op/llama_infer/atb_ops/fused_blha_layer_op_utils.cc b/backends/npu/custom_op/llama_infer/atb_ops/fused_blha_layer_op_utils.cc
index dc19cbd28..5e9f4082e 100644
--- a/backends/npu/custom_op/llama_infer/atb_ops/fused_blha_layer_op_utils.cc
+++ b/backends/npu/custom_op/llama_infer/atb_ops/fused_blha_layer_op_utils.cc
@@ -468,9 +468,9 @@ void FusedBlhaGlobalVar::update_mask(const phi::CustomContext &dev_ctx,
     tmp_mask.Resize({max_seq_len, max_seq_len});
     custom_kernel::ScaleKernel<phi::float16, phi::CustomContext>(
         dev_ctx, tril_ones_tensor, 1.0f, -1.0f, true, &tmp_mask);
-
+    // use 50000 to avoid overflow
     custom_kernel::ScaleKernel<phi::float16, phi::CustomContext>(
-        dev_ctx, tmp_mask, 1000000.0f, 0.0f, true, g_mask.get());
+        dev_ctx, tmp_mask, 50000.0f, 0.0f, true, g_mask.get());
   }
 }
 
diff --git a/backends/npu/custom_op/llama_infer/save_with_output_msg.cc b/backends/npu/custom_op/llama_infer/save_with_output_msg.cc
index fcc12d5c3..f0fdad257 100644
--- a/backends/npu/custom_op/llama_infer/save_with_output_msg.cc
+++ b/backends/npu/custom_op/llama_infer/save_with_output_msg.cc
@@ -38,7 +38,8 @@ void SaveOutMmsg(const paddle::Tensor& x,
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
-  bool not_need_stop_data = not_need_stop.data<bool>()[0];
+  auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), true);
+  bool not_need_stop_data = not_need_stop_cpu.data<bool>()[0];
   msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
   int bsz = x.shape()[0];
   msg_sed.mtext[1] = bsz;
diff --git a/backends/npu/custom_op/llama_infer/update_inputs.cc b/backends/npu/custom_op/llama_infer/update_inputs.cc
index f52fd9e47..c18f03903 100644
--- a/backends/npu/custom_op/llama_infer/update_inputs.cc
+++ b/backends/npu/custom_op/llama_infer/update_inputs.cc
@@ -33,12 +33,8 @@ void UpdateInputes(const paddle::Tensor& stop_flags,
           stop_flags.place()));
   auto stream = static_cast<aclrtStream>(dev_ctx->stream());
 
-  auto not_need_stop_npu = not_need_stop.copy_to(stop_flags.place(), false);
-
   auto stop_flags_tensor =
       static_cast<phi::DenseTensor*>(stop_flags.impl().get());
-  auto not_need_stop_tensor =
-      static_cast<phi::DenseTensor*>(not_need_stop_npu.impl().get());
   auto seq_lens_this_time_tensor =
       static_cast<phi::DenseTensor*>(seq_lens_this_time.impl().get());
   auto seq_lens_encoder_tensor =
@@ -54,28 +50,54 @@ void UpdateInputes(const paddle::Tensor& stop_flags,
   auto is_block_step_tensor =
       static_cast<phi::DenseTensor*>(is_block_step.impl().get());
 
-  const auto& runner = NpuOpRunner("UpdateInputs",
-                                   {*stop_flags_tensor,
-                                    *not_need_stop_tensor,
-                                    *seq_lens_this_time_tensor,
-                                    *seq_lens_encoder_tensor,
-                                    *seq_lens_decoder_tensor,
-                                    *input_ids_tensor,
-                                    *stop_nums_tensor,
-                                    *next_tokens_tensor,
-                                    *is_block_step_tensor},
-                                   {*not_need_stop_tensor,
-                                    *seq_lens_this_time_tensor,
-                                    *seq_lens_encoder_tensor,
-                                    *seq_lens_decoder_tensor,
-                                    *input_ids_tensor},
-                                   {});
-  runner.Run(stream);
-
-  auto not_need_stop_cpu =
-      not_need_stop_npu.copy_to(not_need_stop.place(), true);
-  bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
-  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+  bool not_need_stop_on_host =
+      not_need_stop.place().GetType() == phi::AllocationType::CPU;
+  if (not_need_stop_on_host) {
+    auto not_need_stop_npu = not_need_stop.copy_to(stop_flags.place(), false);
+    auto not_need_stop_tensor =
+        static_cast<phi::DenseTensor*>(not_need_stop_npu.impl().get());
+    const auto& runner = NpuOpRunner("UpdateInputs",
+                                     {*stop_flags_tensor,
+                                      *not_need_stop_tensor,
+                                      *seq_lens_this_time_tensor,
+                                      *seq_lens_encoder_tensor,
+                                      *seq_lens_decoder_tensor,
+                                      *input_ids_tensor,
+                                      *stop_nums_tensor,
+                                      *next_tokens_tensor,
+                                      *is_block_step_tensor},
+                                     {*not_need_stop_tensor,
+                                      *seq_lens_this_time_tensor,
+                                      *seq_lens_encoder_tensor,
+                                      *seq_lens_decoder_tensor,
+                                      *input_ids_tensor},
+                                     {});
+    runner.Run(stream);
+    auto not_need_stop_cpu =
+        not_need_stop_npu.copy_to(not_need_stop.place(), true);
+    bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
+    not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+  } else {
+    auto not_need_stop_tensor =
+        static_cast<phi::DenseTensor*>(not_need_stop.impl().get());
+    const auto& runner = NpuOpRunner("UpdateInputs",
+                                     {*stop_flags_tensor,
+                                      *not_need_stop_tensor,
+                                      *seq_lens_this_time_tensor,
+                                      *seq_lens_encoder_tensor,
+                                      *seq_lens_decoder_tensor,
+                                      *input_ids_tensor,
+                                      *stop_nums_tensor,
+                                      *next_tokens_tensor,
+                                      *is_block_step_tensor},
+                                     {*not_need_stop_tensor,
+                                      *seq_lens_this_time_tensor,
+                                      *seq_lens_encoder_tensor,
+                                      *seq_lens_decoder_tensor,
+                                      *input_ids_tensor},
+                                     {});
+    runner.Run(stream);
+  }
 }
 
 PD_BUILD_OP(update_inputs)
diff --git a/backends/npu/kernels/fusion/blha_get_max_len.cc b/backends/npu/kernels/fusion/blha_get_max_len.cc
new file mode 100644
index 000000000..c8acdb870
--- /dev/null
+++ b/backends/npu/kernels/fusion/blha_get_max_len.cc
@@ -0,0 +1,35 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "kernels/funcs/npu_funcs.h"
+#include "kernels/funcs/npu_op_runner.h"
+
+namespace custom_kernel {
+template <typename T, typename Context>
+void BlhaGetMaxLenKernel(const Context& dev_ctx,
+                         const phi::DenseTensor& seq_lens_encoder,
+                         const phi::DenseTensor& seq_lens_decoder,
+                         const phi::DenseTensor& batch_size,
+                         phi::DenseTensor* max_enc_len_this_time,
+                         phi::DenseTensor* max_dec_len_this_time) {
+  PADDLE_THROW(phi::errors::Unimplemented("Only supports model export"));
+}
+}  // namespace custom_kernel
+
+PD_REGISTER_PLUGIN_KERNEL(blha_get_max_len,
+                          npu,
+                          ALL_LAYOUT,
+                          custom_kernel::BlhaGetMaxLenKernel,
+                          int,
+                          int64_t) {}
diff --git a/backends/npu/kernels/fusion/block_multihead_attention_kernel.cc b/backends/npu/kernels/fusion/block_multihead_attention_kernel.cc
index 3b9e44e1e..49baececc 100644
--- a/backends/npu/kernels/fusion/block_multihead_attention_kernel.cc
+++ b/backends/npu/kernels/fusion/block_multihead_attention_kernel.cc
@@ -44,6 +44,8 @@ void BlockMultiheadAttentionKernel(
     const paddle::optional<phi::DenseTensor>& qkv_bias,
     const paddle::optional<phi::DenseTensor>& out_shift,
     const paddle::optional<phi::DenseTensor>& out_smooth,
+    const paddle::optional<phi::DenseTensor>& max_enc_len_this_time,
+    const paddle::optional<phi::DenseTensor>& max_dec_len_this_time,
     int max_seq_len,
     int block_size,
     bool use_neox_style,
@@ -67,4 +69,7 @@ PD_REGISTER_PLUGIN_KERNEL(block_multihead_attention,
                           ALL_LAYOUT,
                           custom_kernel::BlockMultiheadAttentionKernel,
                           phi::float16,
-                          int32_t) {}
+                          int32_t) {
+  kernel->InputAt(24).SetBackend(phi::Backend::CPU);
+  kernel->InputAt(25).SetBackend(phi::Backend::CPU);
+}
diff --git a/backends/npu/kernels/scale_kernel.cc b/backends/npu/kernels/scale_kernel.cc
index 7dc10f8a7..44a6e4ce5 100644
--- a/backends/npu/kernels/scale_kernel.cc
+++ b/backends/npu/kernels/scale_kernel.cc
@@ -37,10 +37,6 @@ void AclopScaleKernel(const Context& dev_ctx,
   VLOG(4) << "scale:" << scale << ", bias:" << bias
           << " ,bias_after_scale:" << bias_after_scale;
   dev_ctx.template Alloc<T>(out);
-  if (std::isinf(scale) || std::isnan(scale)) {
-    FillNpuTensorWithConstant<T>(out, dev_ctx, static_cast<T>(scale));
-    return;
-  }
   if (!bias_after_scale) {
     bias *= scale;
   }
@@ -109,11 +105,6 @@ void ScaleKernel(const Context& dev_ctx,
   VLOG(4) << "scale:" << scale << ", bias:" << bias
           << " ,bias_after_scale:" << bias_after_scale;
 
-  if (std::isinf(scale) || std::isnan(scale)) {
-    FillNpuTensorWithConstant<T>(out, dev_ctx, static_cast<T>(scale));
-    return;
-  }
-
   if (!bias_after_scale) {
     bias *= scale;
   }
diff --git a/backends/npu/passes/common.py b/backends/npu/passes/common.py
index b918d35be..e8d3f9e02 100644
--- a/backends/npu/passes/common.py
+++ b/backends/npu/passes/common.py
@@ -37,6 +37,7 @@ def addPasses(pass_builder, model_type, quant_type):
     if model_type == "llama" and quant_type == "a8w8":
         register_pass(pass_builder, "remove_residual_in_fused_bias_residual_layernorm")
         register_pass(pass_builder, "remove_residual_in_rms_norm")
+        register_pass(pass_builder, "remove_blha_get_max_len")
         register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer_begin")
         register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer_end")
         register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer")
@@ -46,6 +47,7 @@ def addPasses(pass_builder, model_type, quant_type):
     elif model_type == "llama":
         register_pass(pass_builder, "remove_residual_in_fused_bias_residual_layernorm")
         register_pass(pass_builder, "remove_residual_in_rms_norm")
+        register_pass(pass_builder, "remove_blha_get_max_len")
         register_pass(pass_builder, "llama_fuse_attention_layer_begin")
         register_pass(pass_builder, "llama_fuse_attention_layer_end")
         register_pass(pass_builder, "llama_fuse_attention_layer")
diff --git a/backends/npu/passes/llama.py b/backends/npu/passes/llama.py
index 8e50b4000..7c5a0c7ad 100644
--- a/backends/npu/passes/llama.py
+++ b/backends/npu/passes/llama.py
@@ -58,6 +58,25 @@ def replace(residual, x):
     return pattern, replace
 
 
+@ir.RegisterPass
+def remove_blha_get_max_len():
+    def pattern(seq_lens_encoder, seq_lens_decoder, batch_size):
+        blha_get_max_len = ir.PassDesc.OP.blha_get_max_len(
+            seq_lens_encoder=seq_lens_encoder,
+            seq_lens_decoder=seq_lens_decoder,
+            batch_size=batch_size,
+        )
+        return (
+            blha_get_max_len.Output("max_enc_len_this_time")[0],
+            blha_get_max_len.Output("max_dec_len_this_time")[0],
+        )
+
+    def replace(seq_lens_encoder, seq_lens_decoder, batch_size):
+        return seq_lens_encoder, seq_lens_decoder
+
+    return pattern, replace
+
+
 @ir.RegisterPass
 def llama_fuse_attention_layer():
     def pattern(