Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse qkv and fix bart decoding #5111

Merged
merged 2 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions paddlenlp/ops/fast_transformer/src/fusion_bart_decoding_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,21 @@ std::vector<paddle::Tensor> BartDecodingForward(
const int& beam_size,
const int& topk,
const float& topp,
const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const int64_t& min_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha,
const bool& early_stopping) {
int batch_size = input.shape()[0];
int max_out_len = rel_len ? max_len + input.shape()[1] : max_len;
int min_out_len = rel_len ? min_len + input.shape()[1] : min_len;

std::vector<int64_t> output_dims;
std::vector<int64_t> parent_ids_dims;
Expand Down Expand Up @@ -153,12 +156,14 @@ std::vector<paddle::Tensor> BartDecodingForward(
beam_size,
topk,
topp,
temperature,
n_head,
size_per_head,
num_layer,
bos_id,
eos_id,
max_out_len,
min_out_len,
beam_search_diversity_rate,
alpha,
early_stopping);
Expand Down Expand Up @@ -206,12 +211,14 @@ std::vector<std::vector<int64_t>> BartDecodingInferShape(
const int& beam_size,
const int& topk,
const float& topp,
const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const int64_t& min_len,
const float& beam_search_diversity_rate,
const bool& rel_len,
const float& alpha,
Expand Down Expand Up @@ -328,12 +335,14 @@ PD_BUILD_OP(fusion_bart_decoding)
"beam_size: int",
"topk: int",
"topp: float",
"temperature: float",
"n_head: int",
"size_per_head: int",
"num_layer: int",
"bos_id: int",
"eos_id: int",
"max_len: int64_t",
"min_len: int64_t",
"beam_search_diversity_rate: float",
"rel_len: bool",
"alpha: float",
Expand Down
127 changes: 86 additions & 41 deletions paddlenlp/ops/fast_transformer/src/fusion_bart_decoding_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ limitations under the License. */
#include <sstream>
#include <vector>

#include "cublas_handle.h"
#include "fastertransformer/cuda/cub/cub.cuh"
#include "fusion_bart_decoding_op.h"
#include "pd_traits.h"


template <paddle::DataType D>
std::vector<paddle::Tensor> bart_decoding_kernel(
const paddle::Tensor& input,
Expand Down Expand Up @@ -69,17 +69,17 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
const int& beam_size,
const int& topk,
const float& topp,
const float& temperature,
const int& head_num_,
const int& size_per_head_,
const int& num_layer_,
const int& start_id_,
const int& end_id_,
const int64_t& max_seq_len_,
const int64_t& min_seq_len_,
const float& beam_search_diversity_rate_,
const float& alpha,
const bool& early_stopping,
cublasHandle_t cublas_handle_,
cublasLtHandle_t cublaslt_handle_,
cudaStream_t stream) {
int beam_width_ = (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2" ||
Expand Down Expand Up @@ -110,8 +110,9 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
typedef typename traits_::data_t data_t_;

DecodingInitParam<DataType_> decoding_params;
decoding_params.cublas_handle = cublas_handle_;
decoding_params.cublaslt_handle = cublaslt_handle_;
decoding_params.cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
decoding_params.cublaslt_handle =
CublasHandle::GetInstance()->cublaslt_handle_;

decoding_params.output_ids = output_ids.mutable_data<int>(input.place());
decoding_params.parent_ids = parent_ids.mutable_data<int>(input.place());
Expand All @@ -126,13 +127,25 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
reinterpret_cast<const DataType_*>(input.data<data_t_>());
decoding_params.memory_sequence_length = memory_sequence_length.data<int>();

//TODO(gongenlei): Support MP & PP
TensorParallelParam tensor_parallel_param;
LayerParallelParam layer_parallel_param;
tensor_parallel_param.rank = 0;
tensor_parallel_param.world_size = 1;
tensor_parallel_param.local_head_num_ = head_num_;
tensor_parallel_param.local_hidden_units_ = memory_hidden_dim;
layer_parallel_param.rank = 0;
layer_parallel_param.world_size = 1;
layer_parallel_param.layers_per_group = num_layer_;
layer_parallel_param.local_batch_size = batch_size_;

DecoderInitParam<DataType_>* params =
new DecoderInitParam<DataType_>[num_layer_];

for (int i = 0; i < num_layer_; i++) {
params[i].stream = stream;
params[i].cublas_handle = cublas_handle_;
params[i].cublaslt_handle = cublaslt_handle_;
params[i].cublas_handle = CublasHandle::GetInstance()->cublas_handle_;
params[i].cublaslt_handle = CublasHandle::GetInstance()->cublaslt_handle_;

if (decoding_strategy == "beam_search" ||
decoding_strategy == "beam_search_v2" ||
Expand All @@ -158,20 +171,28 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
params[i].self_attention.query_weight.bias =
reinterpret_cast<const DataType_*>(
self_attn_query_bias[i].data<data_t_>());
// // key
// params[i].self_attention.key_weight.kernel =
// reinterpret_cast<const DataType_*>(
// self_attn_key_weight[i].data<data_t_>());
// params[i].self_attention.key_weight.bias =
// reinterpret_cast<const DataType_*>(
// self_attn_key_bias[i].data<data_t_>());
// // value
// params[i].self_attention.value_weight.kernel =
// reinterpret_cast<const DataType_*>(
// self_attn_value_weight[i].data<data_t_>());
// params[i].self_attention.value_weight.bias =
// reinterpret_cast<const DataType_*>(
// self_attn_value_bias[i].data<data_t_>());

// key
params[i].self_attention.key_weight.kernel =
reinterpret_cast<const DataType_*>(
self_attn_key_weight[i].data<data_t_>());
params[i].self_attention.key_weight.bias =
reinterpret_cast<const DataType_*>(
self_attn_key_bias[i].data<data_t_>());
params[i].self_attention.key_weight.kernel = nullptr;
params[i].self_attention.key_weight.bias = nullptr;
// value
params[i].self_attention.value_weight.kernel =
reinterpret_cast<const DataType_*>(
self_attn_value_weight[i].data<data_t_>());
params[i].self_attention.value_weight.bias =
reinterpret_cast<const DataType_*>(
self_attn_value_bias[i].data<data_t_>());
params[i].self_attention.value_weight.kernel = nullptr;
params[i].self_attention.value_weight.bias = nullptr;

// out proj
params[i].self_attention.attention_output_weight.kernel =
reinterpret_cast<const DataType_*>(
Expand Down Expand Up @@ -269,12 +290,23 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, /*is_fuse_topk_softMax*/
false, /*is_fuse_qkv*/
true, /*is_fuse_qkv*/
false, /*keep_alive_beam*/
alpha,
alpha,
false, /*normalization_before*/
2,
ActivationType::GELU);
2, /*pos_offset*/
ActivationType::GELU,
false, /*pos_bias*/
false, /*prefix_lm*/
-1, /*finished_candidate_num*/
false, /*early_stopping*/
false, /*is_mbart*/
min_seq_len_);

decoding_beamsearch_->set_tensor_parallel_param(
tensor_parallel_param);
decoding_beamsearch_->set_layer_parallel_param(
layer_parallel_param);

decoding_beamsearch_->forward(params, decoding_params);

Expand All @@ -297,7 +329,7 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
end_id_,
beam_search_diversity_rate_,
true, /*is_fuse_topk_softMax*/
false, /*is_fuse_qkv*/
true, /*is_fuse_qkv*/
true, /*keep_alive_beam*/
alpha,
false, /*normalization_before*/
Expand All @@ -306,7 +338,14 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
false, /*pos_bias*/
false, /*prefix_lm*/
finished_candidate_num_,
early_stopping);
early_stopping,
false, /*is_mbart*/
min_seq_len_);

decoding_beamsearch_->set_tensor_parallel_param(
tensor_parallel_param);
decoding_beamsearch_->set_layer_parallel_param(
layer_parallel_param);

decoding_beamsearch_->forward(params, decoding_params);

Expand All @@ -329,10 +368,20 @@ std::vector<paddle::Tensor> bart_decoding_kernel(
end_id_,
candidate_num_,
probability_threshold_,
false,
false,
2,
ActivationType::GELU);
true, /*is_fuse_qkv*/
false, /*normalization_before*/
2, /*pos_offset*/
ActivationType::GELU,
false, /*pos_bias*/
temperature, /*temperature*/
1.0, /*repeat_penalty*/
false, /*prefix_lm*/
false, /*is_mbart*/
min_seq_len_);
decoding_sampling_->set_tensor_parallel_param(
tensor_parallel_param);
decoding_sampling_->set_layer_parallel_param(
layer_parallel_param);

decoding_sampling_->forward(params, decoding_params);

Expand Down Expand Up @@ -389,21 +438,20 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
const int& beam_size,
const int& topk,
const float& topp,
const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const int64_t& min_len,
const float& beam_search_diversity_rate,
const float& alpha,
const bool& early_stopping) {
auto stream = input.stream();
cublasHandle_t cublas_handle_;
cublasCreate(&cublas_handle_);
cublasLtHandle_t cublaslt_handle_;
cublasLtCreate(&cublaslt_handle_);
cublasSetStream(cublas_handle_, stream);

cublasSetStream(CublasHandle::GetInstance()->cublas_handle_, stream);

std::vector<paddle::Tensor> ret;

Expand Down Expand Up @@ -451,17 +499,17 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
beam_size,
topk,
topp,
temperature,
n_head,
size_per_head,
num_layer,
bos_id,
eos_id,
max_len,
min_len,
beam_search_diversity_rate,
alpha,
early_stopping,
cublas_handle_,
cublaslt_handle_,
stream);
break;
}
Expand Down Expand Up @@ -508,17 +556,17 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
beam_size,
topk,
topp,
temperature,
n_head,
size_per_head,
num_layer,
bos_id,
eos_id,
max_len,
min_len,
beam_search_diversity_rate,
alpha,
early_stopping,
cublas_handle_,
cublaslt_handle_,
stream);
break;
}
Expand All @@ -529,8 +577,5 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
break;
}
}

cublasDestroy(cublas_handle_);
cublasLtDestroy(cublaslt_handle_);
return ret;
}
2 changes: 2 additions & 0 deletions paddlenlp/ops/fast_transformer/src/fusion_bart_decoding_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ std::vector<paddle::Tensor> BartDecodingCUDAForward(
const int& beam_size,
const int& topk,
const float& topp,
const float& temperature,
const int& n_head,
const int& size_per_head,
const int& num_layer,
const int& bos_id,
const int& eos_id,
const int64_t& max_len,
const int64_t& min_len,
const float& beam_search_diversity_rate,
const float& alpha,
const bool& early_stopping);
Loading