diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index 9a35f91f9c8..88a073fcfa6 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -89,32 +89,27 @@ std::vector GatherNextToken( return {out}; } - if (enc_batch <= 0) { - out = x.copy_to(x.place(), false); + if (output_padding_offset) { + int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token( + ctx, + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + decoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); } else { - if (output_padding_offset) { - int r = - baidu::xpu::api::plugin::eb_mtp_gather_next_token( - ctx, - reinterpret_cast(x.data()), - reinterpret_cast(out.data()), - encoder_seqs_lods_vp, - decoder_seqs_lods_vp, - encoder_batch_map_vp, - decoder_batch_map_vp, - dim); - PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); - } else { - int r = baidu::xpu::api::plugin::eb_gather_next_token( - ctx, - reinterpret_cast(x.data()), - reinterpret_cast(out.data()), - encoder_seqs_lods_vp, - encoder_batch_map_vp, - decoder_batch_map_vp, - dim); - PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); - } + int r = baidu::xpu::api::plugin::eb_gather_next_token( + ctx, + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + encoder_seqs_lods_vp, + encoder_batch_map_vp, + decoder_batch_map_vp, + dim); + PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed."); } return {out}; }