[NPU] add blha_get_max_len
ronny1996 committed May 20, 2024
1 parent 7397c7f commit 81521f0
Showing 6 changed files with 64 additions and 12 deletions.
@@ -468,9 +468,9 @@ void FusedBlhaGlobalVar::update_mask(const phi::CustomContext &dev_ctx,
    tmp_mask.Resize({max_seq_len, max_seq_len});
    custom_kernel::ScaleKernel<phi::float16>(
        dev_ctx, tril_ones_tensor, 1.0f, -1.0f, true, &tmp_mask);

+   // use 50000 to avoid overflow
    custom_kernel::ScaleKernel<phi::float16>(
-       dev_ctx, tmp_mask, 1000000.0f, 0.0f, true, g_mask.get());
+       dev_ctx, tmp_mask, 50000.0f, 0.0f, true, g_mask.get());
  }
}
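
Note: the switch from 1000000.0f to 50000.0f matters because the mask tensor is phi::float16, whose largest finite magnitude is 65504, so a 1000000.0f multiplier pushes the masked entries past the representable range (to inf), while 50000.0f keeps them finite. A quick numpy check (an illustrative sketch, not part of the commit):

import numpy as np
print(np.finfo(np.float16).max)   # 65504.0
print(np.float16(1000000.0))      # inf  -- the old factor overflows half precision
print(np.float16(50000.0))        # ~49984.0, still finite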

35 changes: 35 additions & 0 deletions backends/npu/kernels/fusion/blha_get_max_len.cc
@@ -0,0 +1,35 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kernels/funcs/npu_funcs.h"
#include "kernels/funcs/npu_op_runner.h"

namespace custom_kernel {
template <typename T, typename Context>
void BlhaGetMaxLenKernel(const Context& dev_ctx,
                         const phi::DenseTensor& seq_lens_encoder,
                         const phi::DenseTensor& seq_lens_decoder,
                         const phi::DenseTensor& batch_size,
                         phi::DenseTensor* max_enc_len_this_time,
                         phi::DenseTensor* max_dec_len_this_time) {
  PADDLE_THROW(phi::errors::Unimplemented("Only supports model export"));
}
} // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(blha_get_max_len,
                          npu,
                          ALL_LAYOUT,
                          custom_kernel::BlhaGetMaxLenKernel,
                          int,
                          int64_t) {}
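
On the NPU backend this kernel is only a stub so that blha_get_max_len can appear in exported programs; it throws Unimplemented at runtime and is stripped by the remove_blha_get_max_len pass added below. Assuming the op follows the reference behaviour of reducing each length tensor to its batch-wise maximum, a rough numpy sketch of the semantics it stands in for (illustrative only, not from this commit):

import numpy as np

def blha_get_max_len_ref(seq_lens_encoder, seq_lens_decoder):
    # hypothetical reference: each output is the maximum over the batch
    max_enc_len_this_time = np.max(seq_lens_encoder, keepdims=True)
    max_dec_len_this_time = np.max(seq_lens_decoder, keepdims=True)
    return max_enc_len_this_time, max_dec_len_this_time

enc = np.array([17, 0, 42, 0], dtype=np.int32)   # prefill lengths this step
dec = np.array([0, 9, 0, 128], dtype=np.int32)   # decode lengths this step
print(blha_get_max_len_ref(enc, dec))            # (array([42], ...), array([128], ...))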
@@ -44,6 +44,8 @@ void BlockMultiheadAttentionKernel(
    const paddle::optional<phi::DenseTensor>& qkv_bias,
    const paddle::optional<phi::DenseTensor>& out_shift,
    const paddle::optional<phi::DenseTensor>& out_smooth,
+   const paddle::optional<phi::DenseTensor>& max_enc_len_this_time,
+   const paddle::optional<phi::DenseTensor>& max_dec_len_this_time,
    int max_seq_len,
    int block_size,
    bool use_neox_style,
@@ -67,4 +69,7 @@ PD_REGISTER_PLUGIN_KERNEL(block_multihead_attention,
                           ALL_LAYOUT,
                           custom_kernel::BlockMultiheadAttentionKernel,
                           phi::float16,
-                          int32_t) {}
+                          int32_t) {
+  kernel->InputAt(24).SetBackend(phi::Backend::CPU);
+  kernel->InputAt(25).SetBackend(phi::Backend::CPU);
+}
9 changes: 0 additions & 9 deletions backends/npu/kernels/scale_kernel.cc
@@ -37,10 +37,6 @@ void AclopScaleKernel(const Context& dev_ctx,
  VLOG(4) << "scale:" << scale << ", bias:" << bias
          << " ,bias_after_scale:" << bias_after_scale;
  dev_ctx.template Alloc<T>(out);
- if (std::isinf(scale) || std::isnan(scale)) {
-   FillNpuTensorWithConstant<T>(out, dev_ctx, static_cast<T>(scale));
-   return;
- }
  if (!bias_after_scale) {
    bias *= scale;
  }
@@ -109,11 +105,6 @@ void ScaleKernel(const Context& dev_ctx,
  VLOG(4) << "scale:" << scale << ", bias:" << bias
          << " ,bias_after_scale:" << bias_after_scale;

- if (std::isinf(scale) || std::isnan(scale)) {
-   FillNpuTensorWithConstant<T>(out, dev_ctx, static_cast<T>(scale));
-   return;
- }
-
  if (!bias_after_scale) {
    bias *= scale;
  }
2 changes: 2 additions & 0 deletions backends/npu/passes/common.py
@@ -37,6 +37,7 @@ def addPasses(pass_builder, model_type, quant_type):
    if model_type == "llama" and quant_type == "a8w8":
        register_pass(pass_builder, "remove_residual_in_fused_bias_residual_layernorm")
        register_pass(pass_builder, "remove_residual_in_rms_norm")
+       register_pass(pass_builder, "remove_blha_get_max_len")
        register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer_begin")
        register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer_end")
        register_pass(pass_builder, "llama_fuse_attention_smooth_quant_layer")
@@ -46,6 +47,7 @@ def addPasses(pass_builder, model_type, quant_type):
    elif model_type == "llama":
        register_pass(pass_builder, "remove_residual_in_fused_bias_residual_layernorm")
        register_pass(pass_builder, "remove_residual_in_rms_norm")
+       register_pass(pass_builder, "remove_blha_get_max_len")
        register_pass(pass_builder, "llama_fuse_attention_layer_begin")
        register_pass(pass_builder, "llama_fuse_attention_layer_end")
        register_pass(pass_builder, "llama_fuse_attention_layer")
19 changes: 19 additions & 0 deletions backends/npu/passes/llama.py
@@ -58,6 +58,25 @@ def replace(residual, x):
    return pattern, replace


@ir.RegisterPass
def remove_blha_get_max_len():
    def pattern(seq_lens_encoder, seq_lens_decoder, batch_size):
        blha_get_max_len = ir.PassDesc.OP.blha_get_max_len(
            seq_lens_encoder=seq_lens_encoder,
            seq_lens_decoder=seq_lens_decoder,
            batch_size=batch_size,
        )
        return (
            blha_get_max_len.Output("max_enc_len_this_time")[0],
            blha_get_max_len.Output("max_dec_len_this_time")[0],
        )

    def replace(seq_lens_encoder, seq_lens_decoder, batch_size):
        return seq_lens_encoder, seq_lens_decoder

    return pattern, replace
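
The replace function wires the op's two outputs straight back to seq_lens_encoder and seq_lens_decoder, so after the pass the consumers of max_enc_len_this_time and max_dec_len_this_time receive the raw per-batch length tensors; presumably this is also why the block_multihead_attention registration above pins inputs 24 and 25 to the CPU backend, so the maxima can be derived on the host. A toy stand-in for the rewrite (hypothetical, not Paddle IR):

import numpy as np

# stand-in length tensors; in the rewritten graph they flow unchanged into the
# attention op's max_enc_len_this_time / max_dec_len_this_time slots
seq_lens_encoder = np.array([17, 0, 42, 0], dtype=np.int32)
seq_lens_decoder = np.array([0, 9, 0, 128], dtype=np.int32)

max_enc_len_this_time, max_dec_len_this_time = seq_lens_encoder, seq_lens_decoder
# a host-side consumer can still recover the scalar maxima cheaply
print(int(max_enc_len_this_time.max()), int(max_dec_len_this_time.max()))  # 42 128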


@ir.RegisterPass
def llama_fuse_attention_layer():
    def pattern(
