diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token.cc
index 540ee6006bf..a45ba8e0a37 100644
--- a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token.cc
+++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token.cc
@@ -17,144 +17,201 @@
 #include <stdio.h>
 #include <sys/ipc.h>
 #include <sys/msg.h>
+#include "../speculate_msg.h"
 #include "paddle/extension.h"

 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif

-#define MAX_BSZ 256
-
 // #define SAVE_WITH_OUTPUT_DEBUG
-#define MAX_DRAFT_TOKENS 6
-struct msgdata {
-    long mtype;
-    int mtext[2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS];  // stop_flag, token_num, tokens
-};

 void MTPSaveFirstToken(const paddle::Tensor& x,
-                       const paddle::Tensor& not_need_stop,
-                       int64_t rank_id,
-                       int msg_queue_id,
-                       bool save_each_rank) {
-    if (!save_each_rank && rank_id > 0) {
-        return;
-    }
-    int x_dim = x.shape()[1];
-    auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
-    int64_t* x_data = x_cpu.data<int64_t>();
-    static struct msgdata msg_sed;
-
-    if (const char* inference_msg_queue_id_env_p =
-            std::getenv("INFERENCE_MSG_QUEUE_ID")) {
-        std::string inference_msg_queue_id_env_str(
-            inference_msg_queue_id_env_p);
-        int inference_msg_queue_id_from_env =
-            std::stoi(inference_msg_queue_id_env_str);
+                       const paddle::Tensor& not_need_stop,
+                       const paddle::Tensor& seq_lens_decoder,
+                       const paddle::Tensor& prompt_lens,
+                       const paddle::Tensor& step_idx,
+                       int64_t rank_id,
+                       int msg_queue_id,
+                       bool save_each_rank,
+                       bool skip_chunk_prefill) {
+  if (!save_each_rank && rank_id > 0) {
+    return;
+  }
+  int x_dim = x.shape()[1];
+  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
+  int64_t* x_data = x_cpu.data<int64_t>();
+
+  auto seq_lens_decoder_cpu =
+      seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
+  int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
+
+  auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
+  int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
+
+  auto step_idx_cpu = step_idx.copy_to(paddle::CPUPlace(), true);
+  int64_t* step_idx_data = step_idx_cpu.data<int64_t>();
+
+  static struct speculate_msgdata msg_sed;
+
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-        std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
-                  << inference_msg_queue_id_from_env << std::endl;
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
 #endif
-        msg_queue_id = inference_msg_queue_id_from_env;
-    }
+    msg_queue_id = inference_msg_queue_id_from_env;
+  }

-    static key_t key = ftok("./", msg_queue_id);
-    static int msgid = msgget(key, IPC_CREAT | 0666);
-
-    msg_sed.mtype = 1;
-    bool not_need_stop_data = not_need_stop.data<bool>()[0];
-    int inference_msg_id_from_env = 1;
-    if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
-        std::string inference_msg_id_env_str(inference_msg_id_env_p);
-        inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
-        if (inference_msg_id_from_env == 2) {
-            // 2 and -2 is preserve for no-output indication.
-            throw std::runtime_error(
-                " INFERENCE_MSG_ID cannot be 2, please use other number.");
-        }
-        if (inference_msg_id_from_env < 0) {
-            throw std::runtime_error(
-                " INFERENCE_MSG_ID cannot be negative, please use other "
-                "number.");
-        }
+  static key_t key = ftok("./", msg_queue_id);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+
+  msg_sed.mtype = 1;
+  bool not_need_stop_data = not_need_stop.data<bool>()[0];
+  int inference_msg_id_from_env = 1;
+  if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) {
+    std::string inference_msg_id_env_str(inference_msg_id_env_p);
+    inference_msg_id_from_env = std::stoi(inference_msg_id_env_str);
+    if (inference_msg_id_from_env == 2) {
+      // 2 and -2 are reserved for no-output indication.
+      throw std::runtime_error(
+          " INFERENCE_MSG_ID cannot be 2, please use another number.");
+    }
+    if (inference_msg_id_from_env < 0) {
+      throw std::runtime_error(
+          " INFERENCE_MSG_ID cannot be negative, please use another "
+          "number.");
+    }
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-        std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
-                  << std::endl;
+    std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env
+              << std::endl;
 #endif
-    } else {
+  } else {
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-        std::cout
-            << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default."
-            << std::endl;
+    std::cout << "Failed to get INFERENCE_MSG_ID from env, using (int)1 as "
+                 "default."
+              << std::endl;
 #endif
-    }
+  }
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+  std::cout << "save_output_key: " << key << std::endl;
+  std::cout << "save msgid: " << msgid << std::endl;
+#endif
+  msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
+                                        : -inference_msg_id_from_env;
+  int bsz = x.shape()[0];
+  msg_sed.mtext[1] = bsz;
+  for (int i = 0; i < bsz; i++) {
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-    std::cout << "save_output_key: " << key << std::endl;
-    std::cout << "save msgid: " << msgid << std::endl;
+    printf("bid: %d. 1: %d. 2: %d.\n",
+           i,
+           (int)x_data[i * x_dim],
+           (int)x_data[i * x_dim + 1]);
 #endif
-    msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
-                                          : -inference_msg_id_from_env;
-    int bsz = x.shape()[0];
-    msg_sed.mtext[1] = bsz;
-    for (int i = 0; i < bsz; i++) {
+    if ((skip_chunk_prefill &&
+         seq_lens_decoder_data[i] < prompt_lens_data[i]) ||
+        step_idx_data[i] == 0) {
+      msg_sed.mtext[i + 2] = 0;
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-        printf("bid: %d. 1: %d. 2: %d.\n", i, (int)x_data[i * x_dim], (int)x_data[i * x_dim + 1]);
+      printf("bid[%d] skip save mtp output \n", i);
 #endif
-        msg_sed.mtext[i + 2] = 2;
-        msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] = (int)x_data[i * x_dim];
-        msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] = (int)x_data[i * x_dim + 1];
+      continue;
+    } else if (step_idx_data[i] == 1) {
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-        printf("mtext[%d]:%d. mtext[%d]:%d. \n", i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
-               msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
-               i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
-               msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
+      printf("bid[%d] save mtp tokens \n", i);
 #endif
+      msg_sed.mtext[i + 2] = 2;
+      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ] =
+          (int)x_data[i * x_dim];
+      msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ] =
+          (int)x_data[i * x_dim + 1];
     }
 #ifdef SAVE_WITH_OUTPUT_DEBUG
-    std::cout << "msg data: ";
-    for (int i = 0; i < bsz; i++) {
-        std::cout << " " << (int)x_data[2*i] << " ";
-        std::cout << " " << (int)x_data[2*i + 1];
+    printf("mtext[%d]:%d. mtext[%d]:%d. \n",
+           i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ,
+           msg_sed.mtext[i * MAX_DRAFT_TOKENS + 2 + MAX_BSZ],
+           i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ,
+           msg_sed.mtext[i * MAX_DRAFT_TOKENS + 1 + 2 + MAX_BSZ]);
+#endif
+  }
-    }
-    std::cout << std::endl;
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+  std::cout << "msg data: ";
+  for (int i = 0; i < bsz; i++) {
+    std::cout << " " << (int)x_data[2 * i] << " ";
+    std::cout << " " << (int)x_data[2 * i + 1];
+  }
+  std::cout << std::endl;
 #endif
-    if ((msgsnd(msgid,
-                &msg_sed,
-                (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4, 0)) == -1) {
-        printf("full msg buffer\n");
-    }
-    return;
+  if ((msgsnd(msgid,
+              &msg_sed,
+              (2 + MAX_BSZ + MAX_BSZ * MAX_DRAFT_TOKENS) * 4,
+              0)) == -1) {
+    printf("full msg buffer\n");
+  }
+  return;
 }

 void MTPSaveFirstTokenStatic(const paddle::Tensor& x,
-                             const paddle::Tensor& not_need_stop,
-                             int64_t rank_id,
-                             bool save_each_rank) {
-    MTPSaveFirstToken(x, not_need_stop, rank_id, 1, save_each_rank);
+                             const paddle::Tensor& not_need_stop,
+                             const paddle::Tensor& seq_lens_decoder,
+                             const paddle::Tensor& prompt_lens,
+                             const paddle::Tensor& step_idx,
+                             int64_t rank_id,
+                             bool save_each_rank,
+                             bool skip_chunk_prefill) {
+  MTPSaveFirstToken(x,
+                    not_need_stop,
+                    seq_lens_decoder,
+                    prompt_lens,
+                    step_idx,
+                    rank_id,
+                    1,
+                    save_each_rank,
+                    skip_chunk_prefill);
 }

 void MTPSaveFirstTokenDynamic(const paddle::Tensor& x,
-                              const paddle::Tensor& not_need_stop,
-                              int64_t rank_id,
-                              int msg_queue_id,
-                              bool save_each_rank) {
-    MTPSaveFirstToken(x, not_need_stop, rank_id, msg_queue_id, save_each_rank);
+                              const paddle::Tensor& not_need_stop,
+                              const paddle::Tensor& seq_lens_decoder,
+                              const paddle::Tensor& prompt_lens,
+                              const paddle::Tensor& step_idx,
+                              int64_t rank_id,
+                              int msg_queue_id,
+                              bool save_each_rank,
+                              bool skip_chunk_prefill) {
+  MTPSaveFirstToken(x,
+                    not_need_stop,
+                    seq_lens_decoder,
+                    prompt_lens,
+                    step_idx,
+                    rank_id,
+                    msg_queue_id,
+                    save_each_rank,
+                    skip_chunk_prefill);
 }

 PD_BUILD_STATIC_OP(mtp_save_first_token)
-    .Inputs({"x", "not_need_stop"})
+    .Inputs(
+        {"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
     .Attrs({"rank_id: int64_t",
-            "save_each_rank: bool"})
+            "save_each_rank: bool",
+            "skip_chunk_prefill: bool"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
     .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenStatic));

 PD_BUILD_STATIC_OP(mtp_save_first_token_dynamic)
-    .Inputs({"x", "not_need_stop"})
-    .Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
+    .Inputs(
+        {"x", "not_need_stop", "seq_lens_decoder", "prompt_lens", "step_idx"})
+    .Attrs({"rank_id: int64_t",
+            "msg_queue_id: int",
+            "save_each_rank: bool",
+            "skip_chunk_prefill: bool"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenDynamic));
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu
index 0f44293eaf5..0464601d654 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu
@@ -15,85 +15,94 @@
 #include "helper.h"

 template <int THREADBLOCK_SIZE>
-__global__ void speculate_schedula_cache(
-    const int64_t *draft_tokens,
-    int *block_tables,
-    bool *stop_flags,
-    const int64_t* prompt_lens,
-    int *seq_lens_this_time,
-    int *seq_lens_encoder,
-    int *seq_lens_decoder,
-    int *step_seq_lens_decoder,
-    int64_t *step_draft_tokens,
-    int *step_seq_lens_this_time,
-    int *accept_num,
-    int64_t *accept_tokens,
-    bool *is_block_step,
-    bool *not_need_stop,
-    const int64_t *stop_nums,
-    const int real_bsz,
-    const int max_bsz,
-    const int max_next_step_tokens,
-    const int draft_tokens_len,
-    const int accept_tokens_len,
-    const int block_size,
-    const int block_num_per_seq) {
-    const int bid = threadIdx.x;
-    int stop_flag_now_int = 0;
-    if (bid < real_bsz) {
-        if (!stop_flags[bid]) {
-            const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
-            int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
-            int *block_table_now = block_tables + bid * block_num_per_seq;
-            int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
+__global__ void speculate_schedula_cache(const int64_t *draft_tokens,
+                                         int *block_tables,
+                                         bool *stop_flags,
+                                         const int64_t *prompt_lens,
+                                         int *seq_lens_this_time,
+                                         int *seq_lens_encoder,
+                                         int *seq_lens_decoder,
+                                         int *step_seq_lens_decoder,
+                                         int64_t *step_draft_tokens,
+                                         int *step_seq_lens_this_time,
+                                         int *accept_num,
+                                         int64_t *accept_tokens,
+                                         bool *is_block_step,
+                                         bool *not_need_stop,
+                                         const int64_t *stop_nums,
+                                         const int real_bsz,
+                                         const int max_bsz,
+                                         const int max_next_step_tokens,
+                                         const int draft_tokens_len,
+                                         const int accept_tokens_len,
+                                         const int block_size,
+                                         const int block_num_per_seq,
+                                         const bool prefill_one_step_stop) {
+  const int bid = threadIdx.x;
+  int stop_flag_now_int = 0;
+  if (bid < real_bsz) {
+    if (!stop_flags[bid]) {
+      const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
+      int64_t *step_draft_tokens_now =
+          step_draft_tokens + bid * draft_tokens_len;
+      int *block_table_now = block_tables + bid * block_num_per_seq;
+      int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;

-            if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
-                // decoder
-                const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
-                if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
-                    is_block_step[bid] = true;
-                    step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
-                    seq_lens_this_time[bid] = 0;
-                    stop_flags[bid] = true;
-                    stop_flag_now_int = 1;
-                    step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
-                    seq_lens_decoder[bid] = 0;
-                    accept_num[bid] = 0;
-                    for (int i = 0; i < accept_tokens_len; i++) {
-                        accept_tokens_now[i] = -1;
-                    }
-                    for (int i = 0; i < draft_tokens_len; i++) {
-                        step_draft_tokens_now[i] = draft_tokens_now[i];
-                    }
-                }
-            } else {
-                // prefill
-                stop_flags[bid] = true;
-                seq_lens_this_time[bid] = 0;
-                seq_lens_decoder[bid] = 0;
-                seq_lens_encoder[bid] = 0;
-                accept_num[bid] = 0;
-                stop_flag_now_int = 1;
-            }
+      if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
+        const int max_possible_block_idx =
+            (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;

-
-        } else {
-            stop_flag_now_int = 1;
+        if (prefill_one_step_stop) {
+          stop_flags[bid] = true;
+          seq_lens_this_time[bid] = 0;
+          seq_lens_decoder[bid] = 0;
+          seq_lens_encoder[bid] = 0;
+          accept_num[bid] = 0;
+          stop_flag_now_int = 1;
+        } else if (max_possible_block_idx < block_num_per_seq &&
+                   block_table_now[max_possible_block_idx] == -1) {
+          is_block_step[bid] = true;
+          step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
+          seq_lens_this_time[bid] = 0;
+          stop_flags[bid] = true;
+          stop_flag_now_int = 1;
+          step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
+          seq_lens_decoder[bid] = 0;
+          accept_num[bid] = 0;
+          for (int i = 0; i < accept_tokens_len; i++) {
+            accept_tokens_now[i] = -1;
+          }
+          for (int i = 0; i < draft_tokens_len; i++) {
+            step_draft_tokens_now[i] = draft_tokens_now[i];
+          }
         }
-    } else if (bid >= real_bsz && bid < max_bsz) {
+      } else {
+        // prefill
+        stop_flags[bid] = true;
+        seq_lens_this_time[bid] = 0;
+        seq_lens_decoder[bid] = 0;
+        seq_lens_encoder[bid] = 0;
+        accept_num[bid] = 0;
         stop_flag_now_int = 1;
+      }
+
+    } else {
+      stop_flag_now_int = 1;
     }
-    __syncthreads();
-    typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
-    __shared__ typename BlockReduce::TempStorage temp_storage;
+  } else if (bid >= real_bsz && bid < max_bsz) {
+    stop_flag_now_int = 1;
+  }
+  __syncthreads();
+  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;

-    // printf("stop_flag_now_int %d \n", stop_flag_now_int);
-    int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
+  // printf("stop_flag_now_int %d \n", stop_flag_now_int);
+  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

-    if (threadIdx.x == 0) {
-        // printf("stop_sum %d \n", stop_sum);
-        not_need_stop[0] = stop_sum < stop_nums[0];
-    }
+  if (threadIdx.x == 0) {
+    // printf("stop_sum %d \n", stop_sum);
+    not_need_stop[0] = stop_sum < stop_nums[0];
+  }
 }

 void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
@@ -113,45 +122,51 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
                             const paddle::Tensor &stop_nums,
                             const int block_size,
                             const int max_draft_tokens) {
-    const int real_bsz = seq_lens_this_time.shape()[0];
-    const int max_bsz = stop_flags.shape()[0];
-    const int accept_tokens_len = accept_tokens.shape()[1];
-    const int draft_token_len = draft_tokens.shape()[1];
-    const int block_num_per_seq = block_tables.shape()[1];
+  const int real_bsz = seq_lens_this_time.shape()[0];
+  const int max_bsz = stop_flags.shape()[0];
+  const int accept_tokens_len = accept_tokens.shape()[1];
+  const int draft_token_len = draft_tokens.shape()[1];
+  const int block_num_per_seq = block_tables.shape()[1];

-    constexpr int BlockSize = 512;
-    const int max_next_step_tokens = 2 * max_draft_tokens + 2;
-
-    auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
-    speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
-        draft_tokens.data<int64_t>(),
-        const_cast<int *>(block_tables.data<int>()),
-        const_cast<bool *>(stop_flags.data<bool>()),
-        prompt_lens.data<int64_t>(),
-        const_cast<int *>(seq_lens_this_time.data<int>()),
-        const_cast<int *>(seq_lens_encoder.data<int>()),
-        const_cast<int *>(seq_lens_decoder.data<int>()),
-        const_cast<int *>(step_seq_lens_decoder.data<int>()),
-        const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
-        const_cast<int *>(step_seq_lens_this_time.data<int>()),
-        const_cast<int *>(accept_num.data<int>()),
-        const_cast<int64_t *>(accept_tokens.data<int64_t>()),
-        const_cast<bool *>(is_block_step.data<bool>()),
-        const_cast<bool *>(not_need_stop_gpu.data<bool>()),
-        stop_nums.data<int64_t>(),
-        real_bsz,
-        max_bsz,
-        max_next_step_tokens,
-        draft_token_len,
-        accept_tokens_len,
-        block_size,
-        block_num_per_seq
-    );
+  constexpr int BlockSize = 512;
+  const int max_next_step_tokens = 2 * max_draft_tokens + 2;
+  bool prefill_one_step_stop = false;
+  if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
+    if (env_p[0] == '1') {
+      prefill_one_step_stop = true;
+    }
+  }
+  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
+  speculate_schedula_cache<BlockSize>
+      <<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
+          draft_tokens.data<int64_t>(),
+          const_cast<int *>(block_tables.data<int>()),
+          const_cast<bool *>(stop_flags.data<bool>()),
+          prompt_lens.data<int64_t>(),
+          const_cast<int *>(seq_lens_this_time.data<int>()),
+          const_cast<int *>(seq_lens_encoder.data<int>()),
+          const_cast<int *>(seq_lens_decoder.data<int>()),
+          const_cast<int *>(step_seq_lens_decoder.data<int>()),
+          const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
+          const_cast<int *>(step_seq_lens_this_time.data<int>()),
+          const_cast<int *>(accept_num.data<int>()),
+          const_cast<int64_t *>(accept_tokens.data<int64_t>()),
+          const_cast<bool *>(is_block_step.data<bool>()),
+          const_cast<bool *>(not_need_stop_gpu.data<bool>()),
+          stop_nums.data<int64_t>(),
+          real_bsz,
+          max_bsz,
+          max_next_step_tokens,
+          draft_token_len,
+          accept_tokens_len,
+          block_size,
+          block_num_per_seq,
+          prefill_one_step_stop);

-    auto not_need_stop_cpu =
-        not_need_stop_gpu.copy_to(not_need_stop.place(), true);
-    bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
-    not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+  auto not_need_stop_cpu =
+      not_need_stop_gpu.copy_to(not_need_stop.place(), true);
+  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
+  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
 }

 PD_BUILD_STATIC_OP(speculate_schedule_cache)
@@ -184,17 +199,19 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
                "accept_tokens_out",
                "is_block_step_out",
                "not_need_stop_out"})
-    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
-                    {"block_tables", "block_tables_out"},
-                    {"stop_flags", "stop_flags_out"},
-                    {"seq_lens_this_time", "seq_lens_this_time_out"},
-                    {"seq_lens_encoder", "seq_lens_encoder_out"},
-                    {"seq_lens_decoder", "seq_lens_decoder_out"},
-                    {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
-                    {"step_draft_tokens", "step_draft_tokens_out"},
-                    {"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
-                    {"accept_num", "accept_num_out"},
-                    {"accept_tokens", "accept_tokens_out"},
-                    {"is_block_step", "is_block_step_out"},
-                    {"not_need_stop", "not_need_stop_out"},})
+    .SetInplaceMap({
+        {"draft_tokens", "draft_tokens_out"},
+        {"block_tables", "block_tables_out"},
+        {"stop_flags", "stop_flags_out"},
+        {"seq_lens_this_time", "seq_lens_this_time_out"},
+        {"seq_lens_encoder", "seq_lens_encoder_out"},
+        {"seq_lens_decoder", "seq_lens_decoder_out"},
+        {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
+        {"step_draft_tokens", "step_draft_tokens_out"},
+        {"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
+        {"accept_num", "accept_num_out"},
+        {"accept_tokens", "accept_tokens_out"},
+        {"is_block_step", "is_block_step_out"},
+        {"not_need_stop", "not_need_stop_out"},
+    })
     .SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
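Note on the stop vote in the kernel above: every thread in the single block must reach cub::BlockReduce, including the padding threads with bid >= real_bsz, which is why each branch only sets stop_flag_now_int and falls through instead of returning early. Thread 0 then compares the block-wide sum of stopped slots against stop_nums[0] and writes the single not_need_stop flag. A stripped-down sketch of just that voting pattern (a hypothetical stop_vote kernel, not part of this diff):

    #include <cub/cub.cuh>

    template <int THREADBLOCK_SIZE>
    __global__ void stop_vote(const int *stop_flag_per_slot,  // 0 or 1 per slot
                              const int64_t *stop_nums,
                              bool *not_need_stop,
                              const int max_bsz) {
      // Threads past max_bsz contribute 0 but still take part in the reduce.
      int flag = (threadIdx.x < max_bsz) ? stop_flag_per_slot[threadIdx.x] : 0;

      __syncthreads();
      typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      int64_t stop_sum = BlockReduce(temp_storage).Sum(flag);

      // Generation continues while fewer than stop_nums[0] slots have stopped.
      if (threadIdx.x == 0) {
        not_need_stop[0] = stop_sum < stop_nums[0];
      }
    }

The launch shape mirrors the one above: a single block of BlockSize threads, e.g. stop_vote<512><<<1, 512, 0, stream>>>(...), with BlockSize >= max_bsz so every slot gets a thread.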
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index b5238cae51d..50d57ed767e 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -77,7 +77,7 @@ def __init__(
         self.enable_logprob = self.model_config.enable_logprob

         # [mixed, prefill, decoder]
-        self.role = "mixed"
+        self.role = self.scheduler_config.splitwise_role

         self.sampler = MTPSampler(fd_config)
         self._init_model_inputs()
@@ -365,6 +365,7 @@ def _init_model_inputs(self):
         )
         # self.model_inputs["caches"] = self.cache_kvs
         # Inherit generation hyperparameters from the main model for consistency
+        self.model_inputs["prompt_lens"] = self.target_model_inputs["prompt_lens"]
         self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
         self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
         self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
@@ -501,9 +502,10 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
                 self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.model_inputs["is_block_step"][idx : idx + 1] = False
                 continue
-        # if has_prefill_task or has_decode_task:
-        #     self.model_inputs["not_need_stop"][0] = True
-        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
+
+        # TODO(liuzichang): Restore this slicing once the splitwise-p bug is fixed.
+        # self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         """
@@ -704,11 +706,25 @@ def _post_process(self, sampled_token_ids):
                 self.model_inputs["substep"],
             )
         if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
+            skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
             mtp_save_first_token(
                 self.model_inputs["base_model_draft_tokens"],
                 self.model_inputs["not_need_stop"],
+                self.model_inputs["seq_lens_decoder"],
+                self.model_inputs["prompt_lens"],
+                self.model_inputs["step_idx"],
                 self.local_rank,
                 self.parallel_config.use_ep,
+                skip_save,
             )
+            # Ensure the first token is saved only once.
+            paddle.assign(
+                paddle.where(
+                    self.model_inputs["stop_flags"],
+                    paddle.zeros_like(self.model_inputs["step_idx"]),
+                    self.model_inputs["step_idx"],
+                ),
+                self.model_inputs["step_idx"],
+            )

     def _propose(self, step_use_cudagraph: bool = False):
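Taken together, the three files implement "save the first token exactly once per request": the op publishes only when step_idx == 1, and the new paddle.assign(paddle.where(...)) zeroes step_idx for stopped slots so a finished request can never reach that branch again. A hypothetical C++ distillation of the per-request decision (the helper name and the -1 "no write" convention are illustrative, not part of the op):

    // Returns the value MTPSaveFirstToken writes to msg_sed.mtext[i + 2]:
    //   0  -> skipped: still in chunked prefill, or nothing decoded yet
    //   2  -> first token plus one draft token are published
    //   -1 -> no write at all: for step_idx > 1 the op leaves the slot
    //         untouched, which is safe because the Python side resets
    //         step_idx to 0 for stopped requests right after saving.
    inline int first_token_slot_count(bool skip_chunk_prefill,
                                      int seq_len_decoder,  // seq_lens_decoder[i]
                                      int64_t prompt_len,   // prompt_lens[i]
                                      int64_t step_idx) {   // step_idx[i]
      if ((skip_chunk_prefill && seq_len_decoder < prompt_len) ||
          step_idx == 0) {
        return 0;
      }
      if (step_idx == 1) {
        return 2;
      }
      return -1;
    }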