From a04725402c14462f5bbcff33be06c79fb12beab6 Mon Sep 17 00:00:00 2001
From: yuanlehome
Date: Fri, 10 May 2024 05:32:39 +0000
Subject: [PATCH] delete csrc/generation/reset_need_stop_value.cc

---
 csrc/generation/reset_need_stop_value.cc | 12 ---
 csrc/generation/save_with_output_msg.cc  |  8 +-
 csrc/generation/update_inputs.cu         | 99 ++++++++++---------
 csrc/setup_cuda.py                       |  1 -
 llm/predictor.py                         | 21 +---
 .../transformers/generation_utils.py     |  4 +-
 6 files changed, 67 insertions(+), 78 deletions(-)
 delete mode 100644 csrc/generation/reset_need_stop_value.cc

diff --git a/csrc/generation/reset_need_stop_value.cc b/csrc/generation/reset_need_stop_value.cc
deleted file mode 100644
index 07efb643d06..00000000000
--- a/csrc/generation/reset_need_stop_value.cc
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "paddle/extension.h"
-
-void SetStopValue(const paddle::Tensor& not_need_stop) {
-    bool *stop_data = const_cast<bool *>(not_need_stop.data<bool>());
-    stop_data[0] = true;
-}
-
-PD_BUILD_OP(reset_stop_value)
-    .Inputs({"not_need_stop"})
-    .Outputs({"not_need_stop_out"})
-    .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}})
-    .SetKernelFn(PD_KERNEL(SetStopValue));
diff --git a/csrc/generation/save_with_output_msg.cc b/csrc/generation/save_with_output_msg.cc
index ea04f8e3e6a..9123578e256 100644
--- a/csrc/generation/save_with_output_msg.cc
+++ b/csrc/generation/save_with_output_msg.cc
@@ -32,13 +32,15 @@ void SaveOutMmsg(const paddle::Tensor& x,
   if (rank_id > 0) return;
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   int64_t *x_data = x_cpu.data<int64_t>();
+  auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
+  bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
+
   static struct msgdata msg_sed;
   static key_t key = ftok("./", 1);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
-  bool not_need_stop_data = not_need_stop.data<bool>()[0];
-  msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
+  msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
   int bsz = x.shape()[0];
   msg_sed.mtext[1] = bsz;
   for (int i = 2; i < bsz + 2; i++) {
@@ -55,4 +57,4 @@ PD_BUILD_OP(save_output)
     .Attrs({"rank_id: int64_t"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
-    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
\ No newline at end of file
+    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
diff --git a/csrc/generation/update_inputs.cu b/csrc/generation/update_inputs.cu
index ab9bcde2720..49c2db002ba 100644
--- a/csrc/generation/update_inputs.cu
+++ b/csrc/generation/update_inputs.cu
@@ -1,19 +1,32 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 template <int THREADBLOCK_SIZE>
-__global__ void update_inputs_kernel(
-    bool *not_need_stop,
-    int *seq_lens_this_time,
-    int *seq_lens_encoder,
-    int *seq_lens_decoder,
-    int64_t *input_ids,
-    const int64_t *stop_nums,
-    const bool *stop_flags,
-    const bool *is_block_step,
-    const int64_t *next_tokens,
-    const int bsz,
-    const int max_bsz,
-    const int input_ids_stride) {
+__global__ void update_inputs_kernel(bool *not_need_stop,
+                                     int *seq_lens_this_time,
+                                     int *seq_lens_encoder,
+                                     int *seq_lens_decoder,
+                                     int64_t *input_ids,
+                                     const int64_t *stop_nums,
+                                     const bool *stop_flags,
+                                     const bool *is_block_step,
+                                     const int64_t *next_tokens,
+                                     const int bsz,
+                                     const int max_bsz,
+                                     const int input_ids_stride) {
   int thread_idx = threadIdx.x;
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -37,7 +50,10 @@ __global__ void update_inputs_kernel(
     const int seq_len_encoder = seq_lens_encoder[thread_idx];
     const int seq_len_decoder = seq_lens_decoder[thread_idx];
 
-    seq_lens_decoder[thread_idx] = stop_flag_now ? 0 : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
+    seq_lens_decoder[thread_idx] =
+        stop_flag_now
+            ? 0
+            : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
 
     seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1;
     seq_lens_encoder[thread_idx] = 0;
@@ -51,43 +67,38 @@ __global__ void update_inputs_kernel(
   }
 }
 
-void UpdateInputes(const paddle::Tensor& stop_flags,
-                   const paddle::Tensor& not_need_stop, // cpu
-                   const paddle::Tensor& seq_lens_this_time,
-                   const paddle::Tensor& seq_lens_encoder,
-                   const paddle::Tensor& seq_lens_decoder,
-                   const paddle::Tensor& input_ids,
-                   const paddle::Tensor& stop_nums,
-                   const paddle::Tensor& next_tokens,
-                   const paddle::Tensor& is_block_step) {
+void UpdateInputes(const paddle::Tensor &stop_flags,
+                   const paddle::Tensor &not_need_stop,
+                   const paddle::Tensor &seq_lens_this_time,
+                   const paddle::Tensor &seq_lens_encoder,
+                   const paddle::Tensor &seq_lens_decoder,
+                   const paddle::Tensor &input_ids,
+                   const paddle::Tensor &stop_nums,
+                   const paddle::Tensor &next_tokens,
+                   const paddle::Tensor &is_block_step) {
   const int max_bsz = stop_flags.shape()[0];
   const int now_bsz = seq_lens_this_time.shape()[0];
   const int input_ids_stride = input_ids.shape()[1];
-  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
   update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
-    const_cast<bool *>(not_need_stop_gpu.data<bool>()),
-    const_cast<int *>(seq_lens_this_time.data<int>()),
-    const_cast<int *>(seq_lens_encoder.data<int>()),
-    const_cast<int *>(seq_lens_decoder.data<int>()),
-    const_cast<int64_t *>(input_ids.data<int64_t>()),
-    stop_nums.data<int64_t>(),
-    stop_flags.data<bool>(),
-    is_block_step.data<bool>(),
-    next_tokens.data<int64_t>(),
-    now_bsz,
-    max_bsz,
-    input_ids_stride
-  );
-  auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false);
-  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
-  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+      const_cast<bool *>(not_need_stop.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_encoder.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int64_t *>(input_ids.data<int64_t>()),
+      stop_nums.data<int64_t>(),
+      stop_flags.data<bool>(),
+      is_block_step.data<bool>(),
+      next_tokens.data<int64_t>(),
+      now_bsz,
+      max_bsz,
+      input_ids_stride);
 }
 
 PD_BUILD_OP(update_inputs)
-    .Inputs({"stop_flags", 
-             "not_need_stop", 
-             "seq_lens_this_time", 
-             "seq_lens_encoder", 
+    .Inputs({"stop_flags",
+             "not_need_stop",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
              "seq_lens_decoder",
             "input_ids",
             "stop_nums",
diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py
index e2957ecf650..0b25ef3eac9 100644
--- a/csrc/setup_cuda.py
+++ b/csrc/setup_cuda.py
@@ -72,7 +72,6 @@ def get_gencode_flags():
         "./generation/stop_generation_multi_ends_v2.cu",
         "./generation/update_inputs.cu",
         "./generation/get_output.cc",
-        "./generation/reset_need_stop_value.cc",
         "./generation/save_with_output_msg.cc",
         "./generation/write_int8_cache_kv.cu",
         "./generation/step.cu",
diff --git a/llm/predictor.py b/llm/predictor.py
index 3e9f47d8025..2093d22e9f5 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -49,8 +49,8 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    ChatGLMv2Tokenizer,
     ChatGLMTokenizer,
+    ChatGLMv2Tokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
@@ -58,16 +58,6 @@
 from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
-try:
-    from paddlenlp_ops import reset_stop_value
-except (ImportError, ModuleNotFoundError):
-    logger.warning(
-        "if you run predictor.py with --inference_model argument, please ensure you install "
-        "the paddlenlp_ops by following the instructions "
-        "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
-
 # Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
 MAX_BSZ = 512
 
@@ -242,7 +232,8 @@ def _preprocess(self, source):
             padding=True,
             # when use chat_template, it should not add special tokens
             # chatglm2 prefix-tokens can not be tokenized into ids
-            add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
+            add_special_tokens=self.tokenizer.chat_template is None
+            or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
         )
 
         return tokenized_source
@@ -877,7 +868,7 @@ def init_inputs(self, config: PredictorArgument):
         self.inputs["seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
         self.inputs["seq_lens_decoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
         self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64")
-        self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool").cpu()
+        self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
         self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool")
         self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
         self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
@@ -945,7 +936,7 @@ def _preprocess(self, source):
             self.inputs["seq_lens_decoder"][i : i + 1] = 0
             self.inputs["step_idx"][i : i + 1] = 0
             self.inputs["stop_flags"][i : i + 1] = False
-            reset_stop_value(self.inputs["not_need_stop"])
+            self.inputs["not_need_stop"][0] = True
             need_block_nums = (
                 length + self.config.max_length + self.pre_cache_length + self.block_size - 1
             ) // self.block_size
@@ -1010,7 +1001,6 @@ def predict(self, input_texts: str | list[str]):
         for i in range(self.config.batch_size):
             self.free_list.extend(self.used_list[i])
             self.used_list[i] = []
-        reset_stop_value(self.inputs["not_need_stop"])
 
         outputs = []
         while len(outputs) < self.batch_size:
@@ -1147,7 +1137,6 @@ def predict(self, input_texts: str | list[str]):
         for i in range(self.config.batch_size):
             self.free_list.extend(self.used_list[i])
             self.used_list[i] = []
-        reset_stop_value(self.inputs["not_need_stop"])
 
         outputs = []
         while len(outputs) < self.batch_size:
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index aff9ac484cf..6dd2a145962 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -671,7 +671,7 @@ def _post_process_(
 
         step_idx = paddle.where(model_kwargs["stop_flags"], model_kwargs["step_idx"], model_kwargs["step_idx"] + 1)
         paddle.assign(step_idx, model_kwargs["step_idx"])
-        length_cond = paddle.greater_equal(model_kwargs["step_idx"], model_kwargs["max_dec_len"])
+        length_cond = paddle.greater_equal(step_idx, model_kwargs["max_dec_len"])
         stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
         set_stop_value_multi_ends_v2(
             next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"]
@@ -679,7 +679,7 @@
         )
         paddle.assign(stop_flags, model_kwargs["stop_flags"])
         # update inputs
         update_inputs(
-            model_kwargs["stop_flags"],
+            stop_flags,
             model_kwargs["not_need_stop"],
             model_kwargs["seq_lens_this_time"],
             model_kwargs["seq_lens_encoder"],