[Inference] Add blha_get_max_len op & modify block_multihead_attention op #64246

Merged
10 commits merged on May 21, 2024
16 changes: 14 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -32,17 +32,29 @@
    func : addcmul_xpu
    data_type : x

- op : blha_get_max_len
  args : (Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor batch_size)
  output : Tensor(max_enc_len_this_time), Tensor(max_dec_len_this_time)
  infer_meta :
    func : BlhaGetMaxLenInferMeta
  kernel :
    func : blha_get_max_len
    data_type : seq_lens_encoder
  support_dygraph_mode : true

- op : block_multihead_attention_
  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, Tensor max_enc_len_this_time, Tensor max_dec_len_this_time, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
  output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out)
  infer_meta :
    func : BlockMultiheadAttentionInferMeta
  kernel :
    func : block_multihead_attention
    data_type : qkv
  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth
  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth, max_enc_len_this_time, max_dec_len_this_time
  inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out)
  support_dygraph_mode : true
  data_transform :
    skip_transform : max_enc_len_this_time, max_dec_len_this_time

- op : bn_act_xpu
args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, float momentum, float epsilon, str data_format, int act_type)
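Note on the fused_ops.yaml change above: blha_get_max_len is a new standalone op, and block_multihead_attention_ now accepts two extra optional inputs, max_enc_len_this_time and max_dec_len_this_time, listed under skip_transform so they are passed through as host (CPU) tensors. Below is a hedged sketch of what that enables on the kernel side; it is not code from this PR, and ResolveMaxLen is a hypothetical helper name.

// Hypothetical helper, not part of this diff: prefer the precomputed
// host-side maximum when the optional input is provided.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"

namespace phi {
namespace fusion {

inline int ResolveMaxLen(
    const paddle::optional<DenseTensor>& max_len_this_time,
    int fallback_value) {
  // skip_transform keeps the tensor on the CPU, so element 0 can be read
  // directly on the host.
  if (max_len_this_time) {
    return max_len_this_time->data<int>()[0];
  }
  return fallback_value;
}

}  // namespace fusion
}  // namespace phi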
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -232,6 +232,17 @@ void FusedMultiTransformerInferMeta(
  out->set_dims(x.dims());
}

void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder,
                            const MetaTensor& seq_lens_decoder,
                            const MetaTensor& batch_size,
                            MetaTensor* max_enc_len_this_time,
                            MetaTensor* max_dec_len_this_time) {
  max_enc_len_this_time->set_dims({1});
  max_enc_len_this_time->set_dtype(phi::DataType::INT32);
  max_dec_len_this_time->set_dims({1});
  max_dec_len_this_time->set_dtype(phi::DataType::INT32);
}

void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                      const MetaTensor& key_cache,
                                      const MetaTensor& value_cache,
@@ -256,6 +267,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                      const MetaTensor& qkv_bias,
                                      const MetaTensor& out_shift,
                                      const MetaTensor& out_smooth,
                                      const MetaTensor& max_enc_len_this_time,
                                      const MetaTensor& max_dec_len_this_time,
                                      int max_seq_len,
                                      int block_size,
                                      bool use_neox_style,
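The InferMeta added above declares both outputs as 1-element INT32 tensors. A minimal sketch of exercising that contract, assuming only the signature introduced in this diff (CheckBlhaGetMaxLenMeta is a hypothetical test helper):

// Build DenseTensor-backed MetaTensors, run the new InferMeta, and inspect
// the declared output metadata.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/fusion.h"

void CheckBlhaGetMaxLenMeta() {
  phi::DenseTensor enc, dec, bsz, max_enc, max_dec;
  phi::MetaTensor enc_meta(&enc), dec_meta(&dec), bsz_meta(&bsz);
  phi::MetaTensor max_enc_meta(&max_enc), max_dec_meta(&max_dec);

  phi::BlhaGetMaxLenInferMeta(
      enc_meta, dec_meta, bsz_meta, &max_enc_meta, &max_dec_meta);
  // Expected afterwards: dims == {1} and dtype == INT32 on both outputs.
}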
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -77,6 +77,12 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x,
                                    float epsilon,
                                    MetaTensor* out);

void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder,
                            const MetaTensor& seq_lens_decoder,
                            const MetaTensor& batch_size,
                            MetaTensor* max_enc_len_this_time,
                            MetaTensor* max_dec_len_this_time);

void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                      const MetaTensor& key_cache,
                                      const MetaTensor& value_cache,
@@ -101,6 +107,8 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                      const MetaTensor& qkv_bias,
                                      const MetaTensor& out_shift,
                                      const MetaTensor& out_smooth,
                                      const MetaTensor& max_enc_len_this_time,
                                      const MetaTensor& max_dec_len_this_time,
                                      int max_seq_len,
                                      int block_size,
                                      bool use_neox_style,
1 change: 1 addition & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -215,6 +215,7 @@ if(WITH_ROCM)
"gpu/qr_kernel.cu"
"gpu/svd_kernel.cu"
"gpudnn/mha_cudnn_frontend.cu"
"fusion/gpu/blha_get_max_len.cu"
"fusion/gpu/block_multi_head_attention_kernel.cu"
"fusion/gpu/fused_bn_add_activation_grad_kernel.cu"
"fusion/gpu/fused_bn_add_activation_kernel.cu"
68 changes: 68 additions & 0 deletions paddle/phi/kernels/fusion/gpu/blha_get_max_len.cu
@@ -0,0 +1,68 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.h"
#include "paddle/phi/kernels/fusion/gpu/block_attn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#include "paddle/phi/kernels/memcpy_kernel.h"
#include "paddle/utils/none.h"

namespace phi {
namespace fusion {

void GetMaxLenTensor(const phi::GPUContext& dev_ctx,
                     const phi::DenseTensor& seq_lens_tensor,
                     const phi::DenseTensor& batch_size,
                     DenseTensor* out) {
  phi::DenseTensor max_len_tensor;
  max_len_tensor.Resize({{1}});
  // Allocate device memory for the single-element reduction result.
  dev_ctx.template Alloc<int>(&max_len_tensor,
                              max_len_tensor.numel() * sizeof(int));
  const int bsz = batch_size.dims()[0];
  constexpr int blockSize = 128;
  // Reduce the per-sequence lengths to their maximum with a single thread
  // block (GetMaxLenKernel is defined in block_attn.h, included above),
  // then copy the scalar result into the host tensor `out`.
  GetMaxLenKernel<blockSize><<<1, blockSize, 0, dev_ctx.stream()>>>(
      seq_lens_tensor.data<int>(), max_len_tensor.data<int>(), bsz);
  MemcpyD2HKernel(dev_ctx, max_len_tensor, 0, out);
}

template <typename T, typename Context>
void BlhaGetMaxLenKernel(const Context& dev_ctx,
                         const DenseTensor& seq_lens_encoder,
                         const DenseTensor& seq_lens_decoder,
                         const phi::DenseTensor& batch_size,
                         DenseTensor* max_enc_len_this_time,
                         DenseTensor* max_dec_len_this_time) {
  // decoder
  max_dec_len_this_time->Resize({{1}});
  GetMaxLenTensor(dev_ctx, seq_lens_decoder, batch_size, max_dec_len_this_time);

  // encoder
  max_enc_len_this_time->Resize({{1}});
  GetMaxLenTensor(dev_ctx, seq_lens_encoder, batch_size, max_enc_len_this_time);
}
} // namespace fusion
} // namespace phi

PD_REGISTER_KERNEL(blha_get_max_len,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::BlhaGetMaxLenKernel,
                   int,
                   int64_t) {}
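The launch in GetMaxLenTensor above relies on GetMaxLenKernel, which is defined in block_attn.h and not shown in this diff. For reference, a single-block max reduction consistent with that <<<1, blockSize>>> launch could look like the following sketch; GetMaxLenKernelSketch is a stand-in name and the in-tree kernel may differ in detail.

// Hedged sketch of a single-block max reduction over the per-batch lengths.
#include <cub/cub.cuh>

template <int THREADBLOCK_SIZE>
__global__ void GetMaxLenKernelSketch(const int* seq_lens,
                                      int* max_len,
                                      const int batch_size) {
  using BlockReduce = cub::BlockReduce<int, THREADBLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread scans a strided slice of the batch.
  int thread_max = 0;
  for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
    thread_max = max(thread_max, seq_lens[i]);
  }

  // Single-block reduction; thread 0 writes the scalar result.
  int block_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) {
    *max_len = block_max;
  }
}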
17 changes: 0 additions & 17 deletions paddle/phi/kernels/fusion/gpu/block_attn.h
@@ -2717,23 +2717,6 @@ __global__ void GetMaxLenKernel(const int *seq_lens,
  }
}

int GetMaxLen(const phi::GPUContext &dev_ctx,
              const phi::DenseTensor &seq_lens_tensor,
              phi::DenseTensor *max_len_tensor,
              const int batch_size) {
  constexpr int blockSize = 128;
  int max_len_cpu = 0;
  GetMaxLenKernel<blockSize><<<1, blockSize, 0, dev_ctx.stream()>>>(
      seq_lens_tensor.data<int>(), max_len_tensor->data<int>(), batch_size);
  memory_utils::Copy(phi::CPUPlace(),
                     &max_len_cpu,
                     dev_ctx.GetPlace(),
                     max_len_tensor->data<int>(),
                     sizeof(int),
                     dev_ctx.stream());
  return max_len_cpu;
}

template <typename T, int VecSize>
__global__ void InitOutValueKernel(T *output_data,
                                   const int64_t numel,
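With the host-side GetMaxLen helper removed here, the maxima are now produced by the new blha_get_max_len op and can be reused across calls rather than recomputed (with a blocking device-to-host copy) inside each attention call. A hedged host-side sketch of that flow, assuming only the signatures shown in this diff and that dev_ctx, seq_lens_encoder, seq_lens_decoder, and batch_size are already in scope:

// Compute the per-step maxima once, then reuse the host tensors across the
// block_multihead_attention layers of the same decoding step.
phi::DenseTensor max_enc_len_this_time;
phi::DenseTensor max_dec_len_this_time;
phi::fusion::BlhaGetMaxLenKernel<int, phi::GPUContext>(dev_ctx,
                                                       seq_lens_encoder,
                                                       seq_lens_decoder,
                                                       batch_size,
                                                       &max_enc_len_this_time,
                                                       &max_dec_len_this_time);
// Both outputs now live in host memory (copied by MemcpyD2HKernel above),
// so they can be passed to block_multihead_attention_ as its new optional
// inputs.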