3 changes: 3 additions & 0 deletions docs/usage/environment_variables.md
@@ -88,5 +88,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Count for cache_transfer_manager process error
     "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
+
+    # Whether to force the inference engine to synchronize token_ids sampled by TP groups.
+    "FD_SYNC_TOKEN_IDS_ACROSS_TP": lambda: bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "0"))),
 }
 ```
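Note: the new flag follows the same `bool(int(...))` convention as the other switches in this table, so any non-zero integer string enables it and the default `"0"` leaves it off. A minimal standalone sketch of that parsing:

```python
import os

# Assumed usage: export the variable before the FastDeploy server/workers start.
os.environ["FD_SYNC_TOKEN_IDS_ACROSS_TP"] = "1"

# Mirrors the lambda registered above: unset or "0" yields False.
enabled = bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "0")))
print(enabled)  # True; non-integer values such as "true" would raise ValueError
```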
6 changes: 5 additions & 1 deletion docs/zh/usage/environment_variables.md
@@ -87,5 +87,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")),
 
     # Consecutive-error threshold for a lingering cache_transfer_manager process
-    "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),}
+    "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
+
+    # Whether to force the inference engine to synchronize the token_ids sampled across TP groups; off by default
+    "FD_SYNC_TOKEN_IDS_ACROSS_TP": lambda: bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "0"))),
 }
 ```
1 change: 1 addition & 0 deletions fastdeploy/envs.py
@@ -155,6 +155,7 @@
     "FD_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS": lambda: int(os.getenv("FD_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", "500")),
     "FD_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE": lambda: int(os.getenv("FD_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", "64")),
     "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT": lambda: int(os.getenv("FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", "120")),
+    "FD_SYNC_TOKEN_IDS_ACROSS_TP": lambda: bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "0"))),
 }


10 changes: 8 additions & 2 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1854,7 +1854,10 @@ def _dummy_sampler_run(
             self.share_inputs["stop_flags"],
         )
         sampler_output = self.sampler(logits, self.sampling_metadata)
-        if self.parallel_config.tensor_parallel_size > 1:
+        if (
+            envs.FD_SYNC_TOKEN_IDS_ACROSS_TP
+            or self.fd_config.structured_outputs_config.guided_decoding_backend != "off"
+        ) and self.parallel_config.tensor_parallel_size > 1:
             paddle.distributed.broadcast(
                 sampler_output.sampled_token_ids,
                 self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
@@ -2447,7 +2450,10 @@ class at the server level, which is too granular for ModelRunner.
                 [sampler_output.sampled_token_ids.shape[0]], device="cpu", dtype="int64"
             ),
         )
-        if self.parallel_config.tensor_parallel_size > 1:
+        if (
+            envs.FD_SYNC_TOKEN_IDS_ACROSS_TP
+            or self.fd_config.structured_outputs_config.guided_decoding_backend != "off"
+        ) and self.parallel_config.tensor_parallel_size > 1:
             paddle.distributed.broadcast(
                 sampler_output.sampled_token_ids,
                 self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
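For review purposes, a minimal sketch (hypothetical helper names, not part of this PR) of the gating now applied in both `_dummy_sampler_run` and the decode path:

```python
def should_sync_sampled_token_ids(
    force_sync: bool,              # envs.FD_SYNC_TOKEN_IDS_ACROSS_TP
    guided_decoding_backend: str,  # structured_outputs_config value, e.g. "off"
    tensor_parallel_size: int,
) -> bool:
    # Broadcast only when sync is explicitly forced, or when guided decoding
    # needs identical token_ids on every rank, and only if TP is actually used.
    return (force_sync or guided_decoding_backend != "off") and tensor_parallel_size > 1


def broadcast_src_rank(data_parallel_rank: int, tensor_parallel_size: int) -> int:
    # The broadcast source is the first TP rank of the current DP group.
    return data_parallel_rank * tensor_parallel_size
```

With the flag off and guided decoding disabled, each TP rank keeps its locally sampled token_ids and the collective is skipped entirely.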
9 changes: 4 additions & 5 deletions fastdeploy/worker/worker_process.py
@@ -289,10 +289,7 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int:
         return model_weights_signal_tensor.item()
 
     def _tp_barrier_wait(self):
-        if current_platform.is_xpu():
-            self.task_queue.worker_process_tp_barrier.wait()
-        else:
-            paddle.distributed.barrier(self.parallel_config.tp_group)
+        self.task_queue.worker_process_tp_barrier.wait()
 
     def _init_eplb_signal(self):
         if not self.eplb_config.enable_eplb:
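The unified `_tp_barrier_wait` now always uses the task queue's process-level barrier rather than a `paddle.distributed` collective, so GPU and XPU workers take the same path. A minimal sketch, assuming `worker_process_tp_barrier` behaves like `multiprocessing.Barrier`:

```python
import multiprocessing as mp

def worker(rank, barrier):
    # ... per-rank work for TP rank `rank` ...
    barrier.wait()  # blocks until all TP workers arrive; no device collective involved

if __name__ == "__main__":
    tp_size = 4  # assumed tensor-parallel degree for this sketch
    barrier = mp.Barrier(tp_size)  # stand-in for task_queue.worker_process_tp_barrier
    procs = [mp.Process(target=worker, args=(rank, barrier)) for rank in range(tp_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```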
@@ -478,7 +475,9 @@ def event_loop_normal(self) -> None:
                 time.sleep(0.01)
                 continue
 
-            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
+            if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or (
+                self.nnode > 1 and self.task_queue.read_finish_flag.get() == 1
+            ):
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
 
                 tasks, read_finish = self.task_queue.get_tasks()
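A minimal sketch (hypothetical function, simplified scalar inputs) of the revised dispatch check: `read_finish_flag` is now only consulted in multi-node deployments, so a single-node worker no longer treats a stale finish flag as pending work:

```python
EXIST = 1  # stand-in for ExistTaskStatus.EXIST

def has_new_requests(exist_task_signal: int, nnode: int, read_finish_flag: int) -> bool:
    # Single node: only the task signal triggers dispatch.
    # Multi-node: a peer node may have finished reading, so honor the flag too.
    return exist_task_signal == EXIST or (nnode > 1 and read_finish_flag == 1)

assert has_new_requests(EXIST, 1, 0)   # task signal set -> dispatch
assert not has_new_requests(0, 1, 1)   # stale flag on a single node -> ignored
assert has_new_requests(0, 2, 1)       # multi-node finish flag -> dispatch
```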