[NPU] add blha_get_max_len (#1246)
ronny1996 committed May 21, 2024
1 parent 147d506 commit 83dfe3d
Showing 8 changed files with 114 additions and 39 deletions.
@@ -468,9 +468,9 @@ void FusedBlhaGlobalVar::update_mask(const phi::CustomContext &dev_ctx,
tmp_mask.Resize({max_seq_len, max_seq_len});
custom_kernel::ScaleKernel<phi::float16>(
dev_ctx, tril_ones_tensor, 1.0f, -1.0f, true, &tmp_mask);

+ // use 50000 to avoid overflow
custom_kernel::ScaleKernel<phi::float16>(
- dev_ctx, tmp_mask, 1000000.0f, 0.0f, true, g_mask.get());
+ dev_ctx, tmp_mask, 50000.0f, 0.0f, true, g_mask.get());
}
}

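The new constant is about float16 range: the additive mask is stored in a phi::float16 tensor, and float16's largest finite value is 65504, so a bias magnitude of 1000000 saturates to infinity and can propagate inf/NaN through the fused attention, whereas 50000 is representable and still large enough to zero out masked positions after softmax. A minimal standalone sketch of the saturation; it assumes a compiler that provides the _Float16 extension type (GCC/Clang), since phi::float16 is not used here:

#include <cstdio>

int main() {
  // float16's largest finite value is 65504, so 1e6 rounds to +inf on
  // conversion while 5e4 stays finite.
  _Float16 overflowed = static_cast<_Float16>(1000000.0f);
  _Float16 in_range = static_cast<_Float16>(50000.0f);
  std::printf("1e6 -> %f, 5e4 -> %f\n",
              static_cast<float>(overflowed),
              static_cast<float>(in_range));
  return 0;
}
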
3 changes: 2 additions & 1 deletion backends/npu/custom_op/llama_infer/save_with_output_msg.cc
@@ -38,7 +38,8 @@ void SaveOutMmsg(const paddle::Tensor& x,
static int msgid = msgget(key, IPC_CREAT | 0666);

msg_sed.mtype = 1;
- bool not_need_stop_data = not_need_stop.data<bool>()[0];
+ auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), true);
+ bool not_need_stop_data = not_need_stop_cpu.data<bool>()[0];
msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
int bsz = x.shape()[0];
msg_sed.mtext[1] = bsz;
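For context on the two added lines: not_need_stop can now be a device-resident (NPU) tensor, and calling .data<bool>() on such a tensor from host code would not read valid host memory, so the value is copied to the CPU first. A sketch of that read-a-device-scalar pattern as a reusable helper (hypothetical helper, not part of this commit; it only uses the copy_to / data APIs already visible above):

// Hypothetical helper: read a single bool from a tensor that may live on
// either the host or the NPU. The blocking copy guarantees the value is
// valid by the time data<bool>() is called.
static bool ReadBoolScalar(const paddle::Tensor& t) {
  auto t_cpu = t.copy_to(paddle::CPUPlace(), /*blocking=*/true);
  return t_cpu.data<bool>()[0];
}
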
74 changes: 48 additions & 26 deletions backends/npu/custom_op/llama_infer/update_inputs.cc
@@ -33,12 +33,8 @@ void UpdateInputes(const paddle::Tensor& stop_flags,
stop_flags.place()));
auto stream = static_cast<aclrtStream>(dev_ctx->stream());

- auto not_need_stop_npu = not_need_stop.copy_to(stop_flags.place(), false);
-
auto stop_flags_tensor =
static_cast<const phi::DenseTensor*>(stop_flags.impl().get());
- auto not_need_stop_tensor =
- static_cast<const phi::DenseTensor*>(not_need_stop_npu.impl().get());
auto seq_lens_this_time_tensor =
static_cast<const phi::DenseTensor*>(seq_lens_this_time.impl().get());
auto seq_lens_encoder_tensor =
@@ -54,28 +50,54 @@ void UpdateInputes(const paddle::Tensor& stop_flags,
auto is_block_step_tensor =
static_cast<const phi::DenseTensor*>(is_block_step.impl().get());

- const auto& runner = NpuOpRunner("UpdateInputs",
- {*stop_flags_tensor,
- *not_need_stop_tensor,
- *seq_lens_this_time_tensor,
- *seq_lens_encoder_tensor,
- *seq_lens_decoder_tensor,
- *input_ids_tensor,
- *stop_nums_tensor,
- *next_tokens_tensor,
- *is_block_step_tensor},
- {*not_need_stop_tensor,
- *seq_lens_this_time_tensor,
- *seq_lens_encoder_tensor,
- *seq_lens_decoder_tensor,
- *input_ids_tensor},
- {});
- runner.Run(stream);
-
- auto not_need_stop_cpu =
- not_need_stop_npu.copy_to(not_need_stop.place(), true);
- bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
- not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+ bool not_need_stop_on_host =
+ not_need_stop.place().GetType() == phi::AllocationType::CPU;
+ if (not_need_stop_on_host) {
+ auto not_need_stop_npu = not_need_stop.copy_to(stop_flags.place(), false);
+ auto not_need_stop_tensor =
+ static_cast<const phi::DenseTensor*>(not_need_stop_npu.impl().get());
+ const auto& runner = NpuOpRunner("UpdateInputs",
+ {*stop_flags_tensor,
+ *not_need_stop_tensor,
+ *seq_lens_this_time_tensor,
+ *seq_lens_encoder_tensor,
+ *seq_lens_decoder_tensor,
+ *input_ids_tensor,
+ *stop_nums_tensor,
+ *next_tokens_tensor,
+ *is_block_step_tensor},
+ {*not_need_stop_tensor,
+ *seq_lens_this_time_tensor,
+ *seq_lens_encoder_tensor,
+ *seq_lens_decoder_tensor,
+ *input_ids_tensor},
+ {});
+ runner.Run(stream);
+ auto not_need_stop_cpu =
+ not_need_stop_npu.copy_to(not_need_stop.place(), true);
+ bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
+ not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+ } else {
+ auto not_need_stop_tensor =
+ static_cast<const phi::DenseTensor*>(not_need_stop.impl().get());
+ const auto& runner = NpuOpRunner("UpdateInputs",
+ {*stop_flags_tensor,
+ *not_need_stop_tensor,
+ *seq_lens_this_time_tensor,
+ *seq_lens_encoder_tensor,
+ *seq_lens_decoder_tensor,
+ *input_ids_tensor,
+ *stop_nums_tensor,
+ *next_tokens_tensor,
+ *is_block_step_tensor},
+ {*not_need_stop_tensor,
+ *seq_lens_this_time_tensor,
+ *seq_lens_encoder_tensor,
+ *seq_lens_decoder_tensor,
+ *input_ids_tensor},
+ {});
+ runner.Run(stream);
+ }
}

PD_BUILD_OP(update_inputs)
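The rewrite above splits update_inputs into two paths based on where not_need_stop lives. When it is a host (CPU) tensor, which was the original situation, the flag is still staged onto the NPU, the UpdateInputs runner updates it there, and the result is copied back into the host tensor. When the flag already resides on the NPU, the runner now uses it in place and both copies disappear. A condensed sketch of that control flow (paraphrase only; RunUpdateInputs and CopyFlagBack are hypothetical wrappers around the NpuOpRunner code shown above):

bool flag_on_host =
    not_need_stop.place().GetType() == phi::AllocationType::CPU;
if (flag_on_host) {
  // Stage the flag on the NPU, update it there, then write the updated value
  // back into the original host tensor so callers still observe the change.
  auto flag_npu = not_need_stop.copy_to(stop_flags.place(), /*blocking=*/false);
  RunUpdateInputs(flag_npu);
  CopyFlagBack(flag_npu, not_need_stop);
} else {
  // The flag already lives on the NPU: update it in place, no host round trip.
  RunUpdateInputs(not_need_stop);
}
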
35 changes: 35 additions & 0 deletions backends/npu/kernels/fusion/blha_get_max_len.cc
@@ -0,0 +1,35 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/npu_funcs.h"
#include "kernels/funcs/npu_op_runner.h"

namespace custom_kernel {
template <typename T, typename Context>
void BlhaGetMaxLenKernel(const Context& dev_ctx,
const phi::DenseTensor& seq_lens_encoder,
const phi::DenseTensor& seq_lens_decoder,
const phi::DenseTensor& batch_size,
phi::DenseTensor* max_enc_len_this_time,
phi::DenseTensor* max_dec_len_this_time) {
PADDLE_THROW(phi::errors::Unimplemented("Only supports model export"));
}
} // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(blha_get_max_len,
npu,
ALL_LAYOUT,
custom_kernel::BlhaGetMaxLenKernel,
int,
int64_t) {}
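This NPU kernel is deliberately a stub: the Unimplemented error only fires if the op actually executes, and on this backend the op is removed by the remove_blha_get_max_len pass added later in this commit, so the registration mainly exists so that exported models referencing blha_get_max_len can still be loaded. For orientation, the op is generally understood to reduce the per-sequence length tensors to their maxima; a host-side sketch of what a real implementation would compute (an assumption based on the op's input/output names, not code from this commit):

#include <algorithm>

// Sketch: seq_lens_encoder_cpu / seq_lens_decoder_cpu are assumed to be the
// int32 inputs already copied to host memory.
const int* enc = seq_lens_encoder_cpu.data<int>();
const int* dec = seq_lens_decoder_cpu.data<int>();
int64_t n = seq_lens_encoder_cpu.numel();
int max_enc_len_this_time = n > 0 ? *std::max_element(enc, enc + n) : 0;
int max_dec_len_this_time = n > 0 ? *std::max_element(dec, dec + n) : 0;
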
@@ -44,6 +44,8 @@ void BlockMultiheadAttentionKernel(
const paddle::optional<phi::DenseTensor>& qkv_bias,
const paddle::optional<phi::DenseTensor>& out_shift,
const paddle::optional<phi::DenseTensor>& out_smooth,
+ const paddle::optional<phi::DenseTensor>& max_enc_len_this_time,
+ const paddle::optional<phi::DenseTensor>& max_dec_len_this_time,
int max_seq_len,
int block_size,
bool use_neox_style,
@@ -67,4 +69,7 @@ PD_REGISTER_PLUGIN_KERNEL(block_multihead_attention,
ALL_LAYOUT,
custom_kernel::BlockMultiheadAttentionKernel,
phi::float16,
- int32_t) {}
+ int32_t) {
+ kernel->InputAt(24).SetBackend(phi::Backend::CPU);
+ kernel->InputAt(25).SetBackend(phi::Backend::CPU);
+ }
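The two SetBackend(phi::Backend::CPU) calls pin two of the kernel's inputs to host memory; by position these line up with the max_enc_len_this_time / max_dec_len_this_time inputs added in this commit, which lets the kernel read the per-step maxima on the host without a device synchronization. Inside the kernel, that read would look roughly like this (a sketch under that assumption, not code from this commit):

// Assumes the optional tensors, when provided, are 1-element int32 tensors
// resident in CPU memory because of the SetBackend calls above.
int max_enc_len = max_enc_len_this_time ? max_enc_len_this_time->data<int>()[0]
                                        : -1;  // -1: compute on device instead
int max_dec_len = max_dec_len_this_time ? max_dec_len_this_time->data<int>()[0]
                                        : -1;
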
9 changes: 0 additions & 9 deletions backends/npu/kernels/scale_kernel.cc
@@ -37,10 +37,6 @@ void AclopScaleKernel(const Context& dev_ctx,
VLOG(4) << "scale:" << scale << ", bias:" << bias
<< " ,bias_after_scale:" << bias_after_scale;
dev_ctx.template Alloc<T>(out);
- if (std::isinf(scale) || std::isnan(scale)) {
- FillNpuTensorWithConstant<T>(out, dev_ctx, static_cast<T>(scale));
- return;
- }
if (!bias_after_scale) {
bias *= scale;
}
@@ -109,11 +105,6 @@ void ScaleKernel(const Context& dev_ctx,
VLOG(4) << "scale:" << scale << ", bias:" << bias
<< " ,bias_after_scale:" << bias_after_scale;

- if (std::isinf(scale) || std::isnan(scale)) {
- FillNpuTensorWithConstant<T>(out, dev_ctx, static_cast<T>(scale));
- return;
- }
-
if (!bias_after_scale) {
bias *= scale;
}
2 changes: 2 additions & 0 deletions backends/npu/passes/common.py
@@ -37,6 +37,7 @@ def addPasses(pass_builder, model_type, quant_type):
if model_type == "llama" and quant_type == "a8w8":
register_pass(pass_builder, "remove_residual_in_fused_bias_residual_layernorm")
register_pass(pass_builder, "remove_residual_in_rms_norm")
register_pass(pass_builder, "remove_blha_get_max_len")
register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer_begin")
register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer_end")
register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer")
@@ -46,6 +47,7 @@ def addPasses(pass_builder, model_type, quant_type):
elif model_type == "llama":
register_pass(pass_builder, "remove_residual_in_fused_bias_residual_layernorm")
register_pass(pass_builder, "remove_residual_in_rms_norm")
register_pass(pass_builder, "remove_blha_get_max_len")
register_pass(pass_builder, "llama_fuse_attention_layer_begin")
register_pass(pass_builder, "llama_fuse_attention_layer_end")
register_pass(pass_builder, "llama_fuse_attention_layer")
19 changes: 19 additions & 0 deletions backends/npu/passes/llama.py
@@ -58,6 +58,25 @@ def replace(residual, x):
return pattern, replace


+ @ir.RegisterPass
+ def remove_blha_get_max_len():
+     def pattern(seq_lens_encoder, seq_lens_decoder, batch_size):
+         blha_get_max_len = ir.PassDesc.OP.blha_get_max_len(
+             seq_lens_encoder=seq_lens_encoder,
+             seq_lens_decoder=seq_lens_decoder,
+             batch_size=batch_size,
+         )
+         return (
+             blha_get_max_len.Output("max_enc_len_this_time")[0],
+             blha_get_max_len.Output("max_dec_len_this_time")[0],
+         )
+
+     def replace(seq_lens_encoder, seq_lens_decoder, batch_size):
+         return seq_lens_encoder, seq_lens_decoder
+
+     return pattern, replace


@ir.RegisterPass
def llama_fuse_attention_layer():
def pattern(
