12 changes: 10 additions & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -49,6 +49,10 @@ void cuda_host_free(uintptr_t ptr) {
check_cuda_error(cudaFreeHost(reinterpret_cast<void*>(ptr)));
}

paddle::Tensor GetStop(paddle::Tensor& not_need_stop);

void SetStop(paddle::Tensor& not_need_stop, bool flag);

void FlashAttentionMask(const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
@@ -437,7 +441,7 @@ void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const bool beam_search);

void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& not_need_stop, // on device
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
@@ -446,7 +450,7 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& is_block_step);

void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& not_need_stop, // on device
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
@@ -1711,4 +1715,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("reasoning_phase_token_constraint",
&ReasoningPhaseTokenConstraint,
"reasoning_phase_token_constraint function");

m.def("get_stop", &GetStop, "get_stop function");

m.def("set_stop", &SetStop, "set_stop function");
}
30 changes: 30 additions & 0 deletions custom_ops/gpu_ops/set_stop.cu
@@ -0,0 +1,30 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

paddle::Tensor GetStop(paddle::Tensor& not_need_stop) {
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
Collaborator (review comment): Follow up with the framework team to get this fixed later.

auto not_need_stop_cpu =
GetEmptyTensor({1}, paddle::DataType::BOOL, paddle::CPUPlace());
bool* not_need_stop_cpu_data =
const_cast<bool*>(not_need_stop_cpu.data<bool>());
not_need_stop_cpu_data[0] = not_need_stop_data[0];
return not_need_stop_cpu;
}

void SetStop(paddle::Tensor& not_need_stop, bool flag) {
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = flag;
}
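
The new ops are exported through the `fastdeploy_ops` pybind module shown above. Below is a minimal usage sketch, not part of this PR: it assumes the compiled module is importable as `fastdeploy_ops` and that `not_need_stop` lives in host-addressable memory (CPU or CUDA pinned), since `GetStop`/`SetStop` dereference its raw pointer on the host.

```python
# Minimal sketch (assumptions noted above); a plain CPU tensor stands in for
# the pinned buffer that FastDeploy shares with the update_inputs kernels.
import paddle
from fastdeploy_ops import get_stop, set_stop  # module name from PYBIND11_MODULE

not_need_stop = paddle.to_tensor([True], place=paddle.CPUPlace())

set_stop(not_need_stop, False)      # writes the flag in place through its raw pointer
flag_cpu = get_stop(not_need_stop)  # returns a 1-element bool tensor on the CPU
print(flag_cpu.numpy()[0])          # -> False
```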
9 changes: 2 additions & 7 deletions custom_ops/gpu_ops/update_inputs.cu
@@ -68,7 +68,7 @@ __global__ void update_inputs_kernel(bool* not_need_stop,
}

void UpdateInputs(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& not_need_stop, // on gpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
@@ -86,9 +86,8 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
const int max_bsz = stop_flags.shape()[0];
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
update_inputs_kernel<1024><<<1, 1024, 0, cu_stream>>>(
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(not_need_stop.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
@@ -99,10 +98,6 @@ void UpdateInputs(const paddle::Tensor& stop_flags,
now_bsz,
max_bsz,
input_ids_stride);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(update_inputs)
9 changes: 2 additions & 7 deletions custom_ops/gpu_ops/update_inputs_v1.cu
@@ -107,7 +107,7 @@ __global__ void update_inputs_kernel_v1(bool* not_need_stop,
}

void UpdateInputsV1(const paddle::Tensor& stop_flags,
const paddle::Tensor& not_need_stop, // only on cpu
const paddle::Tensor& not_need_stop, // on gpu
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
@@ -137,9 +137,8 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
const int now_bsz = seq_lens_this_time.shape()[0];
const int input_ids_stride = input_ids.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(not_need_stop.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
@@ -157,10 +156,6 @@ void UpdateInputsV1(const paddle::Tensor& stop_flags,
block_num_per_seq,
block_size,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}

PD_BUILD_STATIC_OP(update_inputs_v1)
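
Keeping the flag in a GPU-visible buffer lets both kernels write the reduced "any request still running" result directly, so the per-step `copy_to` round trips removed above are no longer needed. A hedged sketch of how the host might now read the flag only when the scheduler needs it (the helper name is hypothetical):

```python
# Hedged sketch, not from this PR: poll the stop flag on demand instead of
# copying it back to the CPU on every update_inputs / update_inputs_v1 call.
from fastdeploy_ops import get_stop

def any_request_still_running(not_need_stop_device):
    # get_stop copies the single bool into a CPU tensor (see set_stop.cu above).
    return bool(get_stop(not_need_stop_device).numpy()[0])
```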
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
@@ -270,6 +270,7 @@ def find_end_files(directory, end_str):
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/set_stop.cu",
"gpu_ops/update_inputs_v1.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/step.cu",
81 changes: 43 additions & 38 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -316,9 +316,6 @@ def post_process_normal(
share_inputs: Dict[str, paddle.Tensor],
sampling_metadata: SamplingMetadata,
block_size: int = 64,
save_each_rank: bool = False,
skip_save_output: bool = False,
async_output_queue: queue.Queue = None,
think_end_id: int = -1,
line_break_id: int = -1,
enable_entropy: bool = False,
@@ -388,7 +385,7 @@ if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
update_inputs_v1(
model_output.stop_flags,
model_output.not_need_stop,
model_output.not_need_stop_device,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
@@ -404,44 +401,53 @@ else:
else:
update_inputs(
model_output.stop_flags,
model_output.not_need_stop,
model_output.not_need_stop_device,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.input_ids,
sampler_output.sampled_token_ids,
model_output.is_block_step,
)
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(
sampler_output.sampled_token_ids,
logprobs=sampler_output.logprobs_tensors,
prompt_logprobs_list=model_output.prompt_logprobs_list,
)
async_output_queue.put(output)


def save_output_normal(
model_output: ModelOutputData,
sampler_output: SamplerOutput,
share_inputs: Dict[str, paddle.Tensor],
async_output_queue: queue.Queue = None,
save_each_rank: bool = False,
):
# Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
Collaborator (review comment): If this goes through the V1 path, could there be a synchronization issue here?

Collaborator (Author): Not tested yet; if this is a CPU-side operation, in theory there should be no synchronization issue.

output = _build_stream_transfer_data(
sampler_output.sampled_token_ids,
logprobs=sampler_output.logprobs_tensors,
prompt_logprobs_list=model_output.prompt_logprobs_list,
)
async_output_queue.put(output)
else:
if sampler_output.logprobs_tensors is None:
save_output(
share_inputs["sampled_token_ids"],
model_output.not_need_stop,
share_inputs["preempted_idx"],
model_output.mp_rank,
save_each_rank,
)
else:
if sampler_output.logprobs_tensors is None:
save_output(
sampler_output.sampled_token_ids,
model_output.not_need_stop,
share_inputs["preempted_idx"],
model_output.mp_rank,
save_each_rank,
)
else:
save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
share_inputs["preempted_idx"],
model_output.mp_rank,
)
save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
share_inputs["preempted_idx"],
model_output.mp_rank,
)
share_inputs["preempted_idx"][:] = 0


def post_process_specualate(
Expand Down Expand Up @@ -540,6 +546,7 @@ def post_process_specualate(
model_output.seq_lens_decoder,
model_output.step_idx,
)
share_inputs["preempted_idx"][:] = 0
Collaborator (review comment): Why was this added here?

Collaborator (Author): It was just moved. This used to run at the end of post_process; now that save_output has been factored out, it is placed right after it, otherwise the scheduling/preemption logic would be affected.



def post_process(
@@ -588,14 +595,10 @@ def post_process(
share_inputs,
sampling_metadata,
block_size,
save_each_rank,
skip_save_output,
async_output_queue,
think_end_id,
line_break_id,
enable_entropy,
)
share_inputs["preempted_idx"][:] = 0


def step_cuda(
@@ -936,3 +939,5 @@ def post_process_pooling(
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(output_tokens=None, pooler_outputs=pooler_output.outputs)
async_output_queue.put(output)

share_inputs["preempted_idx"][:] = 0
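
Taken together, the Python-side change splits output transmission out of `post_process_normal` into `save_output_normal`, and the `preempted_idx` reset now runs after the output hand-off. A hedged sketch of the resulting call order in the model runner (the caller sits outside this diff; `post_process`'s full argument list is omitted):

```python
# Hedged sketch, not from this PR: post_process now only updates stop flags
# and model inputs; saving the output (and clearing preempted_idx) follows.
post_process(...)  # placeholder; see the signature changes above

save_output_normal(
    model_output,
    sampler_output,
    share_inputs,
    async_output_queue=async_output_queue,
    save_each_rank=save_each_rank,
)
```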