Skip to content
22 changes: 8 additions & 14 deletions custom_ops/xpu_ops/src/ops/gather_next_token.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
bool is_speculative,
int max_bsz) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
Expand Down Expand Up @@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};

paddle::Tensor out;
if (output_padding_offset) {
if (is_speculative) {
int need_delete_token_num = 0;
if (enc_batch > 0) {
need_delete_token_num =
Expand All @@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
return {out};
}

if (output_padding_offset) {
if (is_speculative) {
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
Expand Down Expand Up @@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
const std::vector<int64_t>& len_info_cpu_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
// if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU.");
// }
// int64_t bsz = cum_offsets_shape[0];
bool is_speculative) {
int64_t bsz = 0;
int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) {
if (is_speculative) {
return {{-1, dim_embed}};
} else {
return {{bsz, dim_embed}};
Expand All @@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType& decoder_seq_lod_cpu_dtype,
const paddle::DataType& encoder_batch_map_cpu_dtype,
const paddle::DataType& decoder_batch_map_cpu_dtype,
const paddle::DataType& len_info_cpu_dtype,
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
const paddle::DataType& len_info_cpu_dtype) {
return {x_dtype};
}

Expand All @@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
"decoder_seq_lod_cpu",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"len_info_cpu",
paddle::Optional("output_padding_offset")})
"len_info_cpu"})
.Outputs({"out"})
.Attrs({"max_bsz: int"})
.Attrs({"is_speculative: bool", "max_bsz: int"})
.SetKernelFn(PD_KERNEL(GatherNextToken))
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
30 changes: 28 additions & 2 deletions custom_ops/xpu_ops/src/ops/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
bool is_speculative,
int max_bsz);

std::vector<paddle::Tensor> GetImgBoundaries(
Expand Down Expand Up @@ -945,7 +945,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("is_speculative"),
py::arg("max_bsz"),
"Gather next token for XPU");

Expand Down Expand Up @@ -1073,6 +1073,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"),
"Unified update model status");
Comment thread
cmcamdy marked this conversation as resolved.

// Expose the C++ function VerifyDraftTokens to Python as
// `verify_draft_tokens`. Each py::arg below names one keyword argument, in
// the positional order in which it is forwarded to VerifyDraftTokens; the
// trailing string literal is the docstring attached to the Python function.
// NOTE(review): the parameter types and in/out roles are defined by the
// VerifyDraftTokens declaration, which is not visible here — confirm that the
// order below matches it.
m.def("verify_draft_tokens",
&VerifyDraftTokens,
py::arg("step_output_ids"),
py::arg("step_output_len"),
py::arg("step_input_ids"),
py::arg("target_tokens"),
py::arg("candidate_ids"),
py::arg("candidate_scores"),
py::arg("candidate_lens"),
py::arg("topp"),
py::arg("stop_flags"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_this_time"),
py::arg("end_tokens"),
py::arg("is_block_step"),
py::arg("cu_seqlens_q_output"),
py::arg("reasoning_status"),
py::arg("max_dec_len"),
py::arg("step_idx"),
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("verify_strategy"),
py::arg("reject_all"),
py::arg("accept_all"),
"Perform speculative verification for decoding v2");

m.def("mtp_step_paddle",
&MTPStepPaddle,
py::arg("base_model_stop_flags"),
Expand Down
1 change: 0 additions & 1 deletion custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,6 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel(
const int eos_token_id_len,
const int inject_len,
const bool splitwise_role_is_decode);

DLL_EXPORT int verify_draft_tokens(
api::Context* ctx,
// Core I/O
Expand Down
17 changes: 6 additions & 11 deletions custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)


def _run_test_base(seq_lens_this_time_data, output_padding_offset):
def _run_test_base(seq_lens_this_time_data, is_speculative):
"""
通用的基础测试执行函数,包含了两个场景共有的逻辑。
"""
Expand Down Expand Up @@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
encoder_batch_map_cpu,
decoder_batch_map_cpu,
len_info_cpu,
output_padding_offset,
is_speculative,
-1,
)
Comment thread
cmcamdy marked this conversation as resolved.

Expand All @@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset):
encoder_batch_map_cpu,
decoder_batch_map_cpu,
len_info_cpu,
output_padding_offset,
is_speculative,
-1,
)

gather_out_np = gather_out.astype("float32").cpu().numpy()
gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy()

if output_padding_offset is not None:
if is_speculative:
np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!")
else:
for i in range(gather_out_cpu.shape[0]):
Expand All @@ -160,19 +160,14 @@ def test_mix_with_mtp(self):
"""测试混合批次处理中的 MTP (Multi-Token Prediction) 场景"""
print("\nRunning test: test_mix_with_mtp")
seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3]
bsz = len(seq_lens_this_time_data)
output_padding_offset = paddle.zeros(bsz, dtype="int32")

_run_test_base(seq_lens_this_time_data, output_padding_offset)
_run_test_base(seq_lens_this_time_data, True)
print("Test passed for scenario: With MTP")

def test_mix_without_mtp(self):
    """Exercise the non-MTP (single-token prediction) scenario."""
    print("\nRunning test: test_mix_without_mtp")
    seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1]
    # The helper takes the boolean `is_speculative` flag; the earlier
    # `output_padding_offset = None` argument belonged to the previous
    # signature and must not be passed any more.
    _run_test_base(seq_lens_this_time_data, False)
    print("Test passed for scenario: Without MTP")


Expand Down
1 change: 1 addition & 0 deletions fastdeploy/model_executor/forward_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ class XPUForwardMeta(ForwardMeta):
hidden_states: Optional[paddle.Tensor] = None

is_draft: bool = False
is_speculative: bool = False
# max bs
max_num_seqs: int = 0

Expand Down
186 changes: 128 additions & 58 deletions fastdeploy/model_executor/layers/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,17 +1045,129 @@ def forward_cuda(
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
return sampler_output

def forward_xpu(
def _normal_sample_xpu(
    self,
    logits: paddle.Tensor,
    probs: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    share_inputs: List[paddle.Tensor],
) -> SamplerOutput:
    """Sample one token per request on XPU without draft verification (NAIVE mode).

    Args:
        logits: Logits tensor; passed through unchanged into the returned
            ``SamplerOutput``.
        probs: Probabilities that ``top_k_top_p_sampling`` samples from.
        sampling_metadata: Source of ``top_p``, ``top_k``, ``seed`` and
            ``top_k_list``.
        share_inputs: Container indexed by string key. Reads
            ``seq_lens_this_time`` and ``seq_lens_encoder``; writes
            ``accept_tokens`` and ``accept_num``.

    Returns:
        A ``SamplerOutput`` whose ``sampled_token_ids`` is
        ``share_inputs["accept_tokens"]`` and whose ``token_num_per_batch`` is
        ``share_inputs["accept_num"]``. ``logprobs_tensors`` is ``None``.
    """
    seq_lens_this_time = share_inputs["seq_lens_this_time"]

    # Flatten the length tensors before handing them to the padding helper.
    padded_top_p, padded_top_k, padded_seed = padding_sampling_params(
        sampling_metadata.top_p,
        sampling_metadata.top_k,
        sampling_metadata.seed,
        paddle.reshape(seq_lens_this_time, shape=[-1]),
        paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
    )
    _, sampled_ids = top_k_top_p_sampling(
        probs,
        top_p=padded_top_p,
        top_k=padded_top_k,
        top_k_list=sampling_metadata.top_k_list,
        topp_seed=padded_seed,
    )

    batch_rows = seq_lens_this_time.shape[0]
    # 1 where the request has tokens this step, 0 otherwise.
    has_tokens = paddle.reshape(seq_lens_this_time, shape=[-1]) > 0
    active_flags = has_tokens.cast("int32")

    # Store the sampled token in column 0 and the 0/1 flag as the accept count.
    share_inputs["accept_tokens"][:batch_rows, 0] = sampled_ids.squeeze(-1)
    share_inputs["accept_num"][:batch_rows] = active_flags

    return SamplerOutput(
        sampled_token_ids=share_inputs["accept_tokens"],
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        logits=logits,
    )

def _verify_and_sample_xpu(
    self,
    logits: paddle.Tensor,
    probs: paddle.Tensor,
    sampling_metadata: SamplingMetadata,
    max_model_len: int,
    share_inputs: List[paddle.Tensor],
    accept_all_drafts: bool = False,
    reject_all_drafts: bool = False,
) -> SamplerOutput:
    """Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens.

    Depending on ``self.verify_strategy`` one of two inputs is prepared:

    * ``TARGET_MATCH``: target tokens sampled with top-k/top-p.
    * ``GREEDY``: target tokens taken as ``argmax`` over ``probs``.
    * ``TOPP``: candidate ids/scores/lengths from ``top_p_candidates``;
      ``target_tokens`` stays ``None``.

    Args:
        logits: Logits tensor; passed through unchanged into the returned
            ``SamplerOutput``.
        probs: Probabilities used to derive targets or candidates.
        sampling_metadata: Source of ``top_p``, ``top_k``, ``seed``,
            ``top_k_list`` and ``eos_token_ids``.
        max_model_len: Forwarded to ``top_p_candidates`` and
            ``verify_draft_tokens``.
        share_inputs: Container indexed by string key holding the decoding
            state tensors passed to the verification op.
        accept_all_drafts: Per-call override, OR-ed with
            ``self.config_accept_all``.
        reject_all_drafts: Per-call override, OR-ed with
            ``self.config_reject_all`` and ``self.speculative_benchmark_mode``.

    Returns:
        A ``SamplerOutput`` whose ``sampled_token_ids`` is
        ``share_inputs["accept_tokens"]`` and whose ``token_num_per_batch`` is
        ``share_inputs["accept_num"]``. ``logprobs_tensors`` is ``None``.

    Raises:
        ValueError: If ``self.verify_strategy`` is not one of the three
            handled strategies.
    """
    # Imported locally, as in the original; only the two ops used below are
    # needed (the older `speculate_verify` op is no longer called here).
    from fastdeploy.model_executor.ops.xpu import (
        top_p_candidates,
        verify_draft_tokens,
    )

    target_tokens = None
    candidate_ids, candidate_scores, candidate_lens = None, None, None

    if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
        top_p, top_k, topp_seed = padding_sampling_params(
            sampling_metadata.top_p,
            sampling_metadata.top_k,
            sampling_metadata.seed,
            paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
            paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
        )
        _, target_tokens = top_k_top_p_sampling(
            probs,
            top_p=top_p,
            top_k=top_k,
            top_k_list=sampling_metadata.top_k_list,
            topp_seed=topp_seed,
        )
    elif self.verify_strategy == VerifyStrategy.GREEDY:
        target_tokens = paddle.argmax(probs, axis=-1)
    elif self.verify_strategy == VerifyStrategy.TOPP:
        candidate_scores, candidate_ids, candidate_lens = top_p_candidates(
            probs,
            sampling_metadata.top_p,
            share_inputs["batch_id_per_token_output"],
            self.speculative_max_candidate_len,
            max_model_len,
        )
    else:
        raise ValueError(f"Unknown verify strategy: {self.verify_strategy}")

    # Configuration-level switches and per-call overrides are combined.
    final_accept_all = self.config_accept_all or accept_all_drafts
    final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode

    # NOTE(review): the op's return value is discarded and accept_tokens /
    # accept_num are returned afterwards, so it presumably writes its results
    # into those tensors in place — confirm against the op implementation.
    verify_draft_tokens(
        share_inputs["accept_tokens"],
        share_inputs["accept_num"],
        share_inputs["draft_tokens"],
        target_tokens,
        candidate_ids,
        candidate_scores,
        candidate_lens,
        sampling_metadata.top_p,
        share_inputs["stop_flags"],
        share_inputs["seq_lens_encoder"],
        share_inputs["seq_lens_this_time"],
        sampling_metadata.eos_token_ids,
        share_inputs["is_block_step"],
        share_inputs["cu_seqlens_q_output"],
        share_inputs["reasoning_status"],
        share_inputs["max_dec_len"],
        share_inputs["step_idx"],
        max_model_len,
        self.speculative_verify_window,
        self.verify_strategy.value,
        final_reject_all,
        final_accept_all,
    )
    return SamplerOutput(
        sampled_token_ids=share_inputs["accept_tokens"],
        logprobs_tensors=None,
        token_num_per_batch=share_inputs["accept_num"],
        logits=logits,
    )

def forward_xpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
reject_all_drafts: bool = False,
) -> SamplerOutput:
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.token_ids_all,
sampling_metadata.prompt_lens,
Expand All @@ -1078,61 +1190,19 @@ def forward_xpu(

probs = F.softmax(logits)

top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, sampled_token_ids = top_k_top_p_sampling(
probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed
)

verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
share_inputs["batch_id_per_token_output"],
self.speculative_max_candidate_len,
max_model_len,
)

speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
share_inputs["stop_flags"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs[
"draft_tokens"
], # Both input and output, need to write the last 1 token accepted to position 0.
share_inputs["seq_lens_this_time"],
verify_tokens,
verify_scores,
share_inputs["max_dec_len"],
sampling_metadata.eos_token_ids,
share_inputs["is_block_step"],
share_inputs["cu_seqlens_q_output"],
actual_candidate_len,
share_inputs["actual_draft_token_num"],
sampling_metadata.top_p,
max_model_len,
self.speculative_verify_window,
True, # enable_topp
(self.speculative_benchmark_mode or reject_all_drafts),
accept_all_drafts,
)
# TODO(chenhuan09): support return logprobs
token_ids = share_inputs["accept_tokens"]
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=None,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=None,
)
return sampler_output
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
if is_naive:
return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs)
else:
return self._verify_and_sample_xpu(
logits,
probs,
sampling_metadata,
max_model_len,
share_inputs,
accept_all_drafts,
reject_all_drafts,
)


class MTPSampler(nn.Layer):
Expand Down
Loading
Loading