delete csrc/generation/reset_need_stop_value.cc #8413

Merged: 2 commits, May 14, 2024

Changes from all commits
12 changes: 0 additions & 12 deletions csrc/generation/reset_need_stop_value.cc

This file was deleted.

8 changes: 5 additions & 3 deletions csrc/generation/save_with_output_msg.cc
@@ -32,13 +32,15 @@ void SaveOutMmsg(const paddle::Tensor& x,
   if (rank_id > 0) return;
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   int64_t *x_data = x_cpu.data<int64_t>();
+  auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
+  bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
+
   static struct msgdata msg_sed;
   static key_t key = ftok("./", 1);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
-  bool not_need_stop_data = not_need_stop.data<bool>()[0];
-  msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
+  msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
   int bsz = x.shape()[0];
   msg_sed.mtext[1] = bsz;
   for (int i = 2; i < bsz + 2; i++) {
@@ -55,4 +57,4 @@ PD_BUILD_OP(save_output)
     .Attrs({"rank_id: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
-    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
+    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
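Note on the change above: `not_need_stop` may now live on GPU, so the op copies it to host before dereferencing it instead of reading the data pointer directly. A minimal Python sketch of the equivalent read, reusing names from the diff (the 1/-1 encoding of `mtext[0]` is taken from the visible code):

```python
import paddle

# The flag tensor may now sit on GPU; mirror the C++ copy_to(paddle::CPUPlace(), false)
# with a blocking copy to host before reading it.
not_need_stop = paddle.full(shape=[1], fill_value=True, dtype="bool")
flag = bool(not_need_stop.cpu().numpy()[0])

# mtext[0] of the IPC message encodes the flag as 1 (keep generating) or -1 (stop).
mtext0 = 1 if flag else -1
```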
99 changes: 55 additions & 44 deletions csrc/generation/update_inputs.cu
@@ -1,19 +1,32 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "helper.h"
 
 template <int THREADBLOCK_SIZE>
-__global__ void update_inputs_kernel(
-    bool *not_need_stop,
-    int *seq_lens_this_time,
-    int *seq_lens_encoder,
-    int *seq_lens_decoder,
-    int64_t *input_ids,
-    const int64_t *stop_nums,
-    const bool *stop_flags,
-    const bool *is_block_step,
-    const int64_t *next_tokens,
-    const int bsz,
-    const int max_bsz,
-    const int input_ids_stride) {
+__global__ void update_inputs_kernel(bool *not_need_stop,
+                                     int *seq_lens_this_time,
+                                     int *seq_lens_encoder,
+                                     int *seq_lens_decoder,
+                                     int64_t *input_ids,
+                                     const int64_t *stop_nums,
+                                     const bool *stop_flags,
+                                     const bool *is_block_step,
+                                     const int64_t *next_tokens,
+                                     const int bsz,
+                                     const int max_bsz,
+                                     const int input_ids_stride) {
   int thread_idx = threadIdx.x;
   typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -37,7 +50,10 @@ __global__ void update_inputs_kernel(
     const int seq_len_encoder = seq_lens_encoder[thread_idx];
     const int seq_len_decoder = seq_lens_decoder[thread_idx];
 
-    seq_lens_decoder[thread_idx] = stop_flag_now ? 0 : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
+    seq_lens_decoder[thread_idx] =
+        stop_flag_now
+            ? 0
+            : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
 
     seq_lens_this_time[thread_idx] = stop_flag_now ? 0 : 1;
     seq_lens_encoder[thread_idx] = 0;
@@ -51,43 +67,38 @@
   }
 }
 
-void UpdateInputes(const paddle::Tensor& stop_flags,
-                   const paddle::Tensor& not_need_stop, // cpu
-                   const paddle::Tensor& seq_lens_this_time,
-                   const paddle::Tensor& seq_lens_encoder,
-                   const paddle::Tensor& seq_lens_decoder,
-                   const paddle::Tensor& input_ids,
-                   const paddle::Tensor& stop_nums,
-                   const paddle::Tensor& next_tokens,
-                   const paddle::Tensor& is_block_step) {
+void UpdateInputes(const paddle::Tensor &stop_flags,
+                   const paddle::Tensor &not_need_stop,
+                   const paddle::Tensor &seq_lens_this_time,
+                   const paddle::Tensor &seq_lens_encoder,
+                   const paddle::Tensor &seq_lens_decoder,
+                   const paddle::Tensor &input_ids,
+                   const paddle::Tensor &stop_nums,
+                   const paddle::Tensor &next_tokens,
+                   const paddle::Tensor &is_block_step) {
   const int max_bsz = stop_flags.shape()[0];
  const int now_bsz = seq_lens_this_time.shape()[0];
   const int input_ids_stride = input_ids.shape()[1];
-  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
   update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>(
-      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
-      const_cast<int*>(seq_lens_this_time.data<int>()),
-      const_cast<int*>(seq_lens_encoder.data<int>()),
-      const_cast<int*>(seq_lens_decoder.data<int>()),
-      const_cast<int64_t*>(input_ids.data<int64_t>()),
-      stop_nums.data<int64_t>(),
-      stop_flags.data<bool>(),
-      is_block_step.data<bool>(),
-      next_tokens.data<int64_t>(),
-      now_bsz,
-      max_bsz,
-      input_ids_stride
-  );
-  auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false);
-  bool *not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
-  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+      const_cast<bool *>(not_need_stop.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_encoder.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int64_t *>(input_ids.data<int64_t>()),
+      stop_nums.data<int64_t>(),
+      stop_flags.data<bool>(),
+      is_block_step.data<bool>(),
+      next_tokens.data<int64_t>(),
+      now_bsz,
+      max_bsz,
+      input_ids_stride);
 }
 
 PD_BUILD_OP(update_inputs)
-    .Inputs({"stop_flags",
-             "not_need_stop",
-             "seq_lens_this_time",
-             "seq_lens_encoder",
+    .Inputs({"stop_flags",
+             "not_need_stop",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
              "seq_lens_decoder",
              "input_ids",
             "stop_nums",
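The substantive change in this file: the wrapper previously copied `not_need_stop` to GPU, launched the kernel, and copied the result back to the CPU tensor; now the tensor is handed to the kernel directly and written in place, removing the round trip. The per-slot length update itself is unchanged; a sketch of its visible logic in plain Python (the block reduce over stop flags and the `is_block_step` handling are not fully shown in the diff and are omitted here):

```python
def update_slot(stop_flag: bool, seq_len_encoder: int, seq_len_decoder: int):
    """Mirror of the ternary chain visible in update_inputs_kernel."""
    new_decoder = 0 if stop_flag else (
        seq_len_encoder if seq_len_decoder == 0 else seq_len_decoder + 1
    )
    new_this_time = 0 if stop_flag else 1
    new_encoder = 0  # encoder length is consumed after the prefill step
    return new_this_time, new_encoder, new_decoder

# A running slot that just finished prefill inherits the encoder length:
assert update_slot(False, 128, 0) == (1, 0, 128)
# On later steps the decoder length advances by one token:
assert update_slot(False, 128, 5) == (1, 0, 6)
# A stopped slot is zeroed out:
assert update_slot(True, 128, 5) == (0, 0, 0)
```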
1 change: 0 additions & 1 deletion csrc/setup_cuda.py
@@ -72,7 +72,6 @@ def get_gencode_flags():
         "./generation/stop_generation_multi_ends_v2.cu",
         "./generation/update_inputs.cu",
         "./generation/get_output.cc",
-        "./generation/reset_need_stop_value.cc",
         "./generation/save_with_output_msg.cc",
         "./generation/write_int8_cache_kv.cu",
         "./generation/step.cu",
21 changes: 5 additions & 16 deletions llm/predictor.py
@@ -49,25 +49,15 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    ChatGLMv2Tokenizer,
     ChatGLMTokenizer,
+    ChatGLMv2Tokenizer,
     LlamaTokenizer,
     PretrainedModel,
     PretrainedTokenizer,
 )
 from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
-try:
-    from paddlenlp_ops import reset_stop_value
-except (ImportError, ModuleNotFoundError):
-    logger.warning(
-        "if you run predictor.py with --inference_model argument, please ensure you install "
-        "the paddlenlp_ops by following the instructions "
-        "provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
-
 # Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output
 MAX_BSZ = 512
@@ -242,7 +232,8 @@ def _preprocess(self, source):
             padding=True,
             # when use chat_template, it should not add special tokens
             # chatglm2 prefix-tokens can not be tokenized into ids
-            add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
+            add_special_tokens=self.tokenizer.chat_template is None
+            or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)),
         )
         return tokenized_source
@@ -877,7 +868,7 @@ def init_inputs(self, config: PredictorArgument):
         self.inputs["seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
         self.inputs["seq_lens_decoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
         self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64")
-        self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool").cpu()
+        self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
         self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool")
         self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
         self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
@@ -945,7 +936,7 @@ def _preprocess(self, source):
             self.inputs["seq_lens_decoder"][i : i + 1] = 0
             self.inputs["step_idx"][i : i + 1] = 0
             self.inputs["stop_flags"][i : i + 1] = False
-            reset_stop_value(self.inputs["not_need_stop"])
+            self.inputs["not_need_stop"][0] = True
             need_block_nums = (
                 length + self.config.max_length + self.pre_cache_length + self.block_size - 1
             ) // self.block_size
@@ -1010,7 +1001,6 @@ def predict(self, input_texts: str | list[str]):
         for i in range(self.config.batch_size):
             self.free_list.extend(self.used_list[i])
             self.used_list[i] = []
-        reset_stop_value(self.inputs["not_need_stop"])
 
         outputs = []
         while len(outputs) < self.batch_size:
@@ -1147,7 +1137,6 @@ def predict(self, input_texts: str | list[str]):
         for i in range(self.config.batch_size):
             self.free_list.extend(self.used_list[i])
             self.used_list[i] = []
-        reset_stop_value(self.inputs["not_need_stop"])
 
         outputs = []
         while len(outputs) < self.batch_size:
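With the custom `reset_stop_value` op (and its import guard) gone, the predictor resets the flag with a plain tensor write. A minimal sketch of the new pattern, using the same names as the diff:

```python
import paddle

# init_inputs: the flag is now created like any other input tensor (no .cpu() pin).
inputs = {"not_need_stop": paddle.full(shape=[1], fill_value=False, dtype="bool")}

# _preprocess: arm generation with a direct element write instead of reset_stop_value(...).
inputs["not_need_stop"][0] = True
assert bool(inputs["not_need_stop"].cpu().numpy()[0])
```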
4 changes: 2 additions & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -671,15 +671,15 @@
 
         step_idx = paddle.where(model_kwargs["stop_flags"], model_kwargs["step_idx"], model_kwargs["step_idx"] + 1)
         paddle.assign(step_idx, model_kwargs["step_idx"])
-        length_cond = paddle.greater_equal(model_kwargs["step_idx"], model_kwargs["max_dec_len"])
+        length_cond = paddle.greater_equal(step_idx, model_kwargs["max_dec_len"])
         stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
         set_stop_value_multi_ends_v2(
             next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"]
         )  # multi ends
         paddle.assign(stop_flags, model_kwargs["stop_flags"])
         # update inputs
         update_inputs(
-            model_kwargs["stop_flags"],
+            stop_flags,
             model_kwargs["not_need_stop"],
             model_kwargs["seq_lens_this_time"],
             model_kwargs["seq_lens_encoder"],
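Both edited lines make the post-sampling update consume the freshly computed local tensors (`step_idx`, `stop_flags`) rather than the `model_kwargs` entries that the surrounding `paddle.assign` calls update in place; in eager mode the two forms agree, so the change presumably removes the dependence on the in-place assigns having taken effect. A small self-contained sketch of the fixed control flow, with toy shapes and names from the diff:

```python
import paddle

stop_flags = paddle.to_tensor([[False], [True]])
step_idx = paddle.to_tensor([[3], [7]], dtype="int64")
max_dec_len = paddle.to_tensor([[4], [16]], dtype="int64")

# unstopped slots advance one step; stopped slots keep their index
step_idx = paddle.where(stop_flags, step_idx, step_idx + 1)
# compare against the updated local value, not a possibly stale kwargs entry
length_cond = paddle.greater_equal(step_idx, max_dec_len)
stop_flags = paddle.logical_or(stop_flags, length_cond)
# slot 0 hits max_dec_len (4 >= 4) and stops; slot 1 was already stopped
```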