Fuse qkv and fix bart decoding (#5111)
* fix bart decoding and fuse qkv

* add more args
gongenlei committed Mar 6, 2023
1 parent e7e4fa5 commit cfc5a47
Showing 5 changed files with 122 additions and 167 deletions.
9 changes: 9 additions & 0 deletions paddlenlp/ops/fast_transformer/src/fusion_bart_decoding_op.cc
@@ -57,18 +57,21 @@ std::vector<paddle::Tensor> BartDecodingForward(
const int& beam_size,
const int& topk,
const float& topp,
+ const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
+ const int64_t& min_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha,
const bool& early_stopping) {
int batch_size = input.shape()[0];
int max_out_len = rel_len ? max_len + input.shape()[1] : max_len;
+ int min_out_len = rel_len ? min_len + input.shape()[1] : min_len;
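For context: rel_len switches the two length limits between absolute step counts and limits relative to the source length. A minimal illustration with hypothetical values (not from this diff):

    // Hypothetical values, illustrating the two lines above.
    int input_len = 16;                 // input.shape()[1]
    int64_t max_len = 32, min_len = 4;
    bool rel_len = true;
    int max_out_len = rel_len ? max_len + input_len : max_len;  // up to 48 decode steps
    int min_out_len = rel_len ? min_len + input_len : min_len;  // at least 20 tokens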

std::vector<int64_t> output_dims;
std::vector<int64_t> parent_ids_dims;
@@ -153,12 +156,14 @@ std::vector<paddle::Tensor> BartDecodingForward(
beam_size,
topk,
topp,
+ temperature,
n_head,
size_per_head,
num_layer,
bos_id,
eos_id,
max_out_len,
+ min_out_len,
beam_search_diversity_rate,
alpha,
early_stopping);
@@ -206,12 +211,14 @@ std::vector<std::vector<int64_t>> BartDecodingInferShape(
const int& beam_size,
const int& topk,
const float& topp,
+ const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
+ const int64_t& min_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha,
@@ -328,12 +335,14 @@ PD_BUILD_OP(fusion_bart_decoding)
"beam_size: int",
"topk: int",
"topp: float",
"temperature: float",
"n_head: int",
"size_per_head: int",
"num_layer: int",
"bos_id: int",
"eos_id: int",
"max_len: int64_t",
"min_len: int64_t",
"beam_search_diversity_rate: float",
"rel_len: bool",
"alpha: float",
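For context: in Paddle's custom-operator API, each string in Attrs() declares an attribute name and type, and the list's order must line up with the trailing attribute parameters of the functions bound via SetKernelFn and SetInferShapeFn. That is why "temperature" and "min_len" are inserted at the same positions in BartDecodingForward, BartDecodingInferShape, and this Attrs() list. A stripped-down sketch of the pattern (hypothetical op, not from this commit):

    #include "paddle/extension.h"

    std::vector<paddle::Tensor> MyForward(const paddle::Tensor& x,
                                          const float& temperature,
                                          const int64_t& min_len) {
      return {x};  // placeholder body
    }

    PD_BUILD_OP(my_decoding_op)
        .Inputs({"X"})
        .Outputs({"Out"})
        .Attrs({"temperature: float",  // order mirrors MyForward's attributes
                "min_len: int64_t"})
        .SetKernelFn(PD_KERNEL(MyForward));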
127 changes: 86 additions & 41 deletions paddlenlp/ops/fast_transformer/src/fusion_bart_decoding_op.cu
@@ -21,11 +21,11 @@ limitations under the License. */
#include <sstream>
#include <vector>

#include "cublas_handle.h"
#include "fastertransformer/cuda/cub/cub.cuh"
#include "fusion_bart_decoding_op.h"
#include "pd_traits.h"


template <paddle::DataType D>
std::vector<paddle::Tensor> bart_decoding_kernel(
const paddle::Tensor& input,
@@ -69,17 +69,17 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
const int& beam_size,
const int& topk,
const float& topp,
+ const float& temperature,
const int& head_num_,
const int& size_per_head_,
const int& num_layer_,
const int& start_id_,
const int& end_id_,
const int64_t& max_seq_len_,
+ const int64_t& min_seq_len_,
const float& beam_search_diversity_rate_,
const float& alpha,
const bool& early_stopping,
- cublasHandle_t cublas_handle_,
- cublasLtHandle_t cublaslt_handle_,
cudaStream_t stream) {
int beam_width_ = (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2" ||
@@ -110,8 +110,9 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
typedef typename traits_::data_t data_t_;

DecodingInitParam<DataType_> decoding_params;
- decoding_params.cublas_handle = cublas_handle_;
- decoding_params.cublaslt_handle = cublaslt_handle_;
+ decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+ decoding_params.cublaslt_handle =
+     CublasHandle::GetInstance()->cublaslt_handle_;
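cublas_handle.h itself is not shown in this diff. A minimal sketch of the singleton it presumably provides, so that the handles are created once per process rather than once per decoding call (names mirror the usage above; the real implementation may differ):

    #include <cublasLt.h>
    #include <cublas_v2.h>

    class CublasHandle {
     public:
      // One process-wide instance, created lazily on first use.
      static CublasHandle* GetInstance() {
        static CublasHandle instance;
        return &instance;
      }
      cublasHandle_t cublas_handle_;
      cublasLtHandle_t cublaslt_handle_;

     private:
      CublasHandle() {
        cublasCreate(&cublas_handle_);
        cublasLtCreate(&cublaslt_handle_);
      }
      ~CublasHandle() {
        cublasDestroy(cublas_handle_);
        cublasLtDestroy(cublaslt_handle_);
      }
    };

This removes the cublasCreate/cublasDestroy pair that the old code executed on every forward call (see the deletions near the end of this file).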

decoding_params.output_ids = output_ids.mutable_data<int>(input.place());
decoding_params.parent_ids = parent_ids.mutable_data<int>(input.place());
@@ -126,13 +127,25 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
reinterpret_cast<const DataType_*>(input.data<data_t_>());
decoding_params.memory_sequence_length = memory_sequence_length.data<int>();

+ //TODO(gongenlei): Support MP & PP
+ TensorParallelParam tensor_parallel_param;
+ LayerParallelParam layer_parallel_param;
+ tensor_parallel_param.rank = 0;
+ tensor_parallel_param.world_size = 1;
+ tensor_parallel_param.local_head_num_ = head_num_;
+ tensor_parallel_param.local_hidden_units_ = memory_hidden_dim;
+ layer_parallel_param.rank = 0;
+ layer_parallel_param.world_size = 1;
+ layer_parallel_param.layers_per_group = num_layer_;
+ layer_parallel_param.local_batch_size = batch_size_;
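The TODO above marks model parallelism (MP) and pipeline parallelism (PP) as unsupported: rank 0 with world_size 1 makes both structs degenerate to single-GPU execution. For orientation only, a hedged sketch of how the per-rank fields would typically shrink under real parallelism (assumed semantics, not wired up here):

    // Hypothetical multi-GPU wiring; this op currently hard-codes world_size = 1.
    void InitParallelParams(TensorParallelParam& tp, LayerParallelParam& lp,
                            int rank, int world_size, int head_num,
                            int hidden_units, int num_layer, int batch_size) {
      tp.rank = rank;
      tp.world_size = world_size;
      tp.local_head_num_ = head_num / world_size;        // heads sharded across ranks
      tp.local_hidden_units_ = hidden_units / world_size;
      lp.rank = rank;
      lp.world_size = world_size;
      lp.layers_per_group = num_layer / world_size;      // layers per pipeline stage
      lp.local_batch_size = batch_size;
    }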

DecoderInitParam<DataType_>* params =
new DecoderInitParam<DataType_>[num_layer_];

for (int i = 0; i < num_layer_; i++) {
params[i].stream = stream;
- params[i].cublas_handle = cublas_handle_;
- params[i].cublaslt_handle = cublaslt_handle_;
+ params[i].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
+ params[i].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;

if (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2" ||
@@ -158,20 +171,28 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
params[i].self_attention.query_weight.bias =
reinterpret_cast<const DataType_*>(
self_attn_query_bias[i].data<data_t_>());
+ // // key
+ // params[i].self_attention.key_weight.kernel =
+ //     reinterpret_cast<const DataType_*>(
+ //         self_attn_key_weight[i].data<data_t_>());
+ // params[i].self_attention.key_weight.bias =
+ //     reinterpret_cast<const DataType_*>(
+ //         self_attn_key_bias[i].data<data_t_>());
+ // // value
+ // params[i].self_attention.value_weight.kernel =
+ //     reinterpret_cast<const DataType_*>(
+ //         self_attn_value_weight[i].data<data_t_>());
+ // params[i].self_attention.value_weight.bias =
+ //     reinterpret_cast<const DataType_*>(
+ //         self_attn_value_bias[i].data<data_t_>());

// key
- params[i].self_attention.key_weight.kernel =
-     reinterpret_cast<const DataType_*>(
-         self_attn_key_weight[i].data<data_t_>());
- params[i].self_attention.key_weight.bias =
-     reinterpret_cast<const DataType_*>(
-         self_attn_key_bias[i].data<data_t_>());
+ params[i].self_attention.key_weight.kernel = nullptr;
+ params[i].self_attention.key_weight.bias = nullptr;
// value
- params[i].self_attention.value_weight.kernel =
-     reinterpret_cast<const DataType_*>(
-         self_attn_value_weight[i].data<data_t_>());
- params[i].self_attention.value_weight.bias =
-     reinterpret_cast<const DataType_*>(
-         self_attn_value_bias[i].data<data_t_>());
+ params[i].self_attention.value_weight.kernel = nullptr;
+ params[i].self_attention.value_weight.bias = nullptr;
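This block is the heart of the QKV fusion. With is_fuse_qkv enabled (see the constructor calls below), the decoder expects a single fused weight, presumably packed into the query slot by the Python-side weight conversion, so the separate key/value slots are cleared. Conceptually, three projections collapse into one GEMM over a concatenated weight matrix; a minimal CPU sketch of the idea (layout and names are illustrative, not the library's API):

    #include <cstddef>
    #include <vector>

    // Fused QKV projection: instead of q = x*Wq, k = x*Wk, v = x*Wv (three GEMMs),
    // concatenate the weights column-wise into Wqkv [hidden, 3*hidden] and run one
    // GEMM, reading x once. Column blocks [0,h), [h,2h), [2h,3h) of the result are
    // q, k, v respectively.
    std::vector<float> FusedQkv(const std::vector<float>& x,     // [rows, hidden]
                                const std::vector<float>& wqkv,  // [hidden, 3*hidden]
                                size_t rows, size_t hidden) {
      std::vector<float> qkv(rows * 3 * hidden, 0.f);
      for (size_t r = 0; r < rows; ++r)
        for (size_t i = 0; i < hidden; ++i)
          for (size_t c = 0; c < 3 * hidden; ++c)
            qkv[r * 3 * hidden + c] += x[r * hidden + i] * wqkv[i * 3 * hidden + c];
      return qkv;
    }

On the GPU, the benefit is launching one large GEMM per step instead of three smaller ones, which keeps the hardware better utilized during token-by-token decoding.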

// out proj
params[i].self_attention.attention_output_weight.kernel =
reinterpret_cast<const DataType_*>(
@@ -269,12 +290,23 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, /*is_fuse_topk_softMax*/
- false, /*is_fuse_qkv*/
+ true, /*is_fuse_qkv*/
false, /*keep_alive_beam*/
- alpha,
+ alpha,
false, /*normalization_before*/
- 2,
- ActivationType::GELU);
+ 2, /*pos_offset*/
+ ActivationType::GELU,
+ false, /*pos_bias*/
+ false, /*prefix_lm*/
+ -1, /*finished_candidate_num*/
+ false, /*early_stopping*/
+ false, /*is_mbart*/
+ min_seq_len_);

+ decoding_beamsearch_->set_tensor_parallel_param(
+     tensor_parallel_param);
+ decoding_beamsearch_->set_layer_parallel_param(
+     layer_parallel_param);

decoding_beamsearch_->forward(params, decoding_params);
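One argument worth flagging above: the 2 passed as pos_offset matches BART's fairseq-style learned positional embeddings, which reserve the first two rows of the embedding table, so every position id is shifted by 2 before lookup. Roughly (hypothetical names):

    // Fairseq-style offset for BART's learned positional embeddings.
    const int kPosOffset = 2;  // the "2, /*pos_offset*/" argument above
    inline int PositionRow(int step) { return step + kPosOffset; }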

@@ -297,7 +329,7 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, /*is_fuse_topk_softMax*/
- false, /*is_fuse_qkv*/
+ true, /*is_fuse_qkv*/
true, /*keep_alive_beam*/
alpha,
false, /*normalization_before*/
@@ -306,7 +338,14 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
false, /*pos_bias*/
false, /*prefix_lm*/
finished_candidate_num_,
- early_stopping);
+ early_stopping,
+ false, /*is_mbart*/
+ min_seq_len_);

+ decoding_beamsearch_->set_tensor_parallel_param(
+     tensor_parallel_param);
+ decoding_beamsearch_->set_layer_parallel_param(
+     layer_parallel_param);

decoding_beamsearch_->forward(params, decoding_params);
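min_seq_len_ is one of the newly plumbed arguments. Decoders typically enforce a minimum output length by making the end token unselectable until enough steps have run; a sketch of the usual mechanism (illustrative, not this file's code):

    #include <vector>

    // While step < min_out_len, push the EOS logit toward -inf so its
    // post-softmax probability is effectively zero.
    void MaskEosBeforeMinLen(std::vector<float>& logits, int end_id,
                             int step, int min_out_len) {
      if (step < min_out_len) logits[end_id] = -1e20f;
    }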

@@ -329,10 +368,20 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
end_id_,
candidate_num_,
probability_threshold_,
- false,
- false,
- 2,
- ActivationType::GELU);
+ true, /*is_fuse_qkv*/
+ false, /*normalization_before*/
+ 2, /*pos_offset*/
+ ActivationType::GELU,
+ false, /*pos_bias*/
+ temperature, /*temperature*/
+ 1.0, /*repeat_penalty*/
+ false, /*prefix_lm*/
+ false, /*is_mbart*/
+ min_seq_len_);
+ decoding_sampling_->set_tensor_parallel_param(
+     tensor_parallel_param);
+ decoding_sampling_->set_layer_parallel_param(
+     layer_parallel_param);
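temperature, the other new argument, only applies to the sampling strategies: logits are scaled by 1/temperature before softmax, so values below 1.0 sharpen the distribution and values above 1.0 flatten it (1.0 leaves it unchanged). A standard sketch (illustrative, not this file's code):

    #include <vector>

    // Temperature scaling before softmax: t < 1 sharpens, t > 1 flattens.
    void ApplyTemperature(std::vector<float>& logits, float temperature) {
      for (float& logit : logits) logit /= temperature;
    }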

decoding_sampling_->forward(params, decoding_params);

@@ -389,21 +438,20 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
const int& beam_size,
const int& topk,
const float& topp,
+ const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
+ const int64_t& min_len,
const float& beam_search_diversity_rate,
const float& alpha,
const bool& early_stopping) {
auto stream = input.stream();
- cublasHandle_t cublas_handle_;
- cublasCreate(&cublas_handle_);
- cublasLtHandle_t cublaslt_handle_;
- cublasLtCreate(&cublaslt_handle_);
- cublasSetStream(cublas_handle_, stream);

+ cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);

std::vector<paddle::Tensor> ret;

@@ -451,17 +499,17 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
beam_size,
topk,
topp,
+ temperature,
n_head,
size_per_head,
num_layer,
bos_id,
eos_id,
max_len,
+ min_len,
beam_search_diversity_rate,
alpha,
early_stopping,
- cublas_handle_,
- cublaslt_handle_,
stream);
break;
}
@@ -508,17 +556,17 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
beam_size,
topk,
topp,
+ temperature,
n_head,
size_per_head,
num_layer,
bos_id,
eos_id,
max_len,
+ min_len,
beam_search_diversity_rate,
alpha,
early_stopping,
- cublas_handle_,
- cublaslt_handle_,
stream);
break;
}
@@ -529,8 +577,5 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
break;
}
}

- cublasDestroy(cublas_handle_);
- cublasLtDestroy(cublaslt_handle_);
return ret;
}
2 changes: 2 additions & 0 deletions paddlenlp/ops/fast_transformer/src/fusion_bart_decoding_op.h
@@ -70,12 +70,14 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
const int& beam_size,
const int& topk,
const float& topp,
+ const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
+ const int64_t& min_len,
const float& beam_search_diversity_rate,
const float& alpha,
const bool& early_stopping);
(Diffs for the remaining two changed files are not shown.)
