[Inference] add blha_get_max_len op & modify block_multihead_attention op (#64246)

* [Inference] add blha_get_max_len op & modify block_multihead_attention op
ming1753 committed May 21, 2024
1 parent 79a8490 commit 669a261
Showing 12 changed files with 663 additions and 48 deletions.
16 changes: 14 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -32,17 +32,29 @@
     func : addcmul_xpu
     data_type : x
 
+- op : blha_get_max_len
+  args : (Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor batch_size)
+  output : Tensor(max_enc_len_this_time), Tensor(max_dec_len_this_time)
+  infer_meta :
+    func : BlhaGetMaxLenInferMeta
+  kernel :
+    func : blha_get_max_len
+    data_type : seq_lens_encoder
+  support_dygraph_mode : true
+
 - op : block_multihead_attention_
-  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
+  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, Tensor max_enc_len_this_time, Tensor max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
   output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out)
   infer_meta :
     func : BlockMultiheadAttentionInferMeta
   kernel :
     func : block_multihead_attention
     data_type : qkv
-  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth
+  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth, max_enc_len_this_time, max_dec_len_this_time
   inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out)
   support_dygraph_mode : true
+  data_transform :
+    skip_transform : max_enc_len_this_time, max_dec_len_this_time
 
 - op : bn_act_xpu
   args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, float momentum, float epsilon, str data_format, int act_type)
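Both entries declare support_dygraph_mode : true, so Paddle's YAML code-gen should emit eager C++/Python APIs for them. A minimal sketch, assuming the usual code-gen layout — the signature below is inferred from the args/output lines above, not taken from this diff:

// Sketch of the code-generated C++ API (assumed from the YAML above): two
// shape-{1} int32 output tensors, matching BlhaGetMaxLenInferMeta in fusion.cc.
std::tuple<paddle::Tensor, paddle::Tensor> blha_get_max_len(
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_size);

// Hypothetical per-step usage: compute both maxima once, then feed them to
// block_multihead_attention_ as its two new optional inputs; skip_transform
// keeps them on the host, so no place transform runs before the kernel.
auto [max_enc_len_this_time, max_dec_len_this_time] =
    blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size);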
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -232,6 +232,17 @@ void FusedMultiTransformerInferMeta(
   out->set_dims(x.dims());
 }
 
+void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder,
+                            const MetaTensor& seq_lens_decoder,
+                            const MetaTensor& batch_size,
+                            MetaTensor* max_enc_len_this_time,
+                            MetaTensor* max_dec_len_this_time) {
+  max_enc_len_this_time->set_dims({1});
+  max_enc_len_this_time->set_dtype(phi::DataType::INT32);
+  max_dec_len_this_time->set_dims({1});
+  max_dec_len_this_time->set_dtype(phi::DataType::INT32);
+}
+
 void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& key_cache,
                                       const MetaTensor& value_cache,
@@ -256,6 +267,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& qkv_bias,
                                       const MetaTensor& out_shift,
                                       const MetaTensor& out_smooth,
+                                      const MetaTensor& max_enc_len_this_time,
+                                      const MetaTensor& max_dec_len_this_time,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -77,6 +77,12 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x,
                                     float epsilon,
                                     MetaTensor* out);
 
+void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder,
+                            const MetaTensor& seq_lens_decoder,
+                            const MetaTensor& batch_size,
+                            MetaTensor* max_enc_len_this_time,
+                            MetaTensor* max_dec_len_this_time);
+
 void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& key_cache,
                                       const MetaTensor& value_cache,
@@ -101,6 +107,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& qkv_bias,
                                       const MetaTensor& out_shift,
                                       const MetaTensor& out_smooth,
+                                      const MetaTensor& max_enc_len_this_time,
+                                      const MetaTensor& max_dec_len_this_time,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
1 change: 1 addition & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -215,6 +215,7 @@ if(WITH_ROCM)
       "gpu/qr_kernel.cu"
       "gpu/svd_kernel.cu"
       "gpudnn/mha_cudnn_frontend.cu"
+      "fusion/gpu/blha_get_max_len.cu"
       "fusion/gpu/block_multi_head_attention_kernel.cu"
       "fusion/gpu/fused_bn_add_activation_grad_kernel.cu"
       "fusion/gpu/fused_bn_add_activation_kernel.cu"
68 changes: 68 additions & 0 deletions paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu
@@ -0,0 +1,68 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.h"
#include "paddle/phi/kernels/fusion/gpu/block_attn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#include "paddle/utils/none.h"

namespace phi {
namespace fusion {

// Reduce seq_lens_tensor to its maximum element on the GPU, then copy the
// single int32 result to the host tensor `out`.
void GetMaxLenTensor(const phi::GPUContext& dev_ctx,
                     const phi::DenseTensor& seq_lens_tensor,
                     const phi::DenseTensor& batch_size,
                     DenseTensor* out) {
  phi::DenseTensor max_len_tensor;
  max_len_tensor.Resize({{1}});
  dev_ctx.template Alloc<int>(&max_len_tensor,
                              max_len_tensor.numel() * sizeof(int));
  const int bsz = batch_size.dims()[0];
  constexpr int blockSize = 128;
  GetMaxLenKernel<blockSize><<<1, blockSize, 0, dev_ctx.stream()>>>(
      seq_lens_tensor.data<int>(), max_len_tensor.data<int>(), bsz);
  MemcpyD2HKernel(dev_ctx, max_len_tensor, 0, out);
}

template <typename T, typename Context>
void BlhaGetMaxLenKernel(const Context& dev_ctx,
                         const DenseTensor& seq_lens_encoder,
                         const DenseTensor& seq_lens_decoder,
                         const phi::DenseTensor& batch_size,
                         DenseTensor* max_enc_len_this_time,
                         DenseTensor* max_dec_len_this_time) {
  // Longest sequence among those currently in the decode phase.
  max_dec_len_this_time->Resize({{1}});
  GetMaxLenTensor(dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time);

  // Longest sequence among those currently in the encode (prefill) phase.
  max_enc_len_this_time->Resize({{1}});
  GetMaxLenTensor(dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time);
}
}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(blha_get_max_len,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::BlhaGetMaxLenKernel,
                   int,
                   int64_t) {}
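GetMaxLenKernel itself is defined in paddle/phi/kernels/fusion/gpu/block_attn.h; only its closing braces are visible in the next file's hunk. For orientation, a minimal sketch of such a one-block max reduction, assuming CUB is available — the name GetMaxLenKernelSketch and the cub::BlockReduce layout are illustrative, not the header's actual code:

#include <cub/cub.cuh>

// One thread block strides over the batch; each thread keeps a running max,
// then a block-wide reduction produces the final maximum sequence length.
template <int BlockSize>
__global__ void GetMaxLenKernelSketch(const int* seq_lens,
                                      int* max_len,
                                      const int batch_size) {
  typedef cub::BlockReduce<int, BlockSize> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int max_len_this_thread = 0;
  for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
    max_len_this_thread = max(max_len_this_thread, seq_lens[i]);
  }
  const int block_max =
      BlockReduce(temp_storage).Reduce(max_len_this_thread, cub::Max());
  if (threadIdx.x == 0) {
    *max_len = block_max;
  }
}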
17 changes: 0 additions & 17 deletions paddle/phi/kernels/fusion/gpu/block_attn.h
@@ -2717,23 +2717,6 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
   }
 }
 
-int GetMaxLen(const phi::GPUContext &dev_ctx,
-              const phi::DenseTensor &seq_lens_tensor,
-              phi::DenseTensor *max_len_tensor,
-              const int batch_size) {
-  constexpr int blockSize = 128;
-  int max_len_cpu = 0;
-  GetMaxLenKernel<blockSize><<<1, blockSize, 0, dev_ctx.stream()>>>(
-      seq_lens_tensor.data<int>(), max_len_tensor->data<int>(), batch_size);
-  memory_utils::Copy(phi::CPUPlace(),
-                     &max_len_cpu,
-                     dev_ctx.GetPlace(),
-                     max_len_tensor->data<int>(),
-                     sizeof(int),
-                     dev_ctx.stream());
-  return max_len_cpu;
-}
-
 template <typename T, int VecSize>
 __global__ void InitOutValueKernel(T *output_data,
                                    const int64_t numel,
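The deleted helper launched the same reduction but then copied the single int back to the CPU on every call, so each block_multihead_attention invocation paid a device-to-host round trip per length. With the standalone blha_get_max_len op, a caller computes both maxima once per decoding step and the attention op consumes them through its new optional inputs. A hedged sketch of that consumer side — the access pattern below is assumed, not shown in this diff:

// Hypothetical consumer inside block_multihead_attention: when the optional
// inputs are supplied, read the precomputed host-side scalars instead of
// launching a reduction plus a blocking copy.
if (max_enc_len_this_time && max_dec_len_this_time) {
  const int max_enc_len = max_enc_len_this_time->data<int>()[0];
  const int max_dec_len = max_dec_len_this_time->data<int>()[0];
  // ... choose prefill vs. decode paths from the two maxima ...
}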
