From 96787a0f4eadf3fd9c2d551708bd78bb3792d67a Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Sat, 1 Nov 2025 17:27:58 +0000 Subject: [PATCH 01/17] [XPU] support kernel for mtp(base) --- .../speculate_decoding/speculate_verify.cu | 2 +- custom_ops/xpu_ops/src/ops/block_attn.cc | 1 - .../src/ops/mtp/draft_model_preprocess_v2.cc | 148 ++++++ .../mtp/speculate_get_padding_offset_v2.cc | 133 ++++++ .../src/ops/mtp/speculate_save_output.cc | 2 +- .../src/ops/mtp/speculate_step_paddle.cc | 166 +++++++ .../xpu_ops/src/ops/mtp/speculate_verify.cc | 23 +- custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 125 +++++- .../xpu_ops/src/plugin/include/xpu/plugin.h | 92 +++- .../kunlun3cpp/mtp_kernel/compute_order.xpu | 33 +- .../mtp_kernel/draft_model_preprocess_v2.xpu | 250 +++++++++++ .../mtp_kernel/draft_model_update.xpu | 4 +- .../speculate_free_and_dispatch_block.xpu | 332 ++++++++++++++ .../speculate_get_padding_offset.xpu | 37 ++ .../mtp_kernel/speculate_recover_block.xpu | 154 +++++++ .../mtp_kernel/speculate_verify.xpu | 61 +-- .../mtp_wrapper/draft_model_preprocess_v2.cpp | 425 ++++++++++++++++++ .../speculate_free_and_dispatch_block.cpp | 224 +++++++++ .../speculate_get_padding_offset.cpp | 113 +++++ .../mtp_wrapper/speculate_recover_block.cpp | 257 +++++++++++ .../wrapper/mtp_wrapper/speculate_verify.cpp | 90 ++-- .../xpu_ops/test/test_speculate_step.py | 189 ++++++++ 22 files changed, 2757 insertions(+), 104 deletions(-) create mode 100644 custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp create mode 100644 custom_ops/xpu_ops/test/test_speculate_step.py diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index ae4555f5fe4..1eebc175b66 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -111,7 +111,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids, auto *draft_tokens_now = draft_tokens + bid * max_draft_tokens; auto *actual_candidate_len_now = actual_candidate_len + start_token_id; auto *sampled_token_id_now = sampled_token_ids + start_token_id; - +draft_model_update int i = 0; // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index 6153a77dd0a..c9e3313f2e4 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -722,7 +722,6 @@ std::vector BlockAttnKernel( : quant_v_scale_inv, nullptr, // o_maxptr param.head_dim); // vo_head_dim - PD_CHECK(0, "speculative_attention unimplemented"); PD_CHECK(ret == api::SUCCESS, 
"xfa::speculative_attention_decoder failed."); if (!Eq_len) { diff --git a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc new file mode 100644 index 00000000000..d97e28f68f7 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +namespace api = baidu::xpu::api; +void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, + const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& base_model_seq_lens_encoder, + const paddle::Tensor& base_model_seq_lens_decoder, + const paddle::Tensor& base_model_step_idx, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& base_model_is_block_step, + const paddle::Tensor& base_model_draft_tokens, + const int num_model_step, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + api::Context* ctx = static_cast(dev_ctx)->x_context(); + if (draft_tokens.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + int real_bsz = seq_lens_this_time.shape()[0]; + int accept_tokens_len = accept_tokens.shape()[1]; + int input_ids_len = input_ids.shape()[1]; + int draft_tokens_len = draft_tokens.shape()[1]; + int pre_ids_len = pre_ids.shape()[1]; + constexpr int BlockSize = 512; + int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1]; + auto not_need_stop_gpu = + not_need_stop.copy_to(seq_lens_this_time.place(), false); + + int r = baidu::xpu::api::plugin::draft_model_preprocess_v2( + ctx, + const_cast(draft_tokens.data()), + const_cast(input_ids.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_idx.data()), + const_cast(not_need_stop_gpu.data()), + const_cast(is_block_step.data()), + const_cast(batch_drop.data()), + const_cast(pre_ids.data()), + accept_tokens.data(), + accept_num.data(), + base_model_seq_lens_this_time.data(), + base_model_seq_lens_encoder.data(), + base_model_seq_lens_decoder.data(), 
+ base_model_step_idx.data(), + base_model_stop_flags.data(), + base_model_is_block_step.data(), + const_cast(base_model_draft_tokens.data()), + real_bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + + PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed."); + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), false); + bool* not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; +} + +PD_BUILD_STATIC_OP(draft_model_preprocess_v2) + .Inputs({"draft_tokens", + "input_ids", + "stop_flags", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", + "step_idx", + "not_need_stop", + "is_block_step", + "batch_drop", + "pre_ids", + "accept_tokens", + "accept_num", + "base_model_seq_lens_this_time", + "base_model_seq_lens_encoder", + "base_model_seq_lens_decoder", + "base_model_step_idx", + "base_model_stop_flags", + "base_model_is_block_step", + "base_model_draft_tokens"}) + .Outputs({"draft_tokens_out", + "input_ids_out", + "stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "step_idx_out", + "not_need_stop_out", + "batch_drop_out", + "pre_ids_out"}) + .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"}) + .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, + {"input_ids", "input_ids_out"}, + {"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"step_idx", "step_idx_out"}, + {"not_need_stop", "not_need_stop_out"}, + {"batch_drop", "batch_drop_out"}, + {"pre_ids", "pre_ids_out"}}) + .SetKernelFn(PD_KERNEL(DraftModelPreprocessV2)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc new file mode 100644 index 00000000000..18b945bcc05 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
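+
+// Descriptive comment (added for clarity, inferred from the kernel calls below):
+// speculate_get_padding_offset_v2 flattens the padded [bsz, max_seq_len] input_ids into a
+// contiguous token stream (drawing tokens from draft_tokens for decode-phase sequences) and
+// builds cum_offsets_out, batch_id_per_token and cu_seqlens_q/k for the flattened layout.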
+ +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +std::vector SpeculateGetPaddingOffsetV2( + const paddle::Tensor& input_ids, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& token_num, + const paddle::Tensor& seq_len, + const paddle::Tensor& seq_lens_encoder) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int seq_length = input_ids_shape[1]; + const int max_draft_tokens = draft_tokens.shape()[1]; + auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); + auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + + const int token_num_data = cpu_token_num.data()[0]; + auto x_remove_padding = paddle::empty( + {token_num_data}, paddle::DataType::INT64, input_ids.place()); + auto padding_offset = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto batch_id_per_token = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = + paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = + paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); + + PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous"); + PD_CHECK(draft_tokens.is_contiguous(), + "Draft tokens tensor must be contiguous"); + PD_CHECK(cum_offsets.is_contiguous(), + "Cum offsets tensor must be contiguous"); + PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous"); + + int r = baidu::xpu::api::plugin::speculate_get_padding_offset_v2( + xpu_ctx->x_context(), + batch_id_per_token.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + PD_CHECK(r == 0, "XPU speculate_get_padding_offset_v2 failed"); + + r = baidu::xpu::api::plugin::speculate_remove_padding( + xpu_ctx->x_context(), + x_remove_padding.data(), + input_ids.data(), + draft_tokens.data(), + seq_len.data(), + seq_lens_encoder.data(), + cum_offsets_out.data(), + seq_length, + max_draft_tokens, + bsz, + token_num_data); + PD_CHECK(r == 0, "XPU speculate_remove_padding failed"); + + return {x_remove_padding, + cum_offsets_out, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k}; // , enc_token_num, dec_token_num}; +} + +std::vector> SpeculateGetPaddingOffsetV2InferShape( + const std::vector& input_ids_shape, + const std::vector& draft_tokens_shape, + const std::vector& cum_offsets_shape, + const std::vector& token_num_shape, + const std::vector& seq_len_shape, + const std::vector& seq_lens_encoder_shape) { + int64_t bsz = seq_len_shape[0]; + int64_t seq_len = input_ids_shape[1]; + return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; +} + +std::vector SpeculateGetPaddingOffsetV2InferDtype( + const paddle::DataType& input_ids_dtype, + const paddle::DataType& draft_tokens_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& token_num_dtype, + const paddle::DataType& seq_len_dtype, + const paddle::DataType& seq_lens_encoder_dtype) { + return {input_ids_dtype, + seq_len_dtype, + seq_len_dtype, + seq_len_dtype, + seq_len_dtype}; +} + 
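+// Output shapes (see InferShape above): x_remove_padding and batch_id_per_token are
+// token-level ({-1}), cum_offsets_out is {bsz}, and cu_seqlens_q/k have bsz + 1 entries.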
+PD_BUILD_STATIC_OP(speculate_get_padding_offset_v2) + .Inputs({"input_ids", + "draft_tokens", + "cum_offsets", + "token_num", + "seq_len", + "seq_lens_encoder"}) + .Outputs({"x_remove_padding", + "cum_offsets_out", + "batch_id_per_token", + "cu_seqlens_q", + "cu_seqlens_k"}) + .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffsetV2)) + .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetV2InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetV2InferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc index 60764b26a5e..a8e61c708dc 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc @@ -35,7 +35,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, const paddle::Tensor& not_need_stop, int64_t rank_id, int msg_queue_id, - int save_each_rank) { + bool save_each_rank) { // printf("enter save output"); if (!save_each_rank && rank_id > 0) { return; diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc new file mode 100644 index 00000000000..eaec2c6958b --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
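+
+// Descriptive comment (added for clarity, inferred from the plugin calls below):
+// speculate_step_paddle is the block scheduler for speculative decoding. It frees KV-cache
+// blocks of stopped sequences, dispatches new blocks to running sequences that are about to
+// overflow their current block (pausing sequences via is_block_step when free blocks run out),
+// and recovers paused sequences once enough free blocks become available again.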
+ +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" +#include + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +void SpeculateStepPaddle( + const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, const paddle::Tensor &accept_num, + const int block_size, const int encoder_decoder_block_num, + const int max_draft_tokens) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = + paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + + const int bsz = seq_lens_this_time.shape()[0]; + PADDLE_ENFORCE_LE( + bsz, 640, + phi::errors::InvalidArgument( + "Only support bsz <= 640, but received bsz is %d", bsz)); + const int block_num_per_seq = block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_decoder_block_num = pre_id_length / block_size; + int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block( + xpu_ctx->x_context(), const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(first_token_ids.data()), + const_cast(accept_num.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed."); + auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); + int recover_lens_cpu_data = recover_lens_cpu.data()[0]; + if (recover_lens_cpu_data > 0) { + r = baidu::xpu::api::plugin::speculate_recover_block( + xpu_ctx->x_context(), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + ori_seq_lens_encoder.data(), + const_cast(seq_lens_encoder.data()), + seq_lens_decoder.data(), + const_cast(block_tables.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(input_ids.data()), + pre_ids.data(), step_idx.data(), + encoder_block_lens.data(), used_list_len.data(), + next_tokens.data(), first_token_ids.data(), bsz, + block_num_per_seq, length, pre_id_length); + PD_CHECK(r == 0, "speculate_recover_block failed."); + } 
+} + +PD_BUILD_STATIC_OP(speculate_step_paddle) + .Inputs({"stop_flags", + "seq_lens_this_time", + "ori_seq_lens_encoder", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables", + "encoder_block_lens", + "is_block_step", + "step_block_list", + "step_lens", + "recover_block_list", + "recover_lens", + "need_block_list", + "need_block_len", + "used_list_len", + "free_list", + "free_list_len", + "input_ids", + "pre_ids", + "step_idx", + "next_tokens", + "first_token_ids", + "accept_num"}) + .Attrs({"block_size: int", + "encoder_decoder_block_num: int", + "max_draft_tokens: int"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", + "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out", + "first_token_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"block_tables", "block_tables_out"}, + {"encoder_block_lens", "encoder_block_lens_out"}, + {"is_block_step", "is_block_step_out"}, + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", "used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}, + {"first_token_ids", "first_token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateStepPaddle)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc index 53b5b90dc33..59df0f0f2de 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_verify.cc @@ -45,7 +45,10 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, const paddle::Tensor &topp, int max_seq_len, int verify_window, - bool enable_topp) { + bool enable_topp, + bool benchmark_mode, + bool accept_all_drafts) { + // TODO(chenhuan09):support accept_all_drafts auto bsz = accept_tokens.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0]; auto max_draft_tokens = draft_tokens.shape()[1]; @@ -133,7 +136,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } else { baidu::xpu::api::plugin::speculate_verify( ctx, @@ -161,7 +165,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } } else { if (enable_topp) { @@ -191,7 +196,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } else { baidu::xpu::api::plugin::speculate_verify( ctx, @@ -219,7 +225,8 @@ void SpeculateVerify(const paddle::Tensor &accept_tokens, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } } } @@ -246,7 +253,11 @@ PD_BUILD_STATIC_OP(speculate_verify) 
"accept_num_out", "step_idx_out", "stop_flags_out"}) - .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"}) + .Attrs({"max_seq_len: int", + "verify_window: int", + "enable_topp: bool", + "benchmark_mode: bool", + "accept_all_drafts: bool"}) .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, {"accept_num", "accept_num_out"}, {"step_idx", "step_idx_out"}, diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 79f89df37d7..0128f7ca461 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -264,7 +264,9 @@ void SpeculateVerify(const paddle::Tensor& accept_tokens, const paddle::Tensor& topp, int max_seq_len, int verify_window, - bool enable_topp); + bool enable_topp, + bool benchmark_mode, + bool accept_all_drafts); void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, const paddle::Tensor& seq_lens_decoder); @@ -301,6 +303,31 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const bool truncate_first_token, const bool splitwise_prefill); +void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, + const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& base_model_seq_lens_encoder, + const paddle::Tensor& base_model_seq_lens_decoder, + const paddle::Tensor& base_model_step_idx, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& base_model_is_block_step, + const paddle::Tensor& base_model_draft_tokens, + const int num_model_step, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); + void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_seq_lens_this_time, const paddle::Tensor& base_model_seq_lens_encoder, @@ -396,6 +423,14 @@ std::vector SpeculateGetPaddingOffset( const paddle::Tensor& seq_len, const paddle::Tensor& seq_lens_encoder); +std::vector SpeculateGetPaddingOffsetV2( + const paddle::Tensor& input_ids, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& token_num, + const paddle::Tensor& seq_len, + const paddle::Tensor& seq_lens_encoder); + void StepPaddle(const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& ori_seq_lens_encoder, @@ -436,6 +471,24 @@ void MTPStepPaddle( const int block_size, const int max_draft_tokens); +void SpeculateStepPaddle( + const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, const paddle::Tensor 
&used_list_len, + const paddle::Tensor &free_list, const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, const paddle::Tensor &accept_num, + const int block_size, const int encoder_decoder_block_num, + const int max_draft_tokens); + void SaveOutMmsgStatic(const paddle::Tensor& x, const paddle::Tensor& not_need_stop, int64_t rank_id, @@ -637,6 +690,34 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("splitwise_prefill"), "Preprocess data for draft model in speculative decoding"); + m.def("draft_model_preprocess_v2", + &DraftModelPreprocessV2, + py::arg("draft_tokens"), + py::arg("input_ids"), + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("step_idx"), + py::arg("not_need_stop"), + py::arg("is_block_step"), + py::arg("batch_drop"), + py::arg("pre_ids"), + py::arg("accept_tokens"), + py::arg("accept_num"), + py::arg("base_model_seq_lens_this_time"), + py::arg("base_model_seq_lens_encoder"), + py::arg("base_model_seq_lens_decoder"), + py::arg("base_model_step_idx"), + py::arg("base_model_stop_flags"), + py::arg("base_model_is_block_step"), + py::arg("base_model_draft_tokens"), + py::arg("num_model_step"), + py::arg("truncate_first_token"), + py::arg("splitwise_prefill"), + py::arg("kvcache_scheduler_v1"), + "Preprocess data for draft model in speculative decoding"); + m.def("draft_model_postprocess", &DraftModelPostprocess, py::arg("base_model_draft_tokens"), @@ -983,6 +1064,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("max_seq_len"), py::arg("verify_window"), py::arg("enable_topp"), + py::arg("benchmark_mode"), + py::arg("accept_all_drafts"), "Perform speculative verification for decoding"); m.def("speculate_clear_accept_nums", @@ -1021,6 +1104,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_encoder"), "Get padding offset"); + m.def("speculate_get_padding_offset_v2", + &SpeculateGetPaddingOffsetV2, + py::arg("input_ids"), + py::arg("draft_tokens"), + py::arg("cum_offsets"), + py::arg("token_num"), + py::arg("seq_len"), + py::arg("seq_lens_encoder"), + "Get padding offset v2"); + m.def("speculate_step_reschedule", &SpeculateStepSchedule, py::arg("stop_flags"), @@ -1104,6 +1197,36 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("encoder_decoder_block_num"), "Step paddle function"); + m.def("speculate_step_paddle", + &SpeculateStepPaddle, + py::arg("stop_flags"), + py::arg("seq_lens_this_time"), + py::arg("ori_seq_lens_encoder"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_decoder"), + py::arg("block_tables"), + py::arg("encoder_block_lens"), + py::arg("is_block_step"), + py::arg("step_block_list"), + py::arg("step_lens"), + py::arg("recover_block_list"), + py::arg("recover_lens"), + py::arg("need_block_list"), + py::arg("need_block_len"), + py::arg("used_list_len"), + py::arg("free_list"), + py::arg("free_list_len"), + py::arg("input_ids"), + py::arg("pre_ids"), + py::arg("step_idx"), + py::arg("next_tokens"), + py::arg("first_token_ids"), + py::arg("accept_num"), + py::arg("block_size"), + py::arg("encoder_decoder_block_num"), + py::arg("max_draft_tokens"), + "Step paddle function"); + m.def("text_image_gather_scatter", &TextImageGatherScatter, py::arg("input"), diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 4e393b86851..754297ce6ce 100644 --- 
a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -74,6 +74,48 @@ DLL_EXPORT int get_padding_offset(Context* ctx, const int* seq_lens, const int max_seq_len, const int bs); + +DLL_EXPORT int speculate_get_padding_offset_v2(Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz); + +DLL_EXPORT int draft_model_preprocess_v2(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); DLL_EXPORT int update_inputs(Context* ctx, bool* not_need_stop, @@ -111,6 +153,30 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx, const int block_num_per_seq, const int max_decoder_block_num); +DLL_EXPORT int speculate_free_and_dispatch_block(Context* ctx, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_decoder, + int* block_tables, + int* encoder_block_lens, + bool* is_block_step, + int* step_block_list, // [bsz] + int* step_len, + int* recover_block_list, + int* recover_len, + int* need_block_list, + int* need_block_len, + int* used_list_len, + int* free_list, + int* free_list_len, + int64_t* first_token_ids, + int* accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens); + DLL_EXPORT int recover_block(Context* ctx, int* recover_block_list, // [bsz] int* recover_len, @@ -134,6 +200,29 @@ DLL_EXPORT int recover_block(Context* ctx, const int length, const int pre_id_length); +DLL_EXPORT int speculate_recover_block(Context* ctx, + int* recover_block_list, // [bsz] + int* recover_len, + bool* stop_flags, + int* seq_lens_this_time, + const int* ori_seq_lens_encoder, + int* seq_lens_encoder, + const int* seq_lens_decoder, + int* block_tables, + int* free_list, + int* free_list_len, + int64_t* input_ids, + const int64_t* pre_ids, + const int64_t* step_idx, + const int* encoder_block_lens, + const int* used_list_len, + const int64_t* next_tokens, + const int64_t* first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); + DLL_EXPORT int recover_decode_task(Context* ctx, bool* stop_flags, int* seq_lens_this_time, @@ -305,7 +394,8 @@ DLL_EXPORT int speculate_verify(Context* ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop); + const bool prefill_one_step_stop, + const bool benchmark_mode); DLL_EXPORT int speculate_clear_accept_nums(Context* ctx, int* accept_num, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu 
b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu index 7cd399d09c1..b8d70544a47 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu @@ -20,6 +20,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, return; } + // 256 * int char lm[6 * 1024]; int buf_size = 6 * 1024 / (6 * sizeof(int)); int* lm_base_model_seq_lens_this_time = (int*)lm; @@ -68,10 +69,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, in_offset += write_size; } mfence_lm(); - // 2. base model encoder. Base step=0 - } else if (cur_base_model_seq_lens_encoder != 0) { - // nothing happens - // 3. New end + // 2. Base model stop at last verify-step. } else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) { in_offset += cur_base_model_seq_lens_this_time; @@ -80,27 +78,16 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, cur_seq_lens_this_time == 0) { // nothing happens } else { - if (accept_num <= actual_draft_token_num) { - int position_map_val = out_offset; - LM2GM(&position_map_val, - position_map + in_offset + accept_num - 1, - sizeof(int)); - out_offset++; - in_offset += cur_base_model_seq_lens_this_time; - } else { - int position_map_val_1 = out_offset; - LM2GM(&position_map_val_1, - position_map + in_offset + accept_num - 2, - sizeof(int)); - out_offset++; - int position_map_val_2 = out_offset; - LM2GM(&position_map_val_2, - position_map + in_offset + accept_num - 1, - sizeof(int)); - out_offset++; - in_offset += cur_base_model_seq_lens_this_time; + // accept_num << buf_size, so do not need split + for (int i = 0; i < accept_num; i++) { + lm_position_map[i] = out_offset++; } mfence_lm(); + LM2GM(lm_position_map, + position_map + in_offset, + accept_num * sizeof(int)); + in_offset += cur_base_model_seq_lens_this_time; + mfence_lm(); } } } diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu new file mode 100644 index 00000000000..052de1b7fc3 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu @@ -0,0 +1,250 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" +#include "xpu/kernel/cluster_simd.h" + +namespace xpu3 { +namespace plugin { +__global__ void draft_model_preprocess_v2( + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + int cid = 
core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + __shared__ int not_stop_flag_sm[64]; + not_stop_flag_sm[cid] = 0; + int64_t accept_tokens_now[128]; + + int value_zero = 0; + int64_t value_fu = -1; + + if (splitwise_prefill) { + for (; tid < bsz; tid += ncores * nclusters) { + int64_t base_model_step_idx_now = 0; + int seq_lens_encoder_now = 0; + int seq_lens_this_time_now = 0; + bool stop_flags_now = false; + int64_t base_model_first_token; + int seq_lens_encoder_record_now = 0; + int64_t input_ids_now = 0; + + GM2LM_ASYNC(base_model_step_idx + tid, + &base_model_step_idx_now, + sizeof(int64_t)); + GM2LM_ASYNC( + seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); + GM2LM(accept_tokens + tid * accept_tokens_len, + &base_model_first_token, + sizeof(int64_t)); + if (seq_lens_encoder_now > 0) { + not_stop_flag_sm[cid] += 1; + stop_flags_now = false; + int position = seq_lens_encoder_now; + if (truncate_first_token) { + position = position - 1; + input_ids_now = base_model_first_token; + seq_lens_this_time_now = seq_lens_encoder_now; + } else { + input_ids_now = base_model_first_token; + seq_lens_this_time_now = seq_lens_encoder_now + 1; + } + LM2GM_ASYNC(&input_ids_now, + input_ids + tid * input_ids_len + position, + sizeof(int64_t)); + } else { + stop_flags_now = true; + seq_lens_this_time_now = 0; + seq_lens_encoder_now = 0; + not_stop_flag_sm[cid] += 0; + LM2GM_ASYNC(&value_zero, seq_lens_decoder + tid, sizeof(int)); + } + LM2GM_ASYNC( + &seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); + LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); + LM2GM( + &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + } + } else { + for (; tid < bsz; tid += ncores * nclusters) { + bool base_model_stop_flags_now = false; + bool base_model_is_block_step_now = false; + bool batch_drop_now = false; + bool stop_flags_now = false; + bool is_block_step_now = false; + int seq_lens_this_time_now = 0; + int seq_lens_encoder_now = 0; + int seq_lens_decoder_new = 0; + int accept_num_now = 0; + int base_model_seq_lens_decoder_now = 0; + int base_model_seq_lens_this_time_now = 0; + int64_t step_id_now = 0; + int64_t base_model_step_idx_now; + int64_t pre_ids_now; + mfence(); + GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool)); + GM2LM_ASYNC(base_model_stop_flags + tid, + &base_model_stop_flags_now, + sizeof(bool)); + GM2LM_ASYNC(base_model_is_block_step + tid, + &base_model_is_block_step_now, + sizeof(bool)); + GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); + GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool)); + GM2LM_ASYNC( + seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); + GM2LM_ASYNC( + seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); + + GM2LM_ASYNC(accept_tokens + tid * accept_tokens_len, + accept_tokens_now, + accept_tokens_len * sizeof(int64_t)); + GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); + + GM2LM_ASYNC(base_model_seq_lens_this_time + tid, + &base_model_seq_lens_this_time_now, + sizeof(int)); + GM2LM_ASYNC(base_model_seq_lens_decoder + tid, + &base_model_seq_lens_decoder_now, + sizeof(int)); + GM2LM_ASYNC(step_idx + tid, &step_id_now, sizeof(int64_t)); + GM2LM(base_model_step_idx + tid, + &base_model_step_idx_now, + sizeof(int64_t)); + + for (int i = 1; i < base_model_draft_tokens_len; i++) { + LM2GM_ASYNC(&value_fu, + base_model_draft_tokens + + tid * base_model_draft_tokens_len + i, + 
sizeof(int)); + } + if (kvcache_scheduler_v1) { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + stop_flags_now = true; + is_block_step_now = true; + } + } else { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + batch_drop_now = true; + stop_flags_now = true; + } + } + + if (!(base_model_stop_flags_now || batch_drop_now)) { + not_stop_flag_sm[cid] += 1; + if (seq_lens_encoder_now > 0) { + int seq_len_encoder = seq_lens_encoder_now; + stop_flags_now = false; + int64_t base_model_first_token = accept_tokens_now[0]; + LM2GM(&base_model_first_token, + pre_ids + tid * pre_ids_len, + sizeof(int64_t)); + int position = seq_len_encoder; + if (truncate_first_token) { + LM2GM(&base_model_first_token, + input_ids + tid * input_ids_len + position - 1, + sizeof(int64_t)); + seq_lens_this_time_now = seq_len_encoder; + } else { + LM2GM(&base_model_first_token, + input_ids + tid * input_ids_len + position, + sizeof(int64_t)); + seq_lens_this_time_now = seq_len_encoder + 1; + } + } else { + if (kvcache_scheduler_v1) { + if (!base_model_is_block_step_now && + is_block_step_now) { + is_block_step_now = false; + } + } + if (stop_flags_now) { + stop_flags_now = false; + seq_lens_decoder_new = base_model_seq_lens_decoder_now - + base_model_seq_lens_this_time_now; + step_id_now = base_model_step_idx_now - + base_model_seq_lens_this_time_now; + + } else { + seq_lens_decoder_new -= num_model_step - 1; + step_id_now -= num_model_step - 1; + } + for (int i = 0; i < accept_num_now; i++) { + const int pre_id_pos = + base_model_step_idx_now - (accept_num_now - i); + LM2GM(accept_tokens_now + i, + draft_tokens + tid * draft_tokens_len + i, + sizeof(int64_t)); + LM2GM(accept_tokens_now + i, + pre_ids + tid * pre_ids_len + pre_id_pos, + sizeof(int64_t)); + } + seq_lens_this_time_now = accept_num_now; + } + + } else { + stop_flags_now = true; + seq_lens_this_time_now = 0; + seq_lens_encoder_now = 0; + seq_lens_decoder_new = 0; + } + LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); + LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); + LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool)); + LM2GM_ASYNC( + &seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); + LM2GM_ASYNC( + &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + LM2GM_ASYNC( + &seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); + LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); + } + } + mfence(); + sync_cluster(); + bool value_true = true; + bool value_false = false; + if (cid == 0) { + for (int i = 0; i < ncores; i++) { + not_stop_flag_sm[0] += not_stop_flag_sm[i]; + } + if (not_stop_flag_sm[0] > 0) { + LM2GM(&value_true, not_need_stop, sizeof(bool)); + } else { + LM2GM(&value_false, not_need_stop, sizeof(bool)); + } + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu index 0334995f9db..50ba31d61b8 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_update.xpu @@ -60,10 +60,8 @@ __global__ void draft_model_update(const int64_t* inter_next_tokens, token_this_time = next_tokens_start[seq_len_this_time - 1]; draft_token_now[0] = next_tokens_start[seq_len_this_time - 1]; base_model_draft_tokens_now[substep + 1] = token_this_time; - 
for (int i = 0; i < seq_len_this_time; ++i) { - pre_ids_now[step_idx[tid] + 1 + i] = next_tokens_start[i]; - } step_idx[tid] += seq_len_this_time; + pre_ids_now[step_idx[tid]] = token_this_time; } else { token_this_time = next_tokens_start[0]; seq_lens_decoder[tid] = seq_len_encoder + seq_len_decoder; diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu new file mode 100644 index 00000000000..f5652c55aa6 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu @@ -0,0 +1,332 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +static __device__ inline int loada_float(_shared_ptr_ const int *ptr) { + int ret; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr)); + return ret; +} + +static __device__ inline bool storea_float(_shared_ptr_ int *ptr, int value) { + bool ret; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr)); + return ret; +} + +static __device__ int atomic_add(_shared_ptr_ int *ptr, int value) { + bool fail = true; + int old_value; + while (fail) { + old_value = loada_float(ptr); + int new_value = old_value + value; + fail = storea_float(ptr, new_value); + } + return old_value; +} + +static __device__ bool in_need_block_list(const int qid, + _shared_ptr_ int *need_block_list, + const int need_block_len) { + bool res = false; + for (int i = 0; i < need_block_len; i++) { + if (qid == need_block_list[i]) { + need_block_list[i] = -1; + res = true; + break; + } + } + return res; +} + +__global__ void speculate_free_and_dispatch_block( + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0 || cid >= bsz) return; + + // assert bsz <= 640 + const int max_bs = 640; + int value_zero = 0; + bool flag_true = true; + + // 128 = seq_len(8192) / block_size(64) + // 每次最多处理block_table数量为128 + const int block_table_now_len = 128; + int block_table_now[block_table_now_len]; + for (int i = 0; i < block_table_now_len; i++) { + block_table_now[i] = -1; + } + bool stop_flag_lm; + int seq_lens_decoder_lm; + + __shared__ int free_list_len_sm; + // 每次最多处理free_list数量为block_table_now_len + int free_list_now[block_table_now_len]; + __shared__ int need_block_len_sm; + __shared__ int need_block_list_sm[max_bs]; + __shared__ int used_list_len_sm[max_bs]; + __shared__ bool step_max_block_flag; + __shared__ int in_need_block_list_len; + if (cid == 0) { + step_max_block_flag = false; + in_need_block_list_len = 0; + GM2SM_ASYNC(free_list_len, &free_list_len_sm, sizeof(int)); + GM2SM_ASYNC(need_block_len, &need_block_len_sm, sizeof(int)); + mfence(); + if (need_block_len_sm > 0) { + GM2SM_ASYNC( + need_block_list, need_block_list_sm, sizeof(int) * need_block_len_sm); + 
} + GM2SM_ASYNC(used_list_len, used_list_len_sm, sizeof(int) * bsz); + mfence(); + } + sync_cluster(); + + for (int tid = cid; tid < bsz; tid += ncores) { + bool is_block_step_lm; + int seq_lens_this_time_lm; + mfence(); + GM2LM_ASYNC(stop_flags + tid, &stop_flag_lm, sizeof(bool)); + GM2LM_ASYNC(is_block_step + tid, &is_block_step_lm, sizeof(bool)); + GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_lm, sizeof(int)); + GM2LM_ASYNC(seq_lens_this_time + tid, &seq_lens_this_time_lm, sizeof(int)); + mfence(); + int max_possible_block_idx = + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size; + if (stop_flag_lm && !is_block_step_lm) { + // 回收block块 + int64_t first_token_id_lm = -1; + mfence_lm(); + LM2GM(&first_token_id_lm, first_token_ids + tid, sizeof(int64_t)); + int encoder_block_len_lm; + int decoder_used_len_lm = used_list_len_sm[tid]; + GM2LM(encoder_block_lens + tid, &encoder_block_len_lm, sizeof(int)); + if (decoder_used_len_lm > 0) { + const int ori_free_list_len = + atomic_add(&free_list_len_sm, decoder_used_len_lm); + for (int i = 0; i < decoder_used_len_lm; i += block_table_now_len) { + int process_len = min(block_table_now_len, decoder_used_len_lm - i); + GM2LM( + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + free_list_now, + process_len * sizeof(int)); + LM2GM(free_list_now, + free_list + ori_free_list_len + i, + process_len * sizeof(int)); + LM2GM( + block_table_now, + block_tables + tid * block_num_per_seq + encoder_block_len_lm + i, + process_len * sizeof(int)); + } + used_list_len_sm[tid] = 0; + mfence(); + LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int)); + } + } else if (seq_lens_this_time_lm != 0 && + max_possible_block_idx < block_num_per_seq) { + int next_block_id; + GM2LM(block_tables + tid * block_num_per_seq + + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size, + &next_block_id, + sizeof(int)); + if (next_block_id == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = atomic_add(&need_block_len_sm, 1); + need_block_list_sm[ori_need_block_len] = tid; + } + } + } + sync_cluster(); + + bool is_block_step_lm[max_bs]; + int step_len_lm; + int step_block_list_lm[max_bs]; + int recover_len_lm; + int recover_block_list_lm[max_bs]; + if (cid == 0) { + GM2LM_ASYNC(is_block_step, is_block_step_lm, sizeof(bool) * bsz); + GM2LM_ASYNC(step_len, &step_len_lm, sizeof(int)); + GM2LM_ASYNC(step_block_list, step_block_list_lm, sizeof(int) * bsz); + GM2LM_ASYNC(recover_len, &recover_len_lm, sizeof(int)); + GM2LM_ASYNC(recover_block_list, recover_block_list_lm, sizeof(int) * bsz); + mfence(); + } + + if (cid == 0) { + while (need_block_len_sm > free_list_len_sm) { + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len,已解码到最后一个block的query不参与调度(马上就结束) + int max_used_list_len_id = 0; + int max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + if (!is_block_step_lm[i] && + (step_max_block_flag || + used_list_len_sm[i] != max_decoder_block_num) && + (used_list_len_sm[i] > max_used_list_len)) { + max_used_list_len_id = i; + max_used_list_len = used_list_len_sm[i]; + } + } + + if (max_used_list_len == 0) { + step_max_block_flag = true; + } else { + int encoder_block_len; + GM2LM(encoder_block_lens + max_used_list_len_id, + &encoder_block_len, + sizeof(int)); + for (int i = 0; i < max_used_list_len; i += block_table_now_len) { + int process_len = min(block_table_now_len, max_used_list_len - i); + GM2LM(block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len + i, + free_list_now, + process_len * 
sizeof(int)); + LM2GM(free_list_now, + free_list + free_list_len_sm + i, + process_len * sizeof(int)); + LM2GM(block_table_now, + block_tables + max_used_list_len_id * block_num_per_seq + + encoder_block_len + i, + process_len * sizeof(int)); + } + step_block_list_lm[step_len_lm] = max_used_list_len_id; + int need_block_len_all = need_block_len_sm + in_need_block_list_len; + if (in_need_block_list( + max_used_list_len_id, need_block_list_sm, need_block_len_all)) { + need_block_len_sm--; + in_need_block_list_len++; + } + step_len_lm++; + free_list_len_sm += max_used_list_len; + LM2GM_ASYNC( + &flag_true, stop_flags + max_used_list_len_id, sizeof(bool)); + is_block_step_lm[max_used_list_len_id] = true; + LM2GM_ASYNC(&value_zero, + seq_lens_this_time + max_used_list_len_id, + sizeof(int)); + LM2GM_ASYNC( + &value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int)); + mfence(); + } + } + } + sync_cluster(); + + int need_block_len_all = need_block_len_sm + in_need_block_list_len; + for (int tid = cid; tid < need_block_len_all; tid += ncores) { + // 为需要block的位置分配block,每个位置分配一个block + const int need_block_id = need_block_list_sm[tid]; + if (need_block_id != -1) { + GM2LM(stop_flags + need_block_id, &stop_flag_lm, sizeof(bool)); + if (!stop_flag_lm) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len_sm[need_block_id]++; + const int ori_free_list_len = atomic_add(&free_list_len_sm, -1); + int tmp_seq_lens_decoder; + GM2LM(seq_lens_decoder + need_block_id, + &tmp_seq_lens_decoder, + sizeof(int)); + int free_block_id; + GM2LM(free_list + ori_free_list_len - 1, &free_block_id, sizeof(int)); + LM2GM(&free_block_id, + block_tables + need_block_id * block_num_per_seq + + (tmp_seq_lens_decoder + max_draft_tokens + 1) / block_size, + sizeof(int)); + } + need_block_list_sm[tid] = -1; + } + } + sync_cluster(); + + // 计算可以复原的query id + // 每次最多只恢复max_recover_num个query + int max_recover_num = 1; + if (cid == 0 && step_len_lm > 0) { + int ori_free_list_len = free_list_len_sm; + int ori_step_block_id = step_block_list_lm[step_len_lm - 1]; + int tmp_used_len = used_list_len_sm[ori_step_block_id]; + int encoder_block_len_lm; + GM2LM(encoder_block_lens + ori_step_block_id, + &encoder_block_len_lm, + sizeof(int)); + const int max_decoder_block_num_this_seq = + max_decoder_block_num - encoder_block_len_lm; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq + ? tmp_used_len + 1 + : max_decoder_block_num_this_seq; + while (step_len_lm > 0 && ori_free_list_len >= used_len && + max_recover_num-- > 0) { + recover_block_list_lm[recover_len_lm] = ori_step_block_id; + is_block_step_lm[ori_step_block_id] = false; + used_list_len_sm[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list_lm[step_len_lm - 1] = -1; + step_len_lm--; + recover_len_lm++; + if (step_len_lm > 0) { + ori_step_block_id = step_block_list_lm[step_len_lm - 1]; + tmp_used_len = used_list_len_sm[ori_step_block_id]; + used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq + ? tmp_used_len + 1 + : max_decoder_block_num_this_seq; + } + } + } + + // TODO(zhupengyang): + // Before the operator: need_block_len is 0, need_block_list is -1 + // After the operator: need_block_len is 0, need_block_list is -1 + // May need_block_len and need_block_list not need update? 
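+  // Flush the updated scheduling state back to global memory. need_block_len is reset to 0,
+  // and the consumed need_block_list entries were already overwritten with -1 above.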
+ int ori_need_block_len; + if (cid == 0) { + ori_need_block_len = need_block_len_sm; + need_block_len_sm = 0; + } + + if (cid == 0) { + mfence(); + LM2GM_ASYNC(step_block_list_lm, step_block_list, sizeof(int) * bsz); + LM2GM_ASYNC(is_block_step_lm, is_block_step, sizeof(bool) * bsz); + LM2GM_ASYNC(&step_len_lm, step_len, sizeof(int)); + LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int)); + LM2GM_ASYNC(recover_block_list_lm, recover_block_list, sizeof(int) * bsz); + SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int)); + SM2GM_ASYNC(&need_block_len_sm, need_block_len, sizeof(int)); + if (ori_need_block_len > 0) { + SM2GM_ASYNC(need_block_list_sm, + need_block_list, + sizeof(int) * ori_need_block_len); + } + SM2GM_ASYNC(used_list_len_sm, used_list_len, sizeof(int) * bsz); + mfence(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu index c08d756d7c0..4af74b8f620 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu @@ -101,6 +101,43 @@ __global__ void speculate_get_padding_offset(int* padding_offset, } } + +__global__ void speculate_get_padding_offset_v2(int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { + int bid = cluster_id(); + int tid = core_id(); + int ncores = core_num(); + int nclusters = cluster_num(); + int seq_lens_now = 0; + int cum_offsets_now = 0; + int cum_offsets_now_ind = 0; + for (int bi = bid; bi < bsz; bi += nclusters) { + GM2LM(seq_lens + bi, &seq_lens_now, sizeof(int)); + if (bi == 0) { + cum_offsets_now = 0; + } else { + GM2LM(cum_offsets + bi - 1, &cum_offsets_now, sizeof(int)); + } + GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int)); + + for (int i = tid; i < seq_lens_now; i += ncores) { + LM2GM(&bi, + batch_id_per_token + bi * max_seq_len - cum_offsets_now + i, + sizeof(int)); + } + LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int)); + int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets_now_ind; + LM2GM(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int)); + LM2GM(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int)); + } +} + #define _XPU_DEF_SPECULATE_KERNELS_(T) \ template __global__ void speculate_remove_padding(T*, \ const T*, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu new file mode 100644 index 00000000000..46d24821dda --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu @@ -0,0 +1,154 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +static __device__ inline int loada_float(_shared_ptr_ const int* ptr) { + int ret; + __asm__ __volatile__("loada.w %0,%1" : "=&r"(ret) : "r"(ptr)); + return ret; +} + +static __device__ inline bool storea_float(_shared_ptr_ int* ptr, int value) { + bool ret; + __asm__ __volatile__("storea.w %0,%1,%2" : "=&r"(ret) : "r"(value), "r"(ptr)); + return ret; +} + +static __device__ int 
atomic_add(_shared_ptr_ int* ptr, int value) { + bool fail = true; + int old_value; + while (fail) { + old_value = loada_float(ptr); + int new_value = old_value + value; + fail = storea_float(ptr, new_value); + } + return old_value; +} + +__global__ void speculate_recover_block(int* recover_block_list, // [bsz] + int* recover_len, + bool* stop_flags, + int* seq_lens_this_time, + const int* ori_seq_lens_encoder, + int* seq_lens_encoder, + const int* seq_lens_decoder, + int* block_tables, + int* free_list, + int* free_list_len, + int64_t* input_ids, + const int64_t* pre_ids, + const int64_t* step_idx, + const int* encoder_block_lens, + const int* used_list_len, + const int64_t* next_tokens, + const int64_t* first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + if (clusterid != 0) return; + + // 128 = seq_len(8192) / block_size(64) + // 每次最多处理block_table数量为128 + const int block_table_now_len = 128; + int block_table_now[block_table_now_len]; + // max_seq_len == length + // max_seq_len == pre_id_length + + // 32k local memory per 4 core on kl2. + // No enough memory for 16382 input_ids. + const int buf_len = 256; + int64_t input_ids_now[buf_len]; + + bool flag_false = false; + + __shared__ int free_list_len_sm; + // 每次最多处理free_list数量为block_table_now_len + int free_list_now[block_table_now_len]; + if (cid == 0) { + GM2SM(free_list_len, &free_list_len_sm, sizeof(int)); + } + sync_cluster(); + + int recover_len_lm; + GM2LM(recover_len, &recover_len_lm, sizeof(int)); + + for (int bid = cid; bid < recover_len_lm; bid += ncores) { + int recover_id; + int ori_seq_len_encoder; + int step_idx_now; + int encoder_block_len; + int decoder_used_len; + int64_t next_token; + GM2LM(recover_block_list + bid, &recover_id, sizeof(int)); + GM2LM_ASYNC( + ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int)); + GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int)); + GM2LM_ASYNC( + encoder_block_lens + recover_id, &encoder_block_len, sizeof(int)); + GM2LM_ASYNC(used_list_len + recover_id, &decoder_used_len, sizeof(int)); + GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t)); + mfence(); + + int seq_len = ori_seq_len_encoder + step_idx_now; + mfence(); + LM2GM_ASYNC(&seq_len, seq_lens_this_time + recover_id, sizeof(int)); + LM2GM_ASYNC(&seq_len, seq_lens_encoder + recover_id, sizeof(int)); + LM2GM_ASYNC(&flag_false, stop_flags + recover_id, sizeof(bool)); + mfence(); + // // next tokens + // LM2GM_ASYNC(&next_token, + // input_ids + recover_id * length + seq_len - 1, + // sizeof(int64_t)); + // set first prompt token + int64_t first_token_id; + GM2LM(first_token_ids + recover_id, &first_token_id, sizeof(int64_t)); + LM2GM_ASYNC( + &first_token_id, input_ids + recover_id * length, sizeof(int64_t)); + + int ori_free_list_len = atomic_add(&free_list_len_sm, -decoder_used_len); + // 恢复block table + for (int i = 0; i < decoder_used_len; i += block_table_now_len) { + int process_len = min(block_table_now_len, decoder_used_len - i); + GM2LM(free_list + ori_free_list_len - i - process_len, + free_list_now, + process_len * sizeof(int)); + for (int j = 0; j < process_len; j++) { + block_table_now[j] = free_list_now[process_len - 1 - j]; + } + mfence(); + LM2GM( + block_table_now, + block_tables + recover_id * block_num_per_seq + encoder_block_len + i, + process_len * sizeof(int)); + } + // 恢复input_ids + for (int i = 0; i < step_idx_now; i += 
buf_len) { + int real_len = min(buf_len, step_idx_now - i); + GM2LM(pre_ids + recover_id * pre_id_length + i + 1, + input_ids_now, + sizeof(int64_t) * real_len); + LM2GM(input_ids_now, + input_ids + recover_id * length + ori_seq_len_encoder + i, + sizeof(int64_t) * real_len); + } + mfence(); + } + + if (cid == 0) { + recover_len_lm = 0; + mfence(); + LM2GM_ASYNC(&recover_len_lm, recover_len, sizeof(int)); + SM2GM_ASYNC(&free_list_len_sm, free_list_len, sizeof(int)); + mfence(); + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu index 68eb2bd6068..4287c3e7d88 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu @@ -138,7 +138,8 @@ __global__ void speculate_verify( const int max_candidate_len, // scalar, 每个 verify token // 的最大候选数(用于验证或采样) const int verify_window, // scalar, TopK 验证窗口(允许连续 top1 匹配次数) - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { const int cid = core_id(); const int64_t tid = cluster_id() * core_num() + core_id(); const int64_t nthreads = cluster_num() * core_num(); @@ -161,6 +162,9 @@ __global__ void speculate_verify( // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { + if(benchmark_mode){ + break; + } if (seq_lens_encoder[bid] != 0) { break; } @@ -300,33 +304,34 @@ __global__ void speculate_verify( } } } -#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \ - template __global__ void speculate_verify( \ - int64_t * accept_tokens, \ - int *accept_num, \ - int64_t *step_idx, \ - bool *stop_flags, \ - const int *seq_lens_encoder, \ - const int *seq_lens_decoder, \ - const int64_t *draft_tokens, \ - const int *actual_draft_token_nums, \ - const float *dev_curand_states, \ - const float *topp, \ - const int *seq_lens_this_time, \ - const int64_t *verify_tokens, \ - const float *verify_scores, \ - const int64_t *max_dec_len, \ - const int64_t *end_tokens, \ - const bool *is_block_step, \ - const int *output_cum_offsets, \ - const int *actual_candidate_len, \ - int real_bsz, \ - int max_draft_tokens, \ - int end_length, \ - int max_seq_len, \ - int max_candidate_len, \ - int verify_window, \ - bool prefill_one_step_stop); +#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \ + template __global__ void speculate_verify( \ + int64_t * accept_tokens, \ + int *accept_num, \ + int64_t *step_idx, \ + bool *stop_flags, \ + const int *seq_lens_encoder, \ + const int *seq_lens_decoder, \ + const int64_t *draft_tokens, \ + const int *actual_draft_token_nums, \ + const float *dev_curand_states, \ + const float *topp, \ + const int *seq_lens_this_time, \ + const int64_t *verify_tokens, \ + const float *verify_scores, \ + const int64_t *max_dec_len, \ + const int64_t *end_tokens, \ + const bool *is_block_step, \ + const int *output_cum_offsets, \ + const int *actual_candidate_len, \ + int real_bsz, \ + int max_draft_tokens, \ + int end_length, \ + int max_seq_len, \ + int max_candidate_len, \ + int verify_window, \ + bool prefill_one_step_stop, \ + bool benchmark_mode); SPECULATE_VERIFY_INSTANTIATE(true, true) SPECULATE_VERIFY_INSTANTIATE(true, false) SPECULATE_VERIFY_INSTANTIATE(false, true) diff --git 
a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp new file mode 100644 index 00000000000..3eedc4e67f9 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp @@ -0,0 +1,425 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl/launch_strategy.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include "xpu/xdnn.h" + +namespace xpu3 { +namespace plugin { +__attribute__((global)) void draft_model_preprocess_v2( + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); +} // namespace plugin +} // namespace xpu3 + +namespace xpu2 { +namespace plugin {} // namespace plugin +} // namespace xpu2 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + int64_t not_stop_flag_sum = 0; + int64_t not_stop_flag = 0; + for (int tid = 0; tid < bsz; tid++) { + if (splitwise_prefill) { + auto* input_ids_now = input_ids + tid * input_ids_len; + auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; + if (seq_lens_encoder[tid] > 0) { + 
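// Prefill request: copy the base model's first accepted token into input_ids,
+        // either over the last prompt position (truncate_first_token) or appended
+        // right after the prompt, and size seq_lens_this_time accordingly.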
not_stop_flag = 1; + int seq_len_encoder = seq_lens_encoder[tid]; + stop_flags[tid] = false; + int64_t base_model_first_token = accept_tokens_now[0]; + int position = seq_len_encoder; + if (truncate_first_token) { + input_ids_now[position - 1] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder; + } else { + input_ids_now[position] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder + 1; + } + } else { + stop_flags[tid] = true; + seq_lens_this_time[tid] = 0; + seq_lens_decoder[tid] = 0; + seq_lens_encoder[tid] = 0; + not_stop_flag = 0; + } + not_stop_flag_sum += not_stop_flag; + } else { + auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; + auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; + auto accept_num_now = accept_num[tid]; + auto* input_ids_now = input_ids + tid * input_ids_len; + auto* base_model_draft_tokens_now = + base_model_draft_tokens + tid * base_model_draft_tokens_len; + auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]; + const int32_t base_model_seq_len_this_time = + base_model_seq_lens_this_time[tid]; + auto* pre_ids_now = pre_ids + tid * pre_ids_len; + for (int i = 1; i < base_model_draft_tokens_len; i++) { + base_model_draft_tokens_now[i] = -1; + } + if(kvcache_scheduler_v1) { + if (base_model_stop_flags[tid] && + base_model_is_block_step[tid]) { + stop_flags[tid] = true; + is_block_step[tid] = true; + // Need to continue infer + } + } else { + if (base_model_stop_flags[tid] && + base_model_is_block_step[tid]) { + batch_drop[tid] = true; + stop_flags[tid] = true; + } + } + + if (!(base_model_stop_flags[tid] || batch_drop[tid])) { + not_stop_flag = 1; + // prefill generation + if (seq_lens_encoder[tid] > 0) { + // Can be extended to first few tokens + int seq_len_encoder = seq_lens_encoder[tid]; + stop_flags[tid] = false; + int64_t base_model_first_token = accept_tokens_now[0]; + pre_ids_now[0] = base_model_first_token; + int position = seq_len_encoder; + if (truncate_first_token) { + input_ids_now[position - 1] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder; + } else { + input_ids_now[position] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder + 1; + } + } else { // decode generation + if(kvcache_scheduler_v1) { + // 3. 
try to recover mtp infer in V1 mode + if (!base_model_is_block_step[tid] && + is_block_step[tid]) { + is_block_step[tid] = false; + } + } + if (stop_flags[tid]) { + stop_flags[tid] = false; + // TODO: check + seq_lens_decoder[tid] = base_model_seq_len_decoder - + base_model_seq_len_this_time; + step_idx[tid] = base_model_step_idx[tid] - + base_model_seq_len_this_time; + } else { + // 2: Last base model generated token and first MTP + // token + seq_lens_decoder[tid] -= num_model_step - 1; + step_idx[tid] -= num_model_step - 1; + } + for (int i = 0; i < accept_num_now; i++) { + draft_tokens_now[i] = accept_tokens_now[i]; + const int pre_id_pos = + base_model_step_idx[tid] - (accept_num_now - i); + const int64_t accept_token = accept_tokens_now[i]; + pre_ids_now[pre_id_pos] = accept_token; + } + seq_lens_this_time[tid] = accept_num_now; + } + } else { + stop_flags[tid] = true; + seq_lens_this_time[tid] = 0; + seq_lens_decoder[tid] = 0; + seq_lens_encoder[tid] = 0; + } + not_stop_flag_sum += not_stop_flag; + } + } + not_need_stop[0] = not_stop_flag_sum > 0; + return api::SUCCESS; +} + +static int xpu3_wrapper(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + using XPU_INT64 = typename XPUIndexType::type; + + // NOTE: Don't change 16 to 64, because kernel use gsm + xpu3::plugin::draft_model_preprocess_v2<<<1, 64, ctx->xpu_stream>>>( + reinterpret_cast(draft_tokens), + reinterpret_cast(input_ids), + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(step_idx), + not_need_stop, + is_block_step, + batch_drop, + reinterpret_cast(pre_ids), + reinterpret_cast(accept_tokens), + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + reinterpret_cast(base_model_step_idx), + base_model_stop_flags, + base_model_is_block_step, + reinterpret_cast(base_model_draft_tokens), + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + return api::SUCCESS; +} + +int draft_model_preprocess_v2(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* 
base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess_v2", int64_t); + WRAPPER_DUMP_PARAM6(ctx, + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder); + WRAPPER_DUMP_PARAM5( + ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids); + WRAPPER_DUMP_PARAM3( + ctx, accept_tokens, accept_num, base_model_seq_lens_encoder); + WRAPPER_DUMP_PARAM4(ctx, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags); + WRAPPER_DUMP_PARAM3( + ctx, base_model_is_block_step, base_model_draft_tokens, bsz); + WRAPPER_DUMP_PARAM3( + ctx, num_model_step, accept_tokens_len, draft_tokens_len); + WRAPPER_DUMP_PARAM4(ctx, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token); + WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens); + WRAPPER_CHECK_PTR(ctx, + int64_t, + bsz * base_model_draft_tokens_len, + base_model_draft_tokens); + + WRAPPER_ASSERT_GT(ctx, bsz, 0); + WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp new file mode 100644 index 00000000000..68e6a3b3835 --- /dev/null +++ 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp @@ -0,0 +1,224 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include +#include + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void speculate_free_and_dispatch_block( + bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder, + int *block_tables, int *encoder_block_lens, bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, int *recover_block_list, int *recover_len, + int *need_block_list, int *need_block_len, int *used_list_len, + int *free_list, int *free_list_len, int64_t *first_token_ids, + int *accept_num, const int bsz, + const int block_size, const int block_num_per_seq, + const int max_decoder_block_num, const int max_draft_tokens); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, + int *seq_lens_decoder, int *block_tables, + int *encoder_block_lens, bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, int *recover_block_list, int *recover_len, + int *need_block_list, int *need_block_len, + int *used_list_len, int *free_list, int *free_list_len, + int64_t *first_token_ids, int *accept_num, const int bsz, + const int block_size, const int block_num_per_seq, + const int max_decoder_block_num, const int max_draft_tokens) { + for (int i = 0; i < bsz; i++) { + int *block_table_now = block_tables + i * block_num_per_seq; + if (stop_flags[i] && !is_block_step[i]) { + // 回收block块 + const int encoder_block_len = encoder_block_lens[i]; + const int decoder_used_len = used_list_len[i]; + if (decoder_used_len > 0) { + const int ori_free_list_len = free_list_len[0]; + free_list_len[0] += decoder_used_len; + for (int j = 0; j < decoder_used_len; j++) { + free_list[ori_free_list_len + j] = + block_table_now[encoder_block_len + j]; + block_table_now[encoder_block_len + j] = -1; + } + encoder_block_lens[i] = 0; + used_list_len[i] = 0; + } + } else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = need_block_len[0]; + need_block_len[0] += 1; + need_block_list[ori_need_block_len] = i; + } + } + + while (need_block_len[0] > free_list_len[0]) { + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len + int max_used_list_len_id = 0; + int max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + const int used_block_num = !is_block_step[i] ? 
used_list_len[i] : 0; + if (used_block_num > max_used_list_len) { + max_used_list_len_id = i; + max_used_list_len = used_block_num; + } + } + + const int encoder_block_len = encoder_block_lens[max_used_list_len_id]; + int *block_table_now = + block_tables + max_used_list_len_id * block_num_per_seq; + for (int i = 0; i < max_used_list_len; i++) { + free_list[free_list_len[0] + i] = + block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; + } + step_block_list[step_len[0]] = max_used_list_len_id; + step_len[0] += 1; + free_list_len[0] += max_used_list_len; + stop_flags[max_used_list_len_id] = true; + is_block_step[max_used_list_len_id] = true; + seq_lens_this_time[max_used_list_len_id] = 0; + seq_lens_decoder[max_used_list_len_id] = 0; + } + + // 为需要block的位置分配block,每个位置分配一个block + for (int i = 0; i < bsz; i++) { + if (i < need_block_len[0]) { + const int need_block_id = need_block_list[i]; + if (!stop_flags[need_block_id]) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len[need_block_id] += 1; + const int ori_free_list_len = free_list_len[0]; + free_list_len[0]--; + int *block_table_now = + block_tables + need_block_id * block_num_per_seq; + block_table_now[seq_lens_decoder[need_block_id] / block_size] = + free_list[ori_free_list_len - 1]; + } + need_block_list[i] = -1; + } + } + + // 计算可以复原的query id + int ori_step_len = step_len[0]; + if (ori_step_len > 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 + : tmp_used_len; + while (ori_step_len > 0 && ori_free_list_len >= used_len) { + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step[ori_step_block_id] = false; + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len < max_decoder_block_num + ? 
tmp_used_len + 1 + : tmp_used_len; + } + } + need_block_len[0] = 0; + } + return api::SUCCESS; +} + +static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, + int *seq_lens_decoder, int *block_tables, + int *encoder_block_lens, bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, int *recover_block_list, + int *recover_len, int *need_block_list, + int *need_block_len, int *used_list_len, int *free_list, + int *free_list_len, int64_t *first_token_ids, int *accept_num, + const int bsz, const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, const int max_draft_tokens) { + using XPU_INT64 = typename XPUIndexType::type; + auto speculate_free_and_dispatch_block_kernel = xpu3::plugin::speculate_free_and_dispatch_block; + speculate_free_and_dispatch_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, + encoder_block_lens, is_block_step, step_block_list, step_len, + recover_block_list, recover_len, need_block_list, need_block_len, + used_list_len, free_list, free_list_len, + reinterpret_cast(first_token_ids), accept_num, bsz, block_size, + block_num_per_seq, max_decoder_block_num, max_draft_tokens); + return api::SUCCESS; +} + +int speculate_free_and_dispatch_block(Context *ctx, bool *stop_flags, + int *seq_lens_this_time, int *seq_lens_decoder, + int *block_tables, int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, int *recover_block_list, + int *recover_len, int *need_block_list, + int *need_block_len, int *used_list_len, + int *free_list, int *free_list_len, + int64_t *first_token_ids, int *accept_num, + const int bsz, const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, const int max_draft_tokens) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_dispatch_block", float); + WRAPPER_DUMP_PARAM6(ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, + block_tables, encoder_block_lens, is_block_step); + WRAPPER_DUMP_PARAM6(ctx, step_block_list, step_len, recover_block_list, + recover_len, need_block_list, need_block_len); + WRAPPER_DUMP_PARAM4(ctx, used_list_len, free_list, free_list_len, + first_token_ids); + WRAPPER_DUMP_PARAM4(ctx, bsz, block_size, block_num_per_seq, + max_decoder_block_num); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper( + ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, + encoder_block_lens, is_block_step, step_block_list, step_len, + recover_block_list, recover_len, need_block_list, need_block_len, + used_list_len, free_list, free_list_len, first_token_ids, accept_num, bsz, + block_size, block_num_per_seq, max_decoder_block_num, max_draft_tokens); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper( + ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, + encoder_block_lens, is_block_step, step_block_list, step_len, + recover_block_list, recover_len, need_block_list, need_block_len, + used_list_len, free_list, free_list_len, first_token_ids, accept_num, bsz, + block_size, block_num_per_seq, max_decoder_block_num, max_draft_tokens); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp index a0066e45579..0886a0196a8 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp @@ -42,6 +42,16 @@ __attribute__((global)) void speculate_get_padding_offset( const int max_seq_len, int bsz); +__attribute__((global)) void speculate_get_padding_offset_v2( + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz); + } // namespace plugin } // namespace xpu3 @@ -99,6 +109,29 @@ static int cpu_wrapper_get_padding_offset(Context* ctx, return api::SUCCESS; } + +static int cpu_wrapper_get_padding_offset_v2(Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { + for (int bi = 0; bi < bsz; ++bi) { + int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; + for (int i = 0; i < seq_lens[bi]; i++) { + batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi; + } + cum_offsets_out[bi] = cum_offset; + int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; + cu_seqlens_q[bi + 1] = cum_seq_len; + cu_seqlens_k[bi + 1] = cum_seq_len; + } + return api::SUCCESS; +} + template static int xpu3_wrapper_remove_padding(Context* ctx, T* output_data, @@ -150,6 +183,29 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx, return api::SUCCESS; } +static int xpu3_wrapper_get_padding_offset_v2(Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { + xpu3::plugin:: + speculate_get_padding_offset_v2<<ncluster(), 64, ctx->xpu_stream>>>( + batch_id_per_token, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bsz); + return api::SUCCESS; +} + + template int speculate_remove_padding(Context* ctx, T* x_remove_padding, @@ -271,6 +327,63 @@ int speculate_get_padding_offset(Context* ctx, WRAPPER_UNIMPLEMENTED(ctx); } +int speculate_get_padding_offset_v2(Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { + WRAPPER_CHECK_CTX(ctx); + + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float); + WRAPPER_DUMP_PARAM6(ctx, + batch_id_per_token, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens); + WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bsz); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens); + WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets_out); + WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_q); + WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_k); + + WRAPPER_ASSERT_GT(ctx, bsz, 0); + WRAPPER_ASSERT_GT(ctx, max_seq_len, 0); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper_get_padding_offset_v2(ctx, + batch_id_per_token, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bsz); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper_get_padding_offset_v2(ctx, + batch_id_per_token, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + 
seq_lens, + max_seq_len, + bsz); + } + + WRAPPER_UNIMPLEMENTED(ctx); +} + #define INSTANTIATION_SPECULATE_REMOVE_PADDING(T) \ template int speculate_remove_padding(Context * ctx, \ T * x_remove_padding, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp new file mode 100644 index 00000000000..bd6e8c2212d --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp @@ -0,0 +1,257 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace xpu3 { +namespace plugin { + +__attribute__((global)) void speculate_recover_block(int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); + +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { + +static int cpu_wrapper(Context *ctx, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + for (int bid = 0; bid < recover_len[0]; bid++) { + const int recover_id = recover_block_list[bid]; + const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + const int step_idx_now = step_idx[recover_id]; + const int seq_len = ori_seq_len_encoder + step_idx_now; + const int encoder_block_len = encoder_block_lens[recover_id]; + const int decoder_used_len = used_list_len[recover_id]; + int *block_table_now = block_tables + recover_id * block_num_per_seq; + int64_t *input_ids_now = input_ids + recover_id * length; + const int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length; + + seq_lens_this_time[recover_id] = seq_len; + seq_lens_encoder[recover_id] = seq_len; + stop_flags[recover_id] = false; + // input_ids_now[seq_len - 1] = next_tokens[recover_id]; // next tokens + input_ids_now[0] = first_token_ids[recover_id]; // set first prompt token + int ori_free_list_len = 
free_list_len[0]; + free_list_len[0] -= decoder_used_len; + + // 恢复block table + for (int i = 0; i < decoder_used_len; i++) { + block_table_now[encoder_block_len + i] = + free_list[ori_free_list_len - i - 1]; + } + // 恢复input_ids + for (int i = 0; i < step_idx_now; i++) { + input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1]; + } + } + recover_len[0] = 0; + return api::SUCCESS; +} + +static int xpu3_wrapper(Context *ctx, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + using XPU_INT64 = typename XPUIndexType::type; + auto recover_block_kernel = xpu3::plugin::speculate_recover_block; + recover_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + reinterpret_cast(input_ids), + reinterpret_cast(pre_ids), + reinterpret_cast(step_idx), + encoder_block_lens, + used_list_len, + reinterpret_cast(next_tokens), + reinterpret_cast(first_token_ids), + bsz, + block_num_per_seq, + length, + pre_id_length); + return api::SUCCESS; +} + +int speculate_recover_block(Context *ctx, + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_recover_block", float); + WRAPPER_DUMP_PARAM6(ctx, + recover_block_list, + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder); + WRAPPER_DUMP_PARAM6(ctx, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids); + WRAPPER_DUMP_PARAM5(ctx, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids); + WRAPPER_DUMP_PARAM4(ctx, bsz, block_num_per_seq, length, pre_id_length); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids, + bsz, + block_num_per_seq, + length, + pre_id_length); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + recover_block_list, // [bsz] + recover_len, + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + free_list, + free_list_len, + input_ids, + pre_ids, + 
step_idx, + encoder_block_lens, + used_list_len, + next_tokens, + first_token_ids, + bsz, + block_num_per_seq, + length, + pre_id_length); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp index c5e3e425b7b..c9571bd513c 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -48,7 +48,8 @@ __attribute__((global)) void speculate_verify( const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop); + const bool prefill_one_step_stop, + const bool benchmark_mode); } // namespace plugin } // namespace xpu3 @@ -136,14 +137,15 @@ static int cpu_wrapper(Context *ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { for (int bid = 0; bid < real_bsz; ++bid) { - const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; // verify and set stop flags int accept_num_now = 1; int stop_flag_now_int = 0; if (!(is_block_step[bid] || bid >= real_bsz)) { + const int start_token_id = bid * max_seq_len - output_cum_offsets[bid]; // printf("debug cpu bid:%d,start_token_id:%d\n",bid, start_token_id); // printf("bid %d\n", bid); @@ -160,6 +162,9 @@ static int cpu_wrapper(Context *ctx, // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { + if(benchmark_mode){ + break; + } if (seq_lens_encoder[bid] != 0) { break; } @@ -326,7 +331,8 @@ static int xpu3_wrapper(Context *ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { using XPU_INT64 = typename XPUIndexType::type; xpu3::plugin::speculate_verify <<<1, 64, ctx->xpu_stream>>>( @@ -354,7 +360,8 @@ static int xpu3_wrapper(Context *ctx, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); return api::SUCCESS; } template @@ -383,7 +390,8 @@ int speculate_verify(Context *ctx, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop) { + const bool prefill_one_step_stop, + const bool benchmark_mode) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_verify", int64_t); WRAPPER_DUMP_PARAM3(ctx, accept_tokens, accept_num, step_idx); @@ -406,12 +414,13 @@ int speculate_verify(Context *ctx, actual_candidate_len, real_bsz, max_draft_tokens); - WRAPPER_DUMP_PARAM5(ctx, + WRAPPER_DUMP_PARAM6(ctx, end_length, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); WRAPPER_DUMP(ctx); WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_draft_tokens, accept_tokens); WRAPPER_CHECK_PTR(ctx, int, real_bsz, accept_num); @@ -469,7 +478,8 @@ int speculate_verify(Context *ctx, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, @@ -497,40 +507,42 @@ int speculate_verify(Context *ctx, max_seq_len, max_candidate_len, verify_window, 
- prefill_one_step_stop); + prefill_one_step_stop, + benchmark_mode); } WRAPPER_UNIMPLEMENTED(ctx); } -#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ - template int \ - baidu::xpu::api::plugin::speculate_verify( \ - baidu::xpu::api::Context *, /* xpu_ctx */ \ - int64_t *, /* accept_tokens */ \ - int *, /* accept_num */ \ - int64_t *, /* step_idx */ \ - bool *, /* stop_flags */ \ - const int *, /* seq_lens_encoder */ \ - const int *, /* seq_lens_decoder */ \ - const int64_t *, /* draft_tokens */ \ - const int *, /* actual_draft_token_nums */ \ - const float *, /* dev_curand_states or topp */ \ - const float *, /* topp or nullptr */ \ - const int *, /* seq_lens_this_time */ \ - const int64_t *, /* verify_tokens */ \ - const float *, /* verify_scores */ \ - const int64_t *, /* max_dec_len */ \ - const int64_t *, /* end_tokens */ \ - const bool *, /* is_block_step */ \ - const int *, /* output_cum_offsets */ \ - const int *, /* actual_candidate_len */ \ - int, /* real_bsz */ \ - int, /* max_draft_tokens */ \ - int, /* end_length */ \ - int, /* max_seq_len */ \ - int, /* max_candidate_len */ \ - int, /* verify_window */ \ - bool); /* prefill_one_step_stop */ +#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ + template int \ + baidu::xpu::api::plugin::speculate_verify( \ + baidu::xpu::api::Context *, /* xpu_ctx */ \ + int64_t *, /* accept_tokens */ \ + int *, /* accept_num */ \ + int64_t *, /* step_idx */ \ + bool *, /* stop_flags */ \ + const int *, /* seq_lens_encoder */ \ + const int *, /* seq_lens_decoder */ \ + const int64_t *, /* draft_tokens */ \ + const int *, /* actual_draft_token_nums */ \ + const float *, /* dev_curand_states or topp */ \ + const float *, /* topp or nullptr */ \ + const int *, /* seq_lens_this_time */ \ + const int64_t *, /* verify_tokens */ \ + const float *, /* verify_scores */ \ + const int64_t *, /* max_dec_len */ \ + const int64_t *, /* end_tokens */ \ + const bool *, /* is_block_step */ \ + const int *, /* output_cum_offsets */ \ + const int *, /* actual_candidate_len */ \ + int, /* real_bsz */ \ + int, /* max_draft_tokens */ \ + int, /* end_length */ \ + int, /* max_seq_len */ \ + int, /* max_candidate_len */ \ + int, /* verify_window */ \ + bool, \ + bool); /* prefill_one_step_stop */ INSTANTIATE_SPECULATE_VERIFY(false, false) INSTANTIATE_SPECULATE_VERIFY(false, true) diff --git a/custom_ops/xpu_ops/test/test_speculate_step.py b/custom_ops/xpu_ops/test/test_speculate_step.py new file mode 100644 index 00000000000..d3bdb212c4a --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_step.py @@ -0,0 +1,189 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
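+
+# Smoke test for the XPU speculate_step_paddle op: it builds a synthetic block
+# table / free-list layout, marks a random subset of requests as stepped via
+# is_block_step, runs the op once, and prints the scheduling state before and
+# after so block free / dispatch / recover behaviour can be inspected manually.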
+ +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import step_paddle, speculate_step_paddle + +np.random.seed(2023) + +max_bs = 128 +bs = max_bs +max_seq_len = 8192 +block_size = 64 +block_bs = 8 +block_ratio = 0.75 +max_draft_tokens = 1 + +stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool") +seq_lens_this_time = np.zeros([bs], "int32") +seq_lens_encoder = np.zeros([max_bs], "int32") +seq_lens_decoder = np.zeros([max_bs], "int32") +step_idx = np.zeros([max_bs], "int64") +accept_num = np.random.randint(1, 3, [max_bs]).astype("int32") +for i in range(bs): + seq_lens_decoder[i] = 2 + i * 2 + seq_lens_this_time[i] = 1 +ori_seq_lens_encoder = np.zeros([max_bs], "int32") +ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2 +step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64") + +max_block_num = block_bs * max_seq_len // block_size +free_list_len = int(max_block_num * (1 - block_ratio)) +free_list_len = np.full([1], free_list_len, "int32") +free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32") + +encoder_block_lens = np.zeros([max_bs], "int32") +used_list_len = np.zeros([max_bs], "int32") +block_tables = np.full([max_bs, 128], -1, "int32") +encoder_block_id = 0 +for i in range(bs): + enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size + encoder_block_lens[i] = enc_block_num + dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num + used_list_len[i] = dec_block_num + block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32") + encoder_block_id += enc_block_num + if dec_block_num > 0: + block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[ + free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1 + ] + free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1 + free_list_len[0] -= dec_block_num +assert free_list_len[0] >= 0 + +is_block_step = np.zeros([max_bs], "bool") +is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool") +step_block_list = np.full([max_bs], -1, "int32") +step_lens = np.full([1], 0, "int32") +for i in range(bs): + if is_block_step[i]: + step_block_list[step_lens[0]] = i + step_lens[0] += 1 + +recover_lens = np.full([1], 0, "int32") +recover_block_list = np.full([max_bs], -1, "int32") + +need_block_len = np.full([1], 0, "int32") +need_block_list = np.full([max_bs], -1, "int32") + +input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") +pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + +next_tokens = np.random.randint(0, 1000, [max_bs], "int64") +encoder_decoder_block_num = 1 +first_token_ids = np.random.randint(0, 1000, [max_bs], "int64") + +stop_flags = paddle.to_tensor(stop_flags) +seq_lens_this_time = paddle.to_tensor(seq_lens_this_time) +seq_lens_encoder = paddle.to_tensor(seq_lens_encoder) +seq_lens_decoder = paddle.to_tensor(seq_lens_decoder) +ori_seq_lens_encoder = paddle.to_tensor(ori_seq_lens_encoder) +block_tables = paddle.to_tensor(block_tables) +encoder_block_lens = paddle.to_tensor(encoder_block_lens) +is_block_step = paddle.to_tensor(is_block_step) +step_block_list = paddle.to_tensor(step_block_list) +step_lens = paddle.to_tensor(step_lens) +recover_lens = paddle.to_tensor(recover_lens) +recover_block_list = paddle.to_tensor(recover_block_list) +need_block_list = paddle.to_tensor(need_block_list) +need_block_len = paddle.to_tensor(need_block_len) +used_list_len = 
paddle.to_tensor(used_list_len) +free_list_len = paddle.to_tensor(free_list_len) +free_list = paddle.to_tensor(free_list) +input_ids = paddle.to_tensor(input_ids) +pre_ids = paddle.to_tensor(pre_ids) +step_idx = paddle.to_tensor(step_idx) +next_tokens = paddle.to_tensor(next_tokens) +first_token_ids = paddle.to_tensor(first_token_ids) +accept_num = paddle.to_tensor(accept_num) + +print("-" * 50 + "before step op" + "-" * 50) +print("stop_flags: ", stop_flags) +print("seq_lens_this_time: ", seq_lens_this_time) +print("seq_lens_encoder: ", seq_lens_encoder) +print("seq_lens_decoder: ", seq_lens_decoder) +print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) +print("block_tables: ", block_tables) +print("encoder_block_lens: ", encoder_block_lens) +print("is_block_step: ", is_block_step) +print("step_block_list: ", step_block_list) +print("step_lens: ", step_lens) +print("recover_lens: ", recover_lens) +print("recover_block_list: ", recover_block_list) +print("need_block_list: ", need_block_list) +print("need_block_len: ", need_block_len) +print("used_list_len: ", used_list_len) +print("free_list_len: ", free_list_len) +print("free_list: ", free_list) +print("input_ids: ", input_ids) +print("pre_ids: ", pre_ids) +print("step_idx: ", step_idx) +print("next_tokens: ", next_tokens) +print("accept_num: ", accept_num) + +speculate_step_paddle( + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens, +) + +print("-" * 50 + "after step op" + "-" * 50) +print("stop_flags: ", stop_flags) +print("seq_lens_this_time: ", seq_lens_this_time) +print("seq_lens_encoder: ", seq_lens_encoder) +print("seq_lens_decoder: ", seq_lens_decoder) +print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) +print("block_tables: ", block_tables) +print("encoder_block_lens: ", encoder_block_lens) +print("is_block_step: ", is_block_step) +print("step_block_list: ", step_block_list) +print("step_lens: ", step_lens) +print("recover_lens: ", recover_lens) +print("recover_block_list: ", recover_block_list) +print("need_block_list: ", need_block_list) +print("need_block_len: ", need_block_len) +print("used_list_len: ", used_list_len) +print("free_list_len: ", free_list_len) +print("free_list: ", free_list) +print("input_ids: ", input_ids) +print("pre_ids: ", pre_ids) +print("step_idx: ", step_idx) +print("next_tokens: ", next_tokens) +print("first_token_ids: ", first_token_ids) +print("accept_num: ", accept_num) From 2efd8dd66f4dfcab021216079290fe3d36cd21f9 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Sat, 1 Nov 2025 17:30:47 +0000 Subject: [PATCH 02/17] [XPU] support kernel for mtp(base) --- custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index 1eebc175b66..ae4555f5fe4 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -111,7 +111,7 @@ __global__ void speculate_verify(const int64_t *sampled_token_ids, auto *draft_tokens_now = draft_tokens + bid 
* max_draft_tokens; auto *actual_candidate_len_now = actual_candidate_len + start_token_id; auto *sampled_token_id_now = sampled_token_ids + start_token_id; -draft_model_update + int i = 0; // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); From 1fb377b431a4ed7faa0389191de420f695da8070 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Sat, 1 Nov 2025 17:34:17 +0000 Subject: [PATCH 03/17] format --- .../src/ops/mtp/draft_model_preprocess_v2.cc | 52 +- .../src/ops/mtp/speculate_step_paddle.cc | 151 +++--- .../mtp_kernel/draft_model_preprocess_v2.xpu | 380 +++++++------- .../speculate_free_and_dispatch_block.xpu | 48 +- .../mtp_wrapper/draft_model_preprocess_v2.cpp | 476 +++++++++--------- .../speculate_free_and_dispatch_block.cpp | 468 ++++++++++------- .../mtp_wrapper/speculate_recover_block.cpp | 43 +- .../xpu_ops/test/test_speculate_step.py | 2 +- 8 files changed, 866 insertions(+), 754 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc index d97e28f68f7..c2eb3313b27 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc @@ -23,30 +23,29 @@ namespace api = baidu::xpu::api; void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens, - const paddle::Tensor& input_ids, - const paddle::Tensor& stop_flags, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& step_idx, - const paddle::Tensor& not_need_stop, - const paddle::Tensor& is_block_step, - const paddle::Tensor& batch_drop, - const paddle::Tensor& pre_ids, - const paddle::Tensor& accept_tokens, - const paddle::Tensor& accept_num, - const paddle::Tensor& base_model_seq_lens_this_time, - const paddle::Tensor& base_model_seq_lens_encoder, - const paddle::Tensor& base_model_seq_lens_decoder, - const paddle::Tensor& base_model_step_idx, - const paddle::Tensor& base_model_stop_flags, - const paddle::Tensor& base_model_is_block_step, - const paddle::Tensor& base_model_draft_tokens, - const int num_model_step, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1) { - + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, + const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& base_model_seq_lens_encoder, + const paddle::Tensor& base_model_seq_lens_decoder, + const paddle::Tensor& base_model_step_idx, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& base_model_is_block_step, + const paddle::Tensor& base_model_draft_tokens, + const int num_model_step, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); api::Context* ctx = static_cast(dev_ctx)->x_context(); @@ -134,7 +133,10 @@ PD_BUILD_STATIC_OP(draft_model_preprocess_v2) "not_need_stop_out", 
"batch_drop_out", "pre_ids_out"}) - .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"}) + .Attrs({"num_model_step: int", + "truncate_first_token: bool", + "splitwise_prefill: bool", + "kvcache_scheduler_v1: bool"}) .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"input_ids", "input_ids_out"}, {"stop_flags", "stop_flags_out"}, diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc index eaec2c6958b..bdffa3b79a4 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc @@ -12,92 +12,109 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "paddle/extension.h" #include "paddle/phi/core/enforce.h" #include "xpu/plugin.h" -#include #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif void SpeculateStepPaddle( - const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &ori_seq_lens_encoder, const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] const paddle::Tensor &encoder_block_lens, - const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list, - const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list, - const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list, - const paddle::Tensor &need_block_len, const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, const paddle::Tensor &free_list_len, - const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens, - const paddle::Tensor &first_token_ids, const paddle::Tensor &accept_num, - const int block_size, const int encoder_decoder_block_num, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, + const paddle::Tensor &accept_num, + const int block_size, + const int encoder_decoder_block_num, const int max_draft_tokens) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); - const int bsz = seq_lens_this_time.shape()[0]; - PADDLE_ENFORCE_LE( - bsz, 640, - phi::errors::InvalidArgument( - "Only support bsz <= 640, but received bsz is %d", bsz)); - const int block_num_per_seq = block_tables.shape()[1]; - const int length = input_ids.shape()[1]; - const int pre_id_length = 
pre_ids.shape()[1]; - const int max_decoder_block_num = pre_id_length / block_size; - int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block( - xpu_ctx->x_context(), const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_decoder.data()), - const_cast(block_tables.data()), - const_cast(encoder_block_lens.data()), - const_cast(is_block_step.data()), - const_cast(step_block_list.data()), - const_cast(step_lens.data()), + const int bsz = seq_lens_this_time.shape()[0]; + PADDLE_ENFORCE_LE( + bsz, + 640, + phi::errors::InvalidArgument( + "Only support bsz <= 640, but received bsz is %d", bsz)); + const int block_num_per_seq = block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_decoder_block_num = pre_id_length / block_size; + int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block( + xpu_ctx->x_context(), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(first_token_ids.data()), + const_cast(accept_num.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed."); + auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); + int recover_lens_cpu_data = recover_lens_cpu.data()[0]; + if (recover_lens_cpu_data > 0) { + r = baidu::xpu::api::plugin::speculate_recover_block( + xpu_ctx->x_context(), const_cast(recover_block_list.data()), const_cast(recover_lens.data()), - const_cast(need_block_list.data()), - const_cast(need_block_len.data()), - const_cast(used_list_len.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + ori_seq_lens_encoder.data(), + const_cast(seq_lens_encoder.data()), + seq_lens_decoder.data(), + const_cast(block_tables.data()), const_cast(free_list.data()), const_cast(free_list_len.data()), - const_cast(first_token_ids.data()), - const_cast(accept_num.data()), - bsz, - block_size, - block_num_per_seq, - max_decoder_block_num, - max_draft_tokens); - PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed."); - auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); - int recover_lens_cpu_data = recover_lens_cpu.data()[0]; - if (recover_lens_cpu_data > 0) { - r = baidu::xpu::api::plugin::speculate_recover_block( - xpu_ctx->x_context(), - const_cast(recover_block_list.data()), - const_cast(recover_lens.data()), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - ori_seq_lens_encoder.data(), - const_cast(seq_lens_encoder.data()), - seq_lens_decoder.data(), - const_cast(block_tables.data()), - const_cast(free_list.data()), - const_cast(free_list_len.data()), - const_cast(input_ids.data()), - pre_ids.data(), step_idx.data(), - encoder_block_lens.data(), used_list_len.data(), - next_tokens.data(), first_token_ids.data(), bsz, - block_num_per_seq, length, pre_id_length); - PD_CHECK(r == 0, "speculate_recover_block failed."); - } + const_cast(input_ids.data()), 
+ pre_ids.data(), + step_idx.data(), + encoder_block_lens.data(), + used_list_len.data(), + next_tokens.data(), + first_token_ids.data(), + bsz, + block_num_per_seq, + length, + pre_id_length); + PD_CHECK(r == 0, "speculate_recover_block failed."); + } } PD_BUILD_STATIC_OP(speculate_step_paddle) diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu index 052de1b7fc3..9d26919c33a 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu @@ -37,213 +37,203 @@ __global__ void draft_model_preprocess_v2( const bool truncate_first_token, const bool splitwise_prefill, const bool kvcache_scheduler_v1) { - int cid = core_id(); - int ncores = core_num(); - int clusterid = cluster_id(); - int nclusters = cluster_num(); - int tid = clusterid * ncores + cid; - __shared__ int not_stop_flag_sm[64]; - not_stop_flag_sm[cid] = 0; - int64_t accept_tokens_now[128]; + int cid = core_id(); + int ncores = core_num(); + int clusterid = cluster_id(); + int nclusters = cluster_num(); + int tid = clusterid * ncores + cid; + __shared__ int not_stop_flag_sm[64]; + not_stop_flag_sm[cid] = 0; + int64_t accept_tokens_now[128]; - int value_zero = 0; - int64_t value_fu = -1; + int value_zero = 0; + int64_t value_fu = -1; - if (splitwise_prefill) { - for (; tid < bsz; tid += ncores * nclusters) { - int64_t base_model_step_idx_now = 0; - int seq_lens_encoder_now = 0; - int seq_lens_this_time_now = 0; - bool stop_flags_now = false; - int64_t base_model_first_token; - int seq_lens_encoder_record_now = 0; - int64_t input_ids_now = 0; + if (splitwise_prefill) { + for (; tid < bsz; tid += ncores * nclusters) { + int64_t base_model_step_idx_now = 0; + int seq_lens_encoder_now = 0; + int seq_lens_this_time_now = 0; + bool stop_flags_now = false; + int64_t base_model_first_token; + int seq_lens_encoder_record_now = 0; + int64_t input_ids_now = 0; - GM2LM_ASYNC(base_model_step_idx + tid, - &base_model_step_idx_now, - sizeof(int64_t)); - GM2LM_ASYNC( - seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); - GM2LM(accept_tokens + tid * accept_tokens_len, - &base_model_first_token, - sizeof(int64_t)); - if (seq_lens_encoder_now > 0) { - not_stop_flag_sm[cid] += 1; - stop_flags_now = false; - int position = seq_lens_encoder_now; - if (truncate_first_token) { - position = position - 1; - input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_lens_encoder_now; - } else { - input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_lens_encoder_now + 1; - } - LM2GM_ASYNC(&input_ids_now, - input_ids + tid * input_ids_len + position, - sizeof(int64_t)); - } else { - stop_flags_now = true; - seq_lens_this_time_now = 0; - seq_lens_encoder_now = 0; - not_stop_flag_sm[cid] += 0; - LM2GM_ASYNC(&value_zero, seq_lens_decoder + tid, sizeof(int)); - } - LM2GM_ASYNC( - &seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); - LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); - LM2GM( - &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + GM2LM_ASYNC( + base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); + GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); + GM2LM(accept_tokens + tid * accept_tokens_len, + &base_model_first_token, + sizeof(int64_t)); + if 
(seq_lens_encoder_now > 0) { + not_stop_flag_sm[cid] += 1; + stop_flags_now = false; + int position = seq_lens_encoder_now; + if (truncate_first_token) { + position = position - 1; + input_ids_now = base_model_first_token; + seq_lens_this_time_now = seq_lens_encoder_now; + } else { + input_ids_now = base_model_first_token; + seq_lens_this_time_now = seq_lens_encoder_now + 1; } - } else { - for (; tid < bsz; tid += ncores * nclusters) { - bool base_model_stop_flags_now = false; - bool base_model_is_block_step_now = false; - bool batch_drop_now = false; - bool stop_flags_now = false; - bool is_block_step_now = false; - int seq_lens_this_time_now = 0; - int seq_lens_encoder_now = 0; - int seq_lens_decoder_new = 0; - int accept_num_now = 0; - int base_model_seq_lens_decoder_now = 0; - int base_model_seq_lens_this_time_now = 0; - int64_t step_id_now = 0; - int64_t base_model_step_idx_now; - int64_t pre_ids_now; - mfence(); - GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool)); - GM2LM_ASYNC(base_model_stop_flags + tid, - &base_model_stop_flags_now, - sizeof(bool)); - GM2LM_ASYNC(base_model_is_block_step + tid, - &base_model_is_block_step_now, - sizeof(bool)); - GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); - GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool)); - GM2LM_ASYNC( - seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); - GM2LM_ASYNC( - seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); - - GM2LM_ASYNC(accept_tokens + tid * accept_tokens_len, - accept_tokens_now, - accept_tokens_len * sizeof(int64_t)); - GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); - - GM2LM_ASYNC(base_model_seq_lens_this_time + tid, - &base_model_seq_lens_this_time_now, - sizeof(int)); - GM2LM_ASYNC(base_model_seq_lens_decoder + tid, - &base_model_seq_lens_decoder_now, - sizeof(int)); - GM2LM_ASYNC(step_idx + tid, &step_id_now, sizeof(int64_t)); - GM2LM(base_model_step_idx + tid, - &base_model_step_idx_now, - sizeof(int64_t)); - - for (int i = 1; i < base_model_draft_tokens_len; i++) { - LM2GM_ASYNC(&value_fu, - base_model_draft_tokens + - tid * base_model_draft_tokens_len + i, - sizeof(int)); - } - if (kvcache_scheduler_v1) { - if (base_model_stop_flags_now && base_model_is_block_step_now) { - stop_flags_now = true; - is_block_step_now = true; - } - } else { - if (base_model_stop_flags_now && base_model_is_block_step_now) { - batch_drop_now = true; - stop_flags_now = true; - } - } + LM2GM_ASYNC(&input_ids_now, + input_ids + tid * input_ids_len + position, + sizeof(int64_t)); + } else { + stop_flags_now = true; + seq_lens_this_time_now = 0; + seq_lens_encoder_now = 0; + not_stop_flag_sm[cid] += 0; + LM2GM_ASYNC(&value_zero, seq_lens_decoder + tid, sizeof(int)); + } + LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); + LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); + LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + } + } else { + for (; tid < bsz; tid += ncores * nclusters) { + bool base_model_stop_flags_now = false; + bool base_model_is_block_step_now = false; + bool batch_drop_now = false; + bool stop_flags_now = false; + bool is_block_step_now = false; + int seq_lens_this_time_now = 0; + int seq_lens_encoder_now = 0; + int seq_lens_decoder_new = 0; + int accept_num_now = 0; + int base_model_seq_lens_decoder_now = 0; + int base_model_seq_lens_this_time_now = 0; + int64_t step_id_now = 0; + int64_t base_model_step_idx_now; + int64_t pre_ids_now; + mfence(); + GM2LM_ASYNC(is_block_step 
+ tid, &is_block_step_now, sizeof(bool)); + GM2LM_ASYNC(base_model_stop_flags + tid, + &base_model_stop_flags_now, + sizeof(bool)); + GM2LM_ASYNC(base_model_is_block_step + tid, + &base_model_is_block_step_now, + sizeof(bool)); + GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); + GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool)); + GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); + GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); - if (!(base_model_stop_flags_now || batch_drop_now)) { - not_stop_flag_sm[cid] += 1; - if (seq_lens_encoder_now > 0) { - int seq_len_encoder = seq_lens_encoder_now; - stop_flags_now = false; - int64_t base_model_first_token = accept_tokens_now[0]; - LM2GM(&base_model_first_token, - pre_ids + tid * pre_ids_len, - sizeof(int64_t)); - int position = seq_len_encoder; - if (truncate_first_token) { - LM2GM(&base_model_first_token, - input_ids + tid * input_ids_len + position - 1, - sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder; - } else { - LM2GM(&base_model_first_token, - input_ids + tid * input_ids_len + position, - sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder + 1; - } - } else { - if (kvcache_scheduler_v1) { - if (!base_model_is_block_step_now && - is_block_step_now) { - is_block_step_now = false; - } - } - if (stop_flags_now) { - stop_flags_now = false; - seq_lens_decoder_new = base_model_seq_lens_decoder_now - - base_model_seq_lens_this_time_now; - step_id_now = base_model_step_idx_now - - base_model_seq_lens_this_time_now; + GM2LM_ASYNC(accept_tokens + tid * accept_tokens_len, + accept_tokens_now, + accept_tokens_len * sizeof(int64_t)); + GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); - } else { - seq_lens_decoder_new -= num_model_step - 1; - step_id_now -= num_model_step - 1; - } - for (int i = 0; i < accept_num_now; i++) { - const int pre_id_pos = - base_model_step_idx_now - (accept_num_now - i); - LM2GM(accept_tokens_now + i, - draft_tokens + tid * draft_tokens_len + i, - sizeof(int64_t)); - LM2GM(accept_tokens_now + i, - pre_ids + tid * pre_ids_len + pre_id_pos, - sizeof(int64_t)); - } - seq_lens_this_time_now = accept_num_now; - } + GM2LM_ASYNC(base_model_seq_lens_this_time + tid, + &base_model_seq_lens_this_time_now, + sizeof(int)); + GM2LM_ASYNC(base_model_seq_lens_decoder + tid, + &base_model_seq_lens_decoder_now, + sizeof(int)); + GM2LM_ASYNC(step_idx + tid, &step_id_now, sizeof(int64_t)); + GM2LM( + base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); - } else { - stop_flags_now = true; - seq_lens_this_time_now = 0; - seq_lens_encoder_now = 0; - seq_lens_decoder_new = 0; - } - LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); - LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); - LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool)); - LM2GM_ASYNC( - &seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); - LM2GM_ASYNC( - &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); - LM2GM_ASYNC( - &seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); - LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); + for (int i = 1; i < base_model_draft_tokens_len; i++) { + LM2GM_ASYNC( + &value_fu, + base_model_draft_tokens + tid * base_model_draft_tokens_len + i, + sizeof(int)); + } + if (kvcache_scheduler_v1) { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + stop_flags_now = true; + is_block_step_now = true; } - } - mfence(); - sync_cluster(); - 
bool value_true = true; - bool value_false = false; - if (cid == 0) { - for (int i = 0; i < ncores; i++) { - not_stop_flag_sm[0] += not_stop_flag_sm[i]; + } else { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + batch_drop_now = true; + stop_flags_now = true; } - if (not_stop_flag_sm[0] > 0) { - LM2GM(&value_true, not_need_stop, sizeof(bool)); + } + + if (!(base_model_stop_flags_now || batch_drop_now)) { + not_stop_flag_sm[cid] += 1; + if (seq_lens_encoder_now > 0) { + int seq_len_encoder = seq_lens_encoder_now; + stop_flags_now = false; + int64_t base_model_first_token = accept_tokens_now[0]; + LM2GM(&base_model_first_token, + pre_ids + tid * pre_ids_len, + sizeof(int64_t)); + int position = seq_len_encoder; + if (truncate_first_token) { + LM2GM(&base_model_first_token, + input_ids + tid * input_ids_len + position - 1, + sizeof(int64_t)); + seq_lens_this_time_now = seq_len_encoder; + } else { + LM2GM(&base_model_first_token, + input_ids + tid * input_ids_len + position, + sizeof(int64_t)); + seq_lens_this_time_now = seq_len_encoder + 1; + } } else { - LM2GM(&value_false, not_need_stop, sizeof(bool)); + if (kvcache_scheduler_v1) { + if (!base_model_is_block_step_now && is_block_step_now) { + is_block_step_now = false; + } + } + if (stop_flags_now) { + stop_flags_now = false; + seq_lens_decoder_new = base_model_seq_lens_decoder_now - + base_model_seq_lens_this_time_now; + step_id_now = + base_model_step_idx_now - base_model_seq_lens_this_time_now; + + } else { + seq_lens_decoder_new -= num_model_step - 1; + step_id_now -= num_model_step - 1; + } + for (int i = 0; i < accept_num_now; i++) { + const int pre_id_pos = + base_model_step_idx_now - (accept_num_now - i); + LM2GM(accept_tokens_now + i, + draft_tokens + tid * draft_tokens_len + i, + sizeof(int64_t)); + LM2GM(accept_tokens_now + i, + pre_ids + tid * pre_ids_len + pre_id_pos, + sizeof(int64_t)); + } + seq_lens_this_time_now = accept_num_now; } + + } else { + stop_flags_now = true; + seq_lens_this_time_now = 0; + seq_lens_encoder_now = 0; + seq_lens_decoder_new = 0; + } + LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); + LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); + LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool)); + LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); + LM2GM_ASYNC( + &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); + LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); + LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); + } + } + mfence(); + sync_cluster(); + bool value_true = true; + bool value_false = false; + if (cid == 0) { + for (int i = 0; i < ncores; i++) { + not_stop_flag_sm[0] += not_stop_flag_sm[i]; + } + if (not_stop_flag_sm[0] > 0) { + LM2GM(&value_true, not_need_stop, sizeof(bool)); + } else { + LM2GM(&value_false, not_need_stop, sizeof(bool)); } + } } } // namespace plugin diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu index f5652c55aa6..69bcb3f990e 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu @@ -43,28 +43,28 @@ static __device__ bool in_need_block_list(const int qid, } __global__ void speculate_free_and_dispatch_block( - bool 
*stop_flags, - int *seq_lens_this_time, - int *seq_lens_decoder, - int *block_tables, - int *encoder_block_lens, - bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, - int *recover_block_list, - int *recover_len, - int *need_block_list, - int *need_block_len, - int *used_list_len, - int *free_list, - int *free_list_len, - int64_t *first_token_ids, - int *accept_num, - const int bsz, - const int block_size, - const int block_num_per_seq, - const int max_decoder_block_num, - const int max_draft_tokens) { + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { int cid = core_id(); int ncores = core_num(); int clusterid = cluster_id(); @@ -148,8 +148,8 @@ __global__ void speculate_free_and_dispatch_block( mfence(); LM2GM(&value_zero, encoder_block_lens + tid, sizeof(int)); } - } else if (seq_lens_this_time_lm != 0 && - max_possible_block_idx < block_num_per_seq) { + } else if (seq_lens_this_time_lm != 0 && + max_possible_block_idx < block_num_per_seq) { int next_block_id; GM2LM(block_tables + tid * block_num_per_seq + (seq_lens_decoder_lm + max_draft_tokens + 1) / block_size, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp index 3eedc4e67f9..13b3b892b49 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp @@ -93,120 +93,117 @@ static int cpu_wrapper(api::Context* ctx, const bool truncate_first_token, const bool splitwise_prefill, const bool kvcache_scheduler_v1) { - int64_t not_stop_flag_sum = 0; - int64_t not_stop_flag = 0; - for (int tid = 0; tid < bsz; tid++) { - if (splitwise_prefill) { - auto* input_ids_now = input_ids + tid * input_ids_len; - auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; - if (seq_lens_encoder[tid] > 0) { - not_stop_flag = 1; - int seq_len_encoder = seq_lens_encoder[tid]; - stop_flags[tid] = false; - int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder; - if (truncate_first_token) { - input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder; - } else { - input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder + 1; - } - } else { - stop_flags[tid] = true; - seq_lens_this_time[tid] = 0; - seq_lens_decoder[tid] = 0; - seq_lens_encoder[tid] = 0; - not_stop_flag = 0; - } - not_stop_flag_sum += not_stop_flag; + int64_t not_stop_flag_sum = 0; + int64_t not_stop_flag = 0; + for (int tid = 0; tid < bsz; tid++) { + if (splitwise_prefill) { + auto* input_ids_now = input_ids + tid * input_ids_len; + auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; + if (seq_lens_encoder[tid] > 0) { + not_stop_flag = 1; + int seq_len_encoder = seq_lens_encoder[tid]; + stop_flags[tid] = false; + int64_t base_model_first_token = accept_tokens_now[0]; + int position = seq_len_encoder; + if 
(truncate_first_token) { + input_ids_now[position - 1] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder; } else { - auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; - auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; - auto accept_num_now = accept_num[tid]; - auto* input_ids_now = input_ids + tid * input_ids_len; - auto* base_model_draft_tokens_now = - base_model_draft_tokens + tid * base_model_draft_tokens_len; - auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]; - const int32_t base_model_seq_len_this_time = - base_model_seq_lens_this_time[tid]; - auto* pre_ids_now = pre_ids + tid * pre_ids_len; - for (int i = 1; i < base_model_draft_tokens_len; i++) { - base_model_draft_tokens_now[i] = -1; - } - if(kvcache_scheduler_v1) { - if (base_model_stop_flags[tid] && - base_model_is_block_step[tid]) { - stop_flags[tid] = true; - is_block_step[tid] = true; - // Need to continue infer - } - } else { - if (base_model_stop_flags[tid] && - base_model_is_block_step[tid]) { - batch_drop[tid] = true; - stop_flags[tid] = true; - } - } + input_ids_now[position] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder + 1; + } + } else { + stop_flags[tid] = true; + seq_lens_this_time[tid] = 0; + seq_lens_decoder[tid] = 0; + seq_lens_encoder[tid] = 0; + not_stop_flag = 0; + } + not_stop_flag_sum += not_stop_flag; + } else { + auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; + auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; + auto accept_num_now = accept_num[tid]; + auto* input_ids_now = input_ids + tid * input_ids_len; + auto* base_model_draft_tokens_now = + base_model_draft_tokens + tid * base_model_draft_tokens_len; + auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]; + const int32_t base_model_seq_len_this_time = + base_model_seq_lens_this_time[tid]; + auto* pre_ids_now = pre_ids + tid * pre_ids_len; + for (int i = 1; i < base_model_draft_tokens_len; i++) { + base_model_draft_tokens_now[i] = -1; + } + if (kvcache_scheduler_v1) { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + stop_flags[tid] = true; + is_block_step[tid] = true; + // Need to continue infer + } + } else { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + batch_drop[tid] = true; + stop_flags[tid] = true; + } + } - if (!(base_model_stop_flags[tid] || batch_drop[tid])) { - not_stop_flag = 1; - // prefill generation - if (seq_lens_encoder[tid] > 0) { - // Can be extended to first few tokens - int seq_len_encoder = seq_lens_encoder[tid]; - stop_flags[tid] = false; - int64_t base_model_first_token = accept_tokens_now[0]; - pre_ids_now[0] = base_model_first_token; - int position = seq_len_encoder; - if (truncate_first_token) { - input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder; - } else { - input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder + 1; - } - } else { // decode generation - if(kvcache_scheduler_v1) { - // 3. 
try to recover mtp infer in V1 mode - if (!base_model_is_block_step[tid] && - is_block_step[tid]) { - is_block_step[tid] = false; - } - } - if (stop_flags[tid]) { - stop_flags[tid] = false; - // TODO: check - seq_lens_decoder[tid] = base_model_seq_len_decoder - - base_model_seq_len_this_time; - step_idx[tid] = base_model_step_idx[tid] - - base_model_seq_len_this_time; - } else { - // 2: Last base model generated token and first MTP - // token - seq_lens_decoder[tid] -= num_model_step - 1; - step_idx[tid] -= num_model_step - 1; - } - for (int i = 0; i < accept_num_now; i++) { - draft_tokens_now[i] = accept_tokens_now[i]; - const int pre_id_pos = - base_model_step_idx[tid] - (accept_num_now - i); - const int64_t accept_token = accept_tokens_now[i]; - pre_ids_now[pre_id_pos] = accept_token; - } - seq_lens_this_time[tid] = accept_num_now; - } - } else { - stop_flags[tid] = true; - seq_lens_this_time[tid] = 0; - seq_lens_decoder[tid] = 0; - seq_lens_encoder[tid] = 0; + if (!(base_model_stop_flags[tid] || batch_drop[tid])) { + not_stop_flag = 1; + // prefill generation + if (seq_lens_encoder[tid] > 0) { + // Can be extended to first few tokens + int seq_len_encoder = seq_lens_encoder[tid]; + stop_flags[tid] = false; + int64_t base_model_first_token = accept_tokens_now[0]; + pre_ids_now[0] = base_model_first_token; + int position = seq_len_encoder; + if (truncate_first_token) { + input_ids_now[position - 1] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder; + } else { + input_ids_now[position] = base_model_first_token; + seq_lens_this_time[tid] = seq_len_encoder + 1; + } + } else { // decode generation + if (kvcache_scheduler_v1) { + // 3. try to recover mtp infer in V1 mode + if (!base_model_is_block_step[tid] && is_block_step[tid]) { + is_block_step[tid] = false; } - not_stop_flag_sum += not_stop_flag; + } + if (stop_flags[tid]) { + stop_flags[tid] = false; + // TODO: check + seq_lens_decoder[tid] = + base_model_seq_len_decoder - base_model_seq_len_this_time; + step_idx[tid] = + base_model_step_idx[tid] - base_model_seq_len_this_time; + } else { + // 2: Last base model generated token and first MTP + // token + seq_lens_decoder[tid] -= num_model_step - 1; + step_idx[tid] -= num_model_step - 1; + } + for (int i = 0; i < accept_num_now; i++) { + draft_tokens_now[i] = accept_tokens_now[i]; + const int pre_id_pos = + base_model_step_idx[tid] - (accept_num_now - i); + const int64_t accept_token = accept_tokens_now[i]; + pre_ids_now[pre_id_pos] = accept_token; + } + seq_lens_this_time[tid] = accept_num_now; } + } else { + stop_flags[tid] = true; + seq_lens_this_time[tid] = 0; + seq_lens_decoder[tid] = 0; + seq_lens_encoder[tid] = 0; + } + not_stop_flag_sum += not_stop_flag; } - not_need_stop[0] = not_stop_flag_sum > 0; - return api::SUCCESS; + } + not_need_stop[0] = not_stop_flag_sum > 0; + return api::SUCCESS; } static int xpu3_wrapper(api::Context* ctx, @@ -240,41 +237,41 @@ static int xpu3_wrapper(api::Context* ctx, const bool truncate_first_token, const bool splitwise_prefill, const bool kvcache_scheduler_v1) { - using XPU_INT64 = typename XPUIndexType::type; + using XPU_INT64 = typename XPUIndexType::type; - // NOTE: Don't change 16 to 64, because kernel use gsm - xpu3::plugin::draft_model_preprocess_v2<<<1, 64, ctx->xpu_stream>>>( - reinterpret_cast(draft_tokens), - reinterpret_cast(input_ids), - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - reinterpret_cast(step_idx), - not_need_stop, - is_block_step, - batch_drop, - reinterpret_cast(pre_ids), 
- reinterpret_cast(accept_tokens), - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - reinterpret_cast(base_model_step_idx), - base_model_stop_flags, - base_model_is_block_step, - reinterpret_cast(base_model_draft_tokens), - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - return api::SUCCESS; + // NOTE: Don't change 16 to 64, because kernel use gsm + xpu3::plugin::draft_model_preprocess_v2<<<1, 64, ctx->xpu_stream>>>( + reinterpret_cast(draft_tokens), + reinterpret_cast(input_ids), + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + reinterpret_cast(step_idx), + not_need_stop, + is_block_step, + batch_drop, + reinterpret_cast(pre_ids), + reinterpret_cast(accept_tokens), + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + reinterpret_cast(base_model_step_idx), + base_model_stop_flags, + base_model_is_block_step, + reinterpret_cast(base_model_draft_tokens), + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + return api::SUCCESS; } int draft_model_preprocess_v2(api::Context* ctx, @@ -308,115 +305,112 @@ int draft_model_preprocess_v2(api::Context* ctx, const bool truncate_first_token, const bool splitwise_prefill, const bool kvcache_scheduler_v1) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess_v2", int64_t); - WRAPPER_DUMP_PARAM6(ctx, + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess_v2", int64_t); + WRAPPER_DUMP_PARAM6(ctx, + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder); + WRAPPER_DUMP_PARAM5( + ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids); + WRAPPER_DUMP_PARAM3( + ctx, accept_tokens, accept_num, base_model_seq_lens_encoder); + WRAPPER_DUMP_PARAM4(ctx, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags); + WRAPPER_DUMP_PARAM3( + ctx, base_model_is_block_step, base_model_draft_tokens, bsz); + WRAPPER_DUMP_PARAM3(ctx, num_model_step, accept_tokens_len, draft_tokens_len); + WRAPPER_DUMP_PARAM4(ctx, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token); + WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens); + WRAPPER_CHECK_PTR( + ctx, int64_t, bsz * base_model_draft_tokens_len, base_model_draft_tokens); + + WRAPPER_ASSERT_GT(ctx, bsz, 0); + WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128); + + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, 
+ base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, draft_tokens, input_ids, stop_flags, seq_lens_this_time, seq_lens_encoder, - seq_lens_decoder); - WRAPPER_DUMP_PARAM5( - ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids); - WRAPPER_DUMP_PARAM3( - ctx, accept_tokens, accept_num, base_model_seq_lens_encoder); - WRAPPER_DUMP_PARAM4(ctx, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, - base_model_stop_flags); - WRAPPER_DUMP_PARAM3( - ctx, base_model_is_block_step, base_model_draft_tokens, bsz); - WRAPPER_DUMP_PARAM3( - ctx, num_model_step, accept_tokens_len, draft_tokens_len); - WRAPPER_DUMP_PARAM4(ctx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, input_ids_len, base_model_draft_tokens_len, pre_ids_len, - truncate_first_token); - WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1); - WRAPPER_DUMP(ctx); - - WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); - WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens); - WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids); - WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens); - WRAPPER_CHECK_PTR(ctx, - int64_t, - bsz * base_model_draft_tokens_len, - base_model_draft_tokens); - - WRAPPER_ASSERT_GT(ctx, bsz, 0); - WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128); - - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - is_block_step, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - } - if (ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - is_block_step, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - } - WRAPPER_UNIMPLEMENTED(ctx); + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); + } + WRAPPER_UNIMPLEMENTED(ctx); } } // namespace plugin diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp index 
68e6a3b3835..a25537dd7cd 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp @@ -12,213 +12,321 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "xpu/plugin.h" -#include "xpu/refactor/impl_public/wrapper_check.h" #include #include +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" namespace xpu3 { namespace plugin { __attribute__((global)) void speculate_free_and_dispatch_block( - bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder, - int *block_tables, int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, int64_t *first_token_ids, - int *accept_num, const int bsz, - const int block_size, const int block_num_per_seq, - const int max_decoder_block_num, const int max_draft_tokens); + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens); -} // namespace plugin -} // namespace xpu3 +} // namespace plugin +} // namespace xpu3 namespace baidu { namespace xpu { namespace api { namespace plugin { -static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, - int *seq_lens_decoder, int *block_tables, - int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, - int *used_list_len, int *free_list, int *free_list_len, - int64_t *first_token_ids, int *accept_num, const int bsz, - const int block_size, const int block_num_per_seq, - const int max_decoder_block_num, const int max_draft_tokens) { - for (int i = 0; i < bsz; i++) { - int *block_table_now = block_tables + i * block_num_per_seq; - if (stop_flags[i] && !is_block_step[i]) { - // 回收block块 - const int encoder_block_len = encoder_block_lens[i]; - const int decoder_used_len = used_list_len[i]; - if (decoder_used_len > 0) { - const int ori_free_list_len = free_list_len[0]; - free_list_len[0] += decoder_used_len; - for (int j = 0; j < decoder_used_len; j++) { - free_list[ori_free_list_len + j] = - block_table_now[encoder_block_len + j]; - block_table_now[encoder_block_len + j] = -1; - } - encoder_block_lens[i] = 0; - used_list_len[i] = 0; - } - } else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) { - // 统计需要分配block的位置和总数 - const int ori_need_block_len = need_block_len[0]; - need_block_len[0] += 1; - need_block_list[ori_need_block_len] = i; +static int cpu_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + 
int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + for (int i = 0; i < bsz; i++) { + int *block_table_now = block_tables + i * block_num_per_seq; + if (stop_flags[i] && !is_block_step[i]) { + // 回收block块 + const int encoder_block_len = encoder_block_lens[i]; + const int decoder_used_len = used_list_len[i]; + if (decoder_used_len > 0) { + const int ori_free_list_len = free_list_len[0]; + free_list_len[0] += decoder_used_len; + for (int j = 0; j < decoder_used_len; j++) { + free_list[ori_free_list_len + j] = + block_table_now[encoder_block_len + j]; + block_table_now[encoder_block_len + j] = -1; } + encoder_block_lens[i] = 0; + used_list_len[i] = 0; + } + } else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = need_block_len[0]; + need_block_len[0] += 1; + need_block_list[ori_need_block_len] = i; } + } - while (need_block_len[0] > free_list_len[0]) { - // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len - int max_used_list_len_id = 0; - int max_used_list_len = 0; - for (int i = 0; i < bsz; i++) { - const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0; - if (used_block_num > max_used_list_len) { - max_used_list_len_id = i; - max_used_list_len = used_block_num; - } - } + while (need_block_len[0] > free_list_len[0]) { + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len + int max_used_list_len_id = 0; + int max_used_list_len = 0; + for (int i = 0; i < bsz; i++) { + const int used_block_num = !is_block_step[i] ? used_list_len[i] : 0; + if (used_block_num > max_used_list_len) { + max_used_list_len_id = i; + max_used_list_len = used_block_num; + } + } - const int encoder_block_len = encoder_block_lens[max_used_list_len_id]; - int *block_table_now = - block_tables + max_used_list_len_id * block_num_per_seq; - for (int i = 0; i < max_used_list_len; i++) { - free_list[free_list_len[0] + i] = - block_table_now[encoder_block_len + i]; - block_table_now[encoder_block_len + i] = -1; - } - step_block_list[step_len[0]] = max_used_list_len_id; - step_len[0] += 1; - free_list_len[0] += max_used_list_len; - stop_flags[max_used_list_len_id] = true; - is_block_step[max_used_list_len_id] = true; - seq_lens_this_time[max_used_list_len_id] = 0; - seq_lens_decoder[max_used_list_len_id] = 0; + const int encoder_block_len = encoder_block_lens[max_used_list_len_id]; + int *block_table_now = + block_tables + max_used_list_len_id * block_num_per_seq; + for (int i = 0; i < max_used_list_len; i++) { + free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; } + step_block_list[step_len[0]] = max_used_list_len_id; + step_len[0] += 1; + free_list_len[0] += max_used_list_len; + stop_flags[max_used_list_len_id] = true; + is_block_step[max_used_list_len_id] = true; + seq_lens_this_time[max_used_list_len_id] = 0; + seq_lens_decoder[max_used_list_len_id] = 0; + } - // 为需要block的位置分配block,每个位置分配一个block - for (int i = 0; i < bsz; i++) { - if (i < need_block_len[0]) { - const int need_block_id = need_block_list[i]; - if (!stop_flags[need_block_id]) { - // 如果需要的位置正好是上一步中被释放的位置,不做处理 - used_list_len[need_block_id] += 1; - const int ori_free_list_len = free_list_len[0]; - free_list_len[0]--; - int *block_table_now = - block_tables + need_block_id * block_num_per_seq; - 
block_table_now[seq_lens_decoder[need_block_id] / block_size] = - free_list[ori_free_list_len - 1]; - } - need_block_list[i] = -1; - } + // 为需要block的位置分配block,每个位置分配一个block + for (int i = 0; i < bsz; i++) { + if (i < need_block_len[0]) { + const int need_block_id = need_block_list[i]; + if (!stop_flags[need_block_id]) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len[need_block_id] += 1; + const int ori_free_list_len = free_list_len[0]; + free_list_len[0]--; + int *block_table_now = block_tables + need_block_id * block_num_per_seq; + block_table_now[seq_lens_decoder[need_block_id] / block_size] = + free_list[ori_free_list_len - 1]; + } + need_block_list[i] = -1; } + } - // 计算可以复原的query id - int ori_step_len = step_len[0]; - if (ori_step_len > 0) { - int ori_free_list_len = free_list_len[0]; - int ori_step_block_id = step_block_list[ori_step_len - 1]; - int tmp_used_len = used_list_len[ori_step_block_id]; - // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) - int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 - : tmp_used_len; - while (ori_step_len > 0 && ori_free_list_len >= used_len) { - recover_block_list[recover_len[0]] = ori_step_block_id; - is_block_step[ori_step_block_id] = false; - used_list_len[ori_step_block_id] = used_len; - ori_free_list_len -= used_len; - step_block_list[ori_step_len - 1] = -1; - step_len[0] -= 1; - recover_len[0] += 1; - ori_step_len = step_len[0]; - if (ori_step_len > 0) { - ori_step_block_id = step_block_list[ori_step_len - 1]; - tmp_used_len = used_list_len[ori_step_block_id]; - used_len = tmp_used_len < max_decoder_block_num - ? tmp_used_len + 1 - : tmp_used_len; - } - } - need_block_len[0] = 0; + // 计算可以复原的query id + int ori_step_len = step_len[0]; + if (ori_step_len > 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = + tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; + while (ori_step_len > 0 && ori_free_list_len >= used_len) { + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step[ori_step_block_id] = false; + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len < max_decoder_block_num ? 
tmp_used_len + 1 + : tmp_used_len; + } } - return api::SUCCESS; + need_block_len[0] = 0; + } + return api::SUCCESS; } -static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, - int *seq_lens_decoder, int *block_tables, - int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, - int *recover_len, int *need_block_list, - int *need_block_len, int *used_list_len, int *free_list, - int *free_list_len, int64_t *first_token_ids, int *accept_num, - const int bsz, const int block_size, +static int xpu3_wrapper(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, const int block_num_per_seq, - const int max_decoder_block_num, const int max_draft_tokens) { - using XPU_INT64 = typename XPUIndexType::type; - auto speculate_free_and_dispatch_block_kernel = xpu3::plugin::speculate_free_and_dispatch_block; - speculate_free_and_dispatch_block_kernel<<ncluster(), 64, ctx->xpu_stream>>>( - stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, - encoder_block_lens, is_block_step, step_block_list, step_len, - recover_block_list, recover_len, need_block_list, need_block_len, - used_list_len, free_list, free_list_len, - reinterpret_cast(first_token_ids), accept_num, bsz, block_size, - block_num_per_seq, max_decoder_block_num, max_draft_tokens); - return api::SUCCESS; + const int max_decoder_block_num, + const int max_draft_tokens) { + using XPU_INT64 = typename XPUIndexType::type; + auto speculate_free_and_dispatch_block_kernel = + xpu3::plugin::speculate_free_and_dispatch_block; + speculate_free_and_dispatch_block_kernel<<ncluster(), + 64, + ctx->xpu_stream>>>( + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + reinterpret_cast(first_token_ids), + accept_num, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + return api::SUCCESS; } -int speculate_free_and_dispatch_block(Context *ctx, bool *stop_flags, - int *seq_lens_this_time, int *seq_lens_decoder, - int *block_tables, int *encoder_block_lens, - bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, - int *recover_len, int *need_block_list, - int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, - int64_t *first_token_ids, int *accept_num, - const int bsz, const int block_size, - const int block_num_per_seq, - const int max_decoder_block_num, const int max_draft_tokens) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_dispatch_block", float); - WRAPPER_DUMP_PARAM6(ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, - block_tables, encoder_block_lens, is_block_step); - WRAPPER_DUMP_PARAM6(ctx, step_block_list, step_len, recover_block_list, - recover_len, need_block_list, need_block_len); - WRAPPER_DUMP_PARAM4(ctx, used_list_len, free_list, free_list_len, - first_token_ids); - WRAPPER_DUMP_PARAM4(ctx, bsz, block_size, 
block_num_per_seq, - max_decoder_block_num); - WRAPPER_DUMP(ctx); - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper( - ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, - encoder_block_lens, is_block_step, step_block_list, step_len, - recover_block_list, recover_len, need_block_list, need_block_len, - used_list_len, free_list, free_list_len, first_token_ids, accept_num, bsz, - block_size, block_num_per_seq, max_decoder_block_num, max_draft_tokens); - } - if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper( - ctx, stop_flags, seq_lens_this_time, seq_lens_decoder, block_tables, - encoder_block_lens, is_block_step, step_block_list, step_len, - recover_block_list, recover_len, need_block_list, need_block_len, - used_list_len, free_list, free_list_len, first_token_ids, accept_num, bsz, - block_size, block_num_per_seq, max_decoder_block_num, max_draft_tokens); - } - WRAPPER_UNIMPLEMENTED(ctx); +int speculate_free_and_dispatch_block(Context *ctx, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_free_and_dispatch_block", float); + WRAPPER_DUMP_PARAM6(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step); + WRAPPER_DUMP_PARAM6(ctx, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len); + WRAPPER_DUMP_PARAM4( + ctx, used_list_len, free_list, free_list_len, first_token_ids); + WRAPPER_DUMP_PARAM4( + ctx, bsz, block_size, block_num_per_seq, max_decoder_block_num); + WRAPPER_DUMP(ctx); + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + accept_num, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + } + if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_len, + recover_block_list, + recover_len, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + first_token_ids, + accept_num, + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + } + WRAPPER_UNIMPLEMENTED(ctx); } -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp index bd6e8c2212d..2996325c833 100644 --- 
a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp @@ -20,27 +20,28 @@ namespace xpu3 { namespace plugin { -__attribute__((global)) void speculate_recover_block(int *recover_block_list, // [bsz] - int *recover_len, - bool *stop_flags, - int *seq_lens_this_time, - const int *ori_seq_lens_encoder, - int *seq_lens_encoder, - const int *seq_lens_decoder, - int *block_tables, - int *free_list, - int *free_list_len, - int64_t *input_ids, - const int64_t *pre_ids, - const int64_t *step_idx, - const int *encoder_block_lens, - const int *used_list_len, - const int64_t *next_tokens, - const int64_t *first_token_ids, - const int bsz, - const int block_num_per_seq, - const int length, - const int pre_id_length); +__attribute__((global)) void speculate_recover_block( + int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + const int *ori_seq_lens_encoder, + int *seq_lens_encoder, + const int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *encoder_block_lens, + const int *used_list_len, + const int64_t *next_tokens, + const int64_t *first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); } // namespace plugin } // namespace xpu3 diff --git a/custom_ops/xpu_ops/test/test_speculate_step.py b/custom_ops/xpu_ops/test/test_speculate_step.py index d3bdb212c4a..070ea393651 100644 --- a/custom_ops/xpu_ops/test/test_speculate_step.py +++ b/custom_ops/xpu_ops/test/test_speculate_step.py @@ -15,7 +15,7 @@ import numpy as np import paddle -from fastdeploy.model_executor.ops.xpu import step_paddle, speculate_step_paddle +from fastdeploy.model_executor.ops.xpu import speculate_step_paddle np.random.seed(2023) From 414e3edafe1cc86d484a81e414ea9e485fe0c902 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Tue, 4 Nov 2025 05:07:46 +0000 Subject: [PATCH 04/17] format --- custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 86 +++++++++++-------- .../speculate_get_padding_offset.xpu | 1 - .../mtp_kernel/speculate_verify.xpu | 58 ++++++------- .../speculate_get_padding_offset.cpp | 66 +++++++------- .../wrapper/mtp_wrapper/speculate_verify.cpp | 62 ++++++------- 5 files changed, 140 insertions(+), 133 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 0128f7ca461..f57670e6fad 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -304,29 +304,29 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const bool splitwise_prefill); void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens, - const paddle::Tensor& input_ids, - const paddle::Tensor& stop_flags, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& step_idx, - const paddle::Tensor& not_need_stop, - const paddle::Tensor& is_block_step, - const paddle::Tensor& batch_drop, - const paddle::Tensor& pre_ids, - const paddle::Tensor& accept_tokens, - const paddle::Tensor& accept_num, - const paddle::Tensor& base_model_seq_lens_this_time, - const paddle::Tensor& base_model_seq_lens_encoder, - const paddle::Tensor& base_model_seq_lens_decoder, - const paddle::Tensor& base_model_step_idx, - 
const paddle::Tensor& base_model_stop_flags, - const paddle::Tensor& base_model_is_block_step, - const paddle::Tensor& base_model_draft_tokens, - const int num_model_step, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1); + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, + const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& base_model_seq_lens_encoder, + const paddle::Tensor& base_model_seq_lens_decoder, + const paddle::Tensor& base_model_step_idx, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& base_model_is_block_step, + const paddle::Tensor& base_model_draft_tokens, + const int num_model_step, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_seq_lens_this_time, @@ -472,21 +472,31 @@ void MTPStepPaddle( const int max_draft_tokens); void SpeculateStepPaddle( - const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &ori_seq_lens_encoder, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] - const paddle::Tensor &encoder_block_lens, - const paddle::Tensor &is_block_step, const paddle::Tensor &step_block_list, - const paddle::Tensor &step_lens, const paddle::Tensor &recover_block_list, - const paddle::Tensor &recover_lens, const paddle::Tensor &need_block_list, - const paddle::Tensor &need_block_len, const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, const paddle::Tensor &free_list_len, - const paddle::Tensor &input_ids, const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, const paddle::Tensor &next_tokens, - const paddle::Tensor &first_token_ids, const paddle::Tensor &accept_num, - const int block_size, const int encoder_decoder_block_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& ori_seq_lens_encoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& step_block_list, + const paddle::Tensor& step_lens, + const paddle::Tensor& recover_block_list, + const paddle::Tensor& recover_lens, + const paddle::Tensor& need_block_list, + const paddle::Tensor& need_block_len, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& next_tokens, + const paddle::Tensor& first_token_ids, + const paddle::Tensor& accept_num, + const int block_size, + const int encoder_decoder_block_num, const int max_draft_tokens); void SaveOutMmsgStatic(const paddle::Tensor& x, diff --git 
a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu index 4af74b8f620..637e076d625 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu @@ -101,7 +101,6 @@ __global__ void speculate_get_padding_offset(int* padding_offset, } } - __global__ void speculate_get_padding_offset_v2(int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu index 4287c3e7d88..26ad38c9f94 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_verify.xpu @@ -162,7 +162,7 @@ __global__ void speculate_verify( // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { - if(benchmark_mode){ + if (benchmark_mode) { break; } if (seq_lens_encoder[bid] != 0) { @@ -304,34 +304,34 @@ __global__ void speculate_verify( } } } -#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \ - template __global__ void speculate_verify( \ - int64_t * accept_tokens, \ - int *accept_num, \ - int64_t *step_idx, \ - bool *stop_flags, \ - const int *seq_lens_encoder, \ - const int *seq_lens_decoder, \ - const int64_t *draft_tokens, \ - const int *actual_draft_token_nums, \ - const float *dev_curand_states, \ - const float *topp, \ - const int *seq_lens_this_time, \ - const int64_t *verify_tokens, \ - const float *verify_scores, \ - const int64_t *max_dec_len, \ - const int64_t *end_tokens, \ - const bool *is_block_step, \ - const int *output_cum_offsets, \ - const int *actual_candidate_len, \ - int real_bsz, \ - int max_draft_tokens, \ - int end_length, \ - int max_seq_len, \ - int max_candidate_len, \ - int verify_window, \ - bool prefill_one_step_stop, \ - bool benchmark_mode); +#define SPECULATE_VERIFY_INSTANTIATE(ENABLE_TOPP, USE_TOPK) \ + template __global__ void speculate_verify( \ + int64_t * accept_tokens, \ + int *accept_num, \ + int64_t *step_idx, \ + bool *stop_flags, \ + const int *seq_lens_encoder, \ + const int *seq_lens_decoder, \ + const int64_t *draft_tokens, \ + const int *actual_draft_token_nums, \ + const float *dev_curand_states, \ + const float *topp, \ + const int *seq_lens_this_time, \ + const int64_t *verify_tokens, \ + const float *verify_scores, \ + const int64_t *max_dec_len, \ + const int64_t *end_tokens, \ + const bool *is_block_step, \ + const int *output_cum_offsets, \ + const int *actual_candidate_len, \ + int real_bsz, \ + int max_draft_tokens, \ + int end_length, \ + int max_seq_len, \ + int max_candidate_len, \ + int verify_window, \ + bool prefill_one_step_stop, \ + bool benchmark_mode); SPECULATE_VERIFY_INSTANTIATE(true, true) SPECULATE_VERIFY_INSTANTIATE(true, false) SPECULATE_VERIFY_INSTANTIATE(false, true) diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp index 0886a0196a8..21134d86807 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp +++ 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp @@ -109,16 +109,15 @@ static int cpu_wrapper_get_padding_offset(Context* ctx, return api::SUCCESS; } - static int cpu_wrapper_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz) { + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { for (int bi = 0; bi < bsz; ++bi) { int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; for (int i = 0; i < seq_lens[bi]; i++) { @@ -184,14 +183,14 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx, } static int xpu3_wrapper_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz) { + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz) { xpu3::plugin:: speculate_get_padding_offset_v2<<ncluster(), 64, ctx->xpu_stream>>>( batch_id_per_token, @@ -205,7 +204,6 @@ static int xpu3_wrapper_get_padding_offset_v2(Context* ctx, return api::SUCCESS; } - template int speculate_remove_padding(Context* ctx, T* x_remove_padding, @@ -360,25 +358,25 @@ int speculate_get_padding_offset_v2(Context* ctx, if (ctx->dev().type() == api::kCPU) { return cpu_wrapper_get_padding_offset_v2(ctx, - batch_id_per_token, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens, - max_seq_len, - bsz); + batch_id_per_token, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bsz); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper_get_padding_offset_v2(ctx, - batch_id_per_token, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens, - max_seq_len, - bsz); + batch_id_per_token, + cum_offsets_out, + cu_seqlens_q, + cu_seqlens_k, + cum_offsets, + seq_lens, + max_seq_len, + bsz); } WRAPPER_UNIMPLEMENTED(ctx); diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp index c9571bd513c..3989ce8deb0 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_verify.cpp @@ -162,7 +162,7 @@ static int cpu_wrapper(Context *ctx, // printf("seq_lens_this_time[%d]-1: %d \n",bid, // seq_lens_this_time[bid]-1); for (; i < seq_lens_this_time[bid] - 1; i++) { - if(benchmark_mode){ + if (benchmark_mode) { break; } if (seq_lens_encoder[bid] != 0) { @@ -513,36 +513,36 @@ int speculate_verify(Context *ctx, WRAPPER_UNIMPLEMENTED(ctx); } -#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ - template int \ - baidu::xpu::api::plugin::speculate_verify( \ - baidu::xpu::api::Context *, /* xpu_ctx */ \ - int64_t *, /* accept_tokens */ \ - int *, /* accept_num */ \ - int64_t *, /* step_idx */ \ - bool *, /* stop_flags */ \ - const int *, /* seq_lens_encoder */ \ - const int *, /* seq_lens_decoder */ \ - const int64_t *, /* draft_tokens */ \ - const int *, /* actual_draft_token_nums */ \ - const float *, /* dev_curand_states or topp */ \ - const float *, /* topp or 
nullptr */ \ - const int *, /* seq_lens_this_time */ \ - const int64_t *, /* verify_tokens */ \ - const float *, /* verify_scores */ \ - const int64_t *, /* max_dec_len */ \ - const int64_t *, /* end_tokens */ \ - const bool *, /* is_block_step */ \ - const int *, /* output_cum_offsets */ \ - const int *, /* actual_candidate_len */ \ - int, /* real_bsz */ \ - int, /* max_draft_tokens */ \ - int, /* end_length */ \ - int, /* max_seq_len */ \ - int, /* max_candidate_len */ \ - int, /* verify_window */ \ - bool, \ - bool); /* prefill_one_step_stop */ +#define INSTANTIATE_SPECULATE_VERIFY(ENABLE_TOPP, USE_TOPK) \ + template int \ + baidu::xpu::api::plugin::speculate_verify( \ + baidu::xpu::api::Context *, /* xpu_ctx */ \ + int64_t *, /* accept_tokens */ \ + int *, /* accept_num */ \ + int64_t *, /* step_idx */ \ + bool *, /* stop_flags */ \ + const int *, /* seq_lens_encoder */ \ + const int *, /* seq_lens_decoder */ \ + const int64_t *, /* draft_tokens */ \ + const int *, /* actual_draft_token_nums */ \ + const float *, /* dev_curand_states or topp */ \ + const float *, /* topp or nullptr */ \ + const int *, /* seq_lens_this_time */ \ + const int64_t *, /* verify_tokens */ \ + const float *, /* verify_scores */ \ + const int64_t *, /* max_dec_len */ \ + const int64_t *, /* end_tokens */ \ + const bool *, /* is_block_step */ \ + const int *, /* output_cum_offsets */ \ + const int *, /* actual_candidate_len */ \ + int, /* real_bsz */ \ + int, /* max_draft_tokens */ \ + int, /* end_length */ \ + int, /* max_seq_len */ \ + int, /* max_candidate_len */ \ + int, /* verify_window */ \ + bool, \ + bool); /* prefill_one_step_stop */ INSTANTIATE_SPECULATE_VERIFY(false, false) INSTANTIATE_SPECULATE_VERIFY(false, true) From fa537115a9685e236218a590e87087b971d795ca Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Tue, 4 Nov 2025 06:02:27 +0000 Subject: [PATCH 05/17] format --- .../xpu_ops/src/plugin/include/xpu/plugin.h | 172 +++++++++--------- 1 file changed, 87 insertions(+), 85 deletions(-) diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 754297ce6ce..fdbd988fed6 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -74,48 +74,49 @@ DLL_EXPORT int get_padding_offset(Context* ctx, const int* seq_lens, const int max_seq_len, const int bs); - -DLL_EXPORT int speculate_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz); -DLL_EXPORT int draft_model_preprocess_v2(api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - 
const bool splitwise_prefill, - const bool kvcache_scheduler_v1); +DLL_EXPORT int speculate_get_padding_offset_v2(Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + const int max_seq_len, + int bsz); + +DLL_EXPORT int draft_model_preprocess_v2( + api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); DLL_EXPORT int update_inputs(Context* ctx, bool* not_need_stop, @@ -153,29 +154,30 @@ DLL_EXPORT int free_and_dispatch_block(Context* ctx, const int block_num_per_seq, const int max_decoder_block_num); -DLL_EXPORT int speculate_free_and_dispatch_block(Context* ctx, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_decoder, - int* block_tables, - int* encoder_block_lens, - bool* is_block_step, - int* step_block_list, // [bsz] - int* step_len, - int* recover_block_list, - int* recover_len, - int* need_block_list, - int* need_block_len, - int* used_list_len, - int* free_list, - int* free_list_len, - int64_t* first_token_ids, - int* accept_num, - const int bsz, - const int block_size, - const int block_num_per_seq, - const int max_decoder_block_num, - const int max_draft_tokens); +DLL_EXPORT int speculate_free_and_dispatch_block( + Context* ctx, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_decoder, + int* block_tables, + int* encoder_block_lens, + bool* is_block_step, + int* step_block_list, // [bsz] + int* step_len, + int* recover_block_list, + int* recover_len, + int* need_block_list, + int* need_block_len, + int* used_list_len, + int* free_list, + int* free_list_len, + int64_t* first_token_ids, + int* accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, + const int max_draft_tokens); DLL_EXPORT int recover_block(Context* ctx, int* recover_block_list, // [bsz] @@ -201,27 +203,27 @@ DLL_EXPORT int recover_block(Context* ctx, const int pre_id_length); DLL_EXPORT int speculate_recover_block(Context* ctx, - int* recover_block_list, // [bsz] - int* recover_len, - bool* stop_flags, - int* seq_lens_this_time, - const int* ori_seq_lens_encoder, - int* seq_lens_encoder, - const int* seq_lens_decoder, - int* block_tables, - int* free_list, - int* free_list_len, - int64_t* input_ids, - const int64_t* pre_ids, - const int64_t* step_idx, - const int* encoder_block_lens, - const int* used_list_len, - const int64_t* next_tokens, - const int64_t* first_token_ids, - const int bsz, - const int block_num_per_seq, - const int length, - const int pre_id_length); + int* recover_block_list, // [bsz] + int* recover_len, + bool* stop_flags, + int* seq_lens_this_time, + const int* 
ori_seq_lens_encoder, + int* seq_lens_encoder, + const int* seq_lens_decoder, + int* block_tables, + int* free_list, + int* free_list_len, + int64_t* input_ids, + const int64_t* pre_ids, + const int64_t* step_idx, + const int* encoder_block_lens, + const int* used_list_len, + const int64_t* next_tokens, + const int64_t* first_token_ids, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length); DLL_EXPORT int recover_decode_task(Context* ctx, bool* stop_flags, From 5f96446df5e8f04c4fb2afe8aa8fd74fc82d23a0 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Wed, 5 Nov 2025 04:18:43 +0000 Subject: [PATCH 06/17] fix gather next token --- .../xpu_ops/src/ops/gather_next_token.cc | 45 +++++++++++++++---- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index bc875b372bf..18a773c5150 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -35,6 +35,7 @@ std::vector GatherNextToken( typename XPUTypeTrait::Type; // only support bfloat16 typedef paddle::bfloat16 data_t; const int dim = tmp_out.dims()[1]; + const int token_num = tmp_out.shape()[0]; const int bsz = cum_offsets.shape()[0]; int enc_batch = enc_batch_tensor.data()[0]; int dec_batch = dec_batch_tensor.data()[0]; @@ -52,16 +53,42 @@ std::vector GatherNextToken( dec_batch, const_cast(decoder_batch_map.data())}; - auto out = paddle::full({bsz, dim}, -2, tmp_out.type(), tmp_out.place()); + paddle::Tensor out; + std::vector encode_iota_lod_cpu(enc_batch); + if (output_padding_offset) { + int need_delete_token_num = 0; + if (enc_batch > 0) { + need_delete_token_num = + encoder_seq_lod_cpu.data()[enc_batch] - enc_batch; + std::iota(encode_iota_lod_cpu.begin(), encode_iota_lod_cpu.end(), 0); + encoder_batch_map_vp.cpu = + const_cast(encode_iota_lod_cpu.data()); + encoder_batch_map_vp.len = enc_batch; + encoder_batch_map_vp.xpu = nullptr; + } + out = paddle::empty({token_num - need_delete_token_num, dim}, + tmp_out.type(), + tmp_out.place()); + } else { + out = paddle::empty({bsz, dim}, tmp_out.type(), tmp_out.place()); + } + if (tmp_out.shape()[0] == 0) { + return {out}; + } - int r = baidu::xpu::api::plugin::eb_gather_next_token( - xpu_ctx->x_context(), - reinterpret_cast(tmp_out.data()), - reinterpret_cast(out.data()), - encoder_seqs_lods_vp, - encoder_batch_map_vp, - decoder_batch_map_vp, - dim); + if (output_padding_offset && enc_batch <= 0) { + out = tmp_out.copy_to(tmp_out.place(), false); + } else { + int r = baidu::xpu::api::plugin::eb_gather_next_token( + xpu_ctx->x_context(), + reinterpret_cast(tmp_out.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); + } return {out}; } From 75f3e03b0170ce20ea805b57fad98477fc96b01c Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Fri, 7 Nov 2025 04:08:55 +0000 Subject: [PATCH 07/17] fix step && add test --- .../src/ops/mtp/speculate_step_paddle.cc | 10 +- .../speculate_free_and_dispatch_block.xpu | 5 + .../speculate_free_and_dispatch_block.cpp | 16 +- .../xpu_ops/test/test_speculate_step.py | 482 ++++++++++++------ 4 files changed, 338 insertions(+), 175 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc index bdffa3b79a4..d8b113fb81a 100644 --- 
a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
@@ -48,10 +48,14 @@ void SpeculateStepPaddle(
     const int block_size,
     const int encoder_decoder_block_num,
     const int max_draft_tokens) {
+  namespace api = baidu::xpu::api;
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
   auto xpu_ctx = static_cast(dev_ctx);
-
+  api::Context *ctx = xpu_ctx->x_context();
+  if (seq_lens_this_time.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
   const int bsz = seq_lens_this_time.shape()[0];
   PADDLE_ENFORCE_LE(
       bsz,
@@ -63,7 +67,7 @@ void SpeculateStepPaddle(
   const int pre_id_length = pre_ids.shape()[1];
   const int max_decoder_block_num = pre_id_length / block_size;
   int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
-      xpu_ctx->x_context(),
+      ctx,
       const_cast(stop_flags.data()),
       const_cast(seq_lens_this_time.data()),
       const_cast(seq_lens_decoder.data()),
@@ -91,7 +95,7 @@ void SpeculateStepPaddle(
   int recover_lens_cpu_data = recover_lens_cpu.data()[0];
   if (recover_lens_cpu_data > 0) {
     r = baidu::xpu::api::plugin::speculate_recover_block(
-        xpu_ctx->x_context(),
+        ctx,
         const_cast(recover_block_list.data()),
         const_cast(recover_lens.data()),
         const_cast(stop_flags.data()),
diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu
index 69bcb3f990e..133e9d7988d 100644
--- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_free_and_dispatch_block.xpu
@@ -231,6 +231,11 @@ __global__ void speculate_free_and_dispatch_block(
                     sizeof(int));
         LM2GM_ASYNC(
             &value_zero, seq_lens_decoder + max_used_list_len_id, sizeof(int));
+        // Note(@wufeisheng): a stepped (preempted) request must have its
+        // accept_num reset to 0; otherwise save_output would keep streaming
+        // stale tokens for this batch slot on the next step.
+        LM2GM_ASYNC(
+            &value_zero, accept_num + max_used_list_len_id, sizeof(int));
         mfence();
       }
     }
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp
index a25537dd7cd..cefe893e975 100644
--- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_free_and_dispatch_block.cpp
@@ -77,8 +77,11 @@ static int cpu_wrapper(Context *ctx,
                        const int max_draft_tokens) {
   for (int i = 0; i < bsz; i++) {
     int *block_table_now = block_tables + i * block_num_per_seq;
+    int max_possible_block_idx =
+        (seq_lens_decoder[i] + max_draft_tokens + 1) / block_size;
     if (stop_flags[i] && !is_block_step[i]) {  // 回收block块
+      first_token_ids[i] = -1;
       const int encoder_block_len = encoder_block_lens[i];
       const int decoder_used_len = used_list_len[i];
       if (decoder_used_len > 0) {
@@ -92,7 +95,10 @@ static int cpu_wrapper(Context *ctx,
         encoder_block_lens[i] = 0;
         used_list_len[i] = 0;
       }
-    } else if (block_table_now[seq_lens_decoder[i] / block_size] == -1) {
+    } else if (seq_lens_this_time[i] != 0 &&
+               max_possible_block_idx < block_num_per_seq &&
+               block_table_now[(seq_lens_decoder[i] + max_draft_tokens + 1) /
+                                   block_size] == -1) {
      // 
统计需要分配block的位置和总数 const int ori_need_block_len = need_block_len[0]; need_block_len[0] += 1; @@ -126,6 +132,7 @@ static int cpu_wrapper(Context *ctx, is_block_step[max_used_list_len_id] = true; seq_lens_this_time[max_used_list_len_id] = 0; seq_lens_decoder[max_used_list_len_id] = 0; + accept_num[max_used_list_len_id] = 0; } // 为需要block的位置分配block,每个位置分配一个block @@ -138,8 +145,9 @@ static int cpu_wrapper(Context *ctx, const int ori_free_list_len = free_list_len[0]; free_list_len[0]--; int *block_table_now = block_tables + need_block_id * block_num_per_seq; - block_table_now[seq_lens_decoder[need_block_id] / block_size] = - free_list[ori_free_list_len - 1]; + block_table_now[(seq_lens_decoder[need_block_id] + max_draft_tokens + + 1) / + block_size] = free_list[ori_free_list_len - 1]; } need_block_list[i] = -1; } @@ -154,7 +162,7 @@ static int cpu_wrapper(Context *ctx, // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; - while (ori_step_len > 0 && ori_free_list_len >= used_len) { + if (ori_step_len > 0 && ori_free_list_len >= used_len) { recover_block_list[recover_len[0]] = ori_step_block_id; is_block_step[ori_step_block_id] = false; used_list_len[ori_step_block_id] = used_len; diff --git a/custom_ops/xpu_ops/test/test_speculate_step.py b/custom_ops/xpu_ops/test/test_speculate_step.py index 070ea393651..852c144ccbd 100644 --- a/custom_ops/xpu_ops/test/test_speculate_step.py +++ b/custom_ops/xpu_ops/test/test_speculate_step.py @@ -12,178 +12,324 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import numpy as np import paddle +import pytest from fastdeploy.model_executor.ops.xpu import speculate_step_paddle +# 固定随机种子,保证测试可复现 np.random.seed(2023) +paddle.seed(2023) + + +def test_data(): + """定义测试数据夹具,一次性生成所有输入数据,供测试用例复用""" + + # max_bs = 128 + max_bs = 8 + bs = max_bs + max_seq_len = 8192 + block_size = 64 + block_bs = 8 + block_ratio = 0.75 + max_draft_tokens = 1 + encoder_decoder_block_num = 1 + + # 生成原始测试数据(完全复用原有逻辑) + stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool") + seq_lens_this_time = np.zeros([bs], "int32") + seq_lens_encoder = np.zeros([max_bs], "int32") + seq_lens_decoder = np.zeros([max_bs], "int32") + accept_num = np.random.randint(1, 3, [max_bs]).astype("int32") + for i in range(bs): + seq_lens_decoder[i] = 2 + i * 2 + seq_lens_this_time[i] = 1 + + ori_seq_lens_encoder = np.zeros([max_bs], "int32") + ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2 + step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64") + + max_block_num = block_bs * max_seq_len // block_size + free_list_len = int(max_block_num * (1 - block_ratio)) + free_list_len = np.full([1], free_list_len, "int32") + free_list = np.arange( + max_block_num - 1, max_block_num - free_list_len.item() - 1, -1, dtype="int32" # 加 .item() 转为标量 + ) + encoder_block_lens = np.zeros([max_bs], "int32") + used_list_len = np.zeros([max_bs], "int32") + block_tables = np.full([max_bs, 128], -1, "int32") + encoder_block_id = 0 + + for i in range(bs): + enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size + encoder_block_lens[i] = enc_block_num + dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num + used_list_len[i] = dec_block_num + block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32") + encoder_block_id += enc_block_num + if dec_block_num > 0: 
+ block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[ + free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1 + ] + free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1 + free_list_len[0] -= dec_block_num + + assert free_list_len[0] >= 0, "free_list_len should not be negative" + + is_block_step = np.zeros([max_bs], "bool") + is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool") + step_block_list = np.full([max_bs], -1, "int32") + step_lens = np.full([1], 0, "int32") + + for i in range(bs): + if is_block_step[i]: + step_block_list[step_lens[0]] = i + step_lens[0] += 1 + + recover_lens = np.full([1], 0, "int32") + recover_block_list = np.full([max_bs], -1, "int32") + need_block_len = np.full([1], 0, "int32") + need_block_list = np.full([max_bs], -1, "int32") + + input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + next_tokens = np.random.randint(0, 1000, [max_bs], "int64") + first_token_ids = np.random.randint(0, 1000, [max_bs], "int64") + + paddle.set_device("cpu") + # 转换为 paddle tensor(保持原有逻辑) + data_cpu = { + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder), + "block_tables": paddle.to_tensor(block_tables), + "encoder_block_lens": paddle.to_tensor(encoder_block_lens), + "is_block_step": paddle.to_tensor(is_block_step), + "step_block_list": paddle.to_tensor(step_block_list), + "step_lens": paddle.to_tensor(step_lens), + "recover_block_list": paddle.to_tensor(recover_block_list), + "recover_lens": paddle.to_tensor(recover_lens), + "need_block_list": paddle.to_tensor(need_block_list), + "need_block_len": paddle.to_tensor(need_block_len), + "used_list_len": paddle.to_tensor(used_list_len), + "free_list_len": paddle.to_tensor(free_list_len), + "free_list": paddle.to_tensor(free_list), + "input_ids": paddle.to_tensor(input_ids), + "pre_ids": paddle.to_tensor(pre_ids), + "step_idx": paddle.to_tensor(step_idx), + "next_tokens": paddle.to_tensor(next_tokens), + "first_token_ids": paddle.to_tensor(first_token_ids), + "accept_num": paddle.to_tensor(accept_num), + "block_size": block_size, + "encoder_decoder_block_num": encoder_decoder_block_num, + "max_draft_tokens": max_draft_tokens, + } + + paddle.set_device("xpu:0") + data_xpu = { + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder), + "block_tables": paddle.to_tensor(block_tables), + "encoder_block_lens": paddle.to_tensor(encoder_block_lens), + "is_block_step": paddle.to_tensor(is_block_step), + "step_block_list": paddle.to_tensor(step_block_list), + "step_lens": paddle.to_tensor(step_lens), + "recover_block_list": paddle.to_tensor(recover_block_list), + "recover_lens": paddle.to_tensor(recover_lens), + "need_block_list": paddle.to_tensor(need_block_list), + "need_block_len": paddle.to_tensor(need_block_len), + "used_list_len": paddle.to_tensor(used_list_len), + "free_list_len": paddle.to_tensor(free_list_len), + "free_list": paddle.to_tensor(free_list), + "input_ids": paddle.to_tensor(input_ids), + "pre_ids": 
paddle.to_tensor(pre_ids), + "step_idx": paddle.to_tensor(step_idx), + "next_tokens": paddle.to_tensor(next_tokens), + "first_token_ids": paddle.to_tensor(first_token_ids), + "accept_num": paddle.to_tensor(accept_num), + "block_size": block_size, + "encoder_decoder_block_num": encoder_decoder_block_num, + "max_draft_tokens": max_draft_tokens, + } + + return data_cpu, data_xpu + + +def speculate_step_paddle_execution(test_data): + """测试 speculate_step_paddle 函数的执行性和输出合理性""" + # 提取输入数据 + stop_flags = test_data["stop_flags"] # 克隆避免影响夹具数据 + seq_lens_this_time = test_data["seq_lens_this_time"] + ori_seq_lens_encoder = test_data["ori_seq_lens_encoder"] + seq_lens_encoder = test_data["seq_lens_encoder"] + seq_lens_decoder = test_data["seq_lens_decoder"] + block_tables = test_data["block_tables"] + encoder_block_lens = test_data["encoder_block_lens"] + is_block_step = test_data["is_block_step"] + step_block_list = test_data["step_block_list"] + step_lens = test_data["step_lens"] + recover_block_list = test_data["recover_block_list"] + recover_lens = test_data["recover_lens"] + need_block_list = test_data["need_block_list"] + need_block_len = test_data["need_block_len"] + used_list_len = test_data["used_list_len"] + free_list = test_data["free_list"] + free_list_len = test_data["free_list_len"] + input_ids = test_data["input_ids"] + pre_ids = test_data["pre_ids"] + step_idx = test_data["step_idx"] + next_tokens = test_data["next_tokens"] + first_token_ids = test_data["first_token_ids"] + accept_num = test_data["accept_num"] + block_size = test_data["block_size"] + encoder_decoder_block_num = test_data["encoder_decoder_block_num"] + max_draft_tokens = test_data["max_draft_tokens"] + + # 可选:打印执行前关键信息(如需调试可开启) + if os.environ.get("ATTN_MASK_TEST_DEBUG", "0") == "1": + print("-" * 50 + "before step op" + "-" * 50) + print("stop_flags: ", stop_flags) + print("seq_lens_this_time: ", seq_lens_this_time) + print("seq_lens_encoder: ", seq_lens_encoder) + print("seq_lens_decoder: ", seq_lens_decoder) + print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) + print("block_tables: ", block_tables.sum()) + print("encoder_block_lens: ", encoder_block_lens) + print("is_block_step: ", is_block_step) + print("step_block_list: ", step_block_list) + print("step_lens: ", step_lens) + print("recover_lens: ", recover_lens) + print("recover_block_list: ", recover_block_list) + print("need_block_list: ", need_block_list) + print("need_block_len: ", need_block_len) + print("used_list_len: ", used_list_len) + print("free_list_len: ", free_list_len) + print("free_list: ", free_list) + print("input_ids: ", input_ids) + print("pre_ids: ", pre_ids) + print("step_idx: ", step_idx) + print("next_tokens: ", next_tokens) + # 执行目标函数(核心测试步骤) + speculate_step_paddle( + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens, + ) + + if os.environ.get("ATTN_MASK_TEST_DEBUG", "0") == "1": + # 可选:打印执行后关键信息(如需调试可开启) + print("-" * 50 + "before step op" + "-" * 50) + print("stop_flags: ", stop_flags) + print("seq_lens_this_time: ", seq_lens_this_time) + print("seq_lens_encoder: ", seq_lens_encoder) + print("seq_lens_decoder: ", seq_lens_decoder) + 
print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) + print("block_tables: ", block_tables.sum()) + print("encoder_block_lens: ", encoder_block_lens) + print("is_block_step: ", is_block_step) + print("step_block_list: ", step_block_list) + print("step_lens: ", step_lens) + print("recover_lens: ", recover_lens) + print("recover_block_list: ", recover_block_list) + print("need_block_list: ", need_block_list) + print("need_block_len: ", need_block_len) + print("used_list_len: ", used_list_len) + print("free_list_len: ", free_list_len) + print("free_list: ", free_list) + print("input_ids: ", input_ids) + print("pre_ids: ", pre_ids) + print("step_idx: ", step_idx) + print("next_tokens: ", next_tokens) + return test_data + + +def assert_test_data_equal(test_data1, test_data2, rtol=1e-05, atol=1e-08): + """ + 断言两个 test_data 结构和数据完全一致,自动处理 host/device 数据转换(paddle Tensor → numpy) + + Args: + test_data1: 第一个待比较的 test_data(可在 host 或 device 上) + test_data2: 第二个待比较的 test_data(可在 host 或 device 上) + rtol: 相对误差容忍度(仅对浮点型有效) + atol: 绝对误差容忍度(仅对浮点型有效) + """ + # 1. 先校验两个 test_data 的字段名完全一致 + keys1 = set(test_data1.keys()) + keys2 = set(test_data2.keys()) + assert ( + keys1 == keys2 + ), f"两个 test_data 字段不一致!\n仅在第一个中存在:{keys1 - keys2}\n仅在第二个中存在:{keys2 - keys1}" + + # 2. 逐字段校验数据 + for key in keys1: + data1 = test_data1[key] + data2 = test_data2[key] + + # 区分:paddle Tensor(需转 numpy)和 普通标量/数组(直接使用) + if isinstance(data1, paddle.Tensor): + # 转换为 numpy:自动处理 device → host(.cpu())、阻止梯度计算(.detach()) + np1 = data1.detach().cpu().numpy() + else: + np1 = np.asarray(data1) # 非 Tensor(如 int/float)转为 numpy 统一格式 + + if isinstance(data2, paddle.Tensor): + np2 = data2.detach().cpu().numpy() + else: + np2 = np.asarray(data2) + + # 3. 校验数据(分类型处理:布尔型/整数型 严格相等,浮点型 允许微小误差) + if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8): + # 布尔/整数型:必须完全相等 + assert np.array_equal(np1, np2), f"字段 {key} 数据不一致!\n第一个数据:{np1}\n第二个数据:{np2}" + else: + # 浮点型:允许 rtol/atol 范围内的误差(如 float32/float64) + assert np.allclose( + np1, np2, rtol=rtol, atol=atol + ), f"字段 {key} 浮点数据不一致!\n相对误差:{rtol},绝对误差:{atol}\n第一个数据:{np1}\n第二个数据:{np2}" + + print("✅ 两个 test_data 结构和数据完全一致!") + + +def test_speculate_step_paddle(): + data_cpu, data_xpu = test_data() + # check before test + assert_test_data_equal(data_xpu, data_cpu) + result_xpu = speculate_step_paddle_execution(data_xpu) + result_cpu = speculate_step_paddle_execution(data_cpu) + # check after test + assert_test_data_equal(result_xpu, result_cpu) + -max_bs = 128 -bs = max_bs -max_seq_len = 8192 -block_size = 64 -block_bs = 8 -block_ratio = 0.75 -max_draft_tokens = 1 - -stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool") -seq_lens_this_time = np.zeros([bs], "int32") -seq_lens_encoder = np.zeros([max_bs], "int32") -seq_lens_decoder = np.zeros([max_bs], "int32") -step_idx = np.zeros([max_bs], "int64") -accept_num = np.random.randint(1, 3, [max_bs]).astype("int32") -for i in range(bs): - seq_lens_decoder[i] = 2 + i * 2 - seq_lens_this_time[i] = 1 -ori_seq_lens_encoder = np.zeros([max_bs], "int32") -ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2 -step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64") - -max_block_num = block_bs * max_seq_len // block_size -free_list_len = int(max_block_num * (1 - block_ratio)) -free_list_len = np.full([1], free_list_len, "int32") -free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32") - -encoder_block_lens = np.zeros([max_bs], "int32") -used_list_len = np.zeros([max_bs], "int32") -block_tables = 
np.full([max_bs, 128], -1, "int32") -encoder_block_id = 0 -for i in range(bs): - enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size - encoder_block_lens[i] = enc_block_num - dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num - used_list_len[i] = dec_block_num - block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32") - encoder_block_id += enc_block_num - if dec_block_num > 0: - block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[ - free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1 - ] - free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1 - free_list_len[0] -= dec_block_num -assert free_list_len[0] >= 0 - -is_block_step = np.zeros([max_bs], "bool") -is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool") -step_block_list = np.full([max_bs], -1, "int32") -step_lens = np.full([1], 0, "int32") -for i in range(bs): - if is_block_step[i]: - step_block_list[step_lens[0]] = i - step_lens[0] += 1 - -recover_lens = np.full([1], 0, "int32") -recover_block_list = np.full([max_bs], -1, "int32") - -need_block_len = np.full([1], 0, "int32") -need_block_list = np.full([max_bs], -1, "int32") - -input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") -pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") - -next_tokens = np.random.randint(0, 1000, [max_bs], "int64") -encoder_decoder_block_num = 1 -first_token_ids = np.random.randint(0, 1000, [max_bs], "int64") - -stop_flags = paddle.to_tensor(stop_flags) -seq_lens_this_time = paddle.to_tensor(seq_lens_this_time) -seq_lens_encoder = paddle.to_tensor(seq_lens_encoder) -seq_lens_decoder = paddle.to_tensor(seq_lens_decoder) -ori_seq_lens_encoder = paddle.to_tensor(ori_seq_lens_encoder) -block_tables = paddle.to_tensor(block_tables) -encoder_block_lens = paddle.to_tensor(encoder_block_lens) -is_block_step = paddle.to_tensor(is_block_step) -step_block_list = paddle.to_tensor(step_block_list) -step_lens = paddle.to_tensor(step_lens) -recover_lens = paddle.to_tensor(recover_lens) -recover_block_list = paddle.to_tensor(recover_block_list) -need_block_list = paddle.to_tensor(need_block_list) -need_block_len = paddle.to_tensor(need_block_len) -used_list_len = paddle.to_tensor(used_list_len) -free_list_len = paddle.to_tensor(free_list_len) -free_list = paddle.to_tensor(free_list) -input_ids = paddle.to_tensor(input_ids) -pre_ids = paddle.to_tensor(pre_ids) -step_idx = paddle.to_tensor(step_idx) -next_tokens = paddle.to_tensor(next_tokens) -first_token_ids = paddle.to_tensor(first_token_ids) -accept_num = paddle.to_tensor(accept_num) - -print("-" * 50 + "before step op" + "-" * 50) -print("stop_flags: ", stop_flags) -print("seq_lens_this_time: ", seq_lens_this_time) -print("seq_lens_encoder: ", seq_lens_encoder) -print("seq_lens_decoder: ", seq_lens_decoder) -print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) -print("block_tables: ", block_tables) -print("encoder_block_lens: ", encoder_block_lens) -print("is_block_step: ", is_block_step) -print("step_block_list: ", step_block_list) -print("step_lens: ", step_lens) -print("recover_lens: ", recover_lens) -print("recover_block_list: ", recover_block_list) -print("need_block_list: ", need_block_list) -print("need_block_len: ", need_block_len) -print("used_list_len: ", used_list_len) -print("free_list_len: ", free_list_len) -print("free_list: ", free_list) -print("input_ids: ", input_ids) -print("pre_ids: ", pre_ids) 
-print("step_idx: ", step_idx) -print("next_tokens: ", next_tokens) -print("accept_num: ", accept_num) - -speculate_step_paddle( - stop_flags, - seq_lens_this_time, - ori_seq_lens_encoder, - seq_lens_encoder, - seq_lens_decoder, - block_tables, - encoder_block_lens, - is_block_step, - step_block_list, - step_lens, - recover_block_list, - recover_lens, - need_block_list, - need_block_len, - used_list_len, - free_list, - free_list_len, - input_ids, - pre_ids, - step_idx, - next_tokens, - first_token_ids, - accept_num, - block_size, - encoder_decoder_block_num, - max_draft_tokens, -) - -print("-" * 50 + "after step op" + "-" * 50) -print("stop_flags: ", stop_flags) -print("seq_lens_this_time: ", seq_lens_this_time) -print("seq_lens_encoder: ", seq_lens_encoder) -print("seq_lens_decoder: ", seq_lens_decoder) -print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) -print("block_tables: ", block_tables) -print("encoder_block_lens: ", encoder_block_lens) -print("is_block_step: ", is_block_step) -print("step_block_list: ", step_block_list) -print("step_lens: ", step_lens) -print("recover_lens: ", recover_lens) -print("recover_block_list: ", recover_block_list) -print("need_block_list: ", need_block_list) -print("need_block_len: ", need_block_len) -print("used_list_len: ", used_list_len) -print("free_list_len: ", free_list_len) -print("free_list: ", free_list) -print("input_ids: ", input_ids) -print("pre_ids: ", pre_ids) -print("step_idx: ", step_idx) -print("next_tokens: ", next_tokens) -print("first_token_ids: ", first_token_ids) -print("accept_num: ", accept_num) +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) # 直接运行时执行 pytest 并显示详细日志 From d469eab285d3b4cd9d5a8adb0cc20ac1daaf62b5 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Fri, 7 Nov 2025 04:44:10 +0000 Subject: [PATCH 08/17] fix --- custom_ops/xpu_ops/test/test_speculate_step.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/xpu_ops/test/test_speculate_step.py b/custom_ops/xpu_ops/test/test_speculate_step.py index 852c144ccbd..3bbf37bfd14 100644 --- a/custom_ops/xpu_ops/test/test_speculate_step.py +++ b/custom_ops/xpu_ops/test/test_speculate_step.py @@ -194,7 +194,7 @@ def speculate_step_paddle_execution(test_data): max_draft_tokens = test_data["max_draft_tokens"] # 可选:打印执行前关键信息(如需调试可开启) - if os.environ.get("ATTN_MASK_TEST_DEBUG", "0") == "1": + if os.environ.get("STEP_TEST_DEBUG", "0") == "1": print("-" * 50 + "before step op" + "-" * 50) print("stop_flags: ", stop_flags) print("seq_lens_this_time: ", seq_lens_this_time) @@ -247,7 +247,7 @@ def speculate_step_paddle_execution(test_data): max_draft_tokens, ) - if os.environ.get("ATTN_MASK_TEST_DEBUG", "0") == "1": + if os.environ.get("STEP_TEST_DEBUG", "0") == "1": # 可选:打印执行后关键信息(如需调试可开启) print("-" * 50 + "before step op" + "-" * 50) print("stop_flags: ", stop_flags) From 95c7988d129f1ef6b1ad3fd29787aac0f3a9ec0e Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Mon, 24 Nov 2025 08:18:20 +0000 Subject: [PATCH 09/17] mv pre/post process --- .../model_executor/pre_and_post_process.py | 290 +++++++++++++++++ fastdeploy/worker/xpu_model_runner.py | 296 +----------------- 2 files changed, 300 insertions(+), 286 deletions(-) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index d2c82e2afaa..e01ff0fdae0 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -22,6 +22,7 @@ from 
fastdeploy import envs from fastdeploy.config import SpeculativeConfig +from fastdeploy.model_executor.forward_meta import XPUForwardMeta from fastdeploy.platforms import current_platform if current_platform.is_iluvatar(): @@ -64,6 +65,20 @@ ) elif current_platform.is_intel_hpu(): pass +elif current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + adjust_batch, + gather_next_token, + get_infer_param, + get_padding_offset, + limit_thinking_content_length_v1, + limit_thinking_content_length_v2, + save_output, + set_stop_value_multi_ends, + step_paddle, + update_inputs, + update_inputs_v1, + ) else: from fastdeploy.model_executor.ops.gpu import ( get_padding_offset, @@ -921,3 +936,278 @@ def post_process_pooling( if save_each_rank or model_output.mp_rank == 0: output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs) async_output_queue.put(output) + + +def xpu_pre_process( + input_ids: paddle.Tensor, + seq_lens_this_time: int, + share_inputs: Dict, + use_speculate_method: bool, + block_size: int, + draft_tokens: Optional[paddle.Tensor] = None, + seq_lens_encoder: Optional[paddle.Tensor] = None, + seq_lens_decoder: Optional[paddle.Tensor] = None, + is_profiling: bool = False, +) -> XPUForwardMeta: + """ """ + max_len = input_ids.shape[1] + cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") + token_num = paddle.sum(seq_lens_this_time) + + ( + ids_remove_padding, + cum_offsets, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) + + share_inputs["ids_remove_padding"] = None # set this after adjust batch + share_inputs["cum_offsets"] = cum_offsets + share_inputs["batch_id_per_token"] = batch_id_per_token + share_inputs["cu_seqlens_q"] = cu_seqlens_q + share_inputs["cu_seqlens_k"] = cu_seqlens_k + + xpu_forward_meta = XPUForwardMeta( + ids_remove_padding=share_inputs["ids_remove_padding"], + rotary_embs=share_inputs["rope_emb"], + attn_backend=None, + seq_lens_encoder=share_inputs["seq_lens_encoder"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + seq_lens_this_time=share_inputs["seq_lens_this_time"], + cum_offsets=share_inputs["cum_offsets"], + batch_id_per_token=share_inputs["batch_id_per_token"], + cu_seqlens_q=share_inputs["cu_seqlens_q"], + cu_seqlens_k=share_inputs["cu_seqlens_k"], + block_tables=share_inputs["block_tables"], + caches=share_inputs["caches"], + ) + + ( + xpu_forward_meta.encoder_batch_map, + xpu_forward_meta.decoder_batch_map, + xpu_forward_meta.encoder_batch_idx, + xpu_forward_meta.decoder_batch_idx, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, + xpu_forward_meta.encoder_kv_lod, + xpu_forward_meta.prefix_len, + xpu_forward_meta.decoder_context_len, + xpu_forward_meta.decoder_context_len_cache, + xpu_forward_meta.prefix_block_tables, + xpu_forward_meta.encoder_batch_map_cpu, + xpu_forward_meta.decoder_batch_map_cpu, + xpu_forward_meta.encoder_batch_idx_cpu, + xpu_forward_meta.decoder_batch_idx_cpu, + xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.decoder_seq_lod_cpu, + xpu_forward_meta.encoder_kv_lod_cpu, + xpu_forward_meta.prefix_len_cpu, + xpu_forward_meta.decoder_context_len_cpu, + xpu_forward_meta.decoder_context_len_cache_cpu, + xpu_forward_meta.len_info_cpu, + ) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size + ) + xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] + 
xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] + xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] + + adjusted_input = adjust_batch( + ids_remove_padding.reshape([-1, 1]), + cum_offsets, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.encoder_batch_idx, + xpu_forward_meta.decoder_batch_idx, + xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.encoder_batch_idx_cpu, + xpu_forward_meta.decoder_batch_idx_cpu, + xpu_forward_meta.enc_batch, + xpu_forward_meta.dec_batch, + None, # output_padding_offset + -1, # max_input_length + ) + + adjusted_input = adjusted_input.squeeze(1) + + share_inputs["ids_remove_padding"] = adjusted_input + xpu_forward_meta.ids_remove_padding = adjusted_input + # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends + xpu_forward_meta.is_profiling = is_profiling + return xpu_forward_meta + + +def xpu_process_output( + forward_output, + cum_offsets: paddle.Tensor, + xpu_forward_meta: XPUForwardMeta, +) -> paddle.Tensor: + """ """ + + hiddden_states = gather_next_token( + forward_output, + cum_offsets, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.encoder_batch_map, + xpu_forward_meta.decoder_batch_map, + xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.encoder_batch_map_cpu, + xpu_forward_meta.decoder_batch_map_cpu, + xpu_forward_meta.enc_batch, + xpu_forward_meta.dec_batch, + None, # output_padding_offset + -1, # max_input_length + ) + return hiddden_states + + +def xpu_post_process_normal( + sampled_token_ids: paddle.Tensor, + model_output: ModelOutputData, + share_inputs: Dict[str, paddle.Tensor], + block_size: int = 64, + skip_save_output: bool = False, + think_end_id: int = None, + line_break_id: int = None, +) -> None: + """ """ + from fastdeploy.model_executor.ops.xpu import ( + save_output, + set_stop_value_multi_ends, + update_inputs, + ) + + if think_end_id > 0: + limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR + max_think_lens = share_inputs["max_think_lens"] + step_idx = share_inputs["step_idx"] + limit_think_status = share_inputs["limit_think_status"] + stop_flags = share_inputs["stop_flags"] + eos_token_ids = share_inputs["eos_token_id"] + if limit_strategy == "": + # for ernie-45-vl + limit_thinking_content_length_v1( + sampled_token_ids, + max_think_lens, + step_idx, + limit_think_status, + stop_flags, + eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题 + think_end_id, + ) + elif limit_strategy == "\n\n\n": + # for ernie-x1 + assert line_break_id > 0 + limit_thinking_content_length_v2( + sampled_token_ids, + max_think_lens, + step_idx, + limit_think_status, + stop_flags, + think_end_id, + line_break_id, + ) + else: + raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") + + # 1. Set stop value + paddle.assign( + paddle.where( + model_output.stop_flags, + model_output.step_idx, + model_output.step_idx + 1, + ), + model_output.step_idx, + ) + length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len) + paddle.assign( + paddle.logical_or(model_output.stop_flags, length_cond), + model_output.stop_flags, + ) + set_stop_value_multi_ends( + sampled_token_ids, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.eos_token_id, + model_output.next_tokens, + False, + ) # multi ends + + # 2. 
Update the input buffer of the model + with paddle.framework._no_check_dy2st_diff(): + if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: + update_inputs_v1( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + share_inputs["step_seq_lens_decoder"], + share_inputs["prompt_lens"], + sampled_token_ids, + model_output.input_ids, + share_inputs["block_tables"], + model_output.stop_nums, + model_output.next_tokens, + model_output.is_block_step, + block_size, + ) + else: + update_inputs( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.input_ids, + model_output.stop_nums, + sampled_token_ids, + model_output.is_block_step, + ) + # 3. Transmit the model's output and stop generation signal via message queue. + # In the future, we will abandon this approach. + if not skip_save_output: + save_output( + sampled_token_ids, + model_output.not_need_stop, + model_output.mp_rank, + False, # use_ep + ) + + +def step_xpu( + share_inputs: Dict[str, paddle.Tensor], + block_size: int, + enc_dec_block_num: int, +) -> None: + """ + TODO(gongshaotian): normalization name + """ + from fastdeploy.model_executor.ops.xpu import step_paddle + + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index b60eb8cdfc9..ac747879f6b 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -17,7 +17,7 @@ import os import random import time -from typing import Dict, List, Optional +from typing import List, Optional import numpy as np import paddle @@ -28,7 +28,7 @@ from fastdeploy.engine.request import Request, RequestType from fastdeploy.input.ernie4_5_vl_processor import DataProcessor from fastdeploy.inter_communicator import IPCSignal -from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, @@ -43,17 +43,17 @@ from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.ops.xpu import ( - adjust_batch, create_kv_signal_sender, destroy_kv_signal_sender, - get_infer_param, - get_padding_offset, - limit_thinking_content_length_v1, - limit_thinking_content_length_v2, recover_decode_task, set_data_ipc, share_external_data, - update_inputs_v1, +) +from fastdeploy.model_executor.pre_and_post_process import ( # xpu_post_process_specualate, # 
TODO(chenhuan09): add xpu_post_process_specualate + step_xpu, + xpu_post_process_normal, + xpu_pre_process, + xpu_process_output, ) from fastdeploy.utils import get_logger from fastdeploy.worker.model_runner_base import ModelRunnerBase @@ -62,282 +62,6 @@ logger = get_logger("xpu_model_runner", "xpu_model_runner.log") -def xpu_pre_process( - input_ids: paddle.Tensor, - seq_lens_this_time: int, - share_inputs: Dict, - use_speculate_method: bool, - block_size: int, - draft_tokens: Optional[paddle.Tensor] = None, - seq_lens_encoder: Optional[paddle.Tensor] = None, - seq_lens_decoder: Optional[paddle.Tensor] = None, - is_profiling: bool = False, -) -> XPUForwardMeta: - """ """ - max_len = input_ids.shape[1] - cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") - token_num = paddle.sum(seq_lens_this_time) - - ( - ids_remove_padding, - cum_offsets, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k, - ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) - - share_inputs["ids_remove_padding"] = None # set this after adjust batch - share_inputs["cum_offsets"] = cum_offsets - share_inputs["batch_id_per_token"] = batch_id_per_token - share_inputs["cu_seqlens_q"] = cu_seqlens_q - share_inputs["cu_seqlens_k"] = cu_seqlens_k - - xpu_forward_meta = XPUForwardMeta( - ids_remove_padding=share_inputs["ids_remove_padding"], - rotary_embs=share_inputs["rope_emb"], - attn_backend=None, - seq_lens_encoder=share_inputs["seq_lens_encoder"], - seq_lens_decoder=share_inputs["seq_lens_decoder"], - seq_lens_this_time=share_inputs["seq_lens_this_time"], - cum_offsets=share_inputs["cum_offsets"], - batch_id_per_token=share_inputs["batch_id_per_token"], - cu_seqlens_q=share_inputs["cu_seqlens_q"], - cu_seqlens_k=share_inputs["cu_seqlens_k"], - block_tables=share_inputs["block_tables"], - caches=share_inputs["caches"], - ) - - ( - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_kv_lod, - xpu_forward_meta.prefix_len, - xpu_forward_meta.decoder_context_len, - xpu_forward_meta.decoder_context_len_cache, - xpu_forward_meta.prefix_block_tables, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_kv_lod_cpu, - xpu_forward_meta.prefix_len_cpu, - xpu_forward_meta.decoder_context_len_cpu, - xpu_forward_meta.decoder_context_len_cache_cpu, - xpu_forward_meta.len_info_cpu, - ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size - ) - xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] - xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] - xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] - - adjusted_input = adjust_batch( - ids_remove_padding.reshape([-1, 1]), - cum_offsets, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.enc_batch, - xpu_forward_meta.dec_batch, - None, # output_padding_offset - -1, # max_input_length - ) - - adjusted_input = 
adjusted_input.squeeze(1) - - share_inputs["ids_remove_padding"] = adjusted_input - xpu_forward_meta.ids_remove_padding = adjusted_input - # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends - xpu_forward_meta.is_profiling = is_profiling - return xpu_forward_meta - - -def xpu_process_output( - forward_output, - cum_offsets: paddle.Tensor, - xpu_forward_meta: XPUForwardMeta, -) -> paddle.Tensor: - """ """ - from fastdeploy.model_executor.ops.xpu import gather_next_token - - hiddden_states = gather_next_token( - forward_output, - cum_offsets, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.enc_batch, - xpu_forward_meta.dec_batch, - None, # output_padding_offset - -1, # max_input_length - ) - return hiddden_states - - -def xpu_post_process( - sampled_token_ids: paddle.Tensor, - model_output: ModelOutputData, - share_inputs: Dict[str, paddle.Tensor], - block_size: int = 64, - skip_save_output: bool = False, - think_end_id: int = None, - line_break_id: int = None, -) -> None: - """ """ - from fastdeploy.model_executor.ops.xpu import ( - save_output, - set_stop_value_multi_ends, - update_inputs, - ) - - if think_end_id > 0: - limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR - max_think_lens = share_inputs["max_think_lens"] - step_idx = share_inputs["step_idx"] - limit_think_status = share_inputs["limit_think_status"] - stop_flags = share_inputs["stop_flags"] - eos_token_ids = share_inputs["eos_token_id"] - if limit_strategy == "": - # for ernie-45-vl - limit_thinking_content_length_v1( - sampled_token_ids, - max_think_lens, - step_idx, - limit_think_status, - stop_flags, - eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题 - think_end_id, - ) - elif limit_strategy == "\n\n\n": - # for ernie-x1 - assert line_break_id > 0 - limit_thinking_content_length_v2( - sampled_token_ids, - max_think_lens, - step_idx, - limit_think_status, - stop_flags, - think_end_id, - line_break_id, - ) - else: - raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") - - # 1. Set stop value - paddle.assign( - paddle.where( - model_output.stop_flags, - model_output.step_idx, - model_output.step_idx + 1, - ), - model_output.step_idx, - ) - length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len) - paddle.assign( - paddle.logical_or(model_output.stop_flags, length_cond), - model_output.stop_flags, - ) - set_stop_value_multi_ends( - sampled_token_ids, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.eos_token_id, - model_output.next_tokens, - False, - ) # multi ends - - # 2. 
Update the input buffer of the model - with paddle.framework._no_check_dy2st_diff(): - if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: - update_inputs_v1( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - share_inputs["step_seq_lens_decoder"], - share_inputs["prompt_lens"], - sampled_token_ids, - model_output.input_ids, - share_inputs["block_tables"], - model_output.stop_nums, - model_output.next_tokens, - model_output.is_block_step, - block_size, - ) - else: - update_inputs( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.input_ids, - model_output.stop_nums, - sampled_token_ids, - model_output.is_block_step, - ) - # 3. Transmit the model's output and stop generation signal via message queue. - # In the future, we will abandon this approach. - if not skip_save_output: - save_output( - sampled_token_ids, - model_output.not_need_stop, - model_output.mp_rank, - False, # use_ep - ) - - -def step_paddle( - share_inputs: Dict[str, paddle.Tensor], - block_size: int, - enc_dec_block_num: int, -) -> None: - """ - TODO(gongshaotian): normalization name - """ - from fastdeploy.model_executor.ops.xpu import step_paddle - - step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) - - class XPUModelRunner(ModelRunnerBase): """ """ @@ -1247,7 +971,7 @@ class at the server level, which is too granular for ModelRunner. stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) - xpu_post_process( + xpu_post_process_normal( sampled_token_ids=sampler_output.sampled_token_ids, model_output=model_output_data, share_inputs=self.share_inputs, @@ -1260,7 +984,7 @@ class at the server level, which is too granular for ModelRunner. # 7. 
Updata 'infer_seed' and step_paddle() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED - step_paddle( + step_xpu( self.share_inputs, self.cache_config.block_size, self.cache_config.enc_dec_block_num, From d3ce9686fd55ca517a4f114329b3abb30c21283c Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Mon, 24 Nov 2025 10:07:06 +0000 Subject: [PATCH 10/17] add adjust batch / gather next token for mtp --- custom_ops/xpu_ops/src/ops/adjust_batch.cc | 66 +++-- .../xpu_ops/src/ops/gather_next_token.cc | 193 +++++++++------ custom_ops/xpu_ops/src/ops/pybind/pybind.cc | 28 ++- .../xpu_ops/src/plugin/include/xpu/plugin.h | 12 + .../src/kernel/kunlun3cpp/eb_adjust_batch.xpu | 23 +- .../mtp_kernel/eb_mtp_gather_next_token.xpu | 128 ++++++++++ .../plugin/src/wrapper/eb_adjust_batch.cpp | 40 ++- .../mtp_wrapper/eb_mtp_gather_next_token.cpp | 227 ++++++++++++++++++ ...test_adjust_batch_and_gather_next_token.py | 196 +++++++++++++++ .../model_executor/pre_and_post_process.py | 7 +- 10 files changed, 785 insertions(+), 135 deletions(-) create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp create mode 100644 custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py diff --git a/custom_ops/xpu_ops/src/ops/adjust_batch.cc b/custom_ops/xpu_ops/src/ops/adjust_batch.cc index d263d2cae5c..fb3b3168856 100644 --- a/custom_ops/xpu_ops/src/ops/adjust_batch.cc +++ b/custom_ops/xpu_ops/src/ops/adjust_batch.cc @@ -18,38 +18,49 @@ #include "utility/helper.h" #include "xpu/plugin.h" +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + template std::vector AdjustBatchKernel( const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &decoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_idx_cpu, const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, + const paddle::Tensor &len_info_cpu, const paddle::optional &output_padding_offset, int max_input_length) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + auto ctx = static_cast(dev_ctx)->x_context(); PD_CHECK(x.dtype() == T); PD_CHECK(x.dims().size() == 2); - + if (x.is_cpu()) { + ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); + } using XPUType = typename XPUTypeTrait::DataType>::Type; using data_t = typename PDTraits::data_t; const int token_num = x.dims()[0]; const int dim = x.dims()[1]; const int bsz = cum_offsets.shape()[0]; - int enc_batch = enc_batch_tensor.data()[0]; - int dec_batch = dec_batch_tensor.data()[0]; + int enc_batch = len_info_cpu.data()[0]; + int dec_batch = len_info_cpu.data()[1]; baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ const_cast(encoder_seq_lod_cpu.data()), enc_batch + 1, const_cast(encoder_seq_lod.data())}; + baidu::xpu::api::VectorParam decoder_seqs_lods_vp{ + const_cast(decoder_seq_lod_cpu.data()), + dec_batch + 1, + const_cast(decoder_seq_lod.data())}; 
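  // A VectorParam<int> packs the same array three ways: a host pointer (.cpu),
  // its length (.len) and a device pointer (.xpu). The *_seq_lod arrays are
  // cumulative (prefix-sum) offsets with batch + 1 entries, so lod[i + 1] - lod[i]
  // is the token count of sequence i and lod[batch] is that group's total.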
baidu::xpu::api::VectorParam encoder_batch_map_vp{ const_cast(encoder_batch_idx_cpu.data()), enc_batch, @@ -59,13 +70,14 @@ std::vector AdjustBatchKernel( dec_batch, const_cast(decoder_batch_idx.data())}; - auto out = paddle::full({token_num, dim}, -2, x.type(), x.place()); + auto out = paddle::empty({token_num, dim}, x.type(), x.place()); int r = baidu::xpu::api::plugin::eb_adjust_batch( - xpu_ctx->x_context(), + ctx, reinterpret_cast(x.data()), reinterpret_cast(out.data()), encoder_seqs_lods_vp, + decoder_seqs_lods_vp, encoder_batch_map_vp, decoder_batch_map_vp, dim); @@ -76,13 +88,14 @@ using AdjustBatchKernelFuncPtr = std::vector (*)( const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &decoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_idx_cpu, const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, + const paddle::Tensor &len_info_cpu, const paddle::optional &output_padding_offset, int max_input_length); @@ -90,13 +103,14 @@ std::vector AdjustBatch( const paddle::Tensor &x, // [token_num, dim_embed] const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &encoder_seq_lod, + const paddle::Tensor &decoder_seq_lod, const paddle::Tensor &encoder_batch_idx, const paddle::Tensor &decoder_batch_idx, const paddle::Tensor &encoder_seq_lod_cpu, + const paddle::Tensor &decoder_seq_lod_cpu, const paddle::Tensor &encoder_batch_idx_cpu, const paddle::Tensor &decoder_batch_idx_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, + const paddle::Tensor &len_info_cpu, const paddle::optional &output_padding_offset, int max_input_length) { AdjustBatchKernelFuncPtr func = nullptr; @@ -108,12 +122,12 @@ std::vector AdjustBatch( case paddle::DataType::FLOAT16: func = &AdjustBatchKernel; break; - case paddle::DataType::FLOAT32: - func = &AdjustBatchKernel; - break; case paddle::DataType::INT64: func = &AdjustBatchKernel; break; + case paddle::DataType::FLOAT32: + func = &AdjustBatchKernel; + break; default: PD_THROW("Unsupported data type: ", x.dtype()); } @@ -121,13 +135,14 @@ std::vector AdjustBatch( return func(x, cum_offsets, encoder_seq_lod, + decoder_seq_lod, encoder_batch_idx, decoder_batch_idx, encoder_seq_lod_cpu, + decoder_seq_lod_cpu, encoder_batch_idx_cpu, decoder_batch_idx_cpu, - enc_batch_tensor, - dec_batch_tensor, + len_info_cpu, output_padding_offset, max_input_length); } @@ -136,13 +151,14 @@ std::vector> AdjustBatchInferShape( const std::vector &x_shape, const std::vector &cum_offsets_shape, const std::vector &encoder_seq_lod_shape, + const std::vector &decoder_seq_lod_shape, const std::vector &encoder_batch_idx_shape, const std::vector &decoder_batch_idx_shape, const std::vector &encoder_seq_lod_cpu_shape, + const std::vector &decoder_seq_lod_cpu_shape, const std::vector &encoder_batch_idx_cpu_shape, const std::vector &decoder_batch_idx_cpu_shape, - const std::vector &enc_batch_tensor_shape, - const std::vector &dec_batch_tensor_shape, + const std::vector &len_info_cpu_shape, const paddle::optional> &output_padding_offset_shape) { if (output_padding_offset_shape) { PD_THROW("speculative decoding is not supported in XPU."); @@ -156,28 +172,30 @@ std::vector AdjustBatchInferDtype( const 
paddle::DataType &x_dtype, const paddle::DataType &cum_offsets_dtype, const paddle::DataType &encoder_seq_lod_dtype, + const paddle::DataType &decoder_seq_lod_dtype, const paddle::DataType &encoder_batch_idx_dtype, const paddle::DataType &decoder_batch_idx_dtype, const paddle::DataType &encoder_seq_lod_cpu_dtype, + const paddle::DataType &decoder_seq_lod_cpu_dtype, const paddle::DataType &encoder_batch_idx_cpu_dtype, const paddle::DataType &decoder_batch_idx_cpu_dtype, - const paddle::DataType &enc_batch_tensor_dtype, - const paddle::DataType &dec_batch_tensor_dtype, + const paddle::DataType &len_info_cpu_dtype, const paddle::optional &output_padding_offset_dtype) { return {x_dtype}; } -PD_BUILD_OP(adjust_batch) +PD_BUILD_STATIC_OP(adjust_batch) .Inputs({"x", "cum_offsets", "encoder_seq_lod", + "decoder_seq_lod", "encoder_batch_idx", "decoder_batch_idx", "encoder_seq_lod_cpu", + "decoder_seq_lod_cpu", "encoder_batch_idx_cpu", "decoder_batch_idx_cpu", - "enc_batch_tensor", - "dec_batch_tensor", + "len_info_cpu", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) .Attrs({"max_input_length: int"}) diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index 18a773c5150..9a35f91f9c8 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -13,134 +13,169 @@ // limitations under the License. #include +#include #include "paddle/extension.h" #include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + std::vector GatherNextToken( - const paddle::Tensor &tmp_out, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &encoder_seq_lod, - const paddle::Tensor &encoder_batch_map, - const paddle::Tensor &decoder_batch_map, - const paddle::Tensor &encoder_seq_lod_cpu, - const paddle::Tensor &encoder_batch_map_cpu, - const paddle::Tensor &decoder_batch_map_cpu, - const paddle::Tensor &enc_batch_tensor, - const paddle::Tensor &dec_batch_tensor, - const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::Tensor& x, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, + const paddle::Tensor& encoder_batch_map_cpu, + const paddle::Tensor& decoder_batch_map_cpu, + const paddle::Tensor& len_info_cpu, + const paddle::optional& output_padding_offset, + int max_bsz) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); + auto ctx = static_cast(dev_ctx)->x_context(); + if (x.is_cpu()) { + ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); + } using XPUType = typename XPUTypeTrait::Type; // only support bfloat16 typedef paddle::bfloat16 data_t; - const int dim = tmp_out.dims()[1]; - const int token_num = tmp_out.shape()[0]; - const int bsz = cum_offsets.shape()[0]; - int enc_batch = enc_batch_tensor.data()[0]; - int dec_batch = dec_batch_tensor.data()[0]; - + const int dim = x.dims()[1]; + const int token_num = x.shape()[0]; + int bsz = cum_offsets.shape()[0]; + int enc_batch = len_info_cpu.data()[0]; + int dec_batch = 
len_info_cpu.data()[1]; + if (max_bsz > 0) { + PD_CHECK(encoder_batch_map_cpu.data()[enc_batch - 1] <= max_bsz, + "encoder_batch_map_cpu check failed"); + PD_CHECK(decoder_batch_map_cpu.data()[dec_batch - 1] <= max_bsz, + "decoder_batch_map_cpu check failed"); + bsz = max_bsz; + } baidu::xpu::api::VectorParam encoder_seqs_lods_vp{ - const_cast(encoder_seq_lod_cpu.data()), + const_cast(encoder_seq_lod_cpu.data()), enc_batch + 1, - const_cast(encoder_seq_lod.data())}; + const_cast(encoder_seq_lod.data())}; + baidu::xpu::api::VectorParam decoder_seqs_lods_vp{ + const_cast(decoder_seq_lod_cpu.data()), + dec_batch + 1, + const_cast(decoder_seq_lod.data())}; baidu::xpu::api::VectorParam encoder_batch_map_vp{ - const_cast(encoder_batch_map_cpu.data()), + const_cast(encoder_batch_map_cpu.data()), enc_batch, - const_cast(encoder_batch_map.data())}; + const_cast(encoder_batch_map.data())}; baidu::xpu::api::VectorParam decoder_batch_map_vp{ - const_cast(decoder_batch_map_cpu.data()), + const_cast(decoder_batch_map_cpu.data()), dec_batch, - const_cast(decoder_batch_map.data())}; + const_cast(decoder_batch_map.data())}; paddle::Tensor out; - std::vector encode_iota_lod_cpu(enc_batch); if (output_padding_offset) { int need_delete_token_num = 0; if (enc_batch > 0) { need_delete_token_num = encoder_seq_lod_cpu.data()[enc_batch] - enc_batch; - std::iota(encode_iota_lod_cpu.begin(), encode_iota_lod_cpu.end(), 0); - encoder_batch_map_vp.cpu = - const_cast(encode_iota_lod_cpu.data()); - encoder_batch_map_vp.len = enc_batch; - encoder_batch_map_vp.xpu = nullptr; } - out = paddle::empty({token_num - need_delete_token_num, dim}, - tmp_out.type(), - tmp_out.place()); + out = paddle::empty( + {token_num - need_delete_token_num, dim}, x.type(), x.place()); } else { - out = paddle::empty({bsz, dim}, tmp_out.type(), tmp_out.place()); + out = paddle::empty({bsz, dim}, x.type(), x.place()); } - if (tmp_out.shape()[0] == 0) { + if (x.shape()[0] == 0) { return {out}; } - if (output_padding_offset && enc_batch <= 0) { - out = tmp_out.copy_to(tmp_out.place(), false); + if (enc_batch <= 0) { + out = x.copy_to(x.place(), false); } else { - int r = baidu::xpu::api::plugin::eb_gather_next_token( - xpu_ctx->x_context(), - reinterpret_cast(tmp_out.data()), - reinterpret_cast(out.data()), - encoder_seqs_lods_vp, - encoder_batch_map_vp, - decoder_batch_map_vp, - dim); - PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); + if (output_padding_offset) { + int r = + baidu::xpu::api::plugin::eb_mtp_gather_next_token( + ctx, + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + decoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); + } else { + int r = baidu::xpu::api::plugin::eb_gather_next_token( + ctx, + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); + } } return {out}; } std::vector> GatherNextTokenInferShape( - const std::vector &tmp_out_shape, - const std::vector &cum_offsets_shape, - const std::vector &encoder_seq_lod_shape, - const std::vector &encoder_batch_map_shape, - const std::vector &decoder_batch_map_shape, - const std::vector &encoder_seq_lod_cpu_shape, - const std::vector &encoder_batch_map_cpu_shape, - const std::vector &decoder_batch_map_cpu_shape, - const std::vector &enc_batch_tensor_shape, - const std::vector &dec_batch_tensor_shape, - 
const paddle::optional> &output_padding_offset_shape) { + const std::vector& x_shape, + const std::vector& cum_offsets_shape, + const std::vector& encoder_seq_lod_shape, + const std::vector& decoder_seq_lod_shape, + const std::vector& encoder_batch_map_shape, + const std::vector& decoder_batch_map_shape, + const std::vector& encoder_seq_lod_cpu_shape, + const std::vector& decoder_seq_lod_cpu_shape, + const std::vector& encoder_batch_map_cpu_shape, + const std::vector& decoder_batch_map_cpu_shape, + const std::vector& len_info_cpu_shape, + const paddle::optional>& output_padding_offset_shape) { + // if (output_padding_offset_shape) { + // PD_THROW("speculative decoding is not supported in XPU."); + // } + int64_t bsz = cum_offsets_shape[0]; + int64_t dim_embed = x_shape[1]; if (output_padding_offset_shape) { - PD_THROW("speculative decoding is not supported in XPU."); + return {{-1, dim_embed}}; + } else { + int64_t bsz = cum_offsets_shape[0]; + return {{bsz, dim_embed}}; } - int64_t bsz = cum_offsets_shape[0]; - int64_t dim_embed = tmp_out_shape[1]; - return {{bsz, dim_embed}}; } std::vector GatherNextTokenInferDtype( - const paddle::DataType &tmp_out_dtype, - const paddle::DataType &cum_offsets_dtype, - const paddle::DataType &encoder_seq_lod_dtype, - const paddle::DataType &encoder_batch_map_dtype, - const paddle::DataType &decoder_batch_map_dtype, - const paddle::DataType &encoder_seq_lod_cpu_dtype, - const paddle::DataType &encoder_batch_map_cpu_dtype, - const paddle::DataType &decoder_batch_map_cpu_dtype, - const paddle::DataType &enc_batch_tensor_dtype, - const paddle::DataType &dec_batch_tensor_dtype, - const paddle::optional &output_padding_offset_dtype) { - return {tmp_out_dtype}; + const paddle::DataType& x_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& encoder_seq_lod_dtype, + const paddle::DataType& decoder_seq_lod_dtype, + const paddle::DataType& encoder_batch_map_dtype, + const paddle::DataType& decoder_batch_map_dtype, + const paddle::DataType& encoder_seq_lod_cpu_dtype, + const paddle::DataType& decoder_seq_lod_cpu_dtype, + const paddle::DataType& encoder_batch_map_cpu_dtype, + const paddle::DataType& decoder_batch_map_cpu_dtype, + const paddle::DataType& len_info_cpu_dtype, + const paddle::optional& output_padding_offset_dtype) { + return {x_dtype}; } -PD_BUILD_OP(gather_next_token) - .Inputs({"tmp_out", +PD_BUILD_STATIC_OP(gather_next_token) + .Inputs({"x", "cum_offsets", "encoder_seq_lod", + "decoder_seq_lod", "encoder_batch_map", "decoder_batch_map", "encoder_seq_lod_cpu", + "decoder_seq_lod_cpu", "encoder_batch_map_cpu", "decoder_batch_map_cpu", - "enc_batch_tensor", - "dec_batch_tensor", + "len_info_cpu", paddle::Optional("output_padding_offset")}) .Outputs({"out"}) - .Attrs({"max_input_length: int"}) + .Attrs({"max_bsz: int"}) .SetKernelFn(PD_KERNEL(GatherNextToken)) .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index f57670e6fad..8c755e1a078 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -37,13 +37,14 @@ std::vector AdjustBatch( const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& cum_offsets, // [bsz, 1] const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_idx, const paddle::Tensor& decoder_batch_idx, const 
paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, const paddle::Tensor& encoder_batch_idx_cpu, const paddle::Tensor& decoder_batch_idx_cpu, - const paddle::Tensor& enc_batch_tensor, - const paddle::Tensor& dec_batch_tensor, + const paddle::Tensor& len_info_cpu, const paddle::optional& output_padding_offset, int max_input_length); @@ -351,18 +352,19 @@ std::vector EagleGetSelfHiddenStates( const paddle::Tensor& step_idx); std::vector GatherNextToken( - const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& x, // [token_num, dim_embed] const paddle::Tensor& cum_offsets, // [bsz, 1] const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, const paddle::Tensor& encoder_batch_map, const paddle::Tensor& decoder_batch_map, const paddle::Tensor& encoder_seq_lod_cpu, + const paddle::Tensor& decoder_seq_lod_cpu, const paddle::Tensor& encoder_batch_map_cpu, const paddle::Tensor& decoder_batch_map_cpu, - const paddle::Tensor& enc_batch_tensor, - const paddle::Tensor& dec_batch_tensor, + const paddle::Tensor& len_info_cpu, const paddle::optional& output_padding_offset, - int max_input_length); + int max_bsz); std::vector GetImgBoundaries( const paddle::Tensor& task_input_ids, @@ -605,13 +607,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("x"), py::arg("cum_offsets"), py::arg("encoder_seq_lod"), + py::arg("decoder_seq_lod"), py::arg("encoder_batch_idx"), py::arg("decoder_batch_idx"), py::arg("encoder_seq_lod_cpu"), + py::arg("decoder_seq_lod_cpu"), py::arg("encoder_batch_idx_cpu"), py::arg("decoder_batch_idx_cpu"), - py::arg("enc_batch_tensor"), - py::arg("dec_batch_tensor"), + py::arg("len_info_cpu"), py::arg("output_padding_offset"), py::arg("max_input_length"), "adjust batch in XPU"); @@ -818,18 +821,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("gather_next_token", &GatherNextToken, - py::arg("tmp_out"), + py::arg("x"), py::arg("cum_offsets"), py::arg("encoder_seq_lod"), + py::arg("decoder_seq_lod"), py::arg("encoder_batch_map"), py::arg("decoder_batch_map"), py::arg("encoder_seq_lod_cpu"), + py::arg("decoder_seq_lod_cpu"), py::arg("encoder_batch_map_cpu"), py::arg("decoder_batch_map_cpu"), - py::arg("enc_batch_tensor"), - py::arg("dec_batch_tensor"), + py::arg("len_info_cpu"), py::arg("output_padding_offset"), - py::arg("max_input_length"), + py::arg("max_bsz"), "Gather next token for XPU"); m.def("get_img_boundaries", diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index fdbd988fed6..38e604c40cc 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -263,6 +263,7 @@ DLL_EXPORT int eb_adjust_batch( const TX* x, TY* y, VectorParam& encoder_seqs_lods, // NOLINT + VectorParam& decoder_seqs_lods, // NOLINT VectorParam& encoder_batch_map, // NOLINT VectorParam& decoder_batch_map, // NOLINT int64_t hidden_dim); @@ -277,6 +278,17 @@ DLL_EXPORT int eb_gather_next_token( VectorParam& decoder_batch_map, // NOLINT int64_t hidden_dim); +template +DLL_EXPORT int eb_mtp_gather_next_token( + Context* ctx, + const TX* x, + TY* y, + VectorParam& encoder_seqs_lods, // NOLINT + VectorParam& decoder_seqs_lods, // NOLINT + VectorParam& encoder_batch_map, // NOLINT + VectorParam& decoder_batch_map, // NOLINT + int64_t hidden_dim); + template DLL_EXPORT int quant2d_per_channel(api::Context* ctx, const TX* x, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu 
b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu index b675785a402..bd791bd98eb 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/eb_adjust_batch.xpu @@ -4,7 +4,7 @@ namespace xpu3 { namespace plugin { #define MAX_LM_SIZE 28672 -// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is +// One core has 32KB LM(gropu LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is // the stack space #define MAX_BATCH 512 #define ALIGNMENT 64 @@ -53,6 +53,7 @@ template __global__ void eb_adjust_batch(TX* src, TY* dst, int* encoder_seqs_lods, + int* decoder_seqs_lods, int* encoder_batch_map, int* decoder_batch_map, int en_batch, @@ -61,9 +62,11 @@ __global__ void eb_adjust_batch(TX* src, int tid = core_id() * cluster_num() + cluster_id(); int nthreads = core_num() * cluster_num(); __group_shared__ int local_lods_en[MAX_BATCH + 1]; + __group_shared__ int local_lods_de[MAX_BATCH + 1]; __group_shared__ int local_map_en[MAX_BATCH]; __group_shared__ int local_map_de[MAX_BATCH]; GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int)); + GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int)); if (en_batch > 0) { GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int)); } @@ -72,7 +75,8 @@ __global__ void eb_adjust_batch(TX* src, } mfence(); int max_encoder_len = local_lods_en[en_batch]; - int seq_sum = max_encoder_len + de_batch; + int max_decoder_len = local_lods_de[de_batch]; + int seq_sum = max_encoder_len + max_decoder_len; int total_batch = en_batch + de_batch; int start = 0; int end = 0; @@ -82,13 +86,16 @@ __global__ void eb_adjust_batch(TX* src, while (i < end) { if (i >= max_encoder_len) { // dst decode part - int cur_de_bs = i - max_encoder_len; + int cur_de_bs = 0; + get_cur_batch(local_lods_de, de_batch, i - max_encoder_len, cur_de_bs); int cur_en_bs = local_map_de[cur_de_bs] - cur_de_bs; + int cur_len = + min(end, local_lods_de[cur_de_bs + 1] + max_encoder_len) - i; _global_ptr_ TY* cur_dst = dst + i * copy_size; _global_ptr_ TX* cur_src = - src + (cur_de_bs + local_lods_en[cur_en_bs]) * copy_size; - do_memcpy_1d(cur_src, cur_dst, copy_size); - i++; + src + (local_lods_en[cur_en_bs] + i - max_encoder_len) * copy_size; + do_memcpy_1d(cur_src, cur_dst, copy_size * cur_len); + i += cur_len; } else { // dst encode part int cur_en_bs = 0; @@ -97,7 +104,8 @@ __global__ void eb_adjust_batch(TX* src, cur_de_bs = local_map_en[cur_en_bs] - cur_en_bs; int cur_len = min(end, local_lods_en[cur_en_bs + 1]) - i; _global_ptr_ TY* cur_dst = dst + i * copy_size; - _global_ptr_ TX* cur_src = src + (cur_de_bs + i) * copy_size; + _global_ptr_ TX* cur_src = + src + (local_lods_de[cur_de_bs] + i) * copy_size; do_memcpy_1d(cur_src, cur_dst, copy_size * cur_len); i += cur_len; } @@ -108,6 +116,7 @@ __global__ void eb_adjust_batch(TX* src, template __global__ void eb_adjust_batch(TX * src, \ TY * dst, \ int* encoder_seqs_lods, \ + int* decoder_seqs_lods, \ int* encoder_batch_map, \ int* decoder_batch_map, \ int en_batch, \ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu new file mode 100644 index 00000000000..9e964b17746 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu @@ -0,0 +1,128 @@ +#include "xpu/kernel/cluster.h" 
+#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_primitive.h" +namespace xpu3 { +namespace plugin { +#define MAX_LM_SIZE 28672 +// One core has 32KB LM(group LM), MAX_LM_SIZE = (32 - 4)KB / 2 = 30720, 4KB is +// the stack space +#define MAX_BATCH 512 +#define ALIGNMENT 64 + +template +static __device__ void do_memcpy_1d(_global_ptr_ TX* src, + _global_ptr_ TY* dst, + int64_t copy_size) { +#ifdef __XPU3__ + constexpr int buf_size = 2048; +#else + constexpr int buf_size = 512; +#endif + __group_shared__ __simd__ float double_lmx[2][buf_size]; + int64_t pingpong = 0; + for (int64_t i = 0; i < copy_size; i += buf_size) { + int real_size = min(buf_size, copy_size - i); + _group_shared_ptr_ float* lmx = double_lmx[pingpong]; + GM2GSM(src + i, lmx, real_size * sizeof(TX)); + if (!xpu_std::is_same::value) { + primitive_cast_gsm( + (_group_shared_ptr_ TX*)lmx, lmx, real_size); + primitive_cast_gsm( + lmx, (_group_shared_ptr_ TY*)lmx, real_size); + } + GSM2GM_ASYNC((_group_shared_ptr_ TY*)lmx, dst + i, real_size * sizeof(TY)); + pingpong = 1 - pingpong; + } + mfence(); +} + +template +__global__ void eb_mtp_gather_next_token(TX* src, + TY* dst, + int* encoder_seqs_lods, + int* decoder_seqs_lods, + int* encoder_batch_map, + int* decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size) { + int tid = core_id() * cluster_num() + cluster_id(); + int nthreads = core_num() * cluster_num(); + __group_shared__ int local_lods_en[MAX_BATCH + 1]; + __group_shared__ int local_lods_de[MAX_BATCH + 1]; + __group_shared__ int local_map_en[MAX_BATCH]; + __group_shared__ int local_map_de[MAX_BATCH]; + GM2GSM_ASYNC(encoder_seqs_lods, local_lods_en, (en_batch + 1) * sizeof(int)); + GM2GSM_ASYNC(decoder_seqs_lods, local_lods_de, (de_batch + 1) * sizeof(int)); + if (en_batch > 0) { + GM2GSM_ASYNC(encoder_batch_map, local_map_en, en_batch * sizeof(int)); + } + if (de_batch > 0) { + GM2GSM_ASYNC(decoder_batch_map, local_map_de, de_batch * sizeof(int)); + } + mfence(); + int encoder_len_total = en_batch > 0 ? 
local_lods_en[en_batch] : 0; + int output_len = en_batch + local_lods_de[de_batch]; + int start = 0; + int end = 0; + partition(tid, nthreads, output_len, 1, &start, &end); + for (int i = start; i < end; i++) { + int len = 0; + int enc_idx = 0, dec_idx = 0; + bool is_enc; + while (i >= len) { + if (enc_idx >= en_batch) { + len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx]; + dec_idx++; + is_enc = false; + continue; + } + if (dec_idx >= de_batch) { + len += 1; + enc_idx++; + is_enc = true; + continue; + } + if (local_map_en[enc_idx] < local_map_de[dec_idx]) { + len += 1; + enc_idx++; + is_enc = true; + } else { + len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx]; + dec_idx++; + is_enc = false; + } + } + _global_ptr_ TX* cur_src = nullptr; + _global_ptr_ TY* cur_dst = dst + i * copy_size; + if (is_enc) { + cur_src = src + (local_lods_en[enc_idx] - 1) * copy_size; + } else { + cur_src = src + (encoder_len_total + local_lods_de[dec_idx] - (len - i)) * copy_size; + } + do_memcpy_1d(cur_src, cur_dst, copy_size); + } +} +#define _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \ + template __global__ void eb_mtp_gather_next_token( \ + TX * src, \ + TY * dst, \ + int* encoder_seqs_lods, \ + int* decoder_seqs_lods, \ + int* encoder_batch_map, \ + int* decoder_batch_map, \ + int en_batch, \ + int de_batch, \ + int64_t copy_size); + +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, float16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float); +_XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16); +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp index 94f2352137b..4a4ff43c215 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/eb_adjust_batch.cpp @@ -23,6 +23,7 @@ template __attribute__((global)) void eb_adjust_batch(TX *src, TY *dst, int *encoder_seqs_lods, + int *decoder_seqs_lods, int *encoder_batch_map, int *decoder_batch_map, int en_batch, @@ -41,6 +42,7 @@ static int cpu_wrapper(api::Context *ctx, const TX *x, TY *y, const int *encoder_seqs_lods, + const int *decoder_seqs_lods, const int *encoder_batch_map, const int *decoder_batch_map, int en_batch, @@ -56,11 +58,12 @@ static int cpu_wrapper(api::Context *ctx, // get copy size && src_offset int cpy_m = 0; if (de_batch > 0 && decoder_batch_map[de_idx] == i) { - cpy_m = 1; - ret = api::cast(ctx, - x + cur_offset * hidden_dim, - y + (encoder_len_total + de_idx) * hidden_dim, - cpy_m * hidden_dim); + cpy_m = decoder_seqs_lods[de_idx + 1] - decoder_seqs_lods[de_idx]; + ret = api::cast( + ctx, + x + cur_offset * hidden_dim, + y + (encoder_len_total + decoder_seqs_lods[de_idx]) * hidden_dim, + cpy_m * hidden_dim); WRAPPER_ASSERT_SUCCESS(ctx, ret); de_idx++; } @@ -84,6 +87,7 @@ static int xpu3_wrapper(api::Context *ctx, const TX *x, TY *y, api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT api::VectorParam &encoder_batch_map, // NOLINT api::VectorParam &decoder_batch_map, // NOLINT int en_batch, @@ -98,6 +102,7 @@ static int xpu3_wrapper(api::Context *ctx, 
reinterpret_cast(const_cast(x)), reinterpret_cast(y), encoder_seqs_lods.xpu, + decoder_seqs_lods.xpu, encoder_batch_map.xpu, decoder_batch_map.xpu, en_batch, @@ -111,6 +116,7 @@ int eb_adjust_batch(api::Context *ctx, const TX *x, TY *y, api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT api::VectorParam &encoder_batch_map, // NOLINT api::VectorParam &decoder_batch_map, // NOLINT int64_t hidden_dim) { @@ -119,28 +125,35 @@ int eb_adjust_batch(api::Context *ctx, // if (dev_id ==0) { // ctx->set_debug_level(0xA1); // } - + // std::cout << decoder_seqs_lods.cpu[0] << " " << decoder_seqs_lods.cpu[1] << + // std::endl; WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_adjust_batch", TX, TY); WRAPPER_DUMP_PARAM6(ctx, x, y, encoder_seqs_lods, + decoder_seqs_lods, encoder_batch_map, - decoder_batch_map, - hidden_dim); + decoder_batch_map); + WRAPPER_DUMP_PARAM1(ctx, hidden_dim); WRAPPER_DUMP(ctx); int encoder_batch = encoder_batch_map.len; - int total_batch = encoder_batch + decoder_batch_map.len; + int decoder_batch = decoder_batch_map.len; + int total_batch = encoder_batch + decoder_batch; int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch]; - int m = max_encoder_lod + decoder_batch_map.len; + int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch]; + int m = max_encoder_lod + max_decoder_lod; WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x); WRAPPER_CHECK_PTR(ctx, TY, m * hidden_dim, y); WRAPPER_ASSERT_GT(ctx, hidden_dim, 0); // check VectorParam WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1); + WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1); WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0); WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod); + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod); for (int i = 0; i < encoder_batch_map.len; ++i) { WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0); WRAPPER_ASSERT_LT(ctx, encoder_batch_map.cpu[i], total_batch) @@ -150,12 +163,15 @@ int eb_adjust_batch(api::Context *ctx, for (int i = 0; i < decoder_batch_map.len; ++i) { WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0); WRAPPER_ASSERT_LT(ctx, decoder_batch_map.cpu[i], total_batch) + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod); } if (ctx->dev().type() == api::kCPU) { return cpu_wrapper(ctx, x, y, encoder_seqs_lods.cpu, + decoder_seqs_lods.cpu, encoder_batch_map.cpu, decoder_batch_map.cpu, encoder_batch_map.len, @@ -166,6 +182,8 @@ int eb_adjust_batch(api::Context *ctx, api::ctx_guard RAII_GUARD(ctx); api::VectorParam encoder_seqs_lods_xpu = encoder_seqs_lods.to_xpu(RAII_GUARD); + api::VectorParam decoder_seqs_lods_xpu = + decoder_seqs_lods.to_xpu(RAII_GUARD); api::VectorParam encoder_batch_map_xpu = encoder_batch_map.to_xpu(RAII_GUARD); api::VectorParam decoder_batch_map_xpu = @@ -174,6 +192,7 @@ int eb_adjust_batch(api::Context *ctx, x, y, encoder_seqs_lods_xpu, + decoder_seqs_lods_xpu, encoder_batch_map_xpu, decoder_batch_map_xpu, encoder_batch_map.len, @@ -190,6 +209,7 @@ int eb_adjust_batch(api::Context *ctx, api::VectorParam &, \ api::VectorParam &, \ api::VectorParam &, \ + api::VectorParam &, \ int64_t); INSTANTIATION_EB_ADJUST_BATCH(float16, float16); diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp new file mode 100644 index 00000000000..f817baf8938 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/eb_mtp_gather_next_token.cpp @@ -0,0 +1,227 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xpu/plugin.h" +#include "xpu/refactor/impl/launch_strategy.h" +#include "xpu/refactor/impl_public/wrapper_check.h" +#include "xpu/xdnn.h" + +namespace xpu3 { +namespace plugin { +template +__attribute__((global)) void eb_mtp_gather_next_token(TX *src, + TY *dst, + int *encoder_seqs_lods, + int *decoder_seqs_lods, + int *encoder_batch_map, + int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size); +} // namespace plugin +} // namespace xpu3 + +namespace baidu { +namespace xpu { +namespace api { +namespace plugin { +template +static int cpu_wrapper(api::Context *ctx, + const TX *x, + TY *y, + const int *encoder_seqs_lods, + const int *decoder_seqs_lods, + const int *encoder_batch_map, + const int *decoder_batch_map, + int en_batch, + int de_batch, + int64_t hidden_dim) { + int ret = 0; + int encoder_len_total = encoder_seqs_lods[en_batch]; + int decoder_len_total = decoder_seqs_lods[de_batch]; + int output_token_num = en_batch + decoder_len_total; + for (int i = 0; i < output_token_num; i++) { + int len = 0; + int enc_idx = 0, dec_idx = 0; + bool is_enc; + while (i >= len) { + if (enc_idx >= en_batch) { + len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx]; + dec_idx++; + is_enc = false; + continue; + } + if (dec_idx >= de_batch) { + len += 1; + enc_idx++; + is_enc = true; + continue; + } + if ((encoder_batch_map[enc_idx] < decoder_batch_map[dec_idx])) { + len += 1; + enc_idx++; + is_enc = true; + } else { + len += decoder_seqs_lods[dec_idx + 1] - decoder_seqs_lods[dec_idx]; + dec_idx++; + is_enc = false; + } + } + const TX *src = nullptr; + if (is_enc) { + src = x + (encoder_seqs_lods[enc_idx] - 1) * hidden_dim; + } else { + src = x + (encoder_len_total + decoder_seqs_lods[dec_idx] - (len - i)) * + hidden_dim; + } + ret = api::cast(ctx, src, y + i * hidden_dim, hidden_dim); + WRAPPER_ASSERT_SUCCESS(ctx, ret); + } + return api::SUCCESS; +} +template +static int xpu3_wrapper(api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int en_batch, + int de_batch, + int64_t hidden_dim) { + auto eb_mtp_gather_next_token_kernel = + xpu3::plugin::eb_mtp_gather_next_token; + // NOTE: Don't change 16 to 64, because kernel use gsm + eb_mtp_gather_next_token_kernel<<ncluster(), 16, ctx->xpu_stream>>>( + const_cast(x), + y, + encoder_seqs_lods.xpu, + decoder_seqs_lods.xpu, + encoder_batch_map.xpu, + decoder_batch_map.xpu, + en_batch, + de_batch, + hidden_dim); + return api::SUCCESS; +} + +template +int 
eb_mtp_gather_next_token( + api::Context *ctx, + const TX *x, + TY *y, + api::VectorParam &encoder_seqs_lods, // NOLINT + api::VectorParam &decoder_seqs_lods, // NOLINT + api::VectorParam &encoder_batch_map, // NOLINT + api::VectorParam &decoder_batch_map, // NOLINT + int64_t hidden_dim) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T2(ctx, "eb_mtp_gather_next_token", TX, TY); + WRAPPER_DUMP_PARAM6(ctx, + x, + y, + encoder_seqs_lods, + decoder_seqs_lods, + encoder_batch_map, + decoder_batch_map); + WRAPPER_DUMP_PARAM1(ctx, hidden_dim); + WRAPPER_DUMP(ctx); + int encoder_batch = encoder_batch_map.len; + int decoder_batch = decoder_batch_map.len; + int max_encoder_lod = encoder_seqs_lods.cpu[encoder_batch]; + int max_decoder_lod = decoder_seqs_lods.cpu[decoder_batch]; + int m = encoder_seqs_lods.cpu[encoder_batch] + + decoder_seqs_lods.cpu[decoder_batch]; + int out_m = encoder_batch + decoder_seqs_lods.cpu[decoder_batch]; + WRAPPER_CHECK_PTR(ctx, TX, m * hidden_dim, x); + WRAPPER_CHECK_PTR(ctx, TY, out_m * hidden_dim, y); + WRAPPER_ASSERT_GT(ctx, hidden_dim, 0); + // check VectorParam + WRAPPER_ASSERT_EQ(ctx, encoder_seqs_lods.len, encoder_batch_map.len + 1); + WRAPPER_ASSERT_EQ(ctx, decoder_seqs_lods.len, decoder_batch_map.len + 1); + WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[0], 0); + WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[0], max_encoder_lod); + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[0], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[0], max_decoder_lod); + // 注意: encoder/decoder的batch + // map数值上有可能大于batch,因为复原后的batch排布有可能是稀疏的,所以这里只做非负检查 + for (int i = 0; i < encoder_batch_map.len; ++i) { + WRAPPER_ASSERT_GE(ctx, encoder_batch_map.cpu[i], 0); + WRAPPER_ASSERT_GE(ctx, encoder_seqs_lods.cpu[i + 1], 0); + WRAPPER_ASSERT_LE(ctx, encoder_seqs_lods.cpu[i + 1], max_encoder_lod); + } + for (int i = 0; i < decoder_batch_map.len; ++i) { + WRAPPER_ASSERT_GE(ctx, decoder_batch_map.cpu[i], 0); + WRAPPER_ASSERT_GE(ctx, decoder_seqs_lods.cpu[i + 1], 0); + WRAPPER_ASSERT_LE(ctx, decoder_seqs_lods.cpu[i + 1], max_decoder_lod); + } + if (ctx->dev().type() == api::kCPU) { + return cpu_wrapper(ctx, + x, + y, + encoder_seqs_lods.cpu, + decoder_seqs_lods.cpu, + encoder_batch_map.cpu, + decoder_batch_map.cpu, + encoder_batch_map.len, + decoder_batch_map.len, + hidden_dim); + } + if (ctx->dev().type() == api::kXPU3) { + api::ctx_guard RAII_GUARD(ctx); + api::VectorParam encoder_seqs_lods_xpu = + encoder_seqs_lods.to_xpu(RAII_GUARD); + api::VectorParam decoder_seqs_lods_xpu = + decoder_seqs_lods.to_xpu(RAII_GUARD); + api::VectorParam encoder_batch_map_xpu = + encoder_batch_map.to_xpu(RAII_GUARD); + api::VectorParam decoder_batch_map_xpu = + decoder_batch_map.to_xpu(RAII_GUARD); + return xpu3_wrapper(ctx, + x, + y, + encoder_seqs_lods_xpu, + decoder_seqs_lods_xpu, + encoder_batch_map_xpu, + decoder_batch_map_xpu, + encoder_batch_map.len, + decoder_batch_map.len, + hidden_dim); + } + WRAPPER_UNIMPLEMENTED(ctx); +} +#define INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \ + template int eb_mtp_gather_next_token(api::Context *, \ + const TX *, \ + TY *, \ + api::VectorParam &, \ + api::VectorParam &, \ + api::VectorParam &, \ + api::VectorParam &, \ + int64_t); + +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, bfloat16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, float); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, float16); 
+INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float16, bfloat16); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(bfloat16, float); +INSTANTIATION_EB_MTP_GATHER_NEXT_TOKEN(float, bfloat16); +} // namespace plugin +} // namespace api +} // namespace xpu +} // namespace baidu diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py new file mode 100644 index 00000000000..eebbf81f10f --- /dev/null +++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py @@ -0,0 +1,196 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import pytest + +from fastdeploy.model_executor.ops.xpu import ( + adjust_batch, + gather_next_token, + get_infer_param, +) + + +def _run_test_base(seq_lens_this_time_data, output_padding_offset): + """ + 通用的基础测试执行函数,包含了两个场景共有的逻辑。 + """ + seq_lens_encoder = paddle.to_tensor([100, 0, 0, 0, 120, 140, 0], dtype="int32") + seq_lens_decoder = paddle.to_tensor([0, 5, 0, 25, 64, 0, 128], dtype="int32") + seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_data, dtype="int32") + + bsz = seq_lens_this_time.shape[0] + cum_offsets = paddle.zeros(bsz, dtype="int32") + block_table = paddle.arange(0, 56, dtype="int32").reshape((bsz, 8)) + + infer_params = get_infer_param(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, block_table, 64) + + ( + encoder_batch_map, + decoder_batch_map, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod, + decoder_seq_lod, + _, + _, + _, + _, + _, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + _, + _, + _, + _, + len_info_cpu, + ) = infer_params + + token_num = seq_lens_this_time.sum().cpu().item() + hidden_dim = 8192 + row_indices = paddle.arange(token_num, dtype="int32") + row_indices_bf16 = row_indices.astype("bfloat16") + input_tensor = paddle.unsqueeze(row_indices_bf16, axis=1).expand(shape=[token_num, hidden_dim]) + + # 测试 adjust_batch + adjusted_output = adjust_batch( + input_tensor, + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + len_info_cpu, + None, # output_padding_offset + -1, # max_input_length + ) + + adjusted_output_cpu = adjust_batch( + input_tensor.cpu(), + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_idx, + decoder_batch_idx, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_idx_cpu, + decoder_batch_idx_cpu, + len_info_cpu, + None, # output_padding_offset + -1, # max_input_length + ) + assert ( + paddle.equal_all(adjusted_output.astype("float32").cpu(), adjusted_output_cpu.astype("float32")).all().item() + ), "adjust_batch check failed!" 
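    # Reference for the checks below (derived from the kernels added in this patch,
    # illustrative only): adjust_batch regroups the mixed input so that all encoder
    # sequences come first (concatenated in encoder lod order), followed by all
    # decoder tokens. With an output_padding_offset, gather_next_token then keeps one
    # row per encoder sequence (its last hidden state) plus every decoder row,
    # interleaved back in the original batch order, so the expected output row count is
    #     len_info_cpu[0] + decoder_seq_lod_cpu[len_info_cpu[1]]
    # (enc_batch + total decoder tokens); without output_padding_offset the op keeps
    # one row per batch slot (bsz rows), as the non-MTP path did before.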
+ + # 测试 gather_next_token + gather_out = gather_next_token( + adjusted_output, + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_map, + decoder_batch_map, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + len_info_cpu, + output_padding_offset, + -1, + ) + + gather_out_cpu = gather_next_token( + adjusted_output.cpu(), + cum_offsets, + encoder_seq_lod, + decoder_seq_lod, + encoder_batch_map, + decoder_batch_map, + encoder_seq_lod_cpu, + decoder_seq_lod_cpu, + encoder_batch_map_cpu, + decoder_batch_map_cpu, + len_info_cpu, + output_padding_offset, + -1, + ) + + if output_padding_offset is not None: + np.testing.assert_allclose( + gather_out.astype("float32").cpu().numpy(), + gather_out_cpu.astype("float32").cpu().numpy(), + err_msg="gather_next_token check failed!", + ) + else: + for i in range(gather_out_cpu.shape[0]): + if seq_lens_this_time[i] > 0: + np.testing.assert_allclose( + gather_out[i].astype("float32").cpu().numpy(), + gather_out_cpu[i].astype("float32").cpu().numpy(), + err_msg="gather_next_token check failed!", + ) + + +def test_mix_with_mtp(): + """测试混合批次处理中的MTP(Multi-Token Prediction)场景。 + + 验证在不同序列长度(包括零长度)情况下,MTP功能是否能正确处理。 + + Args: + 无显式参数,但内部使用: + seq_lens_this_time_data: 包含不同长度序列的列表,用于模拟混合批次 + output_padding_offset: 用于处理序列填充的偏移量张量 + + Returns: + 无返回值,但会打印测试结果 + """ + print("\nRunning test: test_mix_with_mtp") + seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3] + bsz = len(seq_lens_this_time_data) + output_padding_offset = paddle.zeros(bsz, dtype="int32") + + _run_test_base(seq_lens_this_time_data, output_padding_offset) + print("Test passed for scenario: With MTP") + + +def test_mix_without_mtp(): + """测试非MTP(Single-Token Prediction)场景下的功能。 + + 该测试用例专门验证在非MTP(多令牌预测)场景下,模型处理不同长度序列的能力。 + + Args: + seq_lens_this_time_data: 本次处理的序列长度列表,包含各种长度的序列 + output_padding_offset: 非MTP场景下此参数应为None + """ + print("\nRunning test: test_mix_without_mtp") + seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1] + output_padding_offset = None # 非MTP场景下,此参数为None + + _run_test_base(seq_lens_this_time_data, output_padding_offset) + print("Test passed for scenario: Without MTP") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index e01ff0fdae0..4ac6af5621c 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -1017,15 +1017,16 @@ def xpu_pre_process( ids_remove_padding.reshape([-1, 1]), cum_offsets, xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, xpu_forward_meta.encoder_batch_idx, xpu_forward_meta.decoder_batch_idx, xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.decoder_seq_lod_cpu, xpu_forward_meta.encoder_batch_idx_cpu, xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.enc_batch, - xpu_forward_meta.dec_batch, + xpu_forward_meta.len_info_cpu, None, # output_padding_offset - -1, # max_input_length + -1, # max bs ) adjusted_input = adjusted_input.squeeze(1) From 3f91930993ca8437aa43f0ea14cf8a493eadfbe8 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Mon, 24 Nov 2025 10:09:42 +0000 Subject: [PATCH 11/17] fix code style --- .../mtp_kernel/eb_mtp_gather_next_token.xpu | 43 ++++++++++--------- .../model_executor/pre_and_post_process.py | 10 +++-- fastdeploy/worker/xpu_model_runner.py | 5 ++- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git 
a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu index 9e964b17746..522e2911e1f 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/eb_mtp_gather_next_token.xpu @@ -38,14 +38,14 @@ static __device__ void do_memcpy_1d(_global_ptr_ TX* src, template __global__ void eb_mtp_gather_next_token(TX* src, - TY* dst, - int* encoder_seqs_lods, - int* decoder_seqs_lods, - int* encoder_batch_map, - int* decoder_batch_map, - int en_batch, - int de_batch, - int64_t copy_size) { + TY* dst, + int* encoder_seqs_lods, + int* decoder_seqs_lods, + int* encoder_batch_map, + int* decoder_batch_map, + int en_batch, + int de_batch, + int64_t copy_size) { int tid = core_id() * cluster_num() + cluster_id(); int nthreads = core_num() * cluster_num(); __group_shared__ int local_lods_en[MAX_BATCH + 1]; @@ -75,13 +75,13 @@ __global__ void eb_mtp_gather_next_token(TX* src, len += local_lods_de[dec_idx + 1] - local_lods_de[dec_idx]; dec_idx++; is_enc = false; - continue; + continue; } if (dec_idx >= de_batch) { len += 1; enc_idx++; is_enc = true; - continue; + continue; } if (local_map_en[enc_idx] < local_map_de[dec_idx]) { len += 1; @@ -98,21 +98,22 @@ __global__ void eb_mtp_gather_next_token(TX* src, if (is_enc) { cur_src = src + (local_lods_en[enc_idx] - 1) * copy_size; } else { - cur_src = src + (encoder_len_total + local_lods_de[dec_idx] - (len - i)) * copy_size; + cur_src = src + (encoder_len_total + local_lods_de[dec_idx] - (len - i)) * + copy_size; } do_memcpy_1d(cur_src, cur_dst, copy_size); } } -#define _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \ - template __global__ void eb_mtp_gather_next_token( \ - TX * src, \ - TY * dst, \ - int* encoder_seqs_lods, \ - int* decoder_seqs_lods, \ - int* encoder_batch_map, \ - int* decoder_batch_map, \ - int en_batch, \ - int de_batch, \ +#define _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(TX, TY) \ + template __global__ void eb_mtp_gather_next_token( \ + TX * src, \ + TY * dst, \ + int* encoder_seqs_lods, \ + int* decoder_seqs_lods, \ + int* encoder_batch_map, \ + int* decoder_batch_map, \ + int en_batch, \ + int de_batch, \ int64_t copy_size); _XPU_DEF__EB_MTP_GATHER_NEXT_TOKEN(float16, float16); diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 4ac6af5621c..d603d6d6f53 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -1042,21 +1042,25 @@ def xpu_process_output( forward_output, cum_offsets: paddle.Tensor, xpu_forward_meta: XPUForwardMeta, + share_inputs, ) -> paddle.Tensor: """ """ + output_padding_offset = share_inputs.get("output_padding_offset", None) + hiddden_states = gather_next_token( forward_output, cum_offsets, xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, xpu_forward_meta.encoder_batch_map, xpu_forward_meta.decoder_batch_map, xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.decoder_seq_lod_cpu, xpu_forward_meta.encoder_batch_map_cpu, xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.enc_batch, - xpu_forward_meta.dec_batch, - None, # output_padding_offset + xpu_forward_meta.len_info_cpu, + output_padding_offset, # output_padding_offset -1, # max_input_length ) return hiddden_states diff --git a/fastdeploy/worker/xpu_model_runner.py 
b/fastdeploy/worker/xpu_model_runner.py
index ac747879f6b..9a21b23627c 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -936,8 +936,9 @@ class at the server level, which is too granular for ModelRunner.
             forward_meta=self.forward_meta,
         )
 
-        hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
-
+        hidden_states = xpu_process_output(
+            model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
+        )
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hidden_states)
         sampler_output = self.sampler(logits, self.sampling_metadata)

From bea118c544cd0ff17d51639fe07135d0eda89563 Mon Sep 17 00:00:00 2001
From: cmcamdy <1027740945@qq.com>
Date: Mon, 24 Nov 2025 10:36:51 +0000
Subject: [PATCH 12/17] fix mtp kernel name

---
 .../src/ops/mtp/draft_model_preprocess.cc     |  40 +-
 .../src/ops/mtp/draft_model_preprocess_v2.cc  | 150 -------
 .../ops/mtp/speculate_get_padding_offset.cc   |   8 +-
 .../mtp/speculate_get_padding_offset_v2.cc    | 133 ------
 custom_ops/xpu_ops/src/ops/pybind/pybind.cc   |  79 +---
 .../xpu_ops/src/plugin/include/xpu/plugin.h   | 122 ++---
 .../mtp_kernel/draft_model_preprocess.xpu     | 144 +++---
 .../mtp_kernel/draft_model_preprocess_v2.xpu  | 240 ----------
 .../speculate_get_padding_offset.xpu          |  38 +-
 .../mtp_wrapper/draft_model_preprocess.cpp    | 267 ++++++-----
 .../mtp_wrapper/draft_model_preprocess_v2.cpp | 419 ------------------
 .../speculate_get_padding_offset.cpp          | 123 +----
 12 files changed, 300 insertions(+), 1463 deletions(-)
 delete mode 100644 custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc
 delete mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc
 delete mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu
 delete mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp

diff --git a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc
index ec501a7904b..a4cf8e68748 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc
@@ -29,21 +29,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                           const paddle::Tensor& seq_lens_encoder,
                           const paddle::Tensor& seq_lens_decoder,
                           const paddle::Tensor& step_idx,
-                          const paddle::Tensor& seq_lens_encoder_record,
-                          const paddle::Tensor& seq_lens_decoder_record,
                           const paddle::Tensor& not_need_stop,
+                          const paddle::Tensor& is_block_step,
                           const paddle::Tensor& batch_drop,
+                          const paddle::Tensor& pre_ids,
                           const paddle::Tensor& accept_tokens,
                           const paddle::Tensor& accept_num,
+                          const paddle::Tensor& base_model_seq_lens_this_time,
                           const paddle::Tensor& base_model_seq_lens_encoder,
                           const paddle::Tensor& base_model_seq_lens_decoder,
                           const paddle::Tensor& base_model_step_idx,
                           const paddle::Tensor& base_model_stop_flags,
                           const paddle::Tensor& base_model_is_block_step,
                           const paddle::Tensor& base_model_draft_tokens,
-                          const int max_draft_token,
+                          const int num_model_step,
                           const bool truncate_first_token,
-                          const bool splitwise_prefill) {
+                          const bool splitwise_prefill,
+                          const bool kvcache_scheduler_v1) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
   api::Context* ctx = static_cast(dev_ctx)->x_context();
@@ -54,6 +56,8 @@ void DraftModelPreprocess(const paddle::Tensor&
draft_tokens, int accept_tokens_len = accept_tokens.shape()[1]; int input_ids_len = input_ids.shape()[1]; int draft_tokens_len = draft_tokens.shape()[1]; + int pre_ids_len = pre_ids.shape()[1]; + constexpr int BlockSize = 512; int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1]; auto not_need_stop_gpu = not_need_stop.copy_to(seq_lens_this_time.place(), false); @@ -67,12 +71,13 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(step_idx.data()), - const_cast(seq_lens_encoder_record.data()), - const_cast(seq_lens_decoder_record.data()), const_cast(not_need_stop_gpu.data()), + const_cast(is_block_step.data()), const_cast(batch_drop.data()), + const_cast(pre_ids.data()), accept_tokens.data(), accept_num.data(), + base_model_seq_lens_this_time.data(), base_model_seq_lens_encoder.data(), base_model_seq_lens_decoder.data(), base_model_step_idx.data(), @@ -80,13 +85,16 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, base_model_is_block_step.data(), const_cast(base_model_draft_tokens.data()), real_bsz, - max_draft_token, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); + PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed."); auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); @@ -102,12 +110,13 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "seq_lens_encoder", "seq_lens_decoder", "step_idx", - "seq_lens_encoder_record", - "seq_lens_decoder_record", "not_need_stop", + "is_block_step", "batch_drop", + "pre_ids", "accept_tokens", "accept_num", + "base_model_seq_lens_this_time", "base_model_seq_lens_encoder", "base_model_seq_lens_decoder", "base_model_step_idx", @@ -123,11 +132,11 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "step_idx_out", "not_need_stop_out", "batch_drop_out", - "seq_lens_encoder_record_out", - "seq_lens_decoder_record_out"}) - .Attrs({"max_draft_token: int", + "pre_ids_out"}) + .Attrs({"num_model_step: int", "truncate_first_token: bool", - "splitwise_prefill: bool"}) + "splitwise_prefill: bool", + "kvcache_scheduler_v1: bool"}) .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"input_ids", "input_ids_out"}, {"stop_flags", "stop_flags_out"}, @@ -137,6 +146,5 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) {"step_idx", "step_idx_out"}, {"not_need_stop", "not_need_stop_out"}, {"batch_drop", "batch_drop_out"}, - {"seq_lens_encoder_record", "seq_lens_encoder_record_out"}, - {"seq_lens_decoder_record", "seq_lens_decoder_record_out"}}) + {"pre_ids", "pre_ids_out"}}) .SetKernelFn(PD_KERNEL(DraftModelPreprocess)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc b/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc deleted file mode 100644 index c2eb3313b27..00000000000 --- a/custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess_v2.cc +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "paddle/extension.h" -#include "paddle/phi/core/enforce.h" -#include "xpu/plugin.h" - -#ifndef PD_BUILD_STATIC_OP -#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) -#endif - -namespace api = baidu::xpu::api; -void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens, - const paddle::Tensor& input_ids, - const paddle::Tensor& stop_flags, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& step_idx, - const paddle::Tensor& not_need_stop, - const paddle::Tensor& is_block_step, - const paddle::Tensor& batch_drop, - const paddle::Tensor& pre_ids, - const paddle::Tensor& accept_tokens, - const paddle::Tensor& accept_num, - const paddle::Tensor& base_model_seq_lens_this_time, - const paddle::Tensor& base_model_seq_lens_encoder, - const paddle::Tensor& base_model_seq_lens_decoder, - const paddle::Tensor& base_model_step_idx, - const paddle::Tensor& base_model_stop_flags, - const paddle::Tensor& base_model_is_block_step, - const paddle::Tensor& base_model_draft_tokens, - const int num_model_step, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - api::Context* ctx = static_cast(dev_ctx)->x_context(); - if (draft_tokens.is_cpu()) { - ctx = new api::Context(api::kCPU); - } - int real_bsz = seq_lens_this_time.shape()[0]; - int accept_tokens_len = accept_tokens.shape()[1]; - int input_ids_len = input_ids.shape()[1]; - int draft_tokens_len = draft_tokens.shape()[1]; - int pre_ids_len = pre_ids.shape()[1]; - constexpr int BlockSize = 512; - int base_model_draft_tokens_len = base_model_draft_tokens.shape()[1]; - auto not_need_stop_gpu = - not_need_stop.copy_to(seq_lens_this_time.place(), false); - - int r = baidu::xpu::api::plugin::draft_model_preprocess_v2( - ctx, - const_cast(draft_tokens.data()), - const_cast(input_ids.data()), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_idx.data()), - const_cast(not_need_stop_gpu.data()), - const_cast(is_block_step.data()), - const_cast(batch_drop.data()), - const_cast(pre_ids.data()), - accept_tokens.data(), - accept_num.data(), - base_model_seq_lens_this_time.data(), - base_model_seq_lens_encoder.data(), - base_model_seq_lens_decoder.data(), - base_model_step_idx.data(), - base_model_stop_flags.data(), - base_model_is_block_step.data(), - const_cast(base_model_draft_tokens.data()), - real_bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - - PD_CHECK(r == 0, "xpu::plugin::draft_model_preprocess failed."); - auto not_need_stop_cpu = - not_need_stop_gpu.copy_to(not_need_stop.place(), false); - bool* not_need_stop_data = 
const_cast(not_need_stop.data()); - not_need_stop_data[0] = not_need_stop_cpu.data()[0]; -} - -PD_BUILD_STATIC_OP(draft_model_preprocess_v2) - .Inputs({"draft_tokens", - "input_ids", - "stop_flags", - "seq_lens_this_time", - "seq_lens_encoder", - "seq_lens_decoder", - "step_idx", - "not_need_stop", - "is_block_step", - "batch_drop", - "pre_ids", - "accept_tokens", - "accept_num", - "base_model_seq_lens_this_time", - "base_model_seq_lens_encoder", - "base_model_seq_lens_decoder", - "base_model_step_idx", - "base_model_stop_flags", - "base_model_is_block_step", - "base_model_draft_tokens"}) - .Outputs({"draft_tokens_out", - "input_ids_out", - "stop_flags_out", - "seq_lens_this_time_out", - "seq_lens_encoder_out", - "seq_lens_decoder_out", - "step_idx_out", - "not_need_stop_out", - "batch_drop_out", - "pre_ids_out"}) - .Attrs({"num_model_step: int", - "truncate_first_token: bool", - "splitwise_prefill: bool", - "kvcache_scheduler_v1: bool"}) - .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, - {"input_ids", "input_ids_out"}, - {"stop_flags", "stop_flags_out"}, - {"seq_lens_this_time", "seq_lens_this_time_out"}, - {"seq_lens_encoder", "seq_lens_encoder_out"}, - {"seq_lens_decoder", "seq_lens_decoder_out"}, - {"step_idx", "step_idx_out"}, - {"not_need_stop", "not_need_stop_out"}, - {"batch_drop", "batch_drop_out"}, - {"pre_ids", "pre_ids_out"}}) - .SetKernelFn(PD_KERNEL(DraftModelPreprocessV2)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc index 1cf14b810b4..f22dc7aaa89 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset.cc @@ -43,6 +43,8 @@ std::vector SpeculateGetPaddingOffset( {token_num_data}, paddle::DataType::INT64, input_ids.place()); auto padding_offset = paddle::empty( {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto batch_id_per_token = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_q = paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); auto cu_seqlens_k = @@ -57,7 +59,7 @@ std::vector SpeculateGetPaddingOffset( int r = baidu::xpu::api::plugin::speculate_get_padding_offset( xpu_ctx->x_context(), - padding_offset.data(), + batch_id_per_token.data(), cum_offsets_out.data(), cu_seqlens_q.data(), cu_seqlens_k.data(), @@ -83,7 +85,7 @@ std::vector SpeculateGetPaddingOffset( return {x_remove_padding, cum_offsets_out, - padding_offset, + batch_id_per_token, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num}; } @@ -123,7 +125,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset) "seq_lens_encoder"}) .Outputs({"x_remove_padding", "cum_offsets_out", - "padding_offset", + "batch_id_per_token", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset)) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc deleted file mode 100644 index 18b945bcc05..00000000000 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_padding_offset_v2.cc +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "paddle/extension.h" -#include "xpu/plugin.h" - -#ifndef PD_BUILD_STATIC_OP -#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) -#endif - -std::vector SpeculateGetPaddingOffsetV2( - const paddle::Tensor& input_ids, - const paddle::Tensor& draft_tokens, - const paddle::Tensor& cum_offsets, - const paddle::Tensor& token_num, - const paddle::Tensor& seq_len, - const paddle::Tensor& seq_lens_encoder) { - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - - std::vector input_ids_shape = input_ids.shape(); - const int bsz = seq_len.shape()[0]; - const int seq_length = input_ids_shape[1]; - const int max_draft_tokens = draft_tokens.shape()[1]; - auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); - auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); - - const int token_num_data = cpu_token_num.data()[0]; - auto x_remove_padding = paddle::empty( - {token_num_data}, paddle::DataType::INT64, input_ids.place()); - auto padding_offset = paddle::empty( - {token_num_data}, paddle::DataType::INT32, input_ids.place()); - auto batch_id_per_token = paddle::empty( - {token_num_data}, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_q = - paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_k = - paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place()); - - PD_CHECK(input_ids.is_contiguous(), "Input ids tensor must be contiguous"); - PD_CHECK(draft_tokens.is_contiguous(), - "Draft tokens tensor must be contiguous"); - PD_CHECK(cum_offsets.is_contiguous(), - "Cum offsets tensor must be contiguous"); - PD_CHECK(seq_len.is_contiguous(), "Seq lens tensor must be contiguous"); - - int r = baidu::xpu::api::plugin::speculate_get_padding_offset_v2( - xpu_ctx->x_context(), - batch_id_per_token.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - cum_offsets.data(), - seq_len.data(), - seq_length, - bsz); - PD_CHECK(r == 0, "XPU speculate_get_padding_offset_v2 failed"); - - r = baidu::xpu::api::plugin::speculate_remove_padding( - xpu_ctx->x_context(), - x_remove_padding.data(), - input_ids.data(), - draft_tokens.data(), - seq_len.data(), - seq_lens_encoder.data(), - cum_offsets_out.data(), - seq_length, - max_draft_tokens, - bsz, - token_num_data); - PD_CHECK(r == 0, "XPU speculate_remove_padding failed"); - - return {x_remove_padding, - cum_offsets_out, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k}; // , enc_token_num, dec_token_num}; -} - -std::vector> SpeculateGetPaddingOffsetV2InferShape( - const std::vector& input_ids_shape, - const std::vector& draft_tokens_shape, - const std::vector& cum_offsets_shape, - const std::vector& token_num_shape, - const std::vector& seq_len_shape, - const std::vector& seq_lens_encoder_shape) { - int64_t bsz = seq_len_shape[0]; - int64_t seq_len = input_ids_shape[1]; - return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; -} - -std::vector 
SpeculateGetPaddingOffsetV2InferDtype( - const paddle::DataType& input_ids_dtype, - const paddle::DataType& draft_tokens_dtype, - const paddle::DataType& cum_offsets_dtype, - const paddle::DataType& token_num_dtype, - const paddle::DataType& seq_len_dtype, - const paddle::DataType& seq_lens_encoder_dtype) { - return {input_ids_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype}; -} - -PD_BUILD_STATIC_OP(speculate_get_padding_offset_v2) - .Inputs({"input_ids", - "draft_tokens", - "cum_offsets", - "token_num", - "seq_len", - "seq_lens_encoder"}) - .Outputs({"x_remove_padding", - "cum_offsets_out", - "batch_id_per_token", - "cu_seqlens_q", - "cu_seqlens_k"}) - .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffsetV2)) - .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetV2InferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetV2InferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 8c755e1a078..0400aa02d7d 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -288,46 +288,23 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, - const paddle::Tensor& seq_lens_encoder_record, - const paddle::Tensor& seq_lens_decoder_record, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, + const paddle::Tensor& base_model_seq_lens_this_time, const paddle::Tensor& base_model_seq_lens_encoder, const paddle::Tensor& base_model_seq_lens_decoder, const paddle::Tensor& base_model_step_idx, const paddle::Tensor& base_model_stop_flags, const paddle::Tensor& base_model_is_block_step, const paddle::Tensor& base_model_draft_tokens, - const int max_draft_token, + const int num_model_step, const bool truncate_first_token, - const bool splitwise_prefill); - -void DraftModelPreprocessV2(const paddle::Tensor& draft_tokens, - const paddle::Tensor& input_ids, - const paddle::Tensor& stop_flags, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& step_idx, - const paddle::Tensor& not_need_stop, - const paddle::Tensor& is_block_step, - const paddle::Tensor& batch_drop, - const paddle::Tensor& pre_ids, - const paddle::Tensor& accept_tokens, - const paddle::Tensor& accept_num, - const paddle::Tensor& base_model_seq_lens_this_time, - const paddle::Tensor& base_model_seq_lens_encoder, - const paddle::Tensor& base_model_seq_lens_decoder, - const paddle::Tensor& base_model_step_idx, - const paddle::Tensor& base_model_stop_flags, - const paddle::Tensor& base_model_is_block_step, - const paddle::Tensor& base_model_draft_tokens, - const int num_model_step, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1); + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_seq_lens_this_time, @@ -425,14 +402,6 @@ std::vector SpeculateGetPaddingOffset( const paddle::Tensor& seq_len, const paddle::Tensor& seq_lens_encoder); -std::vector SpeculateGetPaddingOffsetV2( - const paddle::Tensor& input_ids, - const paddle::Tensor& draft_tokens, - 
const paddle::Tensor& cum_offsets, - const paddle::Tensor& token_num, - const paddle::Tensor& seq_len, - const paddle::Tensor& seq_lens_encoder); - void StepPaddle(const paddle::Tensor& stop_flags, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& ori_seq_lens_encoder, @@ -686,32 +655,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_encoder"), py::arg("seq_lens_decoder"), py::arg("step_idx"), - py::arg("seq_lens_encoder_record"), - py::arg("seq_lens_decoder_record"), - py::arg("not_need_stop"), - py::arg("batch_drop"), - py::arg("accept_tokens"), - py::arg("accept_num"), - py::arg("base_model_seq_lens_encoder"), - py::arg("base_model_seq_lens_decoder"), - py::arg("base_model_step_idx"), - py::arg("base_model_stop_flags"), - py::arg("base_model_is_block_step"), - py::arg("base_model_draft_tokens"), - py::arg("max_draft_token"), - py::arg("truncate_first_token"), - py::arg("splitwise_prefill"), - "Preprocess data for draft model in speculative decoding"); - - m.def("draft_model_preprocess_v2", - &DraftModelPreprocessV2, - py::arg("draft_tokens"), - py::arg("input_ids"), - py::arg("stop_flags"), - py::arg("seq_lens_this_time"), - py::arg("seq_lens_encoder"), - py::arg("seq_lens_decoder"), - py::arg("step_idx"), py::arg("not_need_stop"), py::arg("is_block_step"), py::arg("batch_drop"), @@ -1118,16 +1061,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("seq_lens_encoder"), "Get padding offset"); - m.def("speculate_get_padding_offset_v2", - &SpeculateGetPaddingOffsetV2, - py::arg("input_ids"), - py::arg("draft_tokens"), - py::arg("cum_offsets"), - py::arg("token_num"), - py::arg("seq_len"), - py::arg("seq_lens_encoder"), - "Get padding offset v2"); - m.def("speculate_step_reschedule", &SpeculateStepSchedule, py::arg("stop_flags"), diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 38e604c40cc..09a426a3126 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -75,48 +75,47 @@ DLL_EXPORT int get_padding_offset(Context* ctx, const int max_seq_len, const int bs); -DLL_EXPORT int speculate_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz); - -DLL_EXPORT int draft_model_preprocess_v2( - api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1); +DLL_EXPORT int speculate_get_padding_offset(Context* ctx, + int* batch_id_per_token, + int* cum_offsets_out, + int* cu_seqlens_q, + int* cu_seqlens_k, + const int* cum_offsets, + const int* seq_lens, + 
const int max_seq_len, + int bsz); + +DLL_EXPORT int draft_model_preprocess(api::Context* ctx, + int64_t* draft_tokens, + int64_t* input_ids, + bool* stop_flags, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int64_t* step_idx, + bool* not_need_stop, + bool* is_block_step, + bool* batch_drop, + int64_t* pre_ids, + const int64_t* accept_tokens, + const int* accept_num, + const int* base_model_seq_lens_this_time, + const int* base_model_seq_lens_encoder, + const int* base_model_seq_lens_decoder, + const int64_t* base_model_step_idx, + const bool* base_model_stop_flags, + const bool* base_model_is_block_step, + int64_t* base_model_draft_tokens, + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); DLL_EXPORT int update_inputs(Context* ctx, bool* not_need_stop, @@ -446,35 +445,6 @@ DLL_EXPORT int draft_model_update(Context* ctx, const int substep, const bool prefill_one_step_stop); -DLL_EXPORT int draft_model_preprocess(api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, - bool* not_need_stop, - bool* batch_drop, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill); - DLL_EXPORT int speculate_set_stop_value_multi_seqs(Context* ctx, bool* stop_flags, int64_t* accept_tokens, @@ -515,16 +485,6 @@ DLL_EXPORT int speculate_remove_padding(Context* ctx, int bsz, int token_num_data); -DLL_EXPORT int speculate_get_padding_offset(Context* ctx, - int* padding_offset, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz); - DLL_EXPORT int compute_self_order(api::Context* ctx, const int* last_seq_lens_this_time, const int* seq_lens_this_time, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu index 9471fd096d9..425dc4b22f9 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess.xpu @@ -13,26 +13,29 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* 
base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { int cid = core_id(); int ncores = core_num(); int clusterid = cluster_id(); @@ -46,7 +49,7 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, int64_t value_fu = -1; if (splitwise_prefill) { - for (; tid < real_bsz; tid += ncores * nclusters) { + for (; tid < bsz; tid += ncores * nclusters) { int64_t base_model_step_idx_now = 0; int seq_lens_encoder_now = 0; int seq_lens_this_time_now = 0; @@ -57,35 +60,25 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, GM2LM_ASYNC( base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); - GM2LM_ASYNC(seq_lens_encoder_record + tid, - &seq_lens_encoder_record_now, - sizeof(int)); + GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); GM2LM(accept_tokens + tid * accept_tokens_len, &base_model_first_token, sizeof(int64_t)); - - if (base_model_step_idx_now == 1 && seq_lens_encoder_record_now > 0) { + if (seq_lens_encoder_now > 0) { not_stop_flag_sm[cid] += 1; - int seq_len_encoder_record = seq_lens_encoder_record_now; - seq_lens_encoder_now = seq_len_encoder_record; - seq_lens_encoder_record_now = -1; stop_flags_now = false; - int position = seq_len_encoder_record; + int position = seq_lens_encoder_now; if (truncate_first_token) { position = position - 1; input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_len_encoder_record; + seq_lens_this_time_now = seq_lens_encoder_now; } else { input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_len_encoder_record + 1; + seq_lens_this_time_now = seq_lens_encoder_now + 1; } LM2GM_ASYNC(&input_ids_now, input_ids + tid * input_ids_len + position, sizeof(int64_t)); - LM2GM_ASYNC(&seq_lens_encoder_record_now, - seq_lens_encoder_record + tid, - sizeof(int)); - } else { stop_flags_now = true; seq_lens_this_time_now = 0; @@ -98,21 +91,23 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); } } else { - for (; tid < real_bsz; tid += ncores * nclusters) { + for (; tid < bsz; tid += ncores * nclusters) { bool base_model_stop_flags_now = false; bool base_model_is_block_step_now = false; bool batch_drop_now = false; bool stop_flags_now = false; + bool is_block_step_now = false; int seq_lens_this_time_now = 0; - int seq_lens_encoder_record_now = 0; int seq_lens_encoder_now = 0; int seq_lens_decoder_new = 0; - int seq_lens_decoder_record_now = 0; int accept_num_now = 0; int base_model_seq_lens_decoder_now = 0; + int base_model_seq_lens_this_time_now = 0; int64_t step_id_now = 0; int64_t base_model_step_idx_now; + int64_t pre_ids_now; mfence(); + GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool)); GM2LM_ASYNC(base_model_stop_flags + tid, &base_model_stop_flags_now, sizeof(bool)); @@ -121,12 +116,6 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, sizeof(bool)); GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, 
sizeof(bool)); - GM2LM_ASYNC(seq_lens_encoder_record + tid, - &seq_lens_encoder_record_now, - sizeof(int)); - GM2LM_ASYNC(seq_lens_decoder_record + tid, - &seq_lens_decoder_record_now, - sizeof(int)); GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); @@ -135,6 +124,9 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, accept_tokens_len * sizeof(int64_t)); GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); + GM2LM_ASYNC(base_model_seq_lens_this_time + tid, + &base_model_seq_lens_this_time_now, + sizeof(int)); GM2LM_ASYNC(base_model_seq_lens_decoder + tid, &base_model_seq_lens_decoder_now, sizeof(int)); @@ -148,57 +140,67 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, base_model_draft_tokens + tid * base_model_draft_tokens_len + i, sizeof(int)); } - if (base_model_stop_flags_now && base_model_is_block_step_now) { - batch_drop_now = true; - stop_flags_now = true; + if (kvcache_scheduler_v1) { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + stop_flags_now = true; + is_block_step_now = true; + } + } else { + if (base_model_stop_flags_now && base_model_is_block_step_now) { + batch_drop_now = true; + stop_flags_now = true; + } } if (!(base_model_stop_flags_now || batch_drop_now)) { not_stop_flag_sm[cid] += 1; - if (base_model_step_idx_now == 0) { - seq_lens_this_time_now = 0; - not_stop_flag_sm[cid] -= 1; // 因为上面加过,这次减去,符合=0逻辑 - } else if (base_model_step_idx_now == 1 && - seq_lens_encoder_record_now > 0) { - int seq_len_encoder_record = seq_lens_encoder_record_now; - seq_lens_encoder_now = seq_len_encoder_record; - seq_lens_encoder_record_now = -1; - seq_lens_decoder_new = seq_lens_decoder_record_now; - seq_lens_decoder_record_now = 0; + if (seq_lens_encoder_now > 0) { + int seq_len_encoder = seq_lens_encoder_now; stop_flags_now = false; int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder_record; + LM2GM(&base_model_first_token, + pre_ids + tid * pre_ids_len, + sizeof(int64_t)); + int position = seq_len_encoder; if (truncate_first_token) { LM2GM(&base_model_first_token, input_ids + tid * input_ids_len + position - 1, sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder_record; + seq_lens_this_time_now = seq_len_encoder; } else { LM2GM(&base_model_first_token, input_ids + tid * input_ids_len + position, sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder_record + 1; + seq_lens_this_time_now = seq_len_encoder + 1; + } + } else { + if (kvcache_scheduler_v1) { + if (!base_model_is_block_step_now && is_block_step_now) { + is_block_step_now = false; + } } - } else if (accept_num_now <= max_draft_token) { if (stop_flags_now) { stop_flags_now = false; - seq_lens_decoder_new = base_model_seq_lens_decoder_now; - step_id_now = base_model_step_idx_now; + seq_lens_decoder_new = base_model_seq_lens_decoder_now - + base_model_seq_lens_this_time_now; + step_id_now = + base_model_step_idx_now - base_model_seq_lens_this_time_now; + } else { - seq_lens_decoder_new -= max_draft_token - accept_num_now; - step_id_now -= max_draft_token - accept_num_now; + seq_lens_decoder_new -= num_model_step - 1; + step_id_now -= num_model_step - 1; } - int64_t modified_token = accept_tokens_now[accept_num_now - 1]; - LM2GM(&modified_token, - draft_tokens + tid * draft_tokens_len, - sizeof(int64_t)); - seq_lens_this_time_now = 1; - - } else /*Accept all draft tokens*/ { - LM2GM(accept_tokens_now + max_draft_token, - 
draft_tokens + tid * draft_tokens_len + 1, - sizeof(int64_t)); - seq_lens_this_time_now = 2; + for (int i = 0; i < accept_num_now; i++) { + const int pre_id_pos = + base_model_step_idx_now - (accept_num_now - i); + LM2GM(accept_tokens_now + i, + draft_tokens + tid * draft_tokens_len + i, + sizeof(int64_t)); + LM2GM(accept_tokens_now + i, + pre_ids + tid * pre_ids_len + pre_id_pos, + sizeof(int64_t)); + } + seq_lens_this_time_now = accept_num_now; } } else { @@ -209,17 +211,11 @@ __global__ void draft_model_preprocess(int64_t* draft_tokens, } LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); - + LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool)); LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); LM2GM_ASYNC( &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); - LM2GM_ASYNC(&seq_lens_encoder_record_now, - seq_lens_encoder_record + tid, - sizeof(int)); - LM2GM_ASYNC(&seq_lens_decoder_record_now, - seq_lens_decoder_record + tid, - sizeof(int)); LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); } } diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu deleted file mode 100644 index 9d26919c33a..00000000000 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/draft_model_preprocess_v2.xpu +++ /dev/null @@ -1,240 +0,0 @@ -#include "xpu/kernel/cluster.h" -#include "xpu/kernel/cluster_debug.h" -#include "xpu/kernel/cluster_partition.h" -#include "xpu/kernel/cluster_primitive.h" -#include "xpu/kernel/cluster_simd.h" - -namespace xpu3 { -namespace plugin { -__global__ void draft_model_preprocess_v2( - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1) { - int cid = core_id(); - int ncores = core_num(); - int clusterid = cluster_id(); - int nclusters = cluster_num(); - int tid = clusterid * ncores + cid; - __shared__ int not_stop_flag_sm[64]; - not_stop_flag_sm[cid] = 0; - int64_t accept_tokens_now[128]; - - int value_zero = 0; - int64_t value_fu = -1; - - if (splitwise_prefill) { - for (; tid < bsz; tid += ncores * nclusters) { - int64_t base_model_step_idx_now = 0; - int seq_lens_encoder_now = 0; - int seq_lens_this_time_now = 0; - bool stop_flags_now = false; - int64_t base_model_first_token; - int seq_lens_encoder_record_now = 0; - int64_t input_ids_now = 0; - - GM2LM_ASYNC( - base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); - GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, 
sizeof(int)); - GM2LM(accept_tokens + tid * accept_tokens_len, - &base_model_first_token, - sizeof(int64_t)); - if (seq_lens_encoder_now > 0) { - not_stop_flag_sm[cid] += 1; - stop_flags_now = false; - int position = seq_lens_encoder_now; - if (truncate_first_token) { - position = position - 1; - input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_lens_encoder_now; - } else { - input_ids_now = base_model_first_token; - seq_lens_this_time_now = seq_lens_encoder_now + 1; - } - LM2GM_ASYNC(&input_ids_now, - input_ids + tid * input_ids_len + position, - sizeof(int64_t)); - } else { - stop_flags_now = true; - seq_lens_this_time_now = 0; - seq_lens_encoder_now = 0; - not_stop_flag_sm[cid] += 0; - LM2GM_ASYNC(&value_zero, seq_lens_decoder + tid, sizeof(int)); - } - LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); - LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); - LM2GM(&seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); - } - } else { - for (; tid < bsz; tid += ncores * nclusters) { - bool base_model_stop_flags_now = false; - bool base_model_is_block_step_now = false; - bool batch_drop_now = false; - bool stop_flags_now = false; - bool is_block_step_now = false; - int seq_lens_this_time_now = 0; - int seq_lens_encoder_now = 0; - int seq_lens_decoder_new = 0; - int accept_num_now = 0; - int base_model_seq_lens_decoder_now = 0; - int base_model_seq_lens_this_time_now = 0; - int64_t step_id_now = 0; - int64_t base_model_step_idx_now; - int64_t pre_ids_now; - mfence(); - GM2LM_ASYNC(is_block_step + tid, &is_block_step_now, sizeof(bool)); - GM2LM_ASYNC(base_model_stop_flags + tid, - &base_model_stop_flags_now, - sizeof(bool)); - GM2LM_ASYNC(base_model_is_block_step + tid, - &base_model_is_block_step_now, - sizeof(bool)); - GM2LM_ASYNC(batch_drop + tid, &batch_drop_now, sizeof(bool)); - GM2LM_ASYNC(stop_flags + tid, &stop_flags_now, sizeof(bool)); - GM2LM_ASYNC(seq_lens_encoder + tid, &seq_lens_encoder_now, sizeof(int)); - GM2LM_ASYNC(seq_lens_decoder + tid, &seq_lens_decoder_new, sizeof(int)); - - GM2LM_ASYNC(accept_tokens + tid * accept_tokens_len, - accept_tokens_now, - accept_tokens_len * sizeof(int64_t)); - GM2LM_ASYNC(accept_num + tid, &accept_num_now, sizeof(int)); - - GM2LM_ASYNC(base_model_seq_lens_this_time + tid, - &base_model_seq_lens_this_time_now, - sizeof(int)); - GM2LM_ASYNC(base_model_seq_lens_decoder + tid, - &base_model_seq_lens_decoder_now, - sizeof(int)); - GM2LM_ASYNC(step_idx + tid, &step_id_now, sizeof(int64_t)); - GM2LM( - base_model_step_idx + tid, &base_model_step_idx_now, sizeof(int64_t)); - - for (int i = 1; i < base_model_draft_tokens_len; i++) { - LM2GM_ASYNC( - &value_fu, - base_model_draft_tokens + tid * base_model_draft_tokens_len + i, - sizeof(int)); - } - if (kvcache_scheduler_v1) { - if (base_model_stop_flags_now && base_model_is_block_step_now) { - stop_flags_now = true; - is_block_step_now = true; - } - } else { - if (base_model_stop_flags_now && base_model_is_block_step_now) { - batch_drop_now = true; - stop_flags_now = true; - } - } - - if (!(base_model_stop_flags_now || batch_drop_now)) { - not_stop_flag_sm[cid] += 1; - if (seq_lens_encoder_now > 0) { - int seq_len_encoder = seq_lens_encoder_now; - stop_flags_now = false; - int64_t base_model_first_token = accept_tokens_now[0]; - LM2GM(&base_model_first_token, - pre_ids + tid * pre_ids_len, - sizeof(int64_t)); - int position = seq_len_encoder; - if (truncate_first_token) { - LM2GM(&base_model_first_token, - input_ids + tid * input_ids_len + 
position - 1, - sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder; - } else { - LM2GM(&base_model_first_token, - input_ids + tid * input_ids_len + position, - sizeof(int64_t)); - seq_lens_this_time_now = seq_len_encoder + 1; - } - } else { - if (kvcache_scheduler_v1) { - if (!base_model_is_block_step_now && is_block_step_now) { - is_block_step_now = false; - } - } - if (stop_flags_now) { - stop_flags_now = false; - seq_lens_decoder_new = base_model_seq_lens_decoder_now - - base_model_seq_lens_this_time_now; - step_id_now = - base_model_step_idx_now - base_model_seq_lens_this_time_now; - - } else { - seq_lens_decoder_new -= num_model_step - 1; - step_id_now -= num_model_step - 1; - } - for (int i = 0; i < accept_num_now; i++) { - const int pre_id_pos = - base_model_step_idx_now - (accept_num_now - i); - LM2GM(accept_tokens_now + i, - draft_tokens + tid * draft_tokens_len + i, - sizeof(int64_t)); - LM2GM(accept_tokens_now + i, - pre_ids + tid * pre_ids_len + pre_id_pos, - sizeof(int64_t)); - } - seq_lens_this_time_now = accept_num_now; - } - - } else { - stop_flags_now = true; - seq_lens_this_time_now = 0; - seq_lens_encoder_now = 0; - seq_lens_decoder_new = 0; - } - LM2GM_ASYNC(&stop_flags_now, stop_flags + tid, sizeof(bool)); - LM2GM_ASYNC(&batch_drop_now, batch_drop + tid, sizeof(bool)); - LM2GM_ASYNC(&is_block_step_now, is_block_step + tid, sizeof(bool)); - LM2GM_ASYNC(&seq_lens_decoder_new, seq_lens_decoder + tid, sizeof(int)); - LM2GM_ASYNC( - &seq_lens_this_time_now, seq_lens_this_time + tid, sizeof(int)); - LM2GM_ASYNC(&seq_lens_encoder_now, seq_lens_encoder + tid, sizeof(int)); - LM2GM_ASYNC(&step_id_now, step_idx + tid, sizeof(int64_t)); - } - } - mfence(); - sync_cluster(); - bool value_true = true; - bool value_false = false; - if (cid == 0) { - for (int i = 0; i < ncores; i++) { - not_stop_flag_sm[0] += not_stop_flag_sm[i]; - } - if (not_stop_flag_sm[0] > 0) { - LM2GM(&value_true, not_need_stop, sizeof(bool)); - } else { - LM2GM(&value_false, not_need_stop, sizeof(bool)); - } - } -} - -} // namespace plugin -} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu index 637e076d625..a1e766d31ed 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_padding_offset.xpu @@ -65,7 +65,7 @@ __global__ void speculate_remove_padding(T* output_data, } } -__global__ void speculate_get_padding_offset(int* padding_offset, +__global__ void speculate_get_padding_offset(int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -89,42 +89,6 @@ __global__ void speculate_get_padding_offset(int* padding_offset, } GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int)); - for (int i = tid; i < seq_lens_now; i += ncores) { - LM2GM(&cum_offsets_now, - padding_offset + bi * max_seq_len - cum_offsets_now + i, - sizeof(int)); - } - LM2GM(&cum_offsets_now, cum_offsets_out + bi, sizeof(int)); - int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets_now_ind; - LM2GM(&cum_seq_len, cu_seqlens_q + bi + 1, sizeof(int)); - LM2GM(&cum_seq_len, cu_seqlens_k + bi + 1, sizeof(int)); - } -} - -__global__ void speculate_get_padding_offset_v2(int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* 
seq_lens, - const int max_seq_len, - int bsz) { - int bid = cluster_id(); - int tid = core_id(); - int ncores = core_num(); - int nclusters = cluster_num(); - int seq_lens_now = 0; - int cum_offsets_now = 0; - int cum_offsets_now_ind = 0; - for (int bi = bid; bi < bsz; bi += nclusters) { - GM2LM(seq_lens + bi, &seq_lens_now, sizeof(int)); - if (bi == 0) { - cum_offsets_now = 0; - } else { - GM2LM(cum_offsets + bi - 1, &cum_offsets_now, sizeof(int)); - } - GM2LM(cum_offsets + bi, &cum_offsets_now_ind, sizeof(int)); - for (int i = tid; i < seq_lens_now; i += ncores) { LM2GM(&bi, batch_id_per_token + bi * max_seq_len - cum_offsets_now + i, diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp index 9ca1f2224f0..3a9273aee55 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess.cpp @@ -27,26 +27,29 @@ __attribute__((global)) void draft_model_preprocess( int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill); + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); } // namespace plugin } // namespace xpu3 @@ -67,49 +70,47 @@ static int cpu_wrapper(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { int64_t not_stop_flag_sum = 0; int64_t not_stop_flag = 0; - for (int tid = 0; tid < real_bsz; tid++) { + for (int tid = 0; tid < bsz; tid++) { if (splitwise_prefill) { - int base_model_step_idx_now = base_model_step_idx[tid]; auto* input_ids_now = input_ids + tid * 
input_ids_len; auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; - // printf("bid: %d, base_model_step_idx_now: %d seq_lens_encoder_record: - // %d\n", tid, base_model_step_idx_now, seq_lens_encoder_record[tid]); - if (base_model_step_idx_now == 1 && seq_lens_encoder_record[tid] > 0) { + if (seq_lens_encoder[tid] > 0) { not_stop_flag = 1; - int seq_len_encoder_record = seq_lens_encoder_record[tid]; - seq_lens_encoder[tid] = seq_len_encoder_record; - seq_lens_encoder_record[tid] = -1; + int seq_len_encoder = seq_lens_encoder[tid]; stop_flags[tid] = false; int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder_record; + int position = seq_len_encoder; if (truncate_first_token) { input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record; + seq_lens_this_time[tid] = seq_len_encoder; } else { input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record + 1; + seq_lens_this_time[tid] = seq_len_encoder + 1; } } else { stop_flags[tid] = true; @@ -120,63 +121,77 @@ static int cpu_wrapper(api::Context* ctx, } not_stop_flag_sum += not_stop_flag; } else { - auto base_model_step_idx_now = base_model_step_idx[tid]; auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; auto accept_num_now = accept_num[tid]; auto* input_ids_now = input_ids + tid * input_ids_len; auto* base_model_draft_tokens_now = base_model_draft_tokens + tid * base_model_draft_tokens_len; + auto base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]; + const int32_t base_model_seq_len_this_time = + base_model_seq_lens_this_time[tid]; + auto* pre_ids_now = pre_ids + tid * pre_ids_len; for (int i = 1; i < base_model_draft_tokens_len; i++) { base_model_draft_tokens_now[i] = -1; } - if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { - batch_drop[tid] = true; - stop_flags[tid] = true; + if (kvcache_scheduler_v1) { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + stop_flags[tid] = true; + is_block_step[tid] = true; + // Need to continue infer + } + } else { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + batch_drop[tid] = true; + stop_flags[tid] = true; + } } if (!(base_model_stop_flags[tid] || batch_drop[tid])) { not_stop_flag = 1; - // 1. 
first token - - if (base_model_step_idx_now == 0) { - seq_lens_this_time[tid] = 0; - not_stop_flag = 0; - } else if (base_model_step_idx_now == 1 && - seq_lens_encoder_record[tid] > 0) { + // prefill generation + if (seq_lens_encoder[tid] > 0) { // Can be extended to first few tokens - int seq_len_encoder_record = seq_lens_encoder_record[tid]; - seq_lens_encoder[tid] = seq_len_encoder_record; - seq_lens_encoder_record[tid] = -1; - seq_lens_decoder[tid] = seq_lens_decoder_record[tid]; - seq_lens_decoder_record[tid] = 0; + int seq_len_encoder = seq_lens_encoder[tid]; stop_flags[tid] = false; int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder_record; + pre_ids_now[0] = base_model_first_token; + int position = seq_len_encoder; if (truncate_first_token) { input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record; + seq_lens_this_time[tid] = seq_len_encoder; } else { input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder_record + 1; + seq_lens_this_time[tid] = seq_len_encoder + 1; + } + } else { // decode generation + if (kvcache_scheduler_v1) { + // 3. try to recover mtp infer in V1 mode + if (!base_model_is_block_step[tid] && is_block_step[tid]) { + is_block_step[tid] = false; + } } - } else if (accept_num_now <= - max_draft_token) /*Accept partial draft tokens*/ { - // Base Model reject stop if (stop_flags[tid]) { stop_flags[tid] = false; - seq_lens_decoder[tid] = base_model_seq_lens_decoder[tid]; - step_idx[tid] = base_model_step_idx[tid]; + // TODO: check + seq_lens_decoder[tid] = + base_model_seq_len_decoder - base_model_seq_len_this_time; + step_idx[tid] = + base_model_step_idx[tid] - base_model_seq_len_this_time; } else { - seq_lens_decoder[tid] -= max_draft_token - accept_num_now; - step_idx[tid] -= max_draft_token - accept_num_now; + // 2: Last base model generated token and first MTP + // token + seq_lens_decoder[tid] -= num_model_step - 1; + step_idx[tid] -= num_model_step - 1; + } + for (int i = 0; i < accept_num_now; i++) { + draft_tokens_now[i] = accept_tokens_now[i]; + const int pre_id_pos = + base_model_step_idx[tid] - (accept_num_now - i); + const int64_t accept_token = accept_tokens_now[i]; + pre_ids_now[pre_id_pos] = accept_token; } - int64_t modified_token = accept_tokens_now[accept_num_now - 1]; - draft_tokens_now[0] = modified_token; - seq_lens_this_time[tid] = 1; - } else /*Accept all draft tokens*/ { - draft_tokens_now[1] = accept_tokens_now[max_draft_token]; - seq_lens_this_time[tid] = 2; + seq_lens_this_time[tid] = accept_num_now; } } else { stop_flags[tid] = true; @@ -199,26 +214,29 @@ static int xpu3_wrapper(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int 
draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { using XPU_INT64 = typename XPUIndexType::type; // NOTE: Don't change 16 to 64, because kernel use gsm @@ -230,26 +248,29 @@ static int xpu3_wrapper(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, reinterpret_cast(step_idx), - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + reinterpret_cast(pre_ids), reinterpret_cast(accept_tokens), accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, reinterpret_cast(base_model_step_idx), base_model_stop_flags, base_model_is_block_step, reinterpret_cast(base_model_draft_tokens), - real_bsz, - max_draft_token, + bsz, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); return api::SUCCESS; } @@ -261,26 +282,29 @@ int draft_model_preprocess(api::Context* ctx, int* seq_lens_encoder, int* seq_lens_decoder, int64_t* step_idx, - int* seq_lens_encoder_record, - int* seq_lens_decoder_record, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, + int64_t* pre_ids, const int64_t* accept_tokens, const int* accept_num, + const int* base_model_seq_lens_this_time, const int* base_model_seq_lens_encoder, const int* base_model_seq_lens_decoder, const int64_t* base_model_step_idx, const bool* base_model_stop_flags, const bool* base_model_is_block_step, int64_t* base_model_draft_tokens, - int real_bsz, - int max_draft_token, - int accept_tokens_len, - int draft_tokens_len, - int input_ids_len, - int base_model_draft_tokens_len, - bool truncate_first_token, - bool splitwise_prefill) { + const int bsz, + const int num_model_step, + const int accept_tokens_len, + const int draft_tokens_len, + const int input_ids_len, + const int base_model_draft_tokens_len, + const int pre_ids_len, + const bool truncate_first_token, + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { WRAPPER_CHECK_CTX(ctx); WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess", int64_t); WRAPPER_DUMP_PARAM6(ctx, @@ -290,37 +314,34 @@ int draft_model_preprocess(api::Context* ctx, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder); - WRAPPER_DUMP_PARAM5(ctx, - step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, - not_need_stop, - batch_drop); + WRAPPER_DUMP_PARAM5( + ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids); WRAPPER_DUMP_PARAM3( ctx, accept_tokens, accept_num, base_model_seq_lens_encoder); - WRAPPER_DUMP_PARAM3(ctx, + WRAPPER_DUMP_PARAM4(ctx, + base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags); WRAPPER_DUMP_PARAM3( - ctx, base_model_is_block_step, base_model_draft_tokens, real_bsz); - WRAPPER_DUMP_PARAM3( - ctx, max_draft_token, accept_tokens_len, draft_tokens_len); - WRAPPER_DUMP_PARAM3( - ctx, input_ids_len, base_model_draft_tokens_len, truncate_first_token); - WRAPPER_DUMP_PARAM1(ctx, splitwise_prefill); + ctx, base_model_is_block_step, base_model_draft_tokens, bsz); + WRAPPER_DUMP_PARAM3(ctx, num_model_step, accept_tokens_len, draft_tokens_len); + WRAPPER_DUMP_PARAM4(ctx, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token); + WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1); 
WRAPPER_DUMP(ctx); - WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time); - WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * accept_tokens_len, accept_tokens); - WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * input_ids_len, input_ids); - WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * draft_tokens_len, draft_tokens); - WRAPPER_CHECK_PTR(ctx, - int64_t, - real_bsz * base_model_draft_tokens_len, - base_model_draft_tokens); + WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids); + WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens); + WRAPPER_CHECK_PTR( + ctx, int64_t, bsz * base_model_draft_tokens_len, base_model_draft_tokens); - WRAPPER_ASSERT_GT(ctx, real_bsz, 0); + WRAPPER_ASSERT_GT(ctx, bsz, 0); WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128); if (ctx->dev().type() == api::kCPU) { @@ -332,26 +353,29 @@ int draft_model_preprocess(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + pre_ids, accept_tokens, accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags, base_model_is_block_step, base_model_draft_tokens, - real_bsz, - max_draft_token, + bsz, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper(ctx, @@ -362,26 +386,29 @@ int draft_model_preprocess(api::Context* ctx, seq_lens_encoder, seq_lens_decoder, step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + pre_ids, accept_tokens, accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags, base_model_is_block_step, base_model_draft_tokens, - real_bsz, - max_draft_token, + bsz, + num_model_step, accept_tokens_len, draft_tokens_len, input_ids_len, base_model_draft_tokens_len, + pre_ids_len, truncate_first_token, - splitwise_prefill); + splitwise_prefill, + kvcache_scheduler_v1); } WRAPPER_UNIMPLEMENTED(ctx); } diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp deleted file mode 100644 index 13b3b892b49..00000000000 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/draft_model_preprocess_v2.cpp +++ /dev/null @@ -1,419 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
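As a rough host-side picture of the size contract those WRAPPER_CHECK_PTR / WRAPPER_ASSERT lines enforce after the real_bsz -> bsz rename, a minimal Python sketch (the helper name and dict layout are assumptions for illustration, not part of the patch):

def check_draft_model_preprocess_shapes(tensors):
    # Every per-request tensor is laid out as [bsz, *_len]; the wrapper asserts
    # bsz > 0 and bounds accept_tokens_len below 128.
    bsz = tensors["seq_lens_this_time"].shape[0]
    assert bsz > 0
    assert tensors["accept_tokens"].shape[0] == bsz
    assert tensors["accept_tokens"].shape[1] < 128
    for name in ("input_ids", "draft_tokens", "base_model_draft_tokens"):
        assert tensors[name].shape[0] == bsz, f"{name} must have leading dim bsz"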
- -#include "xpu/plugin.h" -#include "xpu/refactor/impl/launch_strategy.h" -#include "xpu/refactor/impl_public/wrapper_check.h" -#include "xpu/xdnn.h" - -namespace xpu3 { -namespace plugin { -__attribute__((global)) void draft_model_preprocess_v2( - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1); -} // namespace plugin -} // namespace xpu3 - -namespace xpu2 { -namespace plugin {} // namespace plugin -} // namespace xpu2 - -namespace baidu { -namespace xpu { -namespace api { -namespace plugin { - -static int cpu_wrapper(api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1) { - int64_t not_stop_flag_sum = 0; - int64_t not_stop_flag = 0; - for (int tid = 0; tid < bsz; tid++) { - if (splitwise_prefill) { - auto* input_ids_now = input_ids + tid * input_ids_len; - auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; - if (seq_lens_encoder[tid] > 0) { - not_stop_flag = 1; - int seq_len_encoder = seq_lens_encoder[tid]; - stop_flags[tid] = false; - int64_t base_model_first_token = accept_tokens_now[0]; - int position = seq_len_encoder; - if (truncate_first_token) { - input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder; - } else { - input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder + 1; - } - } else { - stop_flags[tid] = true; - seq_lens_this_time[tid] = 0; - seq_lens_decoder[tid] = 0; - seq_lens_encoder[tid] = 0; - not_stop_flag = 0; - } - not_stop_flag_sum += not_stop_flag; - } else { - auto* accept_tokens_now = accept_tokens + tid * accept_tokens_len; - auto* draft_tokens_now = draft_tokens + tid * draft_tokens_len; - auto accept_num_now = accept_num[tid]; - auto* input_ids_now = input_ids + tid * input_ids_len; - auto* base_model_draft_tokens_now = - base_model_draft_tokens + tid * base_model_draft_tokens_len; - auto 
base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]; - const int32_t base_model_seq_len_this_time = - base_model_seq_lens_this_time[tid]; - auto* pre_ids_now = pre_ids + tid * pre_ids_len; - for (int i = 1; i < base_model_draft_tokens_len; i++) { - base_model_draft_tokens_now[i] = -1; - } - if (kvcache_scheduler_v1) { - if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { - stop_flags[tid] = true; - is_block_step[tid] = true; - // Need to continue infer - } - } else { - if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { - batch_drop[tid] = true; - stop_flags[tid] = true; - } - } - - if (!(base_model_stop_flags[tid] || batch_drop[tid])) { - not_stop_flag = 1; - // prefill generation - if (seq_lens_encoder[tid] > 0) { - // Can be extended to first few tokens - int seq_len_encoder = seq_lens_encoder[tid]; - stop_flags[tid] = false; - int64_t base_model_first_token = accept_tokens_now[0]; - pre_ids_now[0] = base_model_first_token; - int position = seq_len_encoder; - if (truncate_first_token) { - input_ids_now[position - 1] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder; - } else { - input_ids_now[position] = base_model_first_token; - seq_lens_this_time[tid] = seq_len_encoder + 1; - } - } else { // decode generation - if (kvcache_scheduler_v1) { - // 3. try to recover mtp infer in V1 mode - if (!base_model_is_block_step[tid] && is_block_step[tid]) { - is_block_step[tid] = false; - } - } - if (stop_flags[tid]) { - stop_flags[tid] = false; - // TODO: check - seq_lens_decoder[tid] = - base_model_seq_len_decoder - base_model_seq_len_this_time; - step_idx[tid] = - base_model_step_idx[tid] - base_model_seq_len_this_time; - } else { - // 2: Last base model generated token and first MTP - // token - seq_lens_decoder[tid] -= num_model_step - 1; - step_idx[tid] -= num_model_step - 1; - } - for (int i = 0; i < accept_num_now; i++) { - draft_tokens_now[i] = accept_tokens_now[i]; - const int pre_id_pos = - base_model_step_idx[tid] - (accept_num_now - i); - const int64_t accept_token = accept_tokens_now[i]; - pre_ids_now[pre_id_pos] = accept_token; - } - seq_lens_this_time[tid] = accept_num_now; - } - } else { - stop_flags[tid] = true; - seq_lens_this_time[tid] = 0; - seq_lens_decoder[tid] = 0; - seq_lens_encoder[tid] = 0; - } - not_stop_flag_sum += not_stop_flag; - } - } - not_need_stop[0] = not_stop_flag_sum > 0; - return api::SUCCESS; -} - -static int xpu3_wrapper(api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1) { - using XPU_INT64 = typename XPUIndexType::type; - - // NOTE: Don't change 16 to 64, because kernel use gsm - xpu3::plugin::draft_model_preprocess_v2<<<1, 64, ctx->xpu_stream>>>( - 
reinterpret_cast(draft_tokens), - reinterpret_cast(input_ids), - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - reinterpret_cast(step_idx), - not_need_stop, - is_block_step, - batch_drop, - reinterpret_cast(pre_ids), - reinterpret_cast(accept_tokens), - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - reinterpret_cast(base_model_step_idx), - base_model_stop_flags, - base_model_is_block_step, - reinterpret_cast(base_model_draft_tokens), - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - return api::SUCCESS; -} - -int draft_model_preprocess_v2(api::Context* ctx, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* is_block_step, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool truncate_first_token, - const bool splitwise_prefill, - const bool kvcache_scheduler_v1) { - WRAPPER_CHECK_CTX(ctx); - WRAPPER_DUMP_FUNCTION_T1(ctx, "draft_model_preprocess_v2", int64_t); - WRAPPER_DUMP_PARAM6(ctx, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder); - WRAPPER_DUMP_PARAM5( - ctx, step_idx, not_need_stop, is_block_step, batch_drop, pre_ids); - WRAPPER_DUMP_PARAM3( - ctx, accept_tokens, accept_num, base_model_seq_lens_encoder); - WRAPPER_DUMP_PARAM4(ctx, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags); - WRAPPER_DUMP_PARAM3( - ctx, base_model_is_block_step, base_model_draft_tokens, bsz); - WRAPPER_DUMP_PARAM3(ctx, num_model_step, accept_tokens_len, draft_tokens_len); - WRAPPER_DUMP_PARAM4(ctx, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token); - WRAPPER_DUMP_PARAM2(ctx, splitwise_prefill, kvcache_scheduler_v1); - WRAPPER_DUMP(ctx); - - WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens_this_time); - WRAPPER_CHECK_PTR(ctx, int64_t, bsz * accept_tokens_len, accept_tokens); - WRAPPER_CHECK_PTR(ctx, int64_t, bsz * input_ids_len, input_ids); - WRAPPER_CHECK_PTR(ctx, int64_t, bsz * draft_tokens_len, draft_tokens); - WRAPPER_CHECK_PTR( - ctx, int64_t, bsz * base_model_draft_tokens_len, base_model_draft_tokens); - - WRAPPER_ASSERT_GT(ctx, bsz, 0); - WRAPPER_ASSERT_LT(ctx, accept_tokens_len, 128); - - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper(ctx, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - is_block_step, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - 
num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - } - if (ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper(ctx, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - is_block_step, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill, - kvcache_scheduler_v1); - } - WRAPPER_UNIMPLEMENTED(ctx); -} - -} // namespace plugin -} // namespace api -} // namespace xpu -} // namespace baidu diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp index 21134d86807..fe4096cac0a 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_padding_offset.cpp @@ -33,16 +33,6 @@ __attribute__((global)) void speculate_remove_padding( int token_num_data); __attribute__((global)) void speculate_get_padding_offset( - int* padding_offset, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz); - -__attribute__((global)) void speculate_get_padding_offset_v2( int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, @@ -88,7 +78,7 @@ static int cpu_wrapper_remove_padding(Context* ctx, } static int cpu_wrapper_get_padding_offset(Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -96,28 +86,6 @@ static int cpu_wrapper_get_padding_offset(Context* ctx, const int* seq_lens, const int max_seq_len, int bsz) { - for (int bi = 0; bi < bsz; ++bi) { - int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; - for (int i = 0; i < seq_lens[bi]; i++) { - padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; - } - cum_offsets_out[bi] = cum_offset; - int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; - cu_seqlens_q[bi + 1] = cum_seq_len; - cu_seqlens_k[bi + 1] = cum_seq_len; - } - return api::SUCCESS; -} - -static int cpu_wrapper_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz) { for (int bi = 0; bi < bsz; ++bi) { int cum_offset = bi == 0 ? 
0 : cum_offsets[bi - 1]; for (int i = 0; i < seq_lens[bi]; i++) { @@ -161,7 +129,7 @@ static int xpu3_wrapper_remove_padding(Context* ctx, } static int xpu3_wrapper_get_padding_offset(Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -171,28 +139,6 @@ static int xpu3_wrapper_get_padding_offset(Context* ctx, int bsz) { xpu3::plugin:: speculate_get_padding_offset<<ncluster(), 64, ctx->xpu_stream>>>( - padding_offset, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens, - max_seq_len, - bsz); - return api::SUCCESS; -} - -static int xpu3_wrapper_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz) { - xpu3::plugin:: - speculate_get_padding_offset_v2<<ncluster(), 64, ctx->xpu_stream>>>( batch_id_per_token, cum_offsets_out, cu_seqlens_q, @@ -269,7 +215,7 @@ int speculate_remove_padding(Context* ctx, } int speculate_get_padding_offset(Context* ctx, - int* padding_offset, + int* batch_id_per_token, int* cum_offsets_out, int* cu_seqlens_q, int* cu_seqlens_k, @@ -281,7 +227,7 @@ int speculate_get_padding_offset(Context* ctx, WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float); WRAPPER_DUMP_PARAM6(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, @@ -301,7 +247,7 @@ int speculate_get_padding_offset(Context* ctx, if (ctx->dev().type() == api::kCPU) { return cpu_wrapper_get_padding_offset(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, @@ -312,7 +258,7 @@ int speculate_get_padding_offset(Context* ctx, } if (ctx->dev().type() == api::kXPU3) { return xpu3_wrapper_get_padding_offset(ctx, - padding_offset, + batch_id_per_token, cum_offsets_out, cu_seqlens_q, cu_seqlens_k, @@ -325,63 +271,6 @@ int speculate_get_padding_offset(Context* ctx, WRAPPER_UNIMPLEMENTED(ctx); } -int speculate_get_padding_offset_v2(Context* ctx, - int* batch_id_per_token, - int* cum_offsets_out, - int* cu_seqlens_q, - int* cu_seqlens_k, - const int* cum_offsets, - const int* seq_lens, - const int max_seq_len, - int bsz) { - WRAPPER_CHECK_CTX(ctx); - - WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_padding_offset", float); - WRAPPER_DUMP_PARAM6(ctx, - batch_id_per_token, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens); - WRAPPER_DUMP_PARAM2(ctx, max_seq_len, bsz); - WRAPPER_DUMP(ctx); - - WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets); - WRAPPER_CHECK_PTR(ctx, int, bsz, seq_lens); - WRAPPER_CHECK_PTR(ctx, int, bsz, cum_offsets_out); - WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_q); - WRAPPER_CHECK_PTR(ctx, int, bsz + 1, cu_seqlens_k); - - WRAPPER_ASSERT_GT(ctx, bsz, 0); - WRAPPER_ASSERT_GT(ctx, max_seq_len, 0); - - if (ctx->dev().type() == api::kCPU) { - return cpu_wrapper_get_padding_offset_v2(ctx, - batch_id_per_token, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens, - max_seq_len, - bsz); - } - if (ctx->dev().type() == api::kXPU3) { - return xpu3_wrapper_get_padding_offset_v2(ctx, - batch_id_per_token, - cum_offsets_out, - cu_seqlens_q, - cu_seqlens_k, - cum_offsets, - seq_lens, - max_seq_len, - bsz); - } - - WRAPPER_UNIMPLEMENTED(ctx); -} - #define INSTANTIATION_SPECULATE_REMOVE_PADDING(T) \ template int speculate_remove_padding(Context * ctx, \ T * x_remove_padding, \ From 8ad6f88a879c38828d33f2c31f05f61d0b483705 
Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Mon, 24 Nov 2025 11:54:52 +0000 Subject: [PATCH 13/17] fix mtp kernel test --- ...test_adjust_batch_and_gather_next_token.py | 76 ++- .../test/test_draft_model_preprocess.py | 445 ++++++++++++++---- .../xpu_ops/test/test_speculate_step.py | 181 ++++--- .../xpu_ops/test/test_speculate_update_v3.py | 257 ++++------ 4 files changed, 557 insertions(+), 402 deletions(-) diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py index eebbf81f10f..758dff17e58 100644 --- a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py +++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest # 导入 unittest + import numpy as np import paddle -import pytest from fastdeploy.model_executor.ops.xpu import ( adjust_batch, @@ -100,9 +101,11 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset): None, # output_padding_offset -1, # max_input_length ) - assert ( - paddle.equal_all(adjusted_output.astype("float32").cpu(), adjusted_output_cpu.astype("float32")).all().item() - ), "adjust_batch check failed!" + + # 用 np.testing 替代原生 assert,错误信息更友好 + adjusted_output_np = adjusted_output.astype("float32").cpu().numpy() + adjusted_output_cpu_np = adjusted_output_cpu.astype("float32").cpu().numpy() + np.testing.assert_allclose(adjusted_output_np, adjusted_output_cpu_np, err_msg="adjust_batch check failed!") # 测试 gather_next_token gather_out = gather_next_token( @@ -137,60 +140,41 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset): -1, ) + gather_out_np = gather_out.astype("float32").cpu().numpy() + gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy() + if output_padding_offset is not None: - np.testing.assert_allclose( - gather_out.astype("float32").cpu().numpy(), - gather_out_cpu.astype("float32").cpu().numpy(), - err_msg="gather_next_token check failed!", - ) + np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!") else: for i in range(gather_out_cpu.shape[0]): if seq_lens_this_time[i] > 0: np.testing.assert_allclose( - gather_out[i].astype("float32").cpu().numpy(), - gather_out_cpu[i].astype("float32").cpu().numpy(), - err_msg="gather_next_token check failed!", + gather_out_np[i], gather_out_cpu_np[i], err_msg=f"gather_next_token check failed at index {i}!" 
) -def test_mix_with_mtp(): - """测试混合批次处理中的MTP(Multi-Token Prediction)场景。 - - 验证在不同序列长度(包括零长度)情况下,MTP功能是否能正确处理。 - - Args: - 无显式参数,但内部使用: - seq_lens_this_time_data: 包含不同长度序列的列表,用于模拟混合批次 - output_padding_offset: 用于处理序列填充的偏移量张量 - - Returns: - 无返回值,但会打印测试结果 - """ - print("\nRunning test: test_mix_with_mtp") - seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3] - bsz = len(seq_lens_this_time_data) - output_padding_offset = paddle.zeros(bsz, dtype="int32") - - _run_test_base(seq_lens_this_time_data, output_padding_offset) - print("Test passed for scenario: With MTP") +class TestXPUOps(unittest.TestCase): # 继承 unittest.TestCase + """测试 XPU ops 的 adjust_batch 和 gather_next_token 功能""" + def test_mix_with_mtp(self): + """测试混合批次处理中的 MTP (Multi-Token Prediction) 场景""" + print("\nRunning test: test_mix_with_mtp") + seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3] + bsz = len(seq_lens_this_time_data) + output_padding_offset = paddle.zeros(bsz, dtype="int32") -def test_mix_without_mtp(): - """测试非MTP(Single-Token Prediction)场景下的功能。 + _run_test_base(seq_lens_this_time_data, output_padding_offset) + print("Test passed for scenario: With MTP") - 该测试用例专门验证在非MTP(多令牌预测)场景下,模型处理不同长度序列的能力。 - - Args: - seq_lens_this_time_data: 本次处理的序列长度列表,包含各种长度的序列 - output_padding_offset: 非MTP场景下此参数应为None - """ - print("\nRunning test: test_mix_without_mtp") - seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1] - output_padding_offset = None # 非MTP场景下,此参数为None + def test_mix_without_mtp(self): + """测试非 MTP (Single-Token Prediction) 场景下的功能""" + print("\nRunning test: test_mix_without_mtp") + seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1] + output_padding_offset = None # 非 MTP 场景下,此参数为 None - _run_test_base(seq_lens_this_time_data, output_padding_offset) - print("Test passed for scenario: Without MTP") + _run_test_base(seq_lens_this_time_data, output_padding_offset) + print("Test passed for scenario: Without MTP") if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) + unittest.main() # 使用 unittest 运行测试 diff --git a/custom_ops/xpu_ops/test/test_draft_model_preprocess.py b/custom_ops/xpu_ops/test/test_draft_model_preprocess.py index c687bdf308d..1348e6fcd4e 100644 --- a/custom_ops/xpu_ops/test/test_draft_model_preprocess.py +++ b/custom_ops/xpu_ops/test/test_draft_model_preprocess.py @@ -12,50 +12,284 @@ # See the License for the specific language governing permissions and # limitations under the License. 
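With the tests converted from pytest to unittest, they can still be collected in one pass; a minimal discovery runner (the directory below is an assumption about the checkout layout):

import unittest

# Collect and run every converted test_*.py under the XPU op test directory.
suite = unittest.defaultTestLoader.discover("custom_ops/xpu_ops/test", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)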
+import unittest + import numpy as np import paddle from fastdeploy.model_executor.ops.xpu import draft_model_preprocess -def run_test(device="xpu"): - paddle.seed(2022) - - # Define parameters - bsz = 10 - draft_tokens_len = 4 - input_ids_len = 8 - max_draft_token = 10 - - truncate_first_token = True - splitwise_prefill = False - # Create input tensors - if device == "cpu": - paddle.set_device(device) - - draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64") - input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64") - stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool") - seq_lens_this_time = paddle.randint(0, 100, [bsz], dtype="int32") - seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") - seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") - step_idx = paddle.randint(0, 100, [bsz], dtype="int64") - seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32") - seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32") - not_need_stop = paddle.zeros([1], dtype="bool").cpu() - batch_drop = paddle.zeros([bsz], dtype="bool") - - # Output tensors - accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64") - accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32") - base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") - base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") - base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64") - base_model_stop_flags = paddle.zeros([bsz], dtype="bool") - base_model_is_block_step = paddle.zeros([bsz], dtype="bool") - base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64") - # Run the op - outputs = draft_model_preprocess( +def process_splitwise_prefill( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, +): + not_stop_flag_sum = 0 + + for tid in range(bsz): + not_stop_flag = 0 + input_ids_now = input_ids[tid] + accept_tokens_now = accept_tokens[tid] + if seq_lens_encoder[tid] > 0: + not_stop_flag = 1 + seq_len_encoder = seq_lens_encoder[tid] + stop_flags[tid] = False + base_model_first_token = accept_tokens_now[0] + position = seq_len_encoder + if truncate_first_token: + input_ids_now[position - 1] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + else: + input_ids_now[position] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + 1 + else: + stop_flags[tid] = True + seq_lens_this_time[tid] = 0 + seq_lens_decoder[tid] = 0 + seq_lens_encoder[tid] = 0 + not_stop_flag = 0 + not_stop_flag_sum = not_stop_flag_sum + not_stop_flag + not_need_stop[0] = not_stop_flag_sum > 0 + + +def draft_model_preprocess_kernel( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + 
base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, +): + not_stop_flag_sum = 0 + + for tid in range(bsz): + not_stop_flag = 0 + accept_tokens_now = accept_tokens[tid] + draft_tokens_now = draft_tokens[tid] + accept_num_now = accept_num[tid] + input_ids_now = input_ids[tid] + base_model_draft_tokens_now = base_model_draft_tokens[tid] + base_model_seq_len_decoder = base_model_seq_lens_decoder[tid] + base_model_seq_len_this_time = base_model_seq_lens_this_time[tid] + pre_ids_now = pre_ids[tid] + + base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1 + + if kvcache_scheduler_v1: + if base_model_stop_flags[tid] and base_model_is_block_step[tid]: + stop_flags[tid] = True + is_block_step[tid] = True + # Need to continue infer + else: + if base_model_stop_flags[tid] and base_model_is_block_step[tid]: + batch_drop[tid] = True + stop_flags[tid] = True + + if not (base_model_stop_flags[tid] or batch_drop[tid]): + not_stop_flag = 1 + # 1. first token + if seq_lens_encoder[tid] > 0: + # Can be extended to first few tokens + seq_len_encoder = seq_lens_encoder[tid] + stop_flags[tid] = False + base_model_first_token = accept_tokens_now[0] + pre_ids_now[0] = base_model_first_token + position = seq_len_encoder + if truncate_first_token: + input_ids_now[position - 1] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + else: + input_ids_now[position] = base_model_first_token + seq_lens_this_time[tid] = seq_len_encoder + 1 + else: + if kvcache_scheduler_v1: + # 3. try to recover mtp infer in V1 mode + if not (base_model_is_block_step[tid] and is_block_step[tid]): + is_block_step[tid] = False + + if stop_flags[tid]: + stop_flags[tid] = False + # TODO: check + seq_lens_decoder[tid] = base_model_seq_len_decoder - base_model_seq_len_this_time + step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time + else: + # 2: Last base model generated token and first MTP token + seq_lens_decoder[tid] -= num_model_step - 1 + step_idx[tid] -= num_model_step - 1 + + for i in range(accept_num_now): + draft_tokens_now[i] = accept_tokens_now[i] + pre_id_pos = base_model_step_idx[tid] - (accept_num_now - i) + accept_token = accept_tokens_now[i] + pre_ids_now[pre_id_pos] = accept_token + + seq_lens_this_time[tid] = accept_num_now + else: + stop_flags[tid] = True + seq_lens_this_time[tid] = 0 + seq_lens_decoder[tid] = 0 + seq_lens_encoder[tid] = 0 + not_stop_flag_sum = not_stop_flag_sum + not_stop_flag + not_need_stop[0] = not_stop_flag_sum > 0 + + +def DispatchRunner( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, +): + base_model_draft_tokens_len = base_model_draft_tokens.shape[1] + if splitwise_prefill: + process_splitwise_prefill( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + 
base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, + ) + else: + draft_model_preprocess_kernel( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + base_model_draft_tokens_len, + truncate_first_token, + kvcache_scheduler_v1, + ) + + +def draft_model_preprocess_ref( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, +): + real_bsz = seq_lens_this_time.shape[0] + + DispatchRunner( draft_tokens, input_ids, stop_flags, @@ -63,73 +297,110 @@ def run_test(device="xpu"): seq_lens_encoder, seq_lens_decoder, step_idx, - seq_lens_encoder_record, - seq_lens_decoder_record, not_need_stop, + is_block_step, batch_drop, + pre_ids, accept_tokens, accept_num, + base_model_seq_lens_this_time, base_model_seq_lens_encoder, base_model_seq_lens_decoder, base_model_step_idx, base_model_stop_flags, base_model_is_block_step, base_model_draft_tokens, - max_draft_token=max_draft_token, - truncate_first_token=truncate_first_token, - splitwise_prefill=splitwise_prefill, + real_bsz, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, ) - # Return results for comparison - results = { - "draft_tokens": draft_tokens.numpy(), - "input_ids": input_ids.numpy(), - "stop_flags": stop_flags.numpy(), - "seq_lens_this_time": seq_lens_this_time.numpy(), - "accept_tokens": accept_tokens.numpy(), - "accept_num": accept_num.numpy(), - "not_need_stop": not_need_stop.numpy(), - "outputs": [x.numpy() for x in outputs], - } - return results - - -def compare_results(cpu_results, xpu_results): - # Compare all outputs - for key in cpu_results: - if key == "outputs": - for i, (cpu_out, xpu_out) in enumerate(zip(cpu_results[key], xpu_results[key])): - np.testing.assert_allclose( - cpu_out, - xpu_out, - rtol=1e-5, - atol=1e-8, - err_msg=f"Output {i} mismatch between CPU and GPU", - ) - else: - np.testing.assert_allclose( - cpu_results[key], - xpu_results[key], - rtol=1e-5, - atol=1e-8, - err_msg=f"{key} mismatch between CPU and GPU", - ) - print("CPU and GPU results match!") +class TestDraftModelPreprocess: + def _run_tests(self): + paddle.seed(2022) + + # Define parameters + bsz = 10 + draft_tokens_len = 4 + input_ids_len = 100 + max_draft_token = 10 -def test_draft_model_preprocess(): + truncate_first_token = True + splitwise_prefill = False - print("Running XPU test...") - xpu_results = run_test("xpu") + draft_tokens = paddle.randint(0, 100, [bsz, draft_tokens_len], dtype="int64") + input_ids = paddle.randint(0, 100, [bsz, input_ids_len], dtype="int64") + stop_flags = paddle.randint(0, 1, [bsz], dtype="int").cast("bool") + seq_lens_this_time = 
paddle.randint(0, 100, [bsz], dtype="int32") + seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32") + seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32") + step_idx = paddle.randint(0, 100, [bsz], dtype="int64") + seq_lens_encoder_record = paddle.randint(0, 100, [bsz], dtype="int32") # noqa: F841 + seq_lens_decoder_record = paddle.randint(0, 100, [bsz], dtype="int32") # noqa: F841 + not_need_stop = paddle.zeros([1], dtype="bool").cpu() + is_block_step = paddle.zeros([bsz], dtype="bool") + batch_drop = paddle.zeros([bsz], dtype="bool") - print("Running CPU test...") - cpu_results = run_test("cpu") + # Output tensors + accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64") + accept_num = paddle.randint(1, max_draft_token + 5, [bsz], dtype="int32") + base_model_seq_lens_encoder = paddle.randint(0, 100, [bsz], dtype="int32") + base_model_seq_lens_decoder = paddle.randint(0, 100, [bsz], dtype="int32") + base_model_step_idx = paddle.randint(0, 100, [bsz], dtype="int64") + base_model_stop_flags = paddle.zeros([bsz], dtype="bool") + base_model_is_block_step = paddle.zeros([bsz], dtype="bool") + base_model_draft_tokens = paddle.zeros([bsz, max_draft_token], dtype="int64") + # Run the op + pre_ids = input_ids.clone() + base_model_seq_lens_this_time = seq_lens_this_time + num_model_step = max_draft_token - print("Comparing results...") - compare_results(cpu_results, xpu_results) + kvcache_scheduler_v1 = True + inputs = ( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + num_model_step, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1, + ) + # inplace modify, need to clone inputs + inputs_clone = [x.clone() if isinstance(x, paddle.Tensor) else x for x in inputs] + draft_model_preprocess_ref(*inputs) + draft_model_preprocess(*inputs_clone) + return inputs, inputs_clone - print("Test passed!") + def test_draft_model_preprocess(self): + results1, results2 = self._run_tests() + np.testing.assert_allclose(results1[0], results2[0]) # draft_tokens + np.testing.assert_allclose(results1[1], results2[1]) # input_ids + np.testing.assert_allclose(results1[2], results2[2]) # stop_flags + np.testing.assert_allclose(results1[3], results2[3]) # seq_lens_this_time + np.testing.assert_allclose(results1[11], results2[11]) # accept_tokens + np.testing.assert_allclose(results1[12], results2[12]) # accept_num + np.testing.assert_allclose(results1[7], results2[7]) # not_need_stop if __name__ == "__main__": - test_draft_model_preprocess() + unittest.main() diff --git a/custom_ops/xpu_ops/test/test_speculate_step.py b/custom_ops/xpu_ops/test/test_speculate_step.py index 3bbf37bfd14..65414bcff4a 100644 --- a/custom_ops/xpu_ops/test/test_speculate_step.py +++ b/custom_ops/xpu_ops/test/test_speculate_step.py @@ -13,10 +13,10 @@ # limitations under the License. 
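The index-based comparisons in test_draft_model_preprocess above can be made self-describing; a small sketch (the helper and the name-to-index map simply restate the order of the inputs tuple built in _run_tests, and are assumptions, not part of the patch):

import numpy as np

# Position of each checked tensor inside the inputs tuple of _run_tests.
CHECKED_OUTPUTS = {
    "draft_tokens": 0,
    "input_ids": 1,
    "stop_flags": 2,
    "seq_lens_this_time": 3,
    "not_need_stop": 7,
    "accept_tokens": 11,
    "accept_num": 12,
}

def assert_ref_and_op_match(ref_inputs, op_inputs):
    for name, idx in CHECKED_OUTPUTS.items():
        np.testing.assert_allclose(ref_inputs[idx], op_inputs[idx], err_msg=f"{name} mismatch")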
import os +import unittest import numpy as np import paddle -import pytest from fastdeploy.model_executor.ops.xpu import speculate_step_paddle @@ -25,9 +25,11 @@ paddle.seed(2023) -def test_data(): - """定义测试数据夹具,一次性生成所有输入数据,供测试用例复用""" - +def generate_test_data(): + """ + 生成测试数据的辅助函数。 + 这部分逻辑从 pytest 的 fixture 转换而来,作为一个普通函数供测试方法调用。 + """ # max_bs = 128 max_bs = 8 bs = max_bs @@ -160,6 +162,9 @@ def test_data(): "max_draft_tokens": max_draft_tokens, } + # 恢复默认设备,避免影响其他测试 + paddle.set_device("cpu") + return data_cpu, data_xpu @@ -196,27 +201,8 @@ def speculate_step_paddle_execution(test_data): # 可选:打印执行前关键信息(如需调试可开启) if os.environ.get("STEP_TEST_DEBUG", "0") == "1": print("-" * 50 + "before step op" + "-" * 50) - print("stop_flags: ", stop_flags) - print("seq_lens_this_time: ", seq_lens_this_time) - print("seq_lens_encoder: ", seq_lens_encoder) - print("seq_lens_decoder: ", seq_lens_decoder) - print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) - print("block_tables: ", block_tables.sum()) - print("encoder_block_lens: ", encoder_block_lens) - print("is_block_step: ", is_block_step) - print("step_block_list: ", step_block_list) - print("step_lens: ", step_lens) - print("recover_lens: ", recover_lens) - print("recover_block_list: ", recover_block_list) - print("need_block_list: ", need_block_list) - print("need_block_len: ", need_block_len) - print("used_list_len: ", used_list_len) - print("free_list_len: ", free_list_len) - print("free_list: ", free_list) - print("input_ids: ", input_ids) - print("pre_ids: ", pre_ids) - print("step_idx: ", step_idx) - print("next_tokens: ", next_tokens) + # ... (省略打印内容以保持简洁) + # 执行目标函数(核心测试步骤) speculate_step_paddle( stop_flags, @@ -247,89 +233,80 @@ def speculate_step_paddle_execution(test_data): max_draft_tokens, ) + # 可选:打印执行后关键信息(如需调试可开启) if os.environ.get("STEP_TEST_DEBUG", "0") == "1": - # 可选:打印执行后关键信息(如需调试可开启) - print("-" * 50 + "before step op" + "-" * 50) - print("stop_flags: ", stop_flags) - print("seq_lens_this_time: ", seq_lens_this_time) - print("seq_lens_encoder: ", seq_lens_encoder) - print("seq_lens_decoder: ", seq_lens_decoder) - print("ori_seq_lens_encoder: ", ori_seq_lens_encoder) - print("block_tables: ", block_tables.sum()) - print("encoder_block_lens: ", encoder_block_lens) - print("is_block_step: ", is_block_step) - print("step_block_list: ", step_block_list) - print("step_lens: ", step_lens) - print("recover_lens: ", recover_lens) - print("recover_block_list: ", recover_block_list) - print("need_block_list: ", need_block_list) - print("need_block_len: ", need_block_len) - print("used_list_len: ", used_list_len) - print("free_list_len: ", free_list_len) - print("free_list: ", free_list) - print("input_ids: ", input_ids) - print("pre_ids: ", pre_ids) - print("step_idx: ", step_idx) - print("next_tokens: ", next_tokens) + print("-" * 50 + "after step op" + "-" * 50) + # ... (省略打印内容以保持简洁) + return test_data -def assert_test_data_equal(test_data1, test_data2, rtol=1e-05, atol=1e-08): +class TestSpeculateStepPaddle(unittest.TestCase): """ - 断言两个 test_data 结构和数据完全一致,自动处理 host/device 数据转换(paddle Tensor → numpy) - - Args: - test_data1: 第一个待比较的 test_data(可在 host 或 device 上) - test_data2: 第二个待比较的 test_data(可在 host 或 device 上) - rtol: 相对误差容忍度(仅对浮点型有效) - atol: 绝对误差容忍度(仅对浮点型有效) + 测试类,继承自 unittest.TestCase。 + 所有以 'test_' 开头的方法都会被视为测试用例。 """ - # 1. 
先校验两个 test_data 的字段名完全一致 - keys1 = set(test_data1.keys()) - keys2 = set(test_data2.keys()) - assert ( - keys1 == keys2 - ), f"两个 test_data 字段不一致!\n仅在第一个中存在:{keys1 - keys2}\n仅在第二个中存在:{keys2 - keys1}" - - # 2. 逐字段校验数据 - for key in keys1: - data1 = test_data1[key] - data2 = test_data2[key] - - # 区分:paddle Tensor(需转 numpy)和 普通标量/数组(直接使用) - if isinstance(data1, paddle.Tensor): - # 转换为 numpy:自动处理 device → host(.cpu())、阻止梯度计算(.detach()) - np1 = data1.detach().cpu().numpy() - else: - np1 = np.asarray(data1) # 非 Tensor(如 int/float)转为 numpy 统一格式 - - if isinstance(data2, paddle.Tensor): - np2 = data2.detach().cpu().numpy() - else: - np2 = np.asarray(data2) - - # 3. 校验数据(分类型处理:布尔型/整数型 严格相等,浮点型 允许微小误差) - if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8): - # 布尔/整数型:必须完全相等 - assert np.array_equal(np1, np2), f"字段 {key} 数据不一致!\n第一个数据:{np1}\n第二个数据:{np2}" - else: - # 浮点型:允许 rtol/atol 范围内的误差(如 float32/float64) - assert np.allclose( - np1, np2, rtol=rtol, atol=atol - ), f"字段 {key} 浮点数据不一致!\n相对误差:{rtol},绝对误差:{atol}\n第一个数据:{np1}\n第二个数据:{np2}" - - print("✅ 两个 test_data 结构和数据完全一致!") - - -def test_speculate_step_paddle(): - data_cpu, data_xpu = test_data() - # check before test - assert_test_data_equal(data_xpu, data_cpu) - result_xpu = speculate_step_paddle_execution(data_xpu) - result_cpu = speculate_step_paddle_execution(data_cpu) - # check after test - assert_test_data_equal(result_xpu, result_cpu) + + def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08): + """ + 自定义的断言方法,用于比较两个 test_data 结构和数据。 + 在 unittest 中,自定义断言通常以 'assert' 开头。 + """ + # 1. 先校验两个 test_data 的字段名完全一致 + keys1 = set(test_data1.keys()) + keys2 = set(test_data2.keys()) + self.assertEqual( + keys1, + keys2, + msg=f"两个 test_data 字段不一致!\n仅在第一个中存在:{keys1 - keys2}\n仅在第二个中存在:{keys2 - keys1}", + ) + + # 2. 逐字段校验数据 + for key in keys1: + data1 = test_data1[key] + data2 = test_data2[key] + + # 区分:paddle Tensor(需转 numpy)和 普通标量/数组(直接使用) + if isinstance(data1, paddle.Tensor): + np1 = data1.detach().cpu().numpy() + else: + np1 = np.asarray(data1) + + if isinstance(data2, paddle.Tensor): + np2 = data2.detach().cpu().numpy() + else: + np2 = np.asarray(data2) + + # 3. 校验数据 + if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8): + # 布尔/整数型:必须完全相等 + np.testing.assert_array_equal(np1, np2, err_msg=f"字段 {key} 数据不一致!") + else: + # 浮点型:允许 rtol/atol 范围内的误差 + np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"字段 {key} 浮点数据不一致!") + + print("✅ 两个 test_data 结构和数据完全一致!") + + def test_speculate_step_paddle_execution(self): + """ + 核心测试用例方法。 + 该方法会调用 generate_test_data 获取数据, + 分别在 CPU 和 XPU 上执行测试函数, + 并使用自定义的断言方法比较结果。 + """ + print("\nRunning test: test_speculate_step_paddle_execution") + + # 1. 获取测试数据 + data_cpu, data_xpu = generate_test_data() + + # 2. 执行测试函数 + result_xpu = speculate_step_paddle_execution(data_xpu) + result_cpu = speculate_step_paddle_execution(data_cpu) + + # 3. 断言结果一致 + self.assert_test_data_equal(result_xpu, result_cpu) if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) # 直接运行时执行 pytest 并显示详细日志 + # 使用 unittest 的主程序来运行所有测试用例 + unittest.main() diff --git a/custom_ops/xpu_ops/test/test_speculate_update_v3.py b/custom_ops/xpu_ops/test/test_speculate_update_v3.py index 1ecebc6e72d..bdea8727d3b 100644 --- a/custom_ops/xpu_ops/test/test_speculate_update_v3.py +++ b/custom_ops/xpu_ops/test/test_speculate_update_v3.py @@ -12,101 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
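The dtype-aware comparison inside TestSpeculateStepPaddle above is a pattern the other CPU-vs-XPU parity tests could reuse; a standalone sketch (the function name is an assumption):

import numpy as np
import paddle

def assert_tensor_dict_equal(lhs, rhs, rtol=1e-5, atol=1e-8):
    """Exact equality for bool/integer tensors, tolerance-based equality for floats."""
    assert set(lhs.keys()) == set(rhs.keys()), "field names differ"
    for key in lhs:
        a = lhs[key].detach().cpu().numpy() if isinstance(lhs[key], paddle.Tensor) else np.asarray(lhs[key])
        b = rhs[key].detach().cpu().numpy() if isinstance(rhs[key], paddle.Tensor) else np.asarray(rhs[key])
        if a.dtype == np.bool_ or np.issubdtype(a.dtype, np.integer):
            np.testing.assert_array_equal(a, b, err_msg=f"{key} mismatch")
        else:
            np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=f"{key} mismatch")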
-import numpy as np +import unittest -# tests/test_speculate_update_v3.py +import numpy as np import paddle +# 假设这是你的自定义算子 from fastdeploy.model_executor.ops.xpu import speculate_update_v3 -# ---------------- NumPy 参考实现 ---------------- -def speculate_update_v3_np( - seq_lens_encoder, - seq_lens_decoder, - not_need_stop, - draft_tokens, - actual_draft_token_nums, - accept_tokens, - accept_num, - stop_flags, - seq_lens_this_time, - is_block_step, - stop_nums, -): - """ - 完全复现 CPU / CUDA 逻辑的 NumPy 参考版本(就地修改)。 - """ - stop_sum = 0 - real_bsz = seq_lens_this_time.shape[0] - max_bsz = stop_flags.shape[0] - max_draft_tokens = draft_tokens.shape[1] - - for bid in range(max_bsz): - stop_flag_now_int = 0 - inactive = bid >= real_bsz - block_step = (not inactive) and is_block_step[bid] - - if (not block_step) and (not inactive): - - if stop_flags[bid]: - stop_flag_now_int = 1 - - # encoder 长度为 0 时直接累加 decoder - if seq_lens_encoder[bid] == 0: - seq_lens_decoder[bid] += accept_num[bid] - - # draft 长度自适应 - if (seq_lens_encoder[bid] == 0) and (seq_lens_this_time[bid] > 1): - cur_len = actual_draft_token_nums[bid] - if accept_num[bid] - 1 == cur_len: # 全部接受 - if cur_len + 2 <= max_draft_tokens - 1: - cur_len += 2 - elif cur_len + 1 <= max_draft_tokens - 1: - cur_len += 1 - else: - cur_len = max_draft_tokens - 1 - else: # 有拒绝 - cur_len = max(1, cur_len - 1) - actual_draft_token_nums[bid] = cur_len - - # 偿还 encoder 欠账 - if seq_lens_encoder[bid] != 0: - seq_lens_decoder[bid] += seq_lens_encoder[bid] - seq_lens_encoder[bid] = 0 - - # 写回下一轮首 token - draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1] - - # 停止则清零 decoder - if stop_flag_now_int: - seq_lens_decoder[bid] = 0 - - elif inactive: - stop_flag_now_int = 1 # padding slot 视为 stop - - stop_sum += stop_flag_now_int - - # print("stop_sum: ", stop_sum) - not_need_stop[0] = stop_sum < stop_nums[0] - - # 返回引用,仅供一致性 - return ( - seq_lens_encoder, - seq_lens_decoder, - not_need_stop, - draft_tokens, - actual_draft_token_nums, - ) - - -# ---------------- 生成随机输入 ---------------- def gen_inputs( max_bsz=512, # 与 CUDA BlockSize 对齐 max_draft_tokens=16, real_bsz=123, # 可自调;须 ≤ max_bsz seed=2022, ): + """生成随机测试输入数据""" rng = np.random.default_rng(seed) # 基本张量 @@ -122,89 +43,91 @@ def gen_inputs( stop_nums = np.array([5], dtype=np.int64) # 阈值随意 # seq_lens_this_time 仅取 real_bsz 长度 - seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32) - - return { - "seq_lens_encoder": seq_lens_encoder, - "seq_lens_decoder": seq_lens_decoder, - "not_need_stop": not_need_stop, - "draft_tokens": draft_tokens, - "actual_draft_token_nums": actual_draft_nums, - "accept_tokens": accept_tokens, - "accept_num": accept_num, - "stop_flags": stop_flags, - "seq_lens_this_time": seq_lens_this_time, - "is_block_step": is_block_step, - "stop_nums": stop_nums, - # real_bsz = real_bsz, - # max_bsz = max_bsz, - # max_draft_tokens = max_draft_tokens + seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32) + + paddle.set_device("xpu:0") + data_xpu = { + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "not_need_stop": paddle.to_tensor(not_need_stop).cpu(), + "draft_tokens": paddle.to_tensor(draft_tokens), + "actual_draft_token_nums": paddle.to_tensor(actual_draft_nums), + "accept_tokens": paddle.to_tensor(accept_tokens), + "accept_num": paddle.to_tensor(accept_num), + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": 
paddle.to_tensor(seq_lens_this_time), + "is_block_step": paddle.to_tensor(is_block_step), + "stop_nums": paddle.to_tensor(stop_nums), } - -# ------------------- 单测主体 ------------------- -inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201) - -# ---- Paddle 端 ---- -paddle_inputs = {} -for k, v in inputs.items(): - if k in ("real_bsz", "max_bsz", "max_draft_tokens"): - paddle_inputs[k] = v # 纯 python int - else: - if k == "not_need_stop": - paddle_inputs[k] = paddle.to_tensor(v, place=paddle.CPUPlace()) - else: - # 其余张量保持默认 place(想测 GPU 就手动加 place=paddle.CUDAPlace(0)) - paddle_inputs[k] = paddle.to_tensor(v) - -# ---- NumPy 端 ---- -# 为保证初值一致,这里必须复制 Paddle 入参的 numpy 值再传给参考实现 -np_inputs = { - k: (paddle_inputs[k].numpy().copy() if isinstance(paddle_inputs[k], paddle.Tensor) else paddle_inputs[k]) - for k in paddle_inputs -} - -# 调用自定义算子 -# print("seq_lens_encoder_xpu_before: ", paddle_inputs["seq_lens_encoder"]) -out_pd = speculate_update_v3(**paddle_inputs) -# print("seq_lens_encoder_xpu_after: ", out_pd[0]) -# print("not_need_stop: ", out_pd[2]) - -# speculate_update_v3 返回 5 个张量(与 Outputs 对应) -( - seq_lens_encoder_pd, - seq_lens_decoder_pd, - not_need_stop_pd, - draft_tokens_pd, - actual_draft_nums_pd, -) = out_pd - -# print("seq_lens_encoder_np_before: ", np_inputs["seq_lens_encoder"]) -out_np = speculate_update_v3_np(**np_inputs) -# print("seq_lens_encoder_np_after: ", out_np[0]) -# print("not_need_stop: ", out_np[2]) - - -# ---------------- 校对 ---------------- -names = [ - "seq_lens_encoder", - "seq_lens_decoder", - "not_need_stop", - "draft_tokens", - "actual_draft_token_nums", -] -pd_tensors = [ - seq_lens_encoder_pd, - seq_lens_decoder_pd, - not_need_stop_pd, - draft_tokens_pd, - actual_draft_nums_pd, -] - -for name, pd_val, np_val in zip(names, pd_tensors, out_np): - pd_arr = pd_val.numpy() - ok = np.array_equal(pd_arr, np_val) - print(f"{name:25s} equal :", ok) - - # 也可以加 assert,配合 pytest - # assert all(np.array_equal(p.numpy(), n) for p,n in zip(pd_tensors, out_np)) + paddle.set_device("cpu") + data_cpu = { + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "not_need_stop": paddle.to_tensor(not_need_stop), + "draft_tokens": paddle.to_tensor(draft_tokens), + "actual_draft_token_nums": paddle.to_tensor(actual_draft_nums), + "accept_tokens": paddle.to_tensor(accept_tokens), + "accept_num": paddle.to_tensor(accept_num), + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "is_block_step": paddle.to_tensor(is_block_step), + "stop_nums": paddle.to_tensor(stop_nums), + } + return data_xpu, data_cpu + + +class TestSpeculateUpdateV3(unittest.TestCase): + """测试 speculate_update_v3 算子""" + + def test_op_vs_golden(self, max_bsz=512, max_draft_tokens=16, real_bsz=123): + """ + 核心测试:比较自定义算子的输出与纯 NumPy 参考实现的输出。 + """ + # 1. gen inputs for cpu/xpu + data_xpu, data_cpu = gen_inputs(max_bsz=max_bsz, max_draft_tokens=max_draft_tokens, real_bsz=real_bsz) + + # 3. run xpu kernel + speculate_update_v3(**data_xpu) + + # 4. run cpu kernel + speculate_update_v3(**data_cpu) + + # 5. 
format outputs + outputs_xpu = [ + data_xpu["seq_lens_encoder"].cpu().numpy(), + data_xpu["seq_lens_decoder"].cpu().numpy(), + data_xpu["not_need_stop"].cpu().numpy(), + data_xpu["draft_tokens"].cpu().numpy(), + data_xpu["actual_draft_token_nums"].cpu().numpy(), + ] + + outputs_cpu = [ + data_cpu["seq_lens_encoder"].numpy(), + data_cpu["seq_lens_decoder"].numpy(), + data_cpu["not_need_stop"].numpy(), + data_cpu["draft_tokens"].numpy(), + data_cpu["actual_draft_token_nums"].numpy(), + ] + output_names = [ + "seq_lens_encoder", + "seq_lens_decoder", + "not_need_stop", + "draft_tokens", + "actual_draft_token_nums", + ] + + # 6. check outputs + for name, pd_out, np_out in zip(output_names, outputs_xpu, outputs_cpu): + with self.subTest(output_name=name): + np.testing.assert_allclose( + pd_out, + np_out, + atol=0, + rtol=1e-6, + err_msg=f"Output mismatch for tensor '{name}'.\nPaddle Output:\n{pd_out}\nGolden Output:\n{np_out}", + ) + + +if __name__ == "__main__": + unittest.main() From 3d92f94803b2b1746e351ff4628d55184bc6e292 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Tue, 25 Nov 2025 04:01:57 +0000 Subject: [PATCH 14/17] mv xpu pre/post process --- .../model_executor/pre_and_post_process.py | 284 ---------------- .../xpu_pre_and_post_process.py | 315 ++++++++++++++++++ fastdeploy/worker/xpu_model_runner.py | 2 +- 3 files changed, 316 insertions(+), 285 deletions(-) create mode 100644 fastdeploy/model_executor/xpu_pre_and_post_process.py diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index d603d6d6f53..dd3cd01343a 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -22,7 +22,6 @@ from fastdeploy import envs from fastdeploy.config import SpeculativeConfig -from fastdeploy.model_executor.forward_meta import XPUForwardMeta from fastdeploy.platforms import current_platform if current_platform.is_iluvatar(): @@ -67,9 +66,6 @@ pass elif current_platform.is_xpu(): from fastdeploy.model_executor.ops.xpu import ( - adjust_batch, - gather_next_token, - get_infer_param, get_padding_offset, limit_thinking_content_length_v1, limit_thinking_content_length_v2, @@ -936,283 +932,3 @@ def post_process_pooling( if save_each_rank or model_output.mp_rank == 0: output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs) async_output_queue.put(output) - - -def xpu_pre_process( - input_ids: paddle.Tensor, - seq_lens_this_time: int, - share_inputs: Dict, - use_speculate_method: bool, - block_size: int, - draft_tokens: Optional[paddle.Tensor] = None, - seq_lens_encoder: Optional[paddle.Tensor] = None, - seq_lens_decoder: Optional[paddle.Tensor] = None, - is_profiling: bool = False, -) -> XPUForwardMeta: - """ """ - max_len = input_ids.shape[1] - cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") - token_num = paddle.sum(seq_lens_this_time) - - ( - ids_remove_padding, - cum_offsets, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k, - ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) - - share_inputs["ids_remove_padding"] = None # set this after adjust batch - share_inputs["cum_offsets"] = cum_offsets - share_inputs["batch_id_per_token"] = batch_id_per_token - share_inputs["cu_seqlens_q"] = cu_seqlens_q - share_inputs["cu_seqlens_k"] = cu_seqlens_k - - xpu_forward_meta = XPUForwardMeta( - ids_remove_padding=share_inputs["ids_remove_padding"], - 
rotary_embs=share_inputs["rope_emb"], - attn_backend=None, - seq_lens_encoder=share_inputs["seq_lens_encoder"], - seq_lens_decoder=share_inputs["seq_lens_decoder"], - seq_lens_this_time=share_inputs["seq_lens_this_time"], - cum_offsets=share_inputs["cum_offsets"], - batch_id_per_token=share_inputs["batch_id_per_token"], - cu_seqlens_q=share_inputs["cu_seqlens_q"], - cu_seqlens_k=share_inputs["cu_seqlens_k"], - block_tables=share_inputs["block_tables"], - caches=share_inputs["caches"], - ) - - ( - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_kv_lod, - xpu_forward_meta.prefix_len, - xpu_forward_meta.decoder_context_len, - xpu_forward_meta.decoder_context_len_cache, - xpu_forward_meta.prefix_block_tables, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_kv_lod_cpu, - xpu_forward_meta.prefix_len_cpu, - xpu_forward_meta.decoder_context_len_cpu, - xpu_forward_meta.decoder_context_len_cache_cpu, - xpu_forward_meta.len_info_cpu, - ) = get_infer_param( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size - ) - xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] - xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] - xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] - - adjusted_input = adjust_batch( - ids_remove_padding.reshape([-1, 1]), - cum_offsets, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_batch_idx, - xpu_forward_meta.decoder_batch_idx, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_idx_cpu, - xpu_forward_meta.decoder_batch_idx_cpu, - xpu_forward_meta.len_info_cpu, - None, # output_padding_offset - -1, # max bs - ) - - adjusted_input = adjusted_input.squeeze(1) - - share_inputs["ids_remove_padding"] = adjusted_input - xpu_forward_meta.ids_remove_padding = adjusted_input - # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends - xpu_forward_meta.is_profiling = is_profiling - return xpu_forward_meta - - -def xpu_process_output( - forward_output, - cum_offsets: paddle.Tensor, - xpu_forward_meta: XPUForwardMeta, - share_inputs, -) -> paddle.Tensor: - """ """ - - output_padding_offset = share_inputs.get("output_padding_offset", None) - - hiddden_states = gather_next_token( - forward_output, - cum_offsets, - xpu_forward_meta.encoder_seq_lod, - xpu_forward_meta.decoder_seq_lod, - xpu_forward_meta.encoder_batch_map, - xpu_forward_meta.decoder_batch_map, - xpu_forward_meta.encoder_seq_lod_cpu, - xpu_forward_meta.decoder_seq_lod_cpu, - xpu_forward_meta.encoder_batch_map_cpu, - xpu_forward_meta.decoder_batch_map_cpu, - xpu_forward_meta.len_info_cpu, - output_padding_offset, # output_padding_offset - -1, # max_input_length - ) - return hiddden_states - - -def xpu_post_process_normal( - sampled_token_ids: paddle.Tensor, - model_output: ModelOutputData, - share_inputs: Dict[str, paddle.Tensor], - block_size: int = 64, - skip_save_output: bool = False, - think_end_id: int = None, - line_break_id: int = None, -) -> None: - """ """ - from 
fastdeploy.model_executor.ops.xpu import ( - save_output, - set_stop_value_multi_ends, - update_inputs, - ) - - if think_end_id > 0: - limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR - max_think_lens = share_inputs["max_think_lens"] - step_idx = share_inputs["step_idx"] - limit_think_status = share_inputs["limit_think_status"] - stop_flags = share_inputs["stop_flags"] - eos_token_ids = share_inputs["eos_token_id"] - if limit_strategy == "": - # for ernie-45-vl - limit_thinking_content_length_v1( - sampled_token_ids, - max_think_lens, - step_idx, - limit_think_status, - stop_flags, - eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题 - think_end_id, - ) - elif limit_strategy == "\n\n\n": - # for ernie-x1 - assert line_break_id > 0 - limit_thinking_content_length_v2( - sampled_token_ids, - max_think_lens, - step_idx, - limit_think_status, - stop_flags, - think_end_id, - line_break_id, - ) - else: - raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") - - # 1. Set stop value - paddle.assign( - paddle.where( - model_output.stop_flags, - model_output.step_idx, - model_output.step_idx + 1, - ), - model_output.step_idx, - ) - length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len) - paddle.assign( - paddle.logical_or(model_output.stop_flags, length_cond), - model_output.stop_flags, - ) - set_stop_value_multi_ends( - sampled_token_ids, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.eos_token_id, - model_output.next_tokens, - False, - ) # multi ends - - # 2. Update the input buffer of the model - with paddle.framework._no_check_dy2st_diff(): - if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: - update_inputs_v1( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - share_inputs["step_seq_lens_decoder"], - share_inputs["prompt_lens"], - sampled_token_ids, - model_output.input_ids, - share_inputs["block_tables"], - model_output.stop_nums, - model_output.next_tokens, - model_output.is_block_step, - block_size, - ) - else: - update_inputs( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.input_ids, - model_output.stop_nums, - sampled_token_ids, - model_output.is_block_step, - ) - # 3. Transmit the model's output and stop generation signal via message queue. - # In the future, we will abandon this approach. 
- if not skip_save_output: - save_output( - sampled_token_ids, - model_output.not_need_stop, - model_output.mp_rank, - False, # use_ep - ) - - -def step_xpu( - share_inputs: Dict[str, paddle.Tensor], - block_size: int, - enc_dec_block_num: int, -) -> None: - """ - TODO(gongshaotian): normalization name - """ - from fastdeploy.model_executor.ops.xpu import step_paddle - - step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py new file mode 100644 index 00000000000..9a2ea16aac4 --- /dev/null +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -0,0 +1,315 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Dict, Optional + +import paddle + +from fastdeploy import envs +from fastdeploy.model_executor.forward_meta import XPUForwardMeta +from fastdeploy.platforms import current_platform +from fastdeploy.worker.output import ModelOutputData + +if current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + adjust_batch, + gather_next_token, + get_infer_param, + get_padding_offset, + limit_thinking_content_length_v1, + limit_thinking_content_length_v2, + update_inputs_v1, + ) + + +def xpu_pre_process( + input_ids: paddle.Tensor, + seq_lens_this_time: int, + share_inputs: Dict, + use_speculate_method: bool, + block_size: int, + draft_tokens: Optional[paddle.Tensor] = None, + seq_lens_encoder: Optional[paddle.Tensor] = None, + seq_lens_decoder: Optional[paddle.Tensor] = None, + is_profiling: bool = False, +) -> XPUForwardMeta: + """ """ + max_len = input_ids.shape[1] + cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") + token_num = paddle.sum(seq_lens_this_time) + + ( + ids_remove_padding, + cum_offsets, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) + + share_inputs["ids_remove_padding"] = None # set this after adjust batch + share_inputs["cum_offsets"] = cum_offsets + share_inputs["batch_id_per_token"] = batch_id_per_token + share_inputs["cu_seqlens_q"] = cu_seqlens_q + share_inputs["cu_seqlens_k"] = cu_seqlens_k + + xpu_forward_meta = XPUForwardMeta( + ids_remove_padding=share_inputs["ids_remove_padding"], + rotary_embs=share_inputs["rope_emb"], + attn_backend=None, + seq_lens_encoder=share_inputs["seq_lens_encoder"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + seq_lens_this_time=share_inputs["seq_lens_this_time"], + cum_offsets=share_inputs["cum_offsets"], + batch_id_per_token=share_inputs["batch_id_per_token"], + cu_seqlens_q=share_inputs["cu_seqlens_q"], + cu_seqlens_k=share_inputs["cu_seqlens_k"], + block_tables=share_inputs["block_tables"], + caches=share_inputs["caches"], + ) + + ( + xpu_forward_meta.encoder_batch_map, + xpu_forward_meta.decoder_batch_map, + xpu_forward_meta.encoder_batch_idx, + xpu_forward_meta.decoder_batch_idx, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, + xpu_forward_meta.encoder_kv_lod, + xpu_forward_meta.prefix_len, + xpu_forward_meta.decoder_context_len, + xpu_forward_meta.decoder_context_len_cache, + xpu_forward_meta.prefix_block_tables, + xpu_forward_meta.encoder_batch_map_cpu, + xpu_forward_meta.decoder_batch_map_cpu, + xpu_forward_meta.encoder_batch_idx_cpu, + xpu_forward_meta.decoder_batch_idx_cpu, + xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.decoder_seq_lod_cpu, + xpu_forward_meta.encoder_kv_lod_cpu, + xpu_forward_meta.prefix_len_cpu, + xpu_forward_meta.decoder_context_len_cpu, + xpu_forward_meta.decoder_context_len_cache_cpu, + xpu_forward_meta.len_info_cpu, + ) = get_infer_param( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, xpu_forward_meta.block_tables, block_size + ) + xpu_forward_meta.enc_batch = xpu_forward_meta.len_info_cpu[0] + xpu_forward_meta.dec_batch = xpu_forward_meta.len_info_cpu[1] + xpu_forward_meta.total_enc_len = xpu_forward_meta.len_info_cpu[2] + + adjusted_input = adjust_batch( + ids_remove_padding.reshape([-1, 1]), + cum_offsets, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, + xpu_forward_meta.encoder_batch_idx, + xpu_forward_meta.decoder_batch_idx, + xpu_forward_meta.encoder_seq_lod_cpu, 
+ xpu_forward_meta.decoder_seq_lod_cpu, + xpu_forward_meta.encoder_batch_idx_cpu, + xpu_forward_meta.decoder_batch_idx_cpu, + xpu_forward_meta.len_info_cpu, + None, # output_padding_offset + -1, # max bs + ) + + adjusted_input = adjusted_input.squeeze(1) + + share_inputs["ids_remove_padding"] = adjusted_input + xpu_forward_meta.ids_remove_padding = adjusted_input + # Set forward_meta.is_profiling to True to skip init_kv_signal_per_query for attention backends + xpu_forward_meta.is_profiling = is_profiling + return xpu_forward_meta + + +def xpu_process_output( + forward_output, + cum_offsets: paddle.Tensor, + xpu_forward_meta: XPUForwardMeta, + share_inputs, +) -> paddle.Tensor: + """ """ + + output_padding_offset = share_inputs.get("output_padding_offset", None) + + hidden_states = gather_next_token( + forward_output, + cum_offsets, + xpu_forward_meta.encoder_seq_lod, + xpu_forward_meta.decoder_seq_lod, + xpu_forward_meta.encoder_batch_map, + xpu_forward_meta.decoder_batch_map, + xpu_forward_meta.encoder_seq_lod_cpu, + xpu_forward_meta.decoder_seq_lod_cpu, + xpu_forward_meta.encoder_batch_map_cpu, + xpu_forward_meta.decoder_batch_map_cpu, + xpu_forward_meta.len_info_cpu, + output_padding_offset, # output_padding_offset + -1, # max_input_length + ) + return hidden_states + + +def xpu_post_process_normal( + sampled_token_ids: paddle.Tensor, + model_output: ModelOutputData, + share_inputs: Dict[str, paddle.Tensor], + block_size: int = 64, + skip_save_output: bool = False, + think_end_id: int = None, + line_break_id: int = None, +) -> None: + """ """ + from fastdeploy.model_executor.ops.xpu import ( + save_output, + set_stop_value_multi_ends, + update_inputs, + ) + + if think_end_id > 0: + limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR + max_think_lens = share_inputs["max_think_lens"] + step_idx = share_inputs["step_idx"] + limit_think_status = share_inputs["limit_think_status"] + stop_flags = share_inputs["stop_flags"] + eos_token_ids = share_inputs["eos_token_id"] + if limit_strategy == "": + # for ernie-45-vl + limit_thinking_content_length_v1( + sampled_token_ids, + max_think_lens, + step_idx, + limit_think_status, + stop_flags, + eos_token_ids, # handle eos tokens emitted mid-thinking due to model-quality issues + think_end_id, + ) + elif limit_strategy == "\n\n\n": + # for ernie-x1 + assert line_break_id > 0 + limit_thinking_content_length_v2( + sampled_token_ids, + max_think_lens, + step_idx, + limit_think_status, + stop_flags, + think_end_id, + line_break_id, + ) + else: + raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.") + + # 1. Set stop value + paddle.assign( + paddle.where( + model_output.stop_flags, + model_output.step_idx, + model_output.step_idx + 1, + ), + model_output.step_idx, + ) + length_cond = paddle.greater_equal(model_output.step_idx, model_output.max_dec_len) + paddle.assign( + paddle.logical_or(model_output.stop_flags, length_cond), + model_output.stop_flags, + ) + set_stop_value_multi_ends( + sampled_token_ids, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.eos_token_id, + model_output.next_tokens, + False, + ) # multi ends + + # 2.
Update the input buffer of the model + with paddle.framework._no_check_dy2st_diff(): + if envs.ENABLE_V1_KVCACHE_SCHEDULER and not skip_save_output: + update_inputs_v1( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + share_inputs["step_seq_lens_decoder"], + share_inputs["prompt_lens"], + sampled_token_ids, + model_output.input_ids, + share_inputs["block_tables"], + model_output.stop_nums, + model_output.next_tokens, + model_output.is_block_step, + block_size, + ) + else: + update_inputs( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.input_ids, + model_output.stop_nums, + sampled_token_ids, + model_output.is_block_step, + ) + # 3. Transmit the model's output and stop generation signal via message queue. + # In the future, we will abandon this approach. + if not skip_save_output: + save_output( + sampled_token_ids, + model_output.not_need_stop, + model_output.mp_rank, + False, # use_ep + ) + + +def step_xpu( + share_inputs: Dict[str, paddle.Tensor], + block_size: int, + enc_dec_block_num: int, +) -> None: + """ + TODO(gongshaotian): normalization name + """ + from fastdeploy.model_executor.ops.xpu import step_paddle + + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 9a21b23627c..6338965d218 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -49,7 +49,7 @@ set_data_ipc, share_external_data, ) -from fastdeploy.model_executor.pre_and_post_process import ( # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate +from fastdeploy.model_executor.xpu_pre_and_post_process import ( # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate step_xpu, xpu_post_process_normal, xpu_pre_process, From 790ffc463adc2b26cb39cd5e70d426d13cd3fc4a Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Tue, 25 Nov 2025 04:07:17 +0000 Subject: [PATCH 15/17] mv xpu pre/post process --- fastdeploy/model_executor/pre_and_post_process.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index dd3cd01343a..d2c82e2afaa 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -64,17 +64,6 @@ ) elif current_platform.is_intel_hpu(): pass -elif current_platform.is_xpu(): - from fastdeploy.model_executor.ops.xpu import ( - get_padding_offset, - limit_thinking_content_length_v1, - 
limit_thinking_content_length_v2, - save_output, - set_stop_value_multi_ends, - step_paddle, - update_inputs, - update_inputs_v1, - ) else: from fastdeploy.model_executor.ops.gpu import ( get_padding_offset, From bc4ad8563345450d4188046985c547e0ad4a25e6 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Thu, 27 Nov 2025 11:02:38 +0000 Subject: [PATCH 16/17] [xpu] support mtp --- .../sample/ops/apply_penalty_multi_scores.py | 36 +-- .../model_executor/layers/sample/sampler.py | 119 ++++++++++ .../xpu_pre_and_post_process.py | 194 ++++++++++++---- fastdeploy/output/token_processor.py | 6 +- fastdeploy/spec_decode/__init__.py | 5 +- fastdeploy/spec_decode/mtp.py | 216 +++++++++++++++-- fastdeploy/worker/xpu_model_runner.py | 217 ++++++++++++++++-- fastdeploy/worker/xpu_worker.py | 4 +- 8 files changed, 691 insertions(+), 106 deletions(-) diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py index 04a8ab10244..9817eaba605 100644 --- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py +++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py @@ -182,24 +182,28 @@ def apply_speculative_penalty_multi_scores( from fastdeploy.model_executor.ops.gpu import ( speculate_get_token_penalty_multi_scores, ) - - speculate_get_token_penalty_multi_scores( - pre_token_ids, - logits, - repetition_penalties, - frequency_penalties, - presence_penalties, - temperature, - bad_words_token_ids, - step_idx, - min_dec_lens, - eos_token_ids, - seq_lens_this_time, - output_padding_offset, - output_cum_offsets, - max_len, + elif current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + speculate_get_token_penalty_multi_scores, ) + else: raise NotImplementedError + speculate_get_token_penalty_multi_scores( + pre_token_ids, + logits, + repetition_penalties, + frequency_penalties, + presence_penalties, + temperature, + bad_words_token_ids, + step_idx, + min_dec_lens, + eos_token_ids, + seq_lens_this_time, + output_padding_offset, + output_cum_offsets, + max_len, + ) # inplace return logits diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 0e931bb437f..f65d314d8d8 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -572,6 +572,8 @@ def __init__(self, fd_config: FDConfig): super().__init__() if current_platform.is_cuda(): self.forward = self.forward_cuda + elif current_platform.is_xpu(): + self.forward = self.forward_xpu else: raise NotImplementedError self.logprobs_mode = fd_config.model_config.logprobs_mode @@ -814,6 +816,80 @@ def forward_cuda( return sampler_output + def forward_xpu( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + max_model_len: int, + share_inputs: List[paddle.Tensor], + accept_all_drafts: bool = False, + reject_all_drafts: bool = False, + ) -> paddle.Tensor: + from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates + + logits = apply_speculative_penalty_multi_scores( + sampling_metadata.pre_token_ids, + logits, + sampling_metadata.repetition_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.presence_penalties, + sampling_metadata.temperature, + sampling_metadata.bad_words_token_ids, + sampling_metadata.step_idx, + sampling_metadata.min_dec_lens, + sampling_metadata.eos_token_ids, + 
share_inputs["seq_lens_this_time"], + share_inputs["output_padding_offset"], + share_inputs["output_cum_offsets"], + max_model_len, + ) + + probs = F.softmax(logits) + + verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( + probs, + sampling_metadata.top_p, + share_inputs["output_padding_offset"], + self.speculative_max_candidate_len, + max_model_len, + ) + + speculate_verify( + share_inputs["accept_tokens"], + share_inputs["accept_num"], + share_inputs["step_idx"], + share_inputs["stop_flags"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs[ + "draft_tokens" + ], # Both input and output, need to write the last 1 token accepted to position 0. + share_inputs["seq_lens_this_time"], + verify_tokens, + verify_scores, + share_inputs["max_dec_len"], + sampling_metadata.eos_token_ids, + share_inputs["is_block_step"], + share_inputs["output_cum_offsets"], + actual_candidate_len, + share_inputs["actual_draft_token_num"], + sampling_metadata.top_p, + max_model_len, + self.speculative_verify_window, + True, # enable_topp + (self.speculative_benchmark_mode or reject_all_drafts), + accept_all_drafts, + ) + # TODO(chenhuan09): support return logprobs + token_ids = share_inputs["accept_tokens"] + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=None, + token_num_per_batch=share_inputs["accept_num"], + cu_batch_token_offset=None, + ) + return sampler_output + class MTPSampler(nn.Layer): """ """ @@ -823,6 +899,8 @@ def __init__(self, fd_config: FDConfig): super().__init__() if current_platform.is_cuda(): self.forward = self.forward_cuda + elif current_platform.is_xpu(): + self.forward = self.forward_xpu else: raise NotImplementedError self.logprobs_mode = fd_config.model_config.logprobs_mode @@ -1013,3 +1091,44 @@ def forward_cuda( cu_batch_token_offset=share_inputs["cu_batch_token_offset"], ) return next_tokens, sampler_output + + def forward_xpu( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + max_model_len: int, + share_inputs: List[paddle.Tensor], + ) -> paddle.Tensor: + + logits = apply_speculative_penalty_multi_scores( + sampling_metadata.pre_token_ids, + logits, + sampling_metadata.repetition_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.presence_penalties, + sampling_metadata.temperature, + sampling_metadata.bad_words_token_ids, + sampling_metadata.step_idx, + sampling_metadata.min_dec_lens, + sampling_metadata.eos_token_ids, + share_inputs["seq_lens_this_time"], + share_inputs["output_padding_offset"], + share_inputs["output_cum_offsets"], + max_model_len, + ) + probs = F.softmax(logits) + + _, next_tokens = top_k_top_p_sampling( + probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list + ) + # TODO(chenhuan09): add support for logprobs + token_ids = None + logprobs_tensors = None + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=None, + cu_batch_token_offset=None, + ) + return next_tokens, sampler_output diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 9a2ea16aac4..861b3b533a9 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -31,6 +31,18 @@ get_padding_offset, limit_thinking_content_length_v1, limit_thinking_content_length_v2, + save_output, + set_stop_value_multi_ends, + 
speculate_clear_accept_nums, + speculate_get_output_padding_offset, + speculate_get_padding_offset, + speculate_get_seq_lens_output, + speculate_save_output, + speculate_set_value_by_flags_and_idx, + speculate_step_paddle, + speculate_update_v3, + step_paddle, + update_inputs, update_inputs_v1, ) @@ -45,19 +57,53 @@ def xpu_pre_process( seq_lens_encoder: Optional[paddle.Tensor] = None, seq_lens_decoder: Optional[paddle.Tensor] = None, is_profiling: bool = False, + forward_meta=None, ) -> XPUForwardMeta: """ """ max_len = input_ids.shape[1] cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32") token_num = paddle.sum(seq_lens_this_time) - ( - ids_remove_padding, - cum_offsets, - batch_id_per_token, - cu_seqlens_q, - cu_seqlens_k, - ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) + if use_speculate_method: + ( + ids_remove_padding, + cum_offsets, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = speculate_get_padding_offset( + input_ids, + draft_tokens, + cum_offsets_now, + token_num, + seq_lens_this_time, + seq_lens_encoder, + ) + seq_lens_output = speculate_get_seq_lens_output( + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + ) + if isinstance(seq_lens_output, list): + seq_lens_output = seq_lens_output[0] + output_token_num = paddle.sum(seq_lens_output) + output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32") + output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset( + output_cum_offsets_tmp, + output_token_num, + seq_lens_output, + max_len, + ) + share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) + share_inputs["output_padding_offset"].copy_(output_padding_offset, False) + else: + ( + ids_remove_padding, + cum_offsets, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) share_inputs["ids_remove_padding"] = None # set this after adjust batch share_inputs["cum_offsets"] = cum_offsets @@ -173,11 +219,6 @@ def xpu_post_process_normal( line_break_id: int = None, ) -> None: """ """ - from fastdeploy.model_executor.ops.xpu import ( - save_output, - set_stop_value_multi_ends, - update_inputs, - ) if think_end_id > 0: limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR @@ -277,39 +318,110 @@ def xpu_post_process_normal( ) +def xpu_post_process_specualate( + model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False +): + """""" + speculate_update_v3( + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.not_need_stop, + model_output.draft_tokens, + model_output.actual_draft_token_num, + model_output.accept_tokens, + model_output.accept_num, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.is_block_step, + model_output.stop_nums, + ) + if not skip_save_output: + speculate_save_output( + model_output.accept_tokens, + model_output.accept_num, + model_output.not_need_stop, + model_output.mp_rank, + save_each_rank, # False + ) + + speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder) + + # Update pre_ids through accept tokens + speculate_set_value_by_flags_and_idx( + model_output.pre_ids, + model_output.accept_tokens, + model_output.accept_num, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.step_idx, + ) + + def step_xpu( share_inputs: Dict[str, 
paddle.Tensor], block_size: int, enc_dec_block_num: int, + speculative_decoding: bool, + max_draft_token_num: int, ) -> None: """ - TODO(gongshaotian): normalization name + TODO(chenhuan09): support PD """ - from fastdeploy.model_executor.ops.xpu import step_paddle - - step_paddle( - share_inputs["stop_flags"], - share_inputs["seq_lens_this_time"], - share_inputs["step_seq_lens_encoder"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs["block_tables"], - share_inputs["encoder_block_lens"], - share_inputs["is_block_step"], - share_inputs["step_block_list"], - share_inputs["step_lens"], - share_inputs["recover_block_list"], - share_inputs["recover_lens"], - share_inputs["need_block_list"], - share_inputs["need_block_len"], - share_inputs["used_list_len"], - share_inputs["free_list"], - share_inputs["free_list_len"], - share_inputs["input_ids"], - share_inputs["pre_ids"], - share_inputs["step_idx"], - share_inputs["next_tokens"], - share_inputs["first_token_ids"], - block_size, - enc_dec_block_num, - ) + if speculative_decoding: + speculate_step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + share_inputs["accept_num"], + block_size, + enc_dec_block_num, + max_draft_token_num, + ) + else: + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["step_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + block_size, + enc_dec_block_num, + ) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 529e1f4a7db..3d9f23630b3 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -340,7 +340,11 @@ def process_sampling_results(self): """ if current_platform.is_xpu(): - from fastdeploy.model_executor.ops.xpu import get_output, get_output_ep + from fastdeploy.model_executor.ops.xpu import ( + get_output, + get_output_ep, + speculate_get_output, + ) elif current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import get_output elif current_platform.is_gcu(): diff --git a/fastdeploy/spec_decode/__init__.py b/fastdeploy/spec_decode/__init__.py index 824d5da56ad..086b5003a0d 100644 --- a/fastdeploy/spec_decode/__init__.py +++ b/fastdeploy/spec_decode/__init__.py @@ 
-14,9 +14,12 @@ """ speculative decoding module """ +from fastdeploy.platforms import current_platform from .base import Proposer from .mtp import MTPProposer -from .ngram import NgramProposer +# XPU is not support ngram proposer now +if not current_platform.is_xpu(): + from .ngram import NgramProposer __all__ = ["Proposer", "MTPProposer", "NgramProposer"] diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 437e89bd756..5553c1047fd 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -34,21 +34,39 @@ from fastdeploy.model_executor.layers.sample.sampler import MTPSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models import ModelForCasualLM -from fastdeploy.model_executor.ops.gpu import ( - draft_model_postprocess, - draft_model_preprocess, - draft_model_update, - eagle_get_hidden_states, - eagle_get_self_hidden_states, - hybrid_mtp_ngram, - mtp_save_first_token, - mtp_step_paddle, - share_external_data, - speculate_get_logits, - speculate_save_output_topk, - update_attn_mask_offsets, -) -from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding +from fastdeploy.platforms import current_platform + +if current_platform.is_xpu(): + from fastdeploy.model_executor.ops.xpu import ( + draft_model_postprocess, + draft_model_preprocess, + draft_model_update, + eagle_get_hidden_states, + eagle_get_self_hidden_states, + mtp_save_first_token, + mtp_step_paddle, + share_external_data, + ) + from fastdeploy.model_executor.xpu_pre_and_post_process import ( + xpu_pre_process, + xpu_process_output, + ) +else: + from fastdeploy.model_executor.ops.gpu import ( + draft_model_postprocess, + draft_model_preprocess, + draft_model_update, + eagle_get_hidden_states, + eagle_get_self_hidden_states, + hybrid_mtp_ngram, + mtp_save_first_token, + mtp_step_paddle, + share_external_data, + speculate_get_logits, + speculate_save_output_topk, + update_attn_mask_offsets, + ) + from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding from .base import Proposer @@ -79,6 +97,15 @@ def __init__( # [mixed, prefill, decoder] self.role = self.scheduler_config.splitwise_role + if current_platform.is_xpu(): + self.role = "mixed" + + if current_platform.is_xpu(): + self._propose = self._propose_xpu + elif current_platform.is_cuda(): + self._propose = self._propose_cuda + else: + raise RuntimeError("Unsupported platform.") self.sampler = MTPSampler(fd_config) self._init_model_inputs() @@ -92,7 +119,7 @@ def __init__( self._initialize_attn_backend() # Forward meta store the global meta information of the forward - self.forward_meta: ForwardMeta = None + self.forward_meta = None def _update_mtp_config(self, main_model): """ @@ -166,7 +193,7 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): and hasattr(self.quant_config, "kv_cache_quant_type") and self.quant_config.kv_cache_quant_type is not None ): - cache_type = "uint8" + cache_type = self._get_cache_type() kv_cache_quant_type = self.quant_config.kv_cache_quant_type # Get kv cache shape @@ -220,7 +247,7 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): self.model_inputs["caches"] = list(self.cache_kvs.values()) for value in self.cache_kvs.values(): del value - paddle.device.cuda.empty_cache() + self._empty_cache() def _initialize_attn_backend( self, @@ -245,9 +272,14 @@ def _initialize_attn_backend( self.model_inputs["decoder_tile_ids_per_batch"] = 
paddle.zeros_like( self.target_model_inputs["decoder_tile_ids_per_batch"] ) - self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( - self.target_model_inputs["decoder_num_blocks_cpu"] - ).pin_memory() + if current_platform.is_xpu(): + self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( + self.target_model_inputs["decoder_num_blocks_cpu"] + ).cpu() + else: + self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( + self.target_model_inputs["decoder_num_blocks_cpu"] + ).pin_memory() self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like( self.target_model_inputs["decoder_num_blocks_device"] ) @@ -669,6 +701,36 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False): self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph + def _initialize_forward_meta_xpu(self): + + self.forward_meta.decoder_batch_ids = (self.model_inputs["decoder_batch_ids"],) + self.forward_meta.decoder_tile_ids_per_batch = (self.model_inputs["decoder_tile_ids_per_batch"],) + self.forward_meta.decoder_num_blocks_cpu = (self.model_inputs["decoder_num_blocks_cpu"],) + self.forward_meta.decoder_num_blocks_device = (self.model_inputs["decoder_num_blocks_device"],) + self.forward_meta.decoder_chunk_size_device = (self.model_inputs["decoder_chunk_size_device"],) + self.forward_meta.max_len_tensor_cpu = (self.model_inputs["max_len_tensor_cpu"],) + + self.forward_meta.encoder_batch_ids = (self.model_inputs["encoder_batch_ids"],) + self.forward_meta.encoder_tile_ids_per_batch = (self.model_inputs["encoder_tile_ids_per_batch"],) + self.forward_meta.encoder_num_blocks_x_cpu = (self.model_inputs["encoder_num_blocks_x_cpu"],) + self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],) + self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],) + self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],) + self.forward_meta.pos_emb_type = "NORMAL" + self.forward_meta.attn_backend = self.attn_backends[0] + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) + + # Mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + def exist_prefill(self): """ check whether prefill stage exist @@ -682,7 +744,7 @@ def _prepare_inputs(self, full_hidden_states): """ Prepare MTP inputs """ - use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER + use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER) draft_model_preprocess( self.model_inputs["draft_tokens"], self.model_inputs["input_ids"], @@ -767,7 +829,7 @@ def _post_process(self, sampled_token_ids): self.model_inputs["step_idx"], ) - def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): + def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): """ Main process for MTP inference. 
Args: @@ -928,6 +990,96 @@ def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False) if hasattr(self.model, "empty_input_forward"): self.model.empty_input_forward() + def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): + """ + Main process for MTP inference. + Args: + step_use_cudagraph: bool + Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP. + """ + for substep in range(self.num_model_steps): + if self.model_inputs["not_need_stop"]: + self.model_inputs["substep"] = substep + # Remove padding + self.forward_meta = xpu_pre_process( + self.model_inputs["input_ids"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs, + True, + self.cache_config.block_size, + self.model_inputs["draft_tokens"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + ) + self._initialize_forward_meta_xpu() + # Get sampling metadata + self.sampling_metadata = SamplingMetadata( + temperature=self.model_inputs["temperature"], + top_p=self.model_inputs["top_p"], + top_k=self.model_inputs["top_k"], + seed=self.model_inputs["infer_seed"], + step_idx=self.model_inputs["step_idx"], + pre_token_ids=self.model_inputs["pre_ids"], + frequency_penalties=self.model_inputs["frequency_score"], + presence_penalties=self.model_inputs["presence_score"], + repetition_penalties=self.model_inputs["penalty_score"], + min_dec_lens=self.model_inputs["min_dec_len"], + bad_words_token_ids=self.model_inputs["bad_tokens"], + eos_token_ids=self.model_inputs["eos_token_id"], + max_num_logprobs=20 if self.enable_logprob else None, + temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], + share_inputs=self.model_inputs, + ) + + if self.num_model_steps > 1: + self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"]) + + model_output = self.model( + ids_remove_padding=self.model_inputs["ids_remove_padding"], + previous_hidden_states=self.model_inputs["target_hidden_states"], + forward_meta=self.forward_meta, + ) + hidden_states = xpu_process_output( + model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs + ) + # 4. 
Compute logits, Sample + logits = self.model.compute_logits(hidden_states) + sampled_token_ids, sampler_output = self.sampler( + logits, + self.sampling_metadata, + self.max_model_len, + self.model_inputs, + ) + + if substep == 0 and sampler_output.logprobs_tensors is not None: + real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + self.model_inputs["batch_token_num"][:real_bsz], + self.model_inputs["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + 4, # mtype + self.local_rank, + ) + + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + sampled_token_ids, + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + + self._post_process(sampled_token_ids) + if substep != self.num_model_steps - 1: + self._get_self_hidden_states(hidden_states) + else: + if hasattr(self.model, "empty_input_forward"): + self.model.empty_input_forward() + def _get_self_hidden_states(self, hidden_states): target_hidden_states = eagle_get_self_hidden_states( hidden_states, @@ -1044,3 +1196,21 @@ def padding_cudagraph_inputs(self) -> None: self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] return + + def _empty_cache(self): + if current_platform.is_cuda(): + paddle.device.cuda.empty_cache() + elif current_platform.is_xpu(): + paddle.device.xpu.empty_cache() + else: + raise NotImplementedError + + def _get_cache_type(self): + cache_type = None + if current_platform.is_cuda(): + cache_type = "uint8" + elif current_platform.is_xpu(): + cache_type = "int8" + else: + raise NotImplementedError + return cache_type diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 6338965d218..64c8f8f37ac 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -39,7 +39,7 @@ ) from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata -from fastdeploy.model_executor.layers.sample.sampler import Sampler +from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp from fastdeploy.model_executor.ops.xpu import ( @@ -49,12 +49,14 @@ set_data_ipc, share_external_data, ) -from fastdeploy.model_executor.xpu_pre_and_post_process import ( # xpu_post_process_specualate, # TODO(chenhuan09): add xpu_post_process_specualate +from fastdeploy.model_executor.xpu_pre_and_post_process import ( step_xpu, xpu_post_process_normal, + xpu_post_process_specualate, xpu_pre_process, xpu_process_output, ) +from fastdeploy.spec_decode import MTPProposer from fastdeploy.utils import get_logger from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput @@ -102,9 +104,20 @@ def __init__( "fused_gemm_epilogue", ] + self.device_id = device_id + self.speculative_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.speculative_method is not None + + # used by 
SamplingMetadata + self.enable_logprob = False # fd_config.model_config.enable_logprob + self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop + # Sampler # TODU(lilujia): sync with GPU - self.sampler = Sampler(fd_config) + if not self.speculative_decoding: + self.sampler = Sampler(fd_config) + else: + self.sampler = SpeculativeSampler(fd_config) # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] @@ -143,7 +156,7 @@ def exist_prefill(self): else: return 0 - def insert_tasks_v1(self, req_dicts: List[Request]): + def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): """ Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 req_dict: A list of Request dict @@ -340,7 +353,10 @@ def insert_tasks_v1(self, req_dicts: List[Request]): if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True - def insert_prefill_inputs(self, req_dicts: List[Request]): + if self.speculative_method in ["mtp"]: + self.proposer.insert_tasks_v1(req_dicts, num_running_requests) + + def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): """Process inputs for prefill tasks and update share_inputs buffer""" # NOTE(luotingdan): Set environment variable of prefill node if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill": @@ -480,6 +496,15 @@ def get_attr_from_request(request, attr, default_value=None): self.share_inputs["not_need_stop"][0] = True + if self.speculative_method in ["mtp"]: + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request( + request, "temp_scaled_logprobs", False + ) + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request( + request, "top_p_normalized_logprobs", False + ) + self.proposer.insert_prefill_inputs(req_dicts, num_running_requests) + def _init_share_inputs(self, max_num_seqs: int): """Initialize all share buffers for model inputs. Note: In the future, we may abandon share buffers. 
@@ -558,6 +583,15 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32") + self.share_inputs["ids_remove_padding"] = paddle.full( + [max_num_seqs * self.model_config.max_model_len], + 0, + dtype="int64", + ) + self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + # Initialize thinking related buffers self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32") self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") @@ -629,6 +663,56 @@ def _init_share_inputs(self, max_num_seqs: int): ) self.share_inputs["image_features"] = None + if self.speculative_decoding: + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.share_inputs["input_ids_cpu"] = paddle.full( + shape=[max_num_seqs, self.model_config.max_model_len], + fill_value=1, + dtype="int64", + ).cpu() + self.share_inputs["accept_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32") + self.share_inputs["draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + + self.share_inputs["actual_draft_token_num"] = paddle.full( + shape=[max_num_seqs], + fill_value=max_draft_token_num, + dtype="int32", + ) + self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["output_padding_offset"] = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], + fill_value=0, + dtype="int32", + ) + # For V1_KVCACHE_SCHEDULER + self.share_inputs["step_draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype=bool) + self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype=bool) + # For MTP Logprob + self.share_inputs["draft_logits"] = paddle.full( + [max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size], + -1, + dtype="float32", + ) + self.share_inputs["cu_batch_token_offset"] = paddle.full( + shape=[max_num_seqs + 1], fill_value=0, dtype="int32" + ) + self.max_num_seqs = max_num_seqs + def _prepare_inputs(self, is_dummy_run=False) -> None: """Prepare the model inputs""" if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run: @@ -646,9 +730,9 @@ def _prepare_inputs(self, is_dummy_run=False) -> None: self.share_inputs["input_ids"], self.share_inputs["seq_lens_this_time"], self.share_inputs, - use_speculate_method=False, + use_speculate_method=self.speculative_decoding, block_size=self.cache_config.block_size, - draft_tokens=None, + draft_tokens=self.share_inputs["draft_tokens"] if self.speculative_decoding else None, seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], is_profiling=is_dummy_run, @@ -696,6 
+780,7 @@ def load_model(self) -> None: # 2. Load lora model # 3. Load drafter model(for speculative decoding) + self._init_speculative_proposer() def get_model(self) -> nn.Layer: """Get current model""" @@ -793,6 +878,44 @@ def initialize_attn_backend(self) -> None: ) head_dim = self.model_config.head_dim + if self.speculative_decoding: + # Initialize AttentionBackend buffers + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + decode_max_tile_size = self.max_num_seqs * np.ceil( + (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + ) + + group_size = np.ceil(num_heads / self.model_config.kv_num_heads) + encode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( + (self.model_config.max_model_len * group_size) / encoder_block_shape_q + ) + kv_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( + self.model_config.max_model_len / self.fd_config.cache_config.block_size + ) + self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full( + [int(decode_max_tile_size)], 0, dtype="int32" + ) + self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, + # adapted to cudagraph. + self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") + self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") + self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + + self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full( + [int(encode_max_tile_size)], 0, dtype="int32" + ) + self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -851,12 +974,38 @@ def _dummy_run( """ self._dummy_prefill_inputs(num_tokens, batch_size) + if self.speculative_method in ["mtp"]: + self.proposer.dummy_prefill_inputs( + num_tokens=num_tokens, + batch_size=batch_size, + expected_decode_len=1, + ) + while True: self.execute_model(is_dummy_run=True) if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: break + def _init_speculative_proposer(self): + """ + Init speculative proposer + """ + if self.speculative_method == "ngram": + # xpu not support ngram proposer now + # self.proposer = NgramProposer(self.fd_config) + self.proposer = None + elif self.speculative_method == "mtp": + self.proposer = MTPProposer( + self.fd_config, + self.get_model(), + self.local_rank, + self.device_id, + self.share_inputs, + ) + else: + self.proposer = None + def _set_debug_level( self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False ) -> None: @@ -941,7 +1090,16 @@ class at the server level, which is too granular for ModelRunner. ) # 4. 
Compute logits, Sample logits = self.model.compute_logits(hidden_states) - sampler_output = self.sampler(logits, self.sampling_metadata) + sampler_output = None + if not self.speculative_decoding: + sampler_output = self.sampler(logits, self.sampling_metadata) + else: + self.sampler( + logits, + self.sampling_metadata, + self.model_config.max_model_len, + self.share_inputs, + ) # 5. Speculative decode @@ -961,26 +1119,36 @@ class at the server level, which is too granular for ModelRunner. seq_lens_decoder=self.share_inputs["seq_lens_decoder"], is_block_step=self.share_inputs["is_block_step"], # 投机解码 - full_hidden_states=None, + full_hidden_states=model_output if self.speculative_decoding else None, msg_queue_id=self.parallel_config.msg_queue_id, mp_rank=self.local_rank, use_ep=self.parallel_config.use_ep, - draft_tokens=None, - actual_draft_token_num=None, - accept_tokens=None, - accept_num=None, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], ) - xpu_post_process_normal( - sampled_token_ids=sampler_output.sampled_token_ids, - model_output=model_output_data, - share_inputs=self.share_inputs, - block_size=self.cache_config.block_size, - skip_save_output=is_dummy_run, - think_end_id=self.model_config.think_end_id, - line_break_id=self.model_config.line_break_id, - ) + if self.speculative_decoding: + # base model post process + xpu_post_process_specualate(model_output_data, False, is_dummy_run) + else: + xpu_post_process_normal( + sampled_token_ids=sampler_output.sampled_token_ids, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + skip_save_output=is_dummy_run, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, + ) + + # draft model propose + if self.speculative_method == "mtp": + self.proposer.run(full_hidden_states=model_output) # 7. Updata 'infer_seed' and step_paddle() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) @@ -989,6 +1157,8 @@ class at the server level, which is too granular for ModelRunner. self.share_inputs, self.cache_config.block_size, self.cache_config.enc_dec_block_num, + self.speculative_decoding, + self.speculative_config.num_speculative_tokens, ) if self.pd_disaggregation_mode == "per_chunk" or self.pd_disaggregation_mode == "per_query": @@ -1013,6 +1183,9 @@ def profile_run(self) -> None: self.num_gpu_blocks = self.cache_config.total_block_num self.initialize_kv_cache(profile=True) + if self.speculative_method in ["mtp"]: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) + self._dummy_run( num_tokens=int(self.scheduler_config.max_num_batched_tokens), batch_size=min(self.scheduler_config.max_num_seqs, 1), diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py index 1bf2cde3fd8..9c7c0b2831a 100644 --- a/fastdeploy/worker/xpu_worker.py +++ b/fastdeploy/worker/xpu_worker.py @@ -167,9 +167,9 @@ def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: in and workers and modelrunners should not perceive it. 
""" if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.model_runner.insert_tasks_v1(req_dicts=req_dicts) + self.model_runner.insert_tasks_v1(req_dicts=req_dicts, num_running_requests=num_running_requests) else: - self.model_runner.insert_prefill_inputs(req_dicts=req_dicts) + self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests) def graph_optimize_and_warm_up_model(self) -> None: """ From d641c5d9e06405fb280a45c023a4cad1cd24167d Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Thu, 27 Nov 2025 11:10:33 +0000 Subject: [PATCH 17/17] fix code style --- fastdeploy/model_executor/xpu_pre_and_post_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index dcfb751d2db..861b3b533a9 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -424,4 +424,4 @@ def step_xpu( share_inputs["first_token_ids"], block_size, enc_dec_block_num, - ) \ No newline at end of file + )