diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 2870b981655..37eadd76872 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -229,8 +229,15 @@ def __init__( self.think_end_id = args.get("think_end_id", -1) self.im_patch_id = args.get("image_patch_id", -1) self.line_break_id = args.get("line_break_id", -1) - if self.max_logprobs < -1: + + num_max_logprobs = args.get("max_logprobs", None) + if num_max_logprobs is not None and num_max_logprobs < -1: raise ValueError(" The possible values for max_logprobs can't be less than -1 ") + if self.ori_vocab_size is not None and num_max_logprobs is not None: + if num_max_logprobs > self.ori_vocab_size: + raise ValueError( + f" The possible values for max_logprobs can't be greater than the vocabulary size {self.ori_vocab_size}" + ) self._post_init() diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0bc8faf2330..2fe8e1317cb 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -31,12 +31,7 @@ from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.openai.protocol import ToolCall from fastdeploy.utils import data_processor_logger -from fastdeploy.worker.output import ( - LogprobsLists, - LogprobsTensors, - PromptLogprobs, - SampleLogprobs, -) +from fastdeploy.worker.output import LogprobsLists, PromptLogprobs, SampleLogprobs class RequestStatus(Enum): @@ -519,7 +514,6 @@ def __init__( prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, prompt_logprobs: Optional[PromptLogprobs] = None, - prompt_logprobs_tensors: Optional[LogprobsTensors] = None, output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, @@ -537,7 +531,6 @@ def __init__( self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.prompt_logprobs = prompt_logprobs - self.prompt_logprobs_tensors = prompt_logprobs_tensors self.output_type = output_type self.outputs = outputs self.finished = finished diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 0b0ae5f807d..908ba652100 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -16,12 +16,13 @@ from __future__ import annotations -import os import random from dataclasses import dataclass, fields from enum import Enum from typing import Any, List, Optional, Union +from fastdeploy import envs + @dataclass class SamplingParams: @@ -207,12 +208,17 @@ def _verify_args(self) -> None: raise ValueError( f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}." 
) - if self.logprobs is not None and self.logprobs < -1: - raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.") - if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0": - raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.") - if self.prompt_logprobs is not None and self.prompt_logprobs < -1: - raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.") + + if not envs.FD_USE_GET_SAVE_OUTPUT_V1: # False (0) + if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20): + raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.") + if self.prompt_logprobs is not None: + raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.") + else: # True (1) + if self.logprobs is not None and self.logprobs < -1: + raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < -1: + raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.") if not 0 <= self.seed <= 922337203685477580: raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index e87b650f514..78918314509 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -56,10 +56,11 @@ class EngineClient: EngineClient is a class that handles the communication between the client and the server. """ - def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1): + def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20): self.fd_config = fd_config self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size self.enable_mm = self.fd_config.model_config.enable_mm + self.max_logprobs = max_logprobs input_processor = InputPreprocessor( self.fd_config.model_config, self.fd_config.structured_outputs_config.reasoning_parser, @@ -70,6 +71,11 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers ) self.enable_logprob = self.fd_config.model_config.enable_logprob self.data_processor = input_processor.create_processor() + self.ori_vocab_size = ( + len(self.data_processor.tokenizer.sp_model) + if hasattr(self.data_processor.tokenizer, "sp_model") + else len(self.data_processor.tokenizer.vocab) + ) self.max_model_len = self.fd_config.model_config.max_model_len self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed" @@ -424,6 +430,53 @@ def valid_parameters(self, data): elif logprobs: raise ParameterError("logprobs", "Invalid type for 'logprobs'") + max_logprobs = self.max_logprobs + if max_logprobs == -1: + max_logprobs = self.ori_vocab_size + if max_logprobs < -1: + err_msg = f"Invalid 'max_logprobs': must be >= -1, got {max_logprobs}." + api_server_logger.error(err_msg) + raise ValueError("max_logprobs", err_msg) + if max_logprobs > self.ori_vocab_size: + err_msg = f"Invalid 'max_logprobs': must be <= vocab_size {self.ori_vocab_size}, got {max_logprobs}." 
+ api_server_logger.error(err_msg) + raise ValueError("max_logprobs", err_msg) + + prompt_logprobs = data.get("prompt_logprobs", None) + + if prompt_logprobs is not None: + if not self.enable_logprob: + err_msg = "`enable_logprob` is disabled, please enable it in startup config." + api_server_logger.error(err_msg) + raise ParameterError("prompt_logprobs", err_msg) + + if not envs.FD_USE_GET_SAVE_OUTPUT_V1: + err_msg = "prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is disabled." + api_server_logger.error(err_msg) + raise ParameterError("prompt_logprobs", err_msg) + + if self.enable_prefix_caching: + err_msg = "prompt_logprobs is not supported when prefix caching is enabled." + api_server_logger.error(err_msg) + raise ParameterError("prompt_logprobs", err_msg) + + if prompt_logprobs == -1 and self.ori_vocab_size > max_logprobs: + err_msg = f"The requested value of ({self.ori_vocab_size}) for prompt_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})" + api_server_logger.error(err_msg) + raise ValueError("prompt_logprobs", err_msg) + + if prompt_logprobs < -1: + err_msg = ( + f"prompt_logprobs must be a non-negative value or -1; the current value is {prompt_logprobs}." + ) + api_server_logger.error(err_msg) + raise ValueError("prompt_logprobs", err_msg) + + if prompt_logprobs > max_logprobs: + err_msg = f"Number of prompt_logprobs requested ({prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})." + api_server_logger.error(err_msg) + raise ValueError("prompt_logprobs", err_msg) + # enable_logprob if top_logprobs: if not self.enable_logprob: @@ -437,15 +490,26 @@ def valid_parameters(self, data): api_server_logger.error(err_msg) raise ParameterError("top_logprobs", err_msg) - if top_logprobs < 0: - err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}." - api_server_logger.error(err_msg) - raise ParameterError("top_logprobs", err_msg) - - if top_logprobs > 20: - err_msg = "Invalid value for 'top_logprobs': must be <= 20." - api_server_logger.error(err_msg) - raise ParameterError("top_logprobs", err_msg) + if not envs.FD_USE_GET_SAVE_OUTPUT_V1: + if top_logprobs < 0 or top_logprobs > 20: + err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}." + api_server_logger.error(err_msg) + raise ValueError("top_logprobs", err_msg) + else: + if top_logprobs == -1 and self.ori_vocab_size > max_logprobs: + err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})" + api_server_logger.error(err_msg) + raise ValueError("top_logprobs", err_msg) + + if top_logprobs < -1: + err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}." + api_server_logger.error(err_msg) + raise ValueError("top_logprobs", err_msg) + + if top_logprobs > max_logprobs: + err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+ api_server_logger.error(err_msg) + raise ValueError("top_logprobs", err_msg) def check_health(self, time_interval_threashold=30): """ diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index adf7d1cd7de..462e3c5950f 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -335,23 +335,40 @@ def _add_request( current_sampling_params = sampling_params[i] else: current_sampling_params = sampling_params - if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None: - raise ValueError("prompt_logprobs is not supported with streaming.") + + ori_vocab_size = ( + len(self.llm_engine.data_processor.tokenizer.sp_model) + if hasattr(self.llm_engine.data_processor.tokenizer, "sp_model") + else len(self.llm_engine.data_processor.tokenizer.vocab) + ) max_logprobs = self.llm_engine.cfg.model_config.max_logprobs if max_logprobs == -1: - max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + max_logprobs = ori_vocab_size + if max_logprobs < -1: + raise ValueError(f"max_logprobs ({max_logprobs}) can't be less than -1.") + if max_logprobs > ori_vocab_size: + raise ValueError(f"max_logprobs ({max_logprobs}) exceeds vocabulary size ({ori_vocab_size}).") + if current_sampling_params.logprobs is not None: num_logprobs = current_sampling_params.logprobs - if num_logprobs == -1: - num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + if num_logprobs == -1 and ori_vocab_size > max_logprobs: + raise ValueError( + f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})." + ) if num_logprobs > max_logprobs: raise ValueError( f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})." ) if current_sampling_params.prompt_logprobs is not None: + if self.llm_engine.cfg.cache_config.enable_prefix_caching: + raise ValueError("prompt_logprobs is not supported with prefix caching enabled.") + if kwargs.get("stream"): + raise ValueError("prompt_logprobs is not supported with streaming.") num_prompt_logprobs = current_sampling_params.prompt_logprobs - if num_prompt_logprobs == -1: - num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + if num_prompt_logprobs == -1 and ori_vocab_size > max_logprobs: + raise ValueError( + f"Number of prompt_logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})." + ) if num_prompt_logprobs > max_logprobs: raise ValueError( f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})." @@ -436,7 +453,7 @@ def _build_prompt_logprobs( prompt_token_ranks = ranks.tolist() prompt_logprobs = logprobs.tolist() token_ids = token_ids.tolist() - result: Optional[PromptLogprobs] = [] + result: Optional[PromptLogprobs] = [None] # Make Logprob for each position. for pos in range(num_prompt_tokens): # Handle flattening. 
@@ -548,11 +565,11 @@ def _run_engine( result.outputs.logprobs = self._build_sample_logprobs( result.outputs.top_logprobs, topk_logprobs ) - if result.prompt_logprobs_tensors and num_prompt_logprobs: + if result.prompt_logprobs is not None and num_prompt_logprobs is not None: if num_prompt_logprobs == -1: num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size result.prompt_logprobs = self._build_prompt_logprobs( - result.prompt_logprobs_tensors, num_prompt_logprobs + result.prompt_logprobs, num_prompt_logprobs ) output[pos] = result diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 5d495b881df..2eb62a6ff84 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -181,6 +181,7 @@ async def lifespan(app: FastAPI): port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")), fd_config=fd_config, workers=args.workers, + max_logprobs=args.max_logprobs, ) await engine_client.connection_manager.initialize() app.state.dynamic_load_weight = args.dynamic_load_weight diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 339164fca1d..5d242d354ea 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -21,9 +21,17 @@ import uuid from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationInfo, + field_validator, + model_validator, +) from fastdeploy.engine.pooling_params import PoolingParams +from fastdeploy.worker.output import PromptLogprobs class InvalidParameterException(Exception): @@ -214,10 +222,12 @@ class ChatCompletionResponseChoice(BaseModel): Chat completion response choice. """ + model_config = ConfigDict(arbitrary_types_allowed=True) index: int message: ChatMessage logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None + prompt_logprobs: Optional[PromptLogprobs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -275,10 +285,12 @@ class ChatCompletionResponseStreamChoice(BaseModel): Chat completion response choice for stream response. """ + model_config = ConfigDict(arbitrary_types_allowed=True) index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None + prompt_logprobs: Optional[PromptLogprobs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -301,6 +313,7 @@ class CompletionResponseChoice(BaseModel): Completion response choice. """ + model_config = ConfigDict(arbitrary_types_allowed=True) index: int text: str prompt_token_ids: Optional[List[int]] = None @@ -310,6 +323,7 @@ class CompletionResponseChoice(BaseModel): arrival_time: Optional[float] = None logprobs: Optional[CompletionLogprobs] = None draft_logprobs: Optional[CompletionLogprobs] = None + prompt_logprobs: Optional[PromptLogprobs] = None reasoning_content: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None @@ -344,11 +358,13 @@ class CompletionResponseStreamChoice(BaseModel): Completion response choice for stream response. 
""" + model_config = ConfigDict(arbitrary_types_allowed=True) index: int text: str arrival_time: float = None logprobs: Optional[CompletionLogprobs] = None draft_logprobs: Optional[CompletionLogprobs] = None + prompt_logprobs: Optional[PromptLogprobs] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None prompt_tokens: Optional[str] = None @@ -437,6 +453,7 @@ class CompletionRequest(BaseModel): frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2) logprobs: Optional[int] = None include_draft_logprobs: Optional[bool] = False + prompt_logprobs: Optional[int] = None # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -569,6 +586,18 @@ def validate_stream_options(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (logprobs := data.get("logprobs")) is not None: + if logprobs < -1: + raise ValueError("`logprobs` must be a greater than -1.") + + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if prompt_logprobs < -1: + raise ValueError("`prompt_logprobs` must be a greater than -1.") + return data + class ChatCompletionRequest(BaseModel): """ @@ -583,6 +612,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = Field(None, le=2, ge=-2) logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + prompt_logprobs: Optional[int] = None include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing @@ -651,6 +681,7 @@ def to_dict_for_infer(self, request_id=None): req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens req_dict["logprobs"] = self.top_logprobs if self.logprobs else None + req_dict["prompt_logprobs"] = self.prompt_logprobs req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs @@ -751,12 +782,15 @@ def validate_stream_options(cls, data): def check_logprobs(cls, data): if (top_logprobs := data.get("top_logprobs")) is not None: - if top_logprobs < 0: - raise ValueError("`top_logprobs` must be a positive value.") + if top_logprobs < -1: + raise ValueError("`top_logprobs` must be a greater than -1.") - if top_logprobs > 0 and not data.get("logprobs"): + if not data.get("logprobs"): raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.") + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if prompt_logprobs < -1: + raise ValueError("`prompt_logprobs` must be a greater than -1.") return data diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index b5e789c9cfa..9bb15f90942 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -15,9 +15,11 @@ """ import asyncio +import itertools import time import traceback import uuid +from collections.abc import Iterable from typing import List, Optional import numpy as np @@ -47,9 +49,17 @@ ErrorType, ParameterError, api_server_logger, + clamp_prompt_logprobs, get_host_ip, ) -from fastdeploy.worker.output import LogprobsLists +from fastdeploy.worker.output import ( + Logprob, + LogprobsLists, + LogprobsTensors, + PromptLogprobs, +) + +NONES = itertools.repeat(None) class OpenAIServingChat: @@ -287,6 +297,17 @@ async def chat_completion_stream_generator( num_input_image_tokens = res.get("num_input_image_tokens", 0) num_input_video_tokens = 
res.get("num_input_video_tokens", 0) for i in range(num_choices): + prompt_logprobs_res: Optional[PromptLogprobs] = None + prompt_logprobs_tensors = res.get("prompt_logprobs", None) + if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None: + num_prompt_logprobs = ( + request.prompt_logprobs + if request.prompt_logprobs != -1 + else self.engine_client.ori_vocab_size + ) + prompt_logprobs_res = self._build_prompt_logprobs( + prompt_logprobs_tensors, num_prompt_logprobs + ) choice = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( @@ -296,6 +317,7 @@ async def chat_completion_stream_generator( prompt_token_ids=None, completion_token_ids=None, ), + prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), ) if response_processor.enable_multimodal_content(): choice.delta.multimodal_content = [ @@ -344,12 +366,16 @@ async def chat_completion_stream_generator( logprobs_res: Optional[LogProbs] = None draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: + num_top_logprobs = ( + request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size + ) logprobs_res = self._create_chat_logprobs( - output_top_logprobs, request.logprobs, request.top_logprobs + output_top_logprobs, request.logprobs, num_top_logprobs ) + if request.include_draft_logprobs and output_draft_top_logprobs is not None: draft_logprobs_res = self._create_chat_logprobs( - output_draft_top_logprobs, request.logprobs, request.top_logprobs + output_draft_top_logprobs, request.logprobs, num_top_logprobs ) delta_message = DeltaMessage( @@ -496,6 +522,7 @@ async def chat_completion_full_generator( enable_mm_output=self.enable_mm_output, decoder_base_url=self.tokenizer_base_url, ) + prompt_logprobs_res_list = [[] for _ in range(num_choices)] choices = [] while num_choices > 0: if self.engine_client.check_model_weight_status(): @@ -538,9 +565,12 @@ async def chat_completion_full_generator( output_top_logprobs = output["top_logprobs"] output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: + num_top_logprobs = ( + request.top_logprobs if request.top_logprobs != -1 else self.engine_client.ori_vocab_size + ) # logprobs logprobs_res = self._create_chat_logprobs( - output_top_logprobs, request.logprobs, request.top_logprobs + output_top_logprobs, request.logprobs, num_top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents[idx].extend(logprobs_res.content) @@ -548,11 +578,20 @@ async def chat_completion_full_generator( # draft_logprobs if request.include_draft_logprobs and output_draft_top_logprobs is not None: draft_logprobs_res = self._create_chat_logprobs( - output_draft_top_logprobs, request.logprobs, request.top_logprobs + output_draft_top_logprobs, request.logprobs, num_top_logprobs ) if draft_logprobs_res and draft_logprobs_res.content is not None: draft_logprob_contents[idx].extend(draft_logprobs_res.content) - + prompt_logprobs_tensors = data.get("prompt_logprobs", None) + if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None: + num_prompt_logprobs = ( + request.prompt_logprobs + if request.prompt_logprobs != -1 + else self.engine_client.ori_vocab_size + ) + prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs) + if prompt_logprobs_res: + prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res)) if data["finished"]: num_choices -= 1 reasoning_num_tokens[idx] = 
data["outputs"].get("reasoning_token_num", 0) @@ -573,6 +612,7 @@ async def chat_completion_full_generator( logprob_contents=logprob_contents, draft_logprob_contents=draft_logprob_contents, response_processor=response_processor, + prompt_logprobs_res_list=prompt_logprobs_res_list, max_tokens=max_tokens, ) choices.append(choice) @@ -624,6 +664,7 @@ async def _create_chat_completion_choice( num_image_tokens: list, logprob_contents: list, draft_logprob_contents: list, + prompt_logprobs_res_list: list, response_processor: ChatResponseProcessor, max_tokens: int, ) -> ChatCompletionResponseChoice: @@ -649,11 +690,14 @@ async def _create_chat_completion_choice( message.content = output["text"] logprobs_full_res = None + draft_logprobs_full_res = None + prompt_logprobs_full_res = None if logprob_contents[idx]: logprobs_full_res = LogProbs(content=logprob_contents[idx]) - draft_logprobs_full_res = None if draft_logprob_contents[idx]: draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx]) + if prompt_logprobs_res_list[idx]: + prompt_logprobs_full_res = prompt_logprobs_res_list[idx] num_cached_tokens[idx] = data.get("num_cached_tokens", 0) num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0) @@ -675,6 +719,7 @@ async def _create_chat_completion_choice( message=message, logprobs=logprobs_full_res, draft_logprobs=draft_logprobs_full_res, + prompt_logprobs=prompt_logprobs_full_res, finish_reason=finish_reason, ) @@ -780,3 +825,86 @@ def _get_thinking_status(self, request: ChatCompletionRequest) -> bool: else: enable_thinking = True return enable_thinking + + def _build_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + num_prompt_logprobs: int, + ): + """Update with prompt logprobs from worker. + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + """ + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = [ + self.engine_client.data_processor.process_logprob_response(token_id) + for token_id in token_ids.flatten().tolist() + ] + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the paddle tensors. + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + result: Optional[PromptLogprobs] = [None] + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. + result.append( + self._make_logprob_dict( + prompt_logprobs[pos], + token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + num_prompt_logprobs, + ) + ) + return result + + @staticmethod + def _make_logprob_dict( + logprobs: list[float], + logprob_token_ids: list[int], + decoded_tokens: Iterable[str | None], + rank: int, + num_logprobs: int, + ) -> dict[int, Logprob]: + """Make a Logprob dictionary for a position. 
+ Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + Returns: + dict[token id, Logprob] + """ + if num_logprobs == -1: + num_logprobs = len(logprobs) + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. + topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank,), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) + } diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 9bf242cd03a..93013531759 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -15,9 +15,11 @@ """ import asyncio +import itertools import time import traceback import uuid +from collections.abc import Iterable from typing import List, Optional import numpy as np @@ -43,9 +45,17 @@ ErrorType, ParameterError, api_server_logger, + clamp_prompt_logprobs, get_host_ip, ) -from fastdeploy.worker.output import LogprobsLists +from fastdeploy.worker.output import ( + Logprob, + LogprobsLists, + LogprobsTensors, + PromptLogprobs, +) + +NONES = itertools.repeat(None) class OpenAIServingCompletion: @@ -249,6 +259,7 @@ async def completion_full_generator( aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] + aggregated_prompt_logprobs_tensors = [None] * num_choices completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 while num_choices > 0: @@ -293,6 +304,10 @@ async def completion_full_generator( aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + output_prompt_logprobs_tensors = data.get("prompt_logprobs") or None + if output_prompt_logprobs_tensors is not None: + aggregated_prompt_logprobs_tensors[rid] = output_prompt_logprobs_tensors + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -305,6 +320,7 @@ async def completion_full_generator( data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] data["outputs"]["token_ids"] = aggregated_token_ids[rid] + data["prompt_logprobs_tensors"] = aggregated_prompt_logprobs_tensors[rid] valid_results[rid] = data num_choices -= 1 break @@ -426,8 +442,18 @@ async def completion_stream_generator( idx = int(res["request_id"].split("_")[-1]) if res.get("error_code", 200) != 200: raise ValueError("{}".format(res["error_msg"])) - + prompt_logprobs_res: Optional[PromptLogprobs] = None if first_iteration[idx]: + prompt_logprobs_tensors = res.get("prompt_logprobs", None) + if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None: + num_prompt_logprobs = ( + request.prompt_logprobs + if request.prompt_logprobs != -1 + else self.engine_client.ori_vocab_size + ) + prompt_logprobs_res = self._build_prompt_logprobs( + 
prompt_logprobs_tensors, num_prompt_logprobs + ) if request.return_token_ids: chunk = CompletionStreamResponse( id=request_id, @@ -440,6 +466,7 @@ async def completion_stream_generator( prompt_token_ids=list( prompt_batched_token_ids[idx // (1 if request.n is None else request.n)] ), + prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), prompt_tokens=prompt_tokens_list[ idx // (1 if request.n is None else request.n) ], @@ -468,13 +495,16 @@ async def completion_stream_generator( output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None draft_logprobs_res: Optional[CompletionLogprobs] = None - if request.logprobs and output_top_logprobs is not None: - logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + if request.logprobs is not None and output_top_logprobs is not None: + num_logprobs = ( + request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size + ) + logprobs_res = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0) # draft logprobs if request.include_draft_logprobs and output_draft_top_logprobs is not None: draft_logprobs_res = self._create_completion_logprobs( - output_draft_top_logprobs, request.logprobs, 0 + output_draft_top_logprobs, num_logprobs, 0 ) output_tokens[idx] += len(output.get("token_ids", [])) or 0 num_cache_tokens[idx] += output.get("num_cache_tokens") or 0 @@ -492,6 +522,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: @@ -602,15 +633,22 @@ def request_output_to_completion_response( output_draft_top_logprobs = output.get("draft_top_logprobs") or None aggregated_logprobs: Optional[CompletionLogprobs] = None + num_logprobs = request.logprobs if request.logprobs != -1 else self.engine_client.ori_vocab_size if output_top_logprobs is not None: - aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, num_logprobs, 0) aggregated_draft_logprobs: Optional[CompletionLogprobs] = None if output_draft_top_logprobs is not None: aggregated_draft_logprobs = self._create_completion_logprobs( - output_draft_top_logprobs, request.logprobs, 0 + output_draft_top_logprobs, num_logprobs, 0 ) - + prompt_logprobs_res: Optional[PromptLogprobs] = None + prompt_logprobs_tensors = final_res.get("prompt_logprobs_tensors", None) + if request.prompt_logprobs is not None and prompt_logprobs_tensors is not None: + num_prompt_logprobs = ( + request.prompt_logprobs if request.prompt_logprobs != -1 else self.engine_client.ori_vocab_size + ) + prompt_logprobs_res = self._build_prompt_logprobs(prompt_logprobs_tensors, num_prompt_logprobs) if request.echo: prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n)) token_ids = [*prompt_token_ids, *output["token_ids"]] @@ -641,6 +679,7 @@ def request_output_to_completion_response( tool_calls=output.get("tool_call"), logprobs=aggregated_logprobs, draft_logprobs=aggregated_draft_logprobs, + prompt_logprobs=clamp_prompt_logprobs(prompt_logprobs_res), finish_reason=finish_reason, ) choices.append(choice_data) @@ -749,13 +788,13 @@ def _build_logprobs_response( [tid], clean_up_tokenization_spaces=False ) if "\ufffd" in token_str: - token_bytes = 
token_str.encode("utf-8", errors="replace") + raw_token = self.engine_client.data_processor.tokenizer.convert_ids_to_tokens(tid) + token_bytes = raw_token.encode("utf-8", errors="replace") token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes) if idx == 0: tokens.append(token_str) token_logprobs.append(lp) - else: - top_logprobs[token_str] = lp + top_logprobs[token_str] = lp idx += 1 # Construct the sampled token object (avoid sharing references with top_logprob_entries) @@ -770,3 +809,86 @@ def _build_logprobs_response( except Exception as e: api_server_logger.error(f"Error in _build_logprobs_response: {str(e)}, {str(traceback.format_exc())}") return None + + def _build_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + num_prompt_logprobs: int, + ): + """Update with prompt logprobs from worker. + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + """ + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = [ + self.engine_client.data_processor.process_logprob_response(token_id) + for token_id in token_ids.flatten().tolist() + ] + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the paddle tensors. + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + result: Optional[PromptLogprobs] = [None] + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. + result.append( + self._make_logprob_dict( + prompt_logprobs[pos], + token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + num_prompt_logprobs, + ) + ) + return result + + @staticmethod + def _make_logprob_dict( + logprobs: list[float], + logprob_token_ids: list[int], + decoded_tokens: Iterable[str | None], + rank: int, + num_logprobs: int, + ) -> dict[int, Logprob]: + """Make a Logprob dictionary for a position. + Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + Returns: + dict[token id, Logprob] + """ + if num_logprobs == -1: + num_logprobs = len(logprobs) + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. 
+ topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank,), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) + } diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index 6f285f17ef2..9a7fe239a9d 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -18,9 +18,9 @@ import heapq import random import time +from multiprocessing.reduction import ForkingPickler import aiozmq -import msgpack import zmq from fastdeploy.engine.args_utils import EngineArgs @@ -124,7 +124,7 @@ async def _listen_connection(self, dealer, conn_index): while self.running: try: raw_data = await dealer.read() - response = msgpack.unpackb(raw_data[-1]) + response = ForkingPickler.loads(raw_data[-1]) _zmq_metrics_stats = ZMQMetricsStats() _zmq_metrics_stats.msg_recv_total += 1 if "zmq_send_time" in response: diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index c269d9286b7..0dabd855e50 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -153,7 +153,7 @@ def pack_aggregated_data(self, data): if len(data) > 1: for response in data[1:]: result.add(response) - result = msgpack.packb([result.to_dict()]) + result = ForkingPickler.dumps([result.to_dict()]) return result def receive_json_once(self, block=False): @@ -278,12 +278,12 @@ def _send_response_per_query(self, req_id, data): if self.aggregate_send: result = self.pack_aggregated_data(new_data) else: - result = msgpack.packb([response.to_dict() for response in new_data]) + result = ForkingPickler.dumps([response.to_dict() for response in new_data]) with self.response_token_lock: _zmq_metrics_stats = ZMQMetricsStats() try: - self.socket.send_multipart([self.req_dict[req_id], b"", result]) + self.socket.send_multipart([self.req_dict[req_id], b"", result], copy=False) _zmq_metrics_stats.msg_bytes_send_total += len(result) except Exception as e: _zmq_metrics_stats.msg_send_failed_total += 1 diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 3d9f23630b3..cd6ee320c7d 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -291,7 +291,7 @@ def _process_batch_output_use_zmq(self, receive_datas): llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}") if getattr(stream_data, "prompt_logprobs", None) is not None: try: - result.prompt_logprobs_tensors = stream_data.prompt_logprobs + result.prompt_logprobs = stream_data.prompt_logprobs except Exception as e: llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}") if self.tokens_counter[task_id] == 0: diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index e91ea4ec485..a0878fa7c73 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -52,6 +52,7 @@ from fastdeploy import envs from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse from fastdeploy.logger.logger import FastDeployLogger +from fastdeploy.worker.output import PromptLogprobs T = TypeVar("T") from typing import Callable, List, Optional @@ -1073,6 +1074,21 @@ def _optional_type(val: str) -> Optional[T]: return _optional_type +def clamp_prompt_logprobs( + prompt_logprobs: PromptLogprobs | None, +) -> PromptLogprobs | None: + if prompt_logprobs is None: + 
return prompt_logprobs + + for logprob_dict in prompt_logprobs: + if logprob_dict is None: + continue + for logprob_values in logprob_dict.values(): + if logprob_values.logprob == float("-inf"): + logprob_values.logprob = -9999.0 + return prompt_logprobs + + def to_numpy(tasks: List[Any]): """ Convert PaddlePaddle tensors in multimodal inputs to NumPy arrays. diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 396fc198096..8bed4d9d915 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -30,7 +30,6 @@ class Logprob(NamedTuple): decoded_token: Optional[str] = None -PromptLogprobs = list[dict[int, Logprob] | None] # [{token_id, logprob}] for tokens sampled from the top-k SampleLogprobs = list[dict[int, Logprob]] @@ -125,6 +124,9 @@ def slice_rows(self, start: int, end: int): ) +PromptLogprobs = LogprobsTensors | list[dict[int, Logprob] | None] + + @dataclass class SamplerOutput: """ """ diff --git a/requirements.txt b/requirements.txt index e5f614d8a87..31df63b5d4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,7 +40,7 @@ opentelemetry-instrumentation-mysql opentelemetry-distro opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi -opentelemetry-instrumentation-logging +opentelemetry-instrumentation-logging>=0.57b0 partial_json_parser msgspec einops diff --git a/requirements_dcu.txt b/requirements_dcu.txt index 9e89cccdc80..1cae79f8a92 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -37,5 +37,5 @@ opentelemetry-instrumentation-mysql opentelemetry-distro  opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi -opentelemetry-instrumentation-logging +opentelemetry-instrumentation-logging>=0.57b0 partial_json_parser diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index 561f24d88e0..9d055815e7d 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -37,7 +37,7 @@ opentelemetry-instrumentation-mysql opentelemetry-distro opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi -opentelemetry-instrumentation-logging +opentelemetry-instrumentation-logging>=0.57b0 partial_json_parser msgspec safetensors==0.7.0rc0 diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index f04659410a9..ed3c551da9d 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -40,7 +40,7 @@ opentelemetry-instrumentation-mysql opentelemetry-distro opentelemetry-exporter-otlp opentelemetry-instrumentation-fastapi -opentelemetry-instrumentation-logging +opentelemetry-instrumentation-logging>=0.57b0 partial_json_parser msgspec einops diff --git a/tests/engine/test_sampling_params.py b/tests/engine/test_sampling_params.py index e8210c35ed7..12a20f48f4a 100644 --- a/tests/engine/test_sampling_params.py +++ b/tests/engine/test_sampling_params.py @@ -26,17 +26,28 @@ class TestSamplingParamsVerification(unittest.TestCase): def test_logprobs_valid_values(self): """Test valid logprobs values""" - # Test None value (should pass) - params = SamplingParams(logprobs=None) - params._verify_args() # Should not raise + # Test None value (should pass in both modes) + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + params = SamplingParams(logprobs=None) + params._verify_args() # Should not raise + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=None) + params._verify_args() # Should not raise + + # Test -1 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1") + with patch.dict(os.environ, 
{"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=-1) + params._verify_args() # Should not raise - # Test -1 value (should pass) - params = SamplingParams(logprobs=-1) - params._verify_args() # Should not raise + # Test 0 value (should pass in both modes based on actual behavior) + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + params = SamplingParams(logprobs=0) + params._verify_args() # Should not raise - # Test 0 value (should pass) - params = SamplingParams(logprobs=0) - params._verify_args() # Should not raise + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=0) + params._verify_args() # Should not raise # Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0") with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): @@ -44,13 +55,23 @@ def test_logprobs_valid_values(self): params._verify_args() # Should not raise def test_logprobs_invalid_less_than_minus_one(self): - """Test logprobs less than -1 should raise ValueError""" - with self.assertRaises(ValueError) as cm: - params = SamplingParams(logprobs=-2) - params._verify_args() + """Test logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """ + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=-2) + params._verify_args() - self.assertIn("logprobs must be greater than -1", str(cm.exception)) - self.assertIn("got -2", str(cm.exception)) + self.assertIn("logprobs must be a non-negative value or -1", str(cm.exception)) + self.assertIn("got -2", str(cm.exception)) + + def test_logprobs_invalid_less_than_zero(self): + """Test logprobs less than 0 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "0" """ + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=-1) + params._verify_args() + + self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", str(cm.exception)) def test_logprobs_greater_than_20_with_v1_disabled(self): """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled""" @@ -59,7 +80,7 @@ def test_logprobs_greater_than_20_with_v1_disabled(self): params = SamplingParams(logprobs=21) params._verify_args() - self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception)) + self.assertEqual("Invalid value for 'top_logprobs': must be between 0 and 20.", str(cm.exception)) def test_logprobs_greater_than_20_with_v1_enabled(self): """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled""" @@ -74,46 +95,67 @@ def test_logprobs_greater_than_20_with_v1_enabled(self): def test_prompt_logprobs_valid_values(self): """Test valid prompt_logprobs values""" - # Test None value (should pass) - params = SamplingParams(prompt_logprobs=None) - params._verify_args() # Should not raise + # Test None value (should pass in both modes based on actual behavior) + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + params = SamplingParams(prompt_logprobs=None) + params._verify_args() # Should not raise + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(prompt_logprobs=None) + params._verify_args() # Should not raise - # Test -1 value (should pass) - params = SamplingParams(prompt_logprobs=-1) - params._verify_args() # Should not raise + # Test -1 value (should pass when 
FD_USE_GET_SAVE_OUTPUT_V1 is "1") + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(prompt_logprobs=-1) + params._verify_args() # Should not raise - # Test 0 value (should pass) - params = SamplingParams(prompt_logprobs=0) - params._verify_args() # Should not raise + # Test 0 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1") + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(prompt_logprobs=0) + params._verify_args() # Should not raise - # Test positive values (should pass) - params = SamplingParams(prompt_logprobs=10) - params._verify_args() # Should not raise + # Test positive values (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1") + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(prompt_logprobs=10) + params._verify_args() # Should not raise def test_prompt_logprobs_invalid_less_than_minus_one(self): - """Test prompt_logprobs less than -1 should raise ValueError""" - with self.assertRaises(ValueError) as cm: - params = SamplingParams(prompt_logprobs=-2) - params._verify_args() + """Test prompt_logprobs less than -1 should raise ValueError when FD_USE_GET_SAVE_OUTPUT_V1 is "1" """ + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(prompt_logprobs=-2) + params._verify_args() - self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception)) - self.assertIn("got -2", str(cm.exception)) + self.assertIn("prompt_logprobs a must be non-negative value or -1", str(cm.exception)) + self.assertIn("got -2", str(cm.exception)) def test_combined_logprobs_and_prompt_logprobs(self): """Test both logprobs and prompt_logprobs together""" - # Test valid combination - params = SamplingParams(logprobs=5, prompt_logprobs=3) - params._verify_args() # Should not raise + # Test valid combination when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=5, prompt_logprobs=3) + params._verify_args() # Should not raise - # Test invalid logprobs with valid prompt_logprobs - with self.assertRaises(ValueError): - params = SamplingParams(logprobs=-2, prompt_logprobs=5) - params._verify_args() + # Test invalid logprobs with valid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=-2, prompt_logprobs=5) + params._verify_args() - # Test valid logprobs with invalid prompt_logprobs - with self.assertRaises(ValueError): - params = SamplingParams(logprobs=5, prompt_logprobs=-2) - params._verify_args() + # Test valid logprobs with invalid prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=5, prompt_logprobs=-2) + params._verify_args() + + # Test prompt_logprobs not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=5, prompt_logprobs=3) + params._verify_args() + self.assertIn( + "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception) + ) def test_logprobs_boundary_values(self): """Test boundary values for logprobs""" @@ -130,14 +172,16 @@ def 
test_logprobs_boundary_values(self): def test_prompt_logprobs_boundary_values(self): """Test boundary values for prompt_logprobs""" - # Test boundary value -1 (should pass) - params = SamplingParams(prompt_logprobs=-1) - params._verify_args() # Should pass + # Test boundary value -1 (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "1") + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(prompt_logprobs=-1) + params._verify_args() # Should pass - # Test boundary value just below -1 (should fail) - with self.assertRaises(ValueError): - params = SamplingParams(prompt_logprobs=-2) - params._verify_args() + # Test boundary value just below -1 (should fail when FD_USE_GET_SAVE_OUTPUT_V1 is "1") + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError): + params = SamplingParams(prompt_logprobs=-2) + params._verify_args() def test_environment_variable_handling(self): """Test different environment variable values""" @@ -167,55 +211,111 @@ def test_environment_variable_handling(self): if original_value is not None: os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value + # Test prompt_logprobs behavior with different environment variables + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(prompt_logprobs=5) + params._verify_args() + self.assertIn( + "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(cm.exception) + ) + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(prompt_logprobs=5) + params._verify_args() # Should pass + def test_error_message_formatting(self): """Test that error messages are properly formatted""" - # Test logprobs error message - with self.assertRaises(ValueError) as cm: - params = SamplingParams(logprobs=-5) - params._verify_args() + # Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=-5) + params._verify_args() + + error_msg = str(cm.exception) + self.assertIn("logprobs must be a non-negative value or -1", error_msg) + self.assertIn("got -5", error_msg) + + # Test logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=-1) + params._verify_args() + + error_msg = str(cm.exception) + self.assertIn("Invalid value for 'top_logprobs': must be between 0 and 20", error_msg) + + # Test prompt_logprobs error message when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(prompt_logprobs=-10) + params._verify_args() - error_msg = str(cm.exception) - self.assertIn("logprobs must be greater than -1", error_msg) - self.assertIn("got -5", error_msg) + error_msg = str(cm.exception) + self.assertIn("prompt_logprobs a must be non-negative value or -1", error_msg) + self.assertIn("got -10", error_msg) - # Test prompt_logprobs error message - with self.assertRaises(ValueError) as cm: - params = SamplingParams(prompt_logprobs=-10) - params._verify_args() + # Test prompt_logprobs not supported error message when FD_USE_GET_SAVE_OUTPUT_V1 is "0" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with 
self.assertRaises(ValueError) as cm: + params = SamplingParams(prompt_logprobs=5) + params._verify_args() - error_msg = str(cm.exception) - self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg) - self.assertIn("got -10", error_msg) + error_msg = str(cm.exception) + self.assertIn("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", error_msg) def test_post_init_calls_verify_args(self): """Test that __post_init__ calls _verify_args""" - # This should call _verify_args internally - params = SamplingParams(logprobs=5, prompt_logprobs=3) + # This should call _verify_args internally when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=5, prompt_logprobs=3) - # The params should be successfully created without errors - self.assertEqual(params.logprobs, 5) - self.assertEqual(params.prompt_logprobs, 3) + # The params should be successfully created without errors + self.assertEqual(params.logprobs, 5) + self.assertEqual(params.prompt_logprobs, 3) - # Test that invalid values are caught during initialization - with self.assertRaises(ValueError): - SamplingParams(logprobs=-2) + # Test that invalid values are caught during initialization + with self.assertRaises(ValueError): + SamplingParams(logprobs=-2) + + with self.assertRaises(ValueError): + SamplingParams(prompt_logprobs=-2) - with self.assertRaises(ValueError): - SamplingParams(prompt_logprobs=-2) + # Test that prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError): + SamplingParams(prompt_logprobs=3) + + # Test that logprobs < 0 is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is "0" + with self.assertRaises(ValueError): + SamplingParams(logprobs=-1) def test_logprobs_with_other_parameters(self): """Test logprobs validation with other sampling parameters""" - # Test with temperature - params = SamplingParams(logprobs=5, temperature=0.8) - params._verify_args() # Should pass + # Test with temperature when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=5, temperature=0.8) + params._verify_args() # Should pass - # Test with top_p - params = SamplingParams(logprobs=5, top_p=0.9) - params._verify_args() # Should pass + # Test with top_p when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=5, top_p=0.9) + params._verify_args() # Should pass - # Test with all parameters - params = SamplingParams(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100) - params._verify_args() # Should pass + # Test with all parameters when FD_USE_GET_SAVE_OUTPUT_V1 is "1" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams( + logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100 + ) + params._verify_args() # Should pass + + # Test that prompt_logprobs fails when FD_USE_GET_SAVE_OUTPUT_V1 is "0" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError): + params = SamplingParams( + logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100 + ) + params._verify_args() if __name__ == "__main__": diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py 
b/tests/entrypoints/openai/test_max_streaming_tokens.py index 7ce6df13e58..48935cba838 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -453,6 +453,7 @@ async def test_create_chat_completion_choice(self): num_input_image_tokens = [0, 0] num_input_video_tokens = [0, 0] num_image_tokens = [0, 0] + prompt_logprobs_res_list = [[], []] max_tokens_list = [10, 1] for idx, case in enumerate(test_cases): @@ -469,6 +470,7 @@ async def test_create_chat_completion_choice(self): num_image_tokens=num_image_tokens, logprob_contents=logprob_contents, draft_logprob_contents=draft_logprob_contents, + prompt_logprobs_res_list=prompt_logprobs_res_list, response_processor=mock_response_processor, max_tokens=max_tokens_list[idx], ) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 394a23f0f4e..940e569e186 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -14,14 +14,18 @@ # limitations under the License. """ +import json import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import paddle from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat +from fastdeploy.worker.output import Logprob, LogprobsTensors -class TestOpenAIServingCompletion(unittest.TestCase): +class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase): def setUp(self): """ @@ -66,6 +70,1044 @@ def test_enable_thinking(self): enable_thinking = self.chat_completion_handler._get_thinking_status(request) self.assertEqual(enable_thinking, True) + def test_build_prompt_logprobs_basic(self): + """Test basic functionality of _build_prompt_logprobs""" + # Create mock data + num_prompt_tokens = 2 + num_logprobs = 3 + + # Create tensors + token_ids = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]], dtype=paddle.float32) + ranks = paddle.to_tensor([1, 2], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + # Mock the data processor + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["token1", "token2", "token3", "token4", "token5", "token6"] + + result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + # Verify result structure (first element is None, then actual results) + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + + # Check first position (index 1 since index 0 is None) + first_pos_result = result[1] + self.assertEqual(len(first_pos_result), num_logprobs) + + # Check token IDs and logprobs for first position + expected_tokens = [1, 2, 3] + expected_logprobs = [float(logprobs[0][i]) for i in range(num_logprobs)] + expected_ranks = [1, 1, 2] # First token uses rank from ranks tensor, then topk ranks start from 1 + + for i, token_id in enumerate(expected_tokens): + self.assertIn(token_id, first_pos_result) + self.assertIsInstance(first_pos_result[token_id], Logprob) + self.assertEqual(first_pos_result[token_id].logprob, expected_logprobs[i]) + self.assertEqual(first_pos_result[token_id].rank, expected_ranks[i]) + self.assertEqual(first_pos_result[token_id].decoded_token, 
f"token{i+1}") + + def test_build_prompt_logprobs_with_all_logprobs(self): + """Test _build_prompt_logprobs with num_prompt_logprobs=-1 (all logprobs)""" + num_prompt_tokens = 1 + num_logprobs = 2 + + token_ids = paddle.to_tensor([[10, 20]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-1.0, -2.0]], dtype=paddle.float32) + ranks = paddle.to_tensor([0], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["hello", "world"] + + result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, -1) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + first_pos_result = result[1] + self.assertEqual(len(first_pos_result), num_logprobs) + + # Verify all logprobs are included when num_prompt_logprobs=-1 + for token_id in first_pos_result: + self.assertIn(token_id, [10, 20]) + + def test_build_prompt_logprobs_single_token(self): + """Test _build_prompt_logprobs with single prompt token""" + num_prompt_tokens = 1 + num_logprobs = 1 + + token_ids = paddle.to_tensor([[100]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-0.5]], dtype=paddle.float32) + ranks = paddle.to_tensor([1], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.return_value = "single_token" + + result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + first_pos_result = result[1] + self.assertEqual(len(first_pos_result), num_logprobs) + + # Check the single token + self.assertIn(100, first_pos_result) + self.assertEqual(first_pos_result[100].logprob, -0.5) + self.assertEqual(first_pos_result[100].rank, 1) + self.assertEqual(first_pos_result[100].decoded_token, "single_token") + + def test_build_prompt_logprobs_multiple_positions(self): + """Test _build_prompt_logprobs with multiple prompt positions""" + num_prompt_tokens = 3 + num_logprobs = 2 + + token_ids = paddle.to_tensor([[1, 2], [3, 4], [5, 6]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-0.1, -0.2], [-0.3, -0.4], [-0.5, -0.6]], dtype=paddle.float32) + ranks = paddle.to_tensor([1, 2, 3], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["t1", "t2", "t3", "t4", "t5", "t6"] + + result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + + # Check each position (index + 1 since index 0 is None) + for pos in range(num_prompt_tokens): + pos_result = result[pos + 1] + self.assertEqual(len(pos_result), num_logprobs) + + # Verify token IDs and their properties + expected_tokens = [int(token_ids[pos][0]), int(token_ids[pos][1])] + expected_ranks = [ + ranks[pos], + 1, + ] # First token uses rank from ranks tensor, second token uses topk rank 1 + + for i, token_id in enumerate(expected_tokens): + self.assertIn(token_id, pos_result) + 
self.assertEqual(pos_result[token_id].logprob, float(logprobs[pos][i])) + self.assertEqual(pos_result[token_id].rank, expected_ranks[i]) + self.assertEqual(pos_result[token_id].decoded_token, f"t{pos*2 + i + 1}") + + def test_build_prompt_logprobs_empty_tensors(self): + """Test _build_prompt_logprobs with empty tensors""" + num_prompt_tokens = 0 + num_logprobs = 0 + + token_ids = paddle.to_tensor([], dtype=paddle.int64).reshape([0, 0]) + logprobs = paddle.to_tensor([], dtype=paddle.float32).reshape([0, 0]) + ranks = paddle.to_tensor([], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + result = self.chat_completion_handler._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + + def test_make_logprob_dict(self): + """Test the static method _make_logprob_dict""" + logprobs = [-0.1, -0.2, -0.3] + logprob_token_ids = [1, 2, 3] + decoded_tokens = ["token1", "token2", "token3"] + rank = 1 + num_logprobs = 3 + + result = OpenAIServingChat._make_logprob_dict(logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs) + + self.assertEqual(len(result), num_logprobs) + + # Check first token (sampled token) + self.assertIn(1, result) + self.assertEqual(result[1].logprob, -0.1) + self.assertEqual(result[1].rank, rank) # rank of sampled token + self.assertEqual(result[1].decoded_token, "token1") + + # Check other tokens - topk ranks start from 1 + expected_ranks = [rank, 1, 2] # First token uses rank, then topk ranks + for i, token_id in enumerate(logprob_token_ids): + self.assertIn(token_id, result) + self.assertEqual(result[token_id].logprob, logprobs[i]) + self.assertEqual(result[token_id].rank, expected_ranks[i]) + self.assertEqual(result[token_id].decoded_token, decoded_tokens[i]) + + def test_make_logprob_dict_with_negative_num_logprobs(self): + """Test _make_logprob_dict with num_logprobs=-1""" + logprobs = [-0.1, -0.2] + logprob_token_ids = [1, 2] + decoded_tokens = ["token1", "token2"] + rank = 1 + num_logprobs = -1 + + result = OpenAIServingChat._make_logprob_dict(logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs) + + # Should include all logprobs when num_logprobs=-1 + self.assertEqual(len(result), len(logprobs)) + + # Expected ranks: first token uses rank, second token uses topk rank 1 + expected_ranks = [rank, 1] + + for i, token_id in enumerate(logprob_token_ids): + self.assertIn(token_id, result) + self.assertEqual(result[token_id].logprob, logprobs[i]) + self.assertEqual(result[token_id].rank, expected_ranks[i]) + self.assertEqual(result[token_id].decoded_token, decoded_tokens[i]) + + def test_make_logprob_dict_partial_logprobs(self): + """Test _make_logprob_dict with fewer logprobs than available""" + logprobs = [-0.1, -0.2, -0.3, -0.4] + logprob_token_ids = [1, 2, 3, 4] + decoded_tokens = ["token1", "token2", "token3", "token4"] + rank = 2 + num_logprobs = 2 + + result = OpenAIServingChat._make_logprob_dict(logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs) + + self.assertEqual(len(result), 3) + + # Check sampled token (first token) + self.assertIn(1, result) + self.assertEqual(result[1].logprob, -0.1) + self.assertEqual(result[1].rank, rank) + self.assertEqual(result[1].decoded_token, "token1") + + # Check top-k token (second token) + self.assertIn(2, result) + self.assertEqual(result[2].logprob, -0.2) + self.assertEqual(result[2].rank, 1) # topk rank starts from 1 + self.assertEqual(result[2].decoded_token, 
"token2") + + async def test_chat_completion_stream_generator_with_prompt_logprobs(self): + """Test chat_completion_stream_generator with prompt_logprobs enabled""" + # Create mock request with prompt_logprobs enabled + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=3, logprobs=False, stream=True + ) + + request_id = "test_request_123" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response with prompt_logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3, 4]], dtype=paddle.int64), + logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3, -0.4]], dtype=paddle.float32), + selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64), + ), + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": None, + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["Hello", "world", "test", "token"] + + # Execute the generator + results = [] + async for chunk in self.chat_completion_handler.chat_completion_stream_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ): + results.append(chunk) + + # Verify that prompt_logprobs are included in the response + self.assertGreater(len(results), 0) + + # Check that the first chunk contains prompt_logprobs + first_chunk_data = json.loads(results[0].replace("data: ", "").strip()) + self.assertIn("choices", first_chunk_data) + self.assertEqual(len(first_chunk_data["choices"]), 1) + + choice = first_chunk_data["choices"][0] + self.assertIn("prompt_logprobs", choice) + self.assertIsNotNone(choice["prompt_logprobs"]) + + # Verify prompt_logprobs structure + 
prompt_logprobs = choice["prompt_logprobs"] + self.assertIsInstance(prompt_logprobs, list) + self.assertGreater(len(prompt_logprobs), 0) + + async def test_chat_completion_stream_generator_with_logprobs(self): + """Test chat_completion_stream_generator with logprobs enabled""" + # Create mock request with logprobs enabled + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], + prompt_logprobs=None, + logprobs=True, + top_logprobs=2, + stream=True, + ) + + request_id = "test_request_456" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response with logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": None, + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": [ + [[5, 6]], # logprob_token_ids + [[-0.1, -0.2]], # logprobs + [1], # sampled_token_ranks + ], + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + # Mock the data processor for logprob response + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.return_value = "Hi" + + # Execute the generator + results = [] + async for chunk in self.chat_completion_handler.chat_completion_stream_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ): + results.append(chunk) + + # Verify that logprobs are included in the response + self.assertGreater(len(results), 0) + + # Find chunks that contain logprobs + logprobs_chunks = [] + for result in results: + if "logprobs" in result: + logprobs_chunks.append(result) + + self.assertGreater(len(logprobs_chunks), 0) + + # Check logprobs structure in response + for chunk in logprobs_chunks: + chunk_data = json.loads(chunk.replace("data: ", "").strip()) + if "choices" in chunk_data and len(chunk_data["choices"]) > 0: + choice = 
chunk_data["choices"][0] + if "logprobs" in choice: + self.assertIsNotNone(choice["logprobs"]) + + async def test_chat_completion_stream_generator_with_both_logprobs(self): + """Test chat_completion_stream_generator with both prompt_logprobs and logprobs enabled""" + # Create mock request with both logprobs enabled + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], + prompt_logprobs=2, + logprobs=True, + top_logprobs=2, + stream=True, + ) + + request_id = "test_request_789" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response with both logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3]], dtype=paddle.int64), + logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3]], dtype=paddle.float32), + selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64), + ), + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": [ + [[5, 6]], # logprob_token_ids + [[-0.1, -0.2]], # logprobs + [1], # sampled_token_ranks + ], + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.return_value = "Hi" + + # Execute the generator + results = [] + async for chunk in self.chat_completion_handler.chat_completion_stream_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ): + results.append(chunk) + + # Verify that both types of logprobs are included + self.assertGreater(len(results), 0) + + # Check for prompt_logprobs + first_chunk_data = json.loads(results[0].replace("data: ", "").strip()) + self.assertIn("choices", first_chunk_data) + choice = first_chunk_data["choices"][0] + self.assertIn("prompt_logprobs", choice) + 
self.assertIsNotNone(choice["prompt_logprobs"]) + + # Check for logprobs in subsequent chunks + logprobs_found = False + for result in results: + # Skip [DONE] message + if result.strip() == "data: [DONE]": + continue + chunk_data = json.loads(result.replace("data: ", "").strip()) + if "choices" in chunk_data and len(chunk_data["choices"]) > 0: + choice = chunk_data["choices"][0] + if "logprobs" in choice and choice["logprobs"] is not None: + logprobs_found = True + break + + self.assertTrue(logprobs_found, "logprobs should be found in response chunks") + + async def test_chat_completion_stream_generator_without_logprobs(self): + """Test chat_completion_stream_generator without logprobs enabled""" + # Create mock request without logprobs + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=None, logprobs=False, stream=True + ) + + request_id = "test_request_no_logprobs" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response without logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": None, + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": None, + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + # Execute the generator + results = [] + async for chunk in self.chat_completion_handler.chat_completion_stream_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ): + results.append(chunk) + + # Verify that logprobs are not included in the response + self.assertGreater(len(results), 0) + + for result in results: + # Skip [DONE] message + if result.strip() == "data: [DONE]": + continue + + chunk_data = json.loads(result.replace("data: ", "").strip()) + if "choices" in chunk_data and len(chunk_data["choices"]) > 0: + choice = chunk_data["choices"][0] + # prompt_logprobs should be None when 
not requested + self.assertIsNone(choice.get("prompt_logprobs")) + # logprobs should be None when not requested + self.assertIsNone(choice.get("logprobs")) + + async def test_chat_completion_full_generator_with_prompt_logprobs(self): + """Test chat_completion_full_generator with prompt_logprobs enabled""" + # Create mock request with prompt_logprobs enabled + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=3, logprobs=False, stream=False + ) + + request_id = "test_request_full_123" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response with prompt_logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3, 4]], dtype=paddle.int64), + logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3, -0.4]], dtype=paddle.float32), + selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64), + ), + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": None, + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["Hello", "world", "test", "token"] + + # Execute the generator + result = await self.chat_completion_handler.chat_completion_full_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ) + + # Verify that prompt_logprobs are included in the response + self.assertIsNotNone(result) + self.assertIn("choices", result.model_dump()) + self.assertGreater(len(result.choices), 0) + + choice = result.choices[0] + self.assertIn("prompt_logprobs", choice.model_dump()) + self.assertIsNotNone(choice.prompt_logprobs) + + # Verify prompt_logprobs structure + prompt_logprobs = choice.prompt_logprobs + 
self.assertIsInstance(prompt_logprobs, list) + self.assertGreater(len(prompt_logprobs), 0) + + async def test_chat_completion_full_generator_with_logprobs(self): + """Test chat_completion_full_generator with logprobs enabled""" + # Create mock request with logprobs enabled + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], + prompt_logprobs=None, + logprobs=True, + top_logprobs=2, + stream=False, + ) + + request_id = "test_request_full_456" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response with logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": None, + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": [ + [[5, 6]], # logprob_token_ids + [[-0.1, -0.2]], # logprobs + [1], # sampled_token_ranks + ], + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + # Mock the data processor for logprob response + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.return_value = "Hi" + + # Execute the generator + result = await self.chat_completion_handler.chat_completion_full_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ) + + # Verify that logprobs are included in the response + self.assertIsNotNone(result) + self.assertIn("choices", result.model_dump()) + self.assertGreater(len(result.choices), 0) + + choice = result.choices[0] + self.assertIn("logprobs", choice.model_dump()) + self.assertIsNotNone(choice.logprobs) + + async def test_chat_completion_full_generator_with_both_logprobs(self): + """Test chat_completion_full_generator with both prompt_logprobs and logprobs enabled""" + # Create mock request with both logprobs enabled + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], 
+ prompt_logprobs=2, + logprobs=True, + top_logprobs=2, + stream=False, + ) + + request_id = "test_request_full_789" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response with both logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3]], dtype=paddle.int64), + logprobs=paddle.to_tensor([[-0.1, -0.2, -0.3]], dtype=paddle.float32), + selected_token_ranks=paddle.to_tensor([1], dtype=paddle.int64), + ), + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": [ + [[5, 6]], # logprob_token_ids + [[-0.1, -0.2]], # logprobs + [1], # sampled_token_ranks + ], + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + with patch.object( + self.chat_completion_handler.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.return_value = "Hi" + + # Execute the generator + result = await self.chat_completion_handler.chat_completion_full_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ) + + # Verify that both types of logprobs are included + self.assertIsNotNone(result) + self.assertIn("choices", result.model_dump()) + self.assertGreater(len(result.choices), 0) + + choice = result.choices[0] + + # Check for prompt_logprobs + self.assertIn("prompt_logprobs", choice.model_dump()) + self.assertIsNotNone(choice.prompt_logprobs) + + # Check for logprobs + self.assertIn("logprobs", choice.model_dump()) + self.assertIsNotNone(choice.logprobs) + + async def test_chat_completion_full_generator_without_logprobs(self): + """Test chat_completion_full_generator without logprobs enabled""" + # Create mock request without logprobs + request = ChatCompletionRequest( + messages=[{"role": "user", "content": "Hello"}], prompt_logprobs=None, logprobs=False, stream=False + ) + 
+ request_id = "test_request_full_no_logprobs" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager and response queue + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Create mock response without logprobs data + mock_response = { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": None, + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": None, + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": True, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + + mock_response_queue.get.return_value = mock_response + + # Mock the connection manager + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator(): + yield mock_response + + mock_response_processor.process_response_chat.return_value = mock_async_generator() + + # Mock the cleanup method + self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + # Execute the generator + result = await self.chat_completion_handler.chat_completion_full_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ) + + # Verify that logprobs are not included in the response + self.assertIsNotNone(result) + self.assertIn("choices", result.model_dump()) + self.assertGreater(len(result.choices), 0) + + choice = result.choices[0] + # prompt_logprobs should be None when not requested + self.assertIsNone(choice.prompt_logprobs) + # logprobs should be None when not requested + self.assertIsNone(choice.logprobs) + if __name__ == "__main__": unittest.main() diff --git a/tests/entrypoints/openai/test_serving_completion.py b/tests/entrypoints/openai/test_serving_completion.py index d1aded4f37a..fdefd1cc3e4 100644 --- a/tests/entrypoints/openai/test_serving_completion.py +++ b/tests/entrypoints/openai/test_serving_completion.py @@ -16,7 +16,9 @@ import unittest from typing import List -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock, patch + +import paddle from fastdeploy.entrypoints.openai.serving_completion import ( CompletionRequest, @@ -24,9 +26,10 @@ RequestOutput, ) from fastdeploy.utils import get_host_ip +from fastdeploy.worker.output import Logprob, LogprobsTensors -class TestOpenAIServingCompletion(unittest.TestCase): +class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase): def test_check_master_tp4_dp1(self): engine_client = Mock() @@ -173,6 +176,1044 @@ def 
test_request_output_to_completion_response(self): assert completion_response.usage.completion_tokens_details.reasoning_tokens == 30 + def setUp(self): + """ + Set up the test environment by creating an instance of the OpenAIServingCompletion class using Mock. + """ + self.mock_engine = Mock() + self.serving_completion = OpenAIServingCompletion( + self.mock_engine, + models=None, + pid=123, + ips=None, + max_waiting_time=10, + ) + + def test_build_prompt_logprobs_basic(self): + """Test basic functionality of _build_prompt_logprobs""" + # Create mock data + num_prompt_tokens = 2 + num_logprobs = 3 + + # Create tensors + token_ids = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]], dtype=paddle.float32) + ranks = paddle.to_tensor([1, 2], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + # Mock the data processor + with patch.object( + self.serving_completion.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["token1", "token2", "token3", "token4", "token5", "token6"] + + result = self.serving_completion._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + # Verify result structure (first element is None, then actual results) + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + + # Check first position (index 1 since index 0 is None) + first_pos_result = result[1] + self.assertEqual(len(first_pos_result), num_logprobs) + + # Check token IDs and logprobs for first position + expected_tokens = [1, 2, 3] + expected_logprobs = [float(logprobs[0][i]) for i in range(num_logprobs)] + expected_ranks = [1, 1, 2] # First token uses rank from ranks tensor, then topk ranks start from 1 + + for i, token_id in enumerate(expected_tokens): + self.assertIn(token_id, first_pos_result) + self.assertIsInstance(first_pos_result[token_id], Logprob) + self.assertEqual(first_pos_result[token_id].logprob, expected_logprobs[i]) + self.assertEqual(first_pos_result[token_id].rank, expected_ranks[i]) + self.assertEqual(first_pos_result[token_id].decoded_token, f"token{i+1}") + + def test_build_prompt_logprobs_with_all_logprobs(self): + """Test _build_prompt_logprobs with num_prompt_logprobs=-1 (all logprobs)""" + num_prompt_tokens = 1 + num_logprobs = 2 + + token_ids = paddle.to_tensor([[10, 20]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-1.0, -2.0]], dtype=paddle.float32) + ranks = paddle.to_tensor([0], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + with patch.object( + self.serving_completion.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["hello", "world"] + + result = self.serving_completion._build_prompt_logprobs(prompt_logprobs_tensors, -1) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + first_pos_result = result[1] + self.assertEqual(len(first_pos_result), num_logprobs) + + # Verify all logprobs are included when num_prompt_logprobs=-1 + for token_id in first_pos_result: + self.assertIn(token_id, [10, 20]) + + def test_build_prompt_logprobs_single_token(self): + """Test _build_prompt_logprobs with single prompt token""" + num_prompt_tokens = 1 + num_logprobs = 1 + + token_ids = paddle.to_tensor([[100]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-0.5]], dtype=paddle.float32) + ranks = paddle.to_tensor([1], 
dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + with patch.object( + self.serving_completion.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.return_value = "single_token" + + result = self.serving_completion._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + first_pos_result = result[1] + self.assertEqual(len(first_pos_result), num_logprobs) + + # Check the single token + self.assertIn(100, first_pos_result) + self.assertEqual(first_pos_result[100].logprob, -0.5) + self.assertEqual(first_pos_result[100].rank, 1) + self.assertEqual(first_pos_result[100].decoded_token, "single_token") + + def test_build_prompt_logprobs_multiple_positions(self): + """Test _build_prompt_logprobs with multiple prompt positions""" + num_prompt_tokens = 3 + num_logprobs = 2 + + token_ids = paddle.to_tensor([[1, 2], [3, 4], [5, 6]], dtype=paddle.int64) + logprobs = paddle.to_tensor([[-0.1, -0.2], [-0.3, -0.4], [-0.5, -0.6]], dtype=paddle.float32) + ranks = paddle.to_tensor([1, 2, 3], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + with patch.object( + self.serving_completion.engine_client.data_processor, "process_logprob_response" + ) as mock_decode: + mock_decode.side_effect = ["t1", "t2", "t3", "t4", "t5", "t6"] + + result = self.serving_completion._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + + # Check each position (index + 1 since index 0 is None) + for pos in range(num_prompt_tokens): + pos_result = result[pos + 1] + self.assertEqual(len(pos_result), num_logprobs) + + # Verify token IDs and their properties + expected_tokens = [int(token_ids[pos][0]), int(token_ids[pos][1])] + expected_ranks = [ + ranks[pos], + 1, + ] # First token uses rank from ranks tensor, second token uses topk rank 1 + + for i, token_id in enumerate(expected_tokens): + self.assertIn(token_id, pos_result) + self.assertEqual(pos_result[token_id].logprob, float(logprobs[pos][i])) + self.assertEqual(pos_result[token_id].rank, expected_ranks[i]) + self.assertEqual(pos_result[token_id].decoded_token, f"t{pos*2 + i + 1}") + + def test_build_prompt_logprobs_empty_tensors(self): + """Test _build_prompt_logprobs with empty tensors""" + num_prompt_tokens = 0 + num_logprobs = 0 + + token_ids = paddle.to_tensor([], dtype=paddle.int64).reshape([0, 0]) + logprobs = paddle.to_tensor([], dtype=paddle.float32).reshape([0, 0]) + ranks = paddle.to_tensor([], dtype=paddle.int64) + + prompt_logprobs_tensors = LogprobsTensors(token_ids, logprobs, ranks) + + result = self.serving_completion._build_prompt_logprobs(prompt_logprobs_tensors, num_logprobs) + + self.assertEqual(len(result), num_prompt_tokens + 1) + self.assertIsNone(result[0]) + + def test_make_logprob_dict(self): + """Test the static method _make_logprob_dict""" + logprobs = [-0.1, -0.2, -0.3] + logprob_token_ids = [1, 2, 3] + decoded_tokens = ["token1", "token2", "token3"] + rank = 1 + num_logprobs = 3 + + result = OpenAIServingCompletion._make_logprob_dict( + logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs + ) + + self.assertEqual(len(result), num_logprobs) + + # Check first token (sampled token) + self.assertIn(1, result) + self.assertEqual(result[1].logprob, -0.1) + self.assertEqual(result[1].rank, rank) # rank of sampled token + 
self.assertEqual(result[1].decoded_token, "token1") + + # Check other tokens - topk ranks start from 1 + expected_ranks = [rank, 1, 2] # First token uses rank, then topk ranks + for i, token_id in enumerate(logprob_token_ids): + self.assertIn(token_id, result) + self.assertEqual(result[token_id].logprob, logprobs[i]) + self.assertEqual(result[token_id].rank, expected_ranks[i]) + self.assertEqual(result[token_id].decoded_token, decoded_tokens[i]) + + def test_make_logprob_dict_with_negative_num_logprobs(self): + """Test _make_logprob_dict with num_logprobs=-1""" + logprobs = [-0.1, -0.2] + logprob_token_ids = [1, 2] + decoded_tokens = ["token1", "token2"] + rank = 1 + num_logprobs = -1 + + result = OpenAIServingCompletion._make_logprob_dict( + logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs + ) + + # Should include all logprobs when num_logprobs=-1 + self.assertEqual(len(result), len(logprobs)) + + # Expected ranks: first token uses rank, second token uses topk rank 1 + expected_ranks = [rank, 1] + + for i, token_id in enumerate(logprob_token_ids): + self.assertIn(token_id, result) + self.assertEqual(result[token_id].logprob, logprobs[i]) + self.assertEqual(result[token_id].rank, expected_ranks[i]) + self.assertEqual(result[token_id].decoded_token, decoded_tokens[i]) + + def test_make_logprob_dict_with_limited_logprobs(self): + """Test _make_logprob_dict with fewer logprobs than available""" + logprobs = [-0.1, -0.2, -0.3, -0.4] + logprob_token_ids = [1, 2, 3, 4] + decoded_tokens = ["token1", "token2", "token3", "token4"] + rank = 2 + num_logprobs = 2 + + result = OpenAIServingCompletion._make_logprob_dict( + logprobs, logprob_token_ids, decoded_tokens, rank, num_logprobs + ) + + # When num_logprobs=2, we get the sampled token + 1 topk token + self.assertEqual(len(result), 3) + + # Check sampled token (first token) + self.assertIn(1, result) + self.assertEqual(result[1].logprob, -0.1) + self.assertEqual(result[1].rank, rank) + self.assertEqual(result[1].decoded_token, "token1") + + # Check top-k token (second token) + self.assertIn(2, result) + self.assertEqual(result[2].logprob, -0.2) + self.assertEqual(result[2].rank, 1) # topk rank starts from 1 + self.assertEqual(result[2].decoded_token, "token2") + + async def test_completion_stream_generator_with_prompt_logprobs(self): + """Test completion_stream_generator with prompt_logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock(side_effect=lambda x, **kwargs: f"token_{x}") + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data with 
prompt_logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.int64), + logprobs=paddle.to_tensor( + [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]], dtype=paddle.float32 + ), + selected_token_ranks=paddle.to_tensor([1, 2, 3], dtype=paddle.int64), + ), + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": None, + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request with prompt_logprobs enabled + mock_request = Mock() + mock_request.prompt_logprobs = 3 + mock_request.logprobs = None + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_streaming_response_tokens = 1 + mock_request.max_tokens = None + mock_request.stream_options = Mock() + mock_request.stream_options.include_usage = False + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result_generator = serving_completion.completion_stream_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Collect results + results = [] + async for result in result_generator: + results.append(result) + + # Verify results + self.assertTrue(len(results) > 0) + # Check that the first response contains prompt_logprobs + self.assertIn("prompt_logprobs", results[0]) + + async def test_completion_stream_generator_with_logprobs(self): + """Test completion_stream_generator with logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock(side_effect=lambda x, **kwargs: f"token_{x}") + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data with logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + 
"token_ids": [100], + "top_logprobs": [ + [[100]], # logprob_token_ids (nested properly) + [[-0.1]], # logprobs (nested properly) + [[1]], # sampled_token_ranks (nested properly) + ], + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request with logprobs enabled + mock_request = Mock() + mock_request.prompt_logprobs = None + mock_request.logprobs = 3 + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_streaming_response_tokens = 1 + mock_request.max_tokens = None + mock_request.stream_options = Mock() + mock_request.stream_options.include_usage = False + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result_generator = serving_completion.completion_stream_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Collect results + results = [] + async for result in result_generator: + results.append(result) + + # Verify results + self.assertTrue(len(results) > 0) + # Check that the response contains logprobs + self.assertIn("logprobs", results[0]) + + async def test_completion_stream_generator_with_both_logprobs(self): + """Test completion_stream_generator with both prompt_logprobs and logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data with both prompt_logprobs and logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.int64), + logprobs=paddle.to_tensor( + [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]], dtype=paddle.float32 + ), + selected_token_ranks=paddle.to_tensor([1, 2, 3], dtype=paddle.int64), + ), + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": [ + [[100]], # 
logprob_token_ids (nested properly) + [[-0.1]], # logprobs (nested properly) + [[1]], # sampled_token_ranks (nested properly) + ], + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request with both logprobs enabled + mock_request = Mock() + mock_request.prompt_logprobs = 3 + mock_request.logprobs = 3 + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_streaming_response_tokens = 1 + mock_request.max_tokens = None + mock_request.stream_options = Mock() + mock_request.stream_options.include_usage = False + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result_generator = serving_completion.completion_stream_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Collect results + results = [] + async for result in result_generator: + results.append(result) + + # Verify results + self.assertTrue(len(results) > 0) + # Check that the response contains both prompt_logprobs and logprobs + self.assertIn("prompt_logprobs", results[0]) + self.assertIn("logprobs", results[0]) + + async def test_completion_stream_generator_without_logprobs(self): + """Test completion_stream_generator without logprobs enabled""" + import json + + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data without logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": None, + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string to match expected type + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + 
mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request without logprobs + mock_request = Mock() + mock_request.prompt_logprobs = None + mock_request.logprobs = None + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_streaming_response_tokens = 1 + mock_request.max_tokens = None + mock_request.stream_options = Mock() + mock_request.stream_options.include_usage = False + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result_generator = serving_completion.completion_stream_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Collect results + results = [] + async for result in result_generator: + results.append(result) + + # Verify results + self.assertTrue(len(results) > 0) + + # Parse all results to check for logprobs fields + found_prompt_logprobs = False + found_logprobs = False + prompt_logprobs_null = False + logprobs_null = False + + for result in results: + # Skip [DONE] messages + if result.strip() == "[DONE]": + continue + + # Extract JSON part from SSE format (data: {...}) + if result.startswith("data: "): + json_str = result[6:] # Remove "data: " prefix + # Skip [DONE] messages in data format + if json_str.strip() == "[DONE]": + continue + parsed_result = json.loads(json_str) + else: + # Skip [DONE] messages without data prefix + if result.strip() == "[DONE]": + continue + parsed_result = json.loads(result) + + choice = parsed_result["choices"][0] + + # Check for prompt_logprobs + if "prompt_logprobs" in choice: + found_prompt_logprobs = True + if choice["prompt_logprobs"] is None: + prompt_logprobs_null = True + + # Check for logprobs + if "logprobs" in choice: + found_logprobs = True + if choice["logprobs"] is None: + logprobs_null = True + + # Verify that both fields are found and null when not requested + self.assertTrue(found_prompt_logprobs, "prompt_logprobs field should be present") + self.assertTrue(found_logprobs, "logprobs field should be present") + self.assertTrue(prompt_logprobs_null, "prompt_logprobs should be null when not requested") + self.assertTrue(logprobs_null, "logprobs should be null when not requested") + + async def test_completion_full_generator_with_prompt_logprobs(self): + """Test completion_full_generator with prompt_logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock 
connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data with prompt_logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.int64), + logprobs=paddle.to_tensor( + [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]], dtype=paddle.float32 + ), + selected_token_ranks=paddle.to_tensor([1, 2, 3], dtype=paddle.int64), + ), + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": None, + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request with prompt_logprobs enabled + mock_request = Mock() + mock_request.prompt_logprobs = 3 + mock_request.logprobs = None + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_tokens = None + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result = await serving_completion.completion_full_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Verify results + self.assertIsNotNone(result) + # Check that the response contains prompt_logprobs + self.assertIsNotNone(result.choices[0].prompt_logprobs) + self.assertEqual(len(result.choices[0].prompt_logprobs), 4) # 3 prompt tokens + 1 None element + self.assertIsNone(result.choices[0].prompt_logprobs[0]) # First element should be None + + async def test_completion_full_generator_with_logprobs(self): + """Test completion_full_generator with logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + # Create mock response data with logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "metrics": { + 
"arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": [ + [[100]], # logprob_token_ids (nested properly) + [[-0.1]], # logprobs (nested properly) + [[1]], # sampled_token_ranks (nested properly) + ], + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request with logprobs enabled + mock_request = Mock() + mock_request.prompt_logprobs = None + mock_request.logprobs = 3 + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_tokens = None + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result = await serving_completion.completion_full_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Verify results + self.assertIsNotNone(result) + # Check that the response contains logprobs + self.assertIsNotNone(result.choices[0].logprobs) + self.assertEqual(len(result.choices[0].logprobs.tokens), 1) # 1 completion token + + async def test_completion_full_generator_with_both_logprobs(self): + """Test completion_full_generator with both prompt_logprobs and logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data with both prompt_logprobs and logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "prompt_logprobs": LogprobsTensors( + logprob_token_ids=paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.int64), + logprobs=paddle.to_tensor( + [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]], dtype=paddle.float32 + ), + selected_token_ranks=paddle.to_tensor([1, 2, 3], dtype=paddle.int64), + ), + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": [ + [[100]], # logprob_token_ids (properly 
nested) + [[-0.1]], # logprobs (properly nested) + [[1]], # sampled_token_ranks (properly nested) + ], + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request with both logprobs enabled + mock_request = Mock() + mock_request.prompt_logprobs = 3 + mock_request.logprobs = 3 + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_tokens = None + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result = await serving_completion.completion_full_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Verify results + self.assertIsNotNone(result) + # Check that the response contains both prompt_logprobs and logprobs + self.assertIsNotNone(result.choices[0].prompt_logprobs) + self.assertIsNotNone(result.choices[0].logprobs) + self.assertEqual(len(result.choices[0].prompt_logprobs), 4) # 3 prompt tokens + 1 None element + self.assertIsNone(result.choices[0].prompt_logprobs[0]) # First element should be None + self.assertEqual(len(result.choices[0].logprobs.tokens), 1) # 1 completion token + + async def test_completion_full_generator_without_logprobs(self): + """Test completion_full_generator without logprobs enabled""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + + # Mock the data_processor methods + mock_engine_client.data_processor.process_logprob_response = Mock( + side_effect=lambda x, **kwargs: f"token_{x[0] if isinstance(x, list) else x}" + ) + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Create mock response data without logprobs + mock_response_data = [ + { + "request_id": "test_request_0", + "error_code": 200, + "metrics": { + "arrival_time": 1234567890, + "inference_start_time": 1234567890, + "first_token_time": 1234567890, + }, + "outputs": { + "text": "Hello", + "token_ids": [100], + "top_logprobs": None, + "draft_top_logprobs": None, + "send_idx": 0, + "completion_tokens": "1", # Changed to string + "num_cache_tokens": 0, + "num_image_tokens": 0, + "reasoning_token_num": 0, + }, + "finished": True, + } + ] + + mock_response_queue.get.return_value = mock_response_data + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, 
mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request without logprobs + mock_request = Mock() + mock_request.prompt_logprobs = None + mock_request.logprobs = None + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_tokens = None + mock_request.n = 1 + mock_request.echo = False # Disable echo to avoid the echo logic issue + + # Call the method + result = await serving_completion.completion_full_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Verify results + self.assertIsNotNone(result) + # Check that the response contains null logprobs fields + self.assertIsNone(result.choices[0].prompt_logprobs) + self.assertIsNone(result.choices[0].logprobs) + if __name__ == "__main__": unittest.main() diff --git a/tests/entrypoints/test_engine_client.py b/tests/entrypoints/test_engine_client.py index 480e2ee84bf..07e5d7b8708 100644 --- a/tests/entrypoints/test_engine_client.py +++ b/tests/entrypoints/test_engine_client.py @@ -14,11 +14,16 @@ # limitations under the License. """ +import os +import time import unittest from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import numpy as np from fastdeploy.entrypoints.engine_client import EngineClient +from fastdeploy.utils import EngineError, ParameterError class DummyConfig(SimpleNamespace): @@ -28,13 +33,84 @@ def __getattr__(self, name): class TestEngineClient(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): + """Set up test fixtures before each test method.""" + # Create a properly configured tokenizer mock first + mock_tokenizer = Mock() + mock_tokenizer.sp_model = Mock() + mock_tokenizer.sp_model.__len__ = Mock(return_value=1000) + mock_tokenizer.vocab = Mock() + mock_tokenizer.vocab.__len__ = Mock(return_value=1000) + # Add len() method directly to the tokenizer mock + mock_tokenizer.__len__ = Mock(return_value=1000) + + # Create a proper ModelConfig mock with enable_mm attribute + mock_model_config = Mock() + mock_model_config.enable_mm = True # Match engine_config.model_config.enable_mm + mock_model_config.enable_logprob = True # Match engine_config.model_config.enable_logprob + mock_model_config.max_model_len = 1024 + + # Create a mock FDConfig that contains the model_config + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.cache_config = Mock() + mock_config.cache_config.max_processor_cache = 10 + mock_config.cache_config.enable_prefix_caching = True + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + mock_config.parallel_config = Mock() + mock_config.parallel_config.tensor_parallel_rank = 0 + mock_config.parallel_config.local_data_parallel_id = 0 + mock_config.parallel_config.tensor_parallel_size = 1 + mock_config.scheduler_config = Mock() + mock_config.scheduler_config.splitwise_role = None + mock_config.limit_mm_per_prompt = 5 + mock_config.mm_processor_kwargs = {} + mock_config.tool_parser = None + mock_config.structured_outputs_config = Mock() + mock_config.structured_outputs_config.reasoning_parser = None + + # Create mocks for all the external 
dependencies + mock_input_processor = Mock() + mock_processor = Mock() + mock_processor.tokenizer = mock_tokenizer # Set the tokenizer on the processor + mock_input_processor.create_processor.return_value = mock_processor + + # Mock current platform + mock_platform = Mock() + mock_platform.is_iluvatar.return_value = False + mock_platform.max_chips_per_node = 8 + + # Create mock IPCSignal that behaves properly + mock_ipcsignal = Mock() + mock_signal_instance = Mock() + mock_signal_instance.value = np.array([0]) + mock_ipcsignal.return_value = mock_signal_instance + + # Mock envs for FD_SUPPORT_MAX_CONNECTIONS + mock_envs = Mock() + mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 + # Mock all the dependencies and external components with ( patch("fastdeploy.entrypoints.engine_client.IPCSignal"), patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager"), patch("fastdeploy.entrypoints.engine_client.InputPreprocessor"), patch("fastdeploy.entrypoints.engine_client.FileLock"), patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore"), + patch.multiple( + "fastdeploy.entrypoints.engine_client", + InputPreprocessor=Mock(return_value=mock_input_processor), + ZmqIpcClient=Mock, + IPCSignal=mock_ipcsignal, + StatefulSemaphore=Mock, + DealerConnectionManager=Mock, + FileLock=Mock, + main_process_metrics=Mock(), + current_platform=mock_platform, + envs=mock_envs, + ), + patch("fastdeploy.metrics.metrics.main_process_metrics", Mock()), + patch("os.getenv", return_value="50"), ): self.engine_config = DummyConfig( model_config=DummyConfig(enable_mm=True, enable_logprob=True, max_model_len=1024), @@ -44,15 +120,66 @@ async def asyncSetUp(self): structured_outputs_config=DummyConfig(reasoning_parser="reasoning_parser"), eplb_config=DummyConfig(enable_eplb=True, eplb_max_tokens=1024), ) - self.engine_client = EngineClient(pid=10000, port=1234, fd_config=self.engine_config) + # Create EngineClient instance with mocked dependencies + self.engine_client = EngineClient(pid=1234, port=8080, fd_config=mock_config, workers=1) + self.engine_client.zmq_client = MagicMock() self.engine_client.zmq_client = MagicMock() def test_engine_client_initialized_by_fd_config(self): for config_group_name, config_group in self.engine_config.__dict__.items(): for config_name, config_value in config_group.__dict__.items(): if hasattr(self.engine_client, config_name): + # Skip enable_mm, enable_logprob, and enable_prefix_caching checks as they're handled differently in EngineClient + if config_name in ["enable_mm", "enable_logprob", "enable_prefix_caching"]: + continue assert getattr(self.engine_client, config_name) == config_value + # Check enable_mm separately since it's copied from model_config + assert getattr(self.engine_client, "enable_mm") == self.engine_config.model_config.enable_mm + # Check enable_logprob separately since it's copied from model_config + assert getattr(self.engine_client, "enable_logprob") == self.engine_config.model_config.enable_logprob + # Check enable_prefix_caching separately since it's copied from cache_config + assert ( + getattr(self.engine_client, "enable_prefix_caching") + == self.engine_config.cache_config.enable_prefix_caching + ) + + # Set up mock attributes + self.engine_client.data_processor = Mock() + self.engine_client.data_processor.process_request_dict = Mock() + self.engine_client.zmq_client = Mock() + self.engine_client.zmq_client.send_json = Mock() + self.engine_client.zmq_client.send_pyobj = Mock() + self.engine_client.max_model_len = 1024 + self.engine_client.enable_mm = 
False + self.engine_client.max_logprobs = 20 + self.engine_client.enable_logprob = True + self.engine_client.ori_vocab_size = 1000 + self.engine_client.enable_prefix_caching = False + self.engine_client.enable_splitwise = False + self.engine_client.disable_prefix_mm = False + + # Set up mock attributes for TestEngineClientValidParameters class too + if hasattr(self, "engine_client_valid"): + self.engine_client_valid.zmq_client = Mock() + self.engine_client_valid.zmq_client.send_json = Mock() + self.engine_client_valid.zmq_client.send_pyobj = Mock() + + # Mock IPC signals + self.engine_client.worker_healthy_live_signal = Mock() + self.engine_client.worker_healthy_live_signal.value = np.array([time.time()]) + self.engine_client.model_weights_status_signal = Mock() + self.engine_client.model_weights_status_signal.value = np.array([0]) # NORMAL + self.engine_client.prefix_tree_status_signal = Mock() + self.engine_client.prefix_tree_status_signal.value = np.array([0]) # NORMAL + self.engine_client.kv_cache_status_signal = Mock() + self.engine_client.kv_cache_status_signal.value = np.array([0]) # NORMAL + + # Mock file lock + self.engine_client.clear_update_lock = Mock() + self.engine_client.clear_update_lock.__enter__ = Mock(return_value=None) + self.engine_client.clear_update_lock.__exit__ = Mock(return_value=None) + async def test_add_request(self): request = { "request_id": "test-request-id", @@ -70,6 +197,389 @@ async def test_add_request(self): assert request["tools"] == [1] # assert request["chat_template_kwargs"]["tools"] == [1] + +class TestEngineClientValidParameters(unittest.TestCase): + """Test cases for EngineClient.valid_parameters method""" + + def setUp(self): + """Set up test fixtures for valid_parameters tests""" + # Mock the dependencies + mock_tokenizer = MagicMock() + mock_tokenizer.sp_model = MagicMock() + mock_tokenizer.sp_model.__len__ = MagicMock(return_value=1000) + mock_tokenizer.vocab = MagicMock() + mock_tokenizer.vocab.__len__ = MagicMock(return_value=1000) + + mock_data_processor = MagicMock() + mock_data_processor.tokenizer = mock_tokenizer + mock_model_config = MagicMock() + mock_model_config.enable_mm = False + + # Mock config object + mock_config = MagicMock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = MagicMock() + mock_config.eplb_config.enable_eplb = False + mock_config.parallel_config = MagicMock() + mock_config.parallel_config.tensor_parallel_rank = 0 + mock_config.parallel_config.local_data_parallel_id = 0 + mock_config.parallel_config.tensor_parallel_size = 1 # Add this missing attribute + mock_config.scheduler_config = MagicMock() + mock_config.scheduler_config.splitwise_role = None + mock_config.cache_config = MagicMock() # Add cache_config + mock_config.cache_config.enable_prefix_caching = False + mock_config.cache_config.max_processor_cache = 0 + mock_config.limit_mm_per_prompt = 5 # Add this attribute + mock_config.mm_processor_kwargs = {} # Add this attribute + mock_config.structured_outputs_config = MagicMock() # Add this + mock_config.structured_outputs_config.reasoning_parser = None + mock_config.tool_parser = None # Add this attribute + + # Mock IPCSignal to avoid file system dependencies + with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: + mock_ipcsignal.return_value = MagicMock() + + with patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore") as mock_semaphore: + mock_semaphore.return_value = MagicMock() + + with 
patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager") as mock_connection_manager: + mock_connection_manager.return_value = MagicMock() + + with patch("fastdeploy.entrypoints.engine_client.FileLock") as mock_filelock: + mock_filelock.return_value = MagicMock() + + with patch("fastdeploy.config.ModelConfig") as mock_model_config_class: + mock_model_config_class.return_value = mock_model_config + + with patch( + "fastdeploy.entrypoints.engine_client.InputPreprocessor" + ) as mock_input_processor: + mock_input_processor_instance = MagicMock() + mock_input_processor_instance.create_processor.return_value = mock_data_processor + mock_input_processor.return_value = mock_input_processor_instance + + # Create EngineClient with minimal required parameters + self.engine_client = EngineClient( + pid=1234, + port=8080, + fd_config=mock_config, + workers=1, + ) + + # Set up mock attributes for TestEngineClientValidParameters class + self.engine_client.zmq_client = Mock() + self.engine_client.zmq_client.send_json = Mock() + self.engine_client.zmq_client.send_pyobj = Mock() + self.engine_client.max_logprobs = 20 + self.engine_client.enable_logprob = True + self.engine_client.ori_vocab_size = 1000 + self.engine_client.enable_prefix_caching = False + self.engine_client.enable_splitwise = False + self.engine_client.disable_prefix_mm = False + self.engine_client.max_model_len = 1024 + self.engine_client.enable_mm = False + self.engine_client.config = mock_config + self.engine_client.max_chips_per_node = 8 + self.engine_client.tensor_parallel_size = 1 + self.engine_client.is_master = True + self.engine_client.worker_healthy_live_signal = Mock() + self.engine_client.worker_healthy_live_signal.value = np.array([0]) + self.engine_client.model_weights_status_signal = Mock() + self.engine_client.model_weights_status_signal.value = np.array([0]) + self.engine_client.clear_update_lock = Mock() + self.engine_client.clear_update_lock.__enter__ = Mock(return_value=None) + self.engine_client.clear_update_lock.__exit__ = Mock(return_value=None) + self.engine_client.kv_cache_status_signal = Mock() + self.engine_client.kv_cache_status_signal.value = np.array([0]) + self.engine_client.prefix_tree_status_signal = Mock() + self.engine_client.prefix_tree_status_signal.value = np.array([0]) + + def test_max_logprobs_valid_values(self): + """Test valid max_logprobs values""" + # Test positive max_logprobs + self.engine_client.max_logprobs = 20 + data = {"request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + # Test -1 (unlimited) + self.engine_client.max_logprobs = -1 + data = {"request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + def test_max_logprobs_invalid_values(self): + """Test invalid max_logprobs values""" + # Test negative value less than -1 + self.engine_client.max_logprobs = -2 + data = {"request_id": "test"} + + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("max_logprobs", str(context.exception)) + self.assertIn("must be >= -1", str(context.exception)) + self.assertIn("got -2", str(context.exception)) + + def test_max_logprobs_exceeds_vocab_size(self): + """Test max_logprobs exceeding vocab_size""" + self.engine_client.max_logprobs = 1500 + self.engine_client.ori_vocab_size = 1000 + data = {"request_id": "test"} + + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("max_logprobs", str(context.exception)) + 
self.assertIn("must be <= vocab_size", str(context.exception)) + self.assertIn("1000", str(context.exception)) + self.assertIn("got 1500", str(context.exception)) + + def test_max_logprobs_unlimited(self): + """Test max_logprobs = -1 (unlimited) sets to ori_vocab_size""" + self.engine_client.max_logprobs = -1 + self.engine_client.ori_vocab_size = 1000 + data = {"request_id": "test"} + + # This should not raise and internally max_logprobs should be set to ori_vocab_size + self.engine_client.valid_parameters(data) # Should not raise + # The actual max_logprobs value should be set to ori_vocab_size internally + self.assertEqual(self.engine_client.max_logprobs, -1) # Original value remains unchanged + + def test_prompt_logprobs_valid_values(self): + """Test valid prompt_logprobs values""" + self.engine_client.max_logprobs = 20 + self.engine_client.enable_logprob = True + + # Test valid positive value with FD_USE_GET_SAVE_OUTPUT_V1=1 + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + data = {"prompt_logprobs": 10, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + # Test -1 (unlimited) with FD_USE_GET_SAVE_OUTPUT_V1=1 + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + self.engine_client.max_logprobs = -1 + data = {"prompt_logprobs": -1, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + # Test None (default) + data = {"request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + def test_prompt_logprobs_unlimited_sets_to_vocab_size(self): + """Test prompt_logprobs = -1 sets to ori_vocab_size""" + self.engine_client.max_logprobs = -1 # Set to unlimited to allow prompt_logprobs = -1 + self.engine_client.enable_logprob = True + self.engine_client.ori_vocab_size = 1000 + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + data = {"prompt_logprobs": -1, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + # prompt_logprobs should be set to ori_vocab_size internally + + def test_prompt_logprobs_disabled_when_fd_use_get_save_output_v1_disabled(self): + """Test prompt_logprobs when FD_USE_GET_SAVE_OUTPUT_V1 is disabled""" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + data = {"prompt_logprobs": 10, "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn( + "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled", str(context.exception) + ) + + def test_prompt_logprobs_disabled_logprob(self): + """Test prompt_logprobs when logprob is disabled""" + self.engine_client.enable_logprob = False + data = {"prompt_logprobs": 10, "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("`enable_logprob` is disabled, please enable it in startup config.", str(context.exception)) + + def test_prompt_logprobs_disabled_when_prefix_caching_enabled(self): + """Test prompt_logprobs when prefix caching is enabled""" + self.engine_client.enable_prefix_caching = True + self.engine_client.enable_logprob = True + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + data = {"prompt_logprobs": 10, "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("prompt_logprobs is not support when prefix caching is enabled", str(context.exception)) + 
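+    # Editor's note: illustrative sketch only, not part of this PR's code. It summarises
+    # the request-side boundary behaviour that the surrounding tests assert, assuming the
+    # validation semantics exercised above. `_expected_logprob_bounds` is a hypothetical
+    # helper introduced purely for documentation; it does not exist in fastdeploy and is
+    # not called by any test.
+    @staticmethod
+    def _expected_logprob_bounds(use_v1: bool, max_logprobs: int, vocab_size: int):
+        """Return the (lower, upper) values accepted for top_logprobs / prompt_logprobs."""
+        if not use_v1:
+            # Legacy path (FD_USE_GET_SAVE_OUTPUT_V1=0): top_logprobs is limited to
+            # [0, 20] and prompt_logprobs is rejected outright.
+            return 0, 20
+        # V1 path: -1 is accepted and means "full vocabulary"; other values are capped by
+        # max_logprobs, which itself falls back to the vocabulary size when set to -1.
+        upper = vocab_size if max_logprobs == -1 else max_logprobs
+        return -1, upper
+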
+ def test_prompt_logprobs_invalid_values(self): + """Test invalid prompt_logprobs values""" + self.engine_client.enable_logprob = True + + # Test negative value less than -1 with FD_USE_GET_SAVE_OUTPUT_V1=1 + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + data = {"prompt_logprobs": -2, "request_id": "test"} + + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("prompt_logprobs", str(context.exception)) + self.assertIn("must be a non-negative value or -1", str(context.exception)) + self.assertIn("current value is -2", str(context.exception)) + + def test_prompt_logprobs_exceeds_max_logprobs(self): + """Test prompt_logprobs exceeding max_logprobs""" + self.engine_client.max_logprobs = 10 + self.engine_client.enable_logprob = True + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + data = {"prompt_logprobs": 15, "request_id": "test"} + + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("prompt_logprobs", str(context.exception)) + self.assertIn("exceeds maximum allowed value", str(context.exception)) + self.assertIn("15", str(context.exception)) + self.assertIn("10", str(context.exception)) + + def test_top_logprobs_validation_with_fd_use_get_save_output_v1_enabled(self): + """Test top_logprobs validation when FD_USE_GET_SAVE_OUTPUT_V1 is enabled""" + self.engine_client.max_logprobs = 20 + self.engine_client.enable_logprob = True + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + # Test -1 (unlimited) - should set to ori_vocab_size, but need max_logprobs also to be -1 + self.engine_client.max_logprobs = -1 # Set to unlimited to allow top_logprobs = -1 + data = {"logprobs": True, "top_logprobs": -1, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + # Reset max_logprobs for other tests + self.engine_client.max_logprobs = 20 + + # Test valid positive value + data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + # Test value less than -1 - should raise ValueError + data = {"logprobs": True, "top_logprobs": -2, "request_id": "test"} + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + self.assertIn("must be a non-negative value or -1", str(context.exception)) + self.assertIn("current value is -2", str(context.exception)) + + # Test value exceeding max_logprobs - should raise ValueError + data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"} + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + self.assertIn("exceeds maximum allowed value", str(context.exception)) + + def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self): + """Test top_logprobs validation when FD_USE_GET_SAVE_OUTPUT_V1 is disabled""" + self.engine_client.max_logprobs = 20 + self.engine_client.enable_logprob = True + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + # Test negative value - should raise ValueError + data = {"logprobs": True, "top_logprobs": -1, "request_id": "test"} + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + self.assertIn("top_logprobs must be between 0 and 20", str(context.exception)) + self.assertIn("current value is -1", str(context.exception)) + + # Test value > 20 - should raise ValueError + data = {"logprobs": True, 
"top_logprobs": 25, "request_id": "test"} + with self.assertRaises(ValueError) as context: + self.engine_client.valid_parameters(data) + self.assertIn("top_logprobs must be between 0 and 20", str(context.exception)) + self.assertIn("current value is 25", str(context.exception)) + + # Test valid value + data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + def test_top_logprobs_disabled_logprob(self): + """Test top_logprobs when logprob is disabled""" + self.engine_client.enable_logprob = False + data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("disabled", str(context.exception)) + + def test_top_logprobs_invalid_type(self): + """Test top_logprobs with invalid type""" + self.engine_client.enable_logprob = True + + # Test with string type + data = {"logprobs": True, "top_logprobs": "10", "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("top_logprobs", str(context.exception)) + self.assertIn("Invalid type", str(context.exception)) + self.assertIn("expected int", str(context.exception)) + + def test_logprobs_invalid_type(self): + """Test logprobs with invalid type""" + self.engine_client.enable_logprob = True + + # Test with string type + data = {"logprobs": "true", "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("logprobs", str(context.exception)) + self.assertIn("Invalid type", str(context.exception)) + + def test_logprobs_disabled(self): + """Test logprobs when logprob is disabled""" + self.engine_client.enable_logprob = False + + # Test with logprobs=True + data = {"logprobs": True, "request_id": "test"} + + with self.assertRaises(ParameterError) as context: + self.engine_client.valid_parameters(data) + + self.assertIn("disabled", str(context.exception)) + + def test_unlimited_max_logprobs_with_prompt_logprobs(self): + """Test unlimited max_logprobs (-1) with prompt_logprobs""" + self.engine_client.max_logprobs = -1 # Unlimited + self.engine_client.enable_logprob = True + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + # Should allow any prompt_logprobs value + data = {"prompt_logprobs": 1000, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + def test_unlimited_max_logprobs_with_top_logprobs(self): + """Test unlimited max_logprobs (-1) with top_logprobs""" + self.engine_client.max_logprobs = -1 # Unlimited + self.engine_client.enable_logprob = True + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + # Should allow any top_logprobs value + data = {"logprobs": True, "top_logprobs": 1000, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + def test_edge_case_zero_values(self): + """Test edge cases with zero values""" + self.engine_client.max_logprobs = 20 + self.engine_client.enable_logprob = True + + # Test prompt_logprobs = 0 with FD_USE_GET_SAVE_OUTPUT_V1=1 + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + data = {"prompt_logprobs": 0, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + + # Test top_logprobs = 0 with FD_USE_GET_SAVE_OUTPUT_V1=0 + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + data = {"logprobs": 
True, "top_logprobs": 0, "request_id": "test"} + self.engine_client.valid_parameters(data) # Should not raise + def test_valid_parameters(self): request = { "request_id": "test-request-id", @@ -83,6 +593,1406 @@ def test_valid_parameters(self): self.engine_client.valid_parameters(request) assert request["temperature"] == 1e-6 + async def test_init_basic_parameters(self): + """Test EngineClient initialization with basic parameters.""" + # Create a proper ModelConfig mock with enable_mm attribute + mock_model_config = Mock() + mock_model_config.enable_mm = False + + # Create mocks for all the external dependencies + mock_input_processor = Mock() + mock_processor = Mock() + mock_input_processor.create_processor.return_value = mock_processor + + # Mock current platform + mock_platform = Mock() + mock_platform.is_iluvatar.return_value = False + + # Create mock IPCSignal that behaves properly + mock_ipcsignal = Mock() + mock_signal_instance = Mock() + mock_signal_instance.value = np.array([0]) + mock_ipcsignal.return_value = mock_signal_instance + + # Mock envs for FD_SUPPORT_MAX_CONNECTIONS + mock_envs = Mock() + mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 + + with ( + patch.multiple( + "fastdeploy.entrypoints.engine_client", + InputPreprocessor=Mock(return_value=mock_input_processor), + current_platform=mock_platform, + IPCSignal=mock_ipcsignal, + StatefulSemaphore=Mock, + DealerConnectionManager=Mock, + FileLock=Mock, + work_process_metrics=Mock(), + envs=mock_envs, + ), + patch("os.getenv", return_value="50"), + ): + # Create a mock config for this test + mock_config = Mock() + mock_config.model_config = Mock() + mock_config.model_config.enable_mm = False + + client = EngineClient( + model_name_or_path="test_model", + tokenizer=Mock(), + max_model_len=2048, + tensor_parallel_size=2, + pid=5678, + port=9090, + limit_mm_per_prompt=3, + mm_processor_kwargs={"test": "value"}, + config=mock_config, + reasoning_parser=None, + data_parallel_size=1, + enable_logprob=False, + workers=2, + tool_parser=None, + enable_prefix_caching=True, + splitwise_role="master", + max_processor_cache=100, + ) + + self.assertEqual(client.max_model_len, 2048) + self.assertEqual(client.enable_logprob, False) + self.assertEqual(client.enable_prefix_caching, True) + self.assertEqual(client.enable_splitwise, True) + + async def test_format_and_add_data_without_request_id(self): + """Test format_and_add_data adds request_id when missing.""" + prompts = {"prompt_token_ids": [1, 2, 3], "max_tokens": 50} + + with patch.object(self.engine_client, "add_requests") as mock_add: + mock_add.return_value = None + + result = await self.engine_client.format_and_add_data(prompts) + + self.assertIn("request_id", prompts) + self.assertEqual(result, prompts["prompt_token_ids"]) + mock_add.assert_called_once_with(prompts) + + async def test_format_and_add_data_with_max_tokens_default(self): + """Test format_and_add_data sets default max_tokens when missing.""" + prompts = {"request_id": "test-id", "prompt_token_ids": [1, 2, 3]} + + with patch.object(self.engine_client, "add_requests") as mock_add: + mock_add.return_value = None + + await self.engine_client.format_and_add_data(prompts) + + self.assertEqual(prompts["max_tokens"], self.engine_client.max_model_len - 1) + + async def test_check_mm_disable_prefix_cache_with_disabled_cache(self): + """Test _check_mm_disable_prefix_cache when prefix cache is disabled.""" + self.engine_client.disable_prefix_mm = False + task = {"multimodal_inputs": {"token_type_ids": [1, 2, 3]}} + + result = 
self.engine_client._check_mm_disable_prefix_cache(task) + + self.assertFalse(result) + + async def test_check_mm_disable_prefix_cache_with_no_multimodal_data(self): + """Test _check_mm_disable_prefix_cache with no multimodal inputs.""" + self.engine_client.disable_prefix_mm = True + task = {"multimodal_inputs": []} + + result = self.engine_client._check_mm_disable_prefix_cache(task) + + self.assertFalse(result) + + async def test_check_mm_disable_prefix_cache_with_multimodal_data(self): + """Test _check_mm_disable_prefix_cache detects multimodal data.""" + self.engine_client.disable_prefix_mm = True + task = {"multimodal_inputs": {"token_type_ids": [1, 0, 2]}} + + result = self.engine_client._check_mm_disable_prefix_cache(task) + + self.assertTrue(result) + + async def test_add_requests_successful_processing(self): + """Test successful request processing in add_requests.""" + task = { + "request_id": "test-id", + "chat_template_kwargs": {"existing": "value"}, + "chat_template": "test_template", + "prompt_token_ids": [1, 2, 3, 4, 5], + "max_tokens": 100, + "min_tokens": 1, + "messages": "test message", + } + + self.engine_client.data_processor.process_request_dict = Mock() + + with patch.object(self.engine_client, "_send_task") as mock_send: + await self.engine_client.add_requests(task) + + self.assertEqual(task["chat_template_kwargs"]["chat_template"], "test_template") + self.assertEqual(task["prompt_token_ids_len"], 5) + self.assertNotIn("messages", task) + mock_send.assert_called_once() + + async def test_add_requests_with_coroutine_processor(self): + """Test add_requests with async processor.""" + task = {"request_id": "test-id", "prompt_token_ids": [1, 2, 3], "max_tokens": 100} + + async_mock = AsyncMock() + self.engine_client.data_processor.process_request_dict = async_mock + + with patch.object(self.engine_client, "_send_task"): + await self.engine_client.add_requests(task) + + async_mock.assert_called_once() + + async def test_add_requests_with_multimodal_prefix_cache_error(self): + """Test add_requests raises error for multimodal data with prefix cache.""" + self.engine_client.enable_mm = True + self.engine_client.enable_prefix_caching = True + self.engine_client.disable_prefix_mm = True + + task = { + "request_id": "test-id", + "prompt_token_ids": [1, 2, 3], + "multimodal_inputs": {"token_type_ids": [1, 0, 1]}, + } + + with self.assertRaises(Exception): # EngineError + await self.engine_client.add_requests(task) + + async def test_add_requests_input_length_validation_error(self): + """Test add_requests validation for input length.""" + task = {"request_id": "test-id", "prompt_token_ids": list(range(1024)), "min_tokens": 1} # At max length + + with self.assertRaises(Exception): # EngineError + await self.engine_client.add_requests(task) + + async def test_add_requests_stop_sequences_validation(self): + """Test add_requests validation for stop sequences.""" + task = { + "request_id": "test-id", + "prompt_token_ids": [1, 2, 3], + "stop_seqs_len": list(range(25)), # Exceeds default limit + } + + with patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs: + mock_envs.FD_MAX_STOP_SEQS_NUM = 20 + mock_envs.FD_STOP_SEQS_MAX_LEN = 100 + + with self.assertRaises(Exception): # EngineError + await self.engine_client.add_requests(task) + + async def test_add_requests_with_n_parameter_multiple_requests(self): + """Test add_requests with n parameter for multiple requests.""" + task = {"request_id": "test-id_1", "prompt_token_ids": [1, 2, 3], "n": 3, "max_tokens": 100} + + with 
patch.object(self.engine_client, "_send_task") as mock_send: + await self.engine_client.add_requests(task) + + # Should send 3 tasks with indices 3, 4, 5 (1*3 to (1+1)*3) + self.assertEqual(mock_send.call_count, 3) + + def test_send_task_without_multimodal(self): + """Test _send_task for non-multimodal content.""" + self.engine_client.enable_mm = False + task = {"test": "data"} + + self.engine_client._send_task(task) + + self.engine_client.zmq_client.send_json.assert_called_once_with(task) + + def test_send_task_with_multimodal(self): + """Test _send_task for multimodal content.""" + self.engine_client.enable_mm = True + task = {"test": "multimodal_data"} + + self.engine_client._send_task(task) + + self.engine_client.zmq_client.send_pyobj.assert_called_once_with(task) + + def test_valid_parameters_max_tokens_valid(self): + """Test valid_parameters accepts valid max_tokens.""" + data = {"max_tokens": 100} + + # Should not raise exception + self.engine_client.valid_parameters(data) + + def test_valid_parameters_max_tokens_too_small(self): + """Test valid_parameters rejects max_tokens < 1.""" + data = {"max_tokens": 0} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_max_tokens_too_large(self): + """Test valid_parameters rejects max_tokens >= max_model_len.""" + data = {"max_tokens": 2048} # Equal to max_model_len, should raise exception + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_reasoning_max_tokens_adjustment(self): + """Test valid_parameters adjusts reasoning_max_tokens when needed.""" + data = {"max_tokens": 50, "reasoning_max_tokens": 100, "request_id": "test-id"} # Larger than max_tokens + + with patch("fastdeploy.entrypoints.engine_client.api_server_logger") as mock_logger: + self.engine_client.valid_parameters(data) + + self.assertEqual(data["reasoning_max_tokens"], 50) + mock_logger.warning.assert_called_once() + + def test_valid_parameters_temperature_zero_adjustment(self): + """Test valid_parameters adjusts zero temperature.""" + data = {"temperature": 0} + + self.engine_client.valid_parameters(data) + + self.assertEqual(data["temperature"], 1e-6) + + def test_valid_parameters_logprobs_disabled_when_enabled(self): + """Test valid_parameters rejects logprobs when disabled.""" + self.engine_client.enable_logprob = False + data = {"logprobs": True} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_logprobs_with_invalid_type(self): + """Test valid_parameters rejects invalid logprobs type.""" + data = {"logprobs": "invalid"} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_top_logprobs_disabled(self): + """Test valid_parameters rejects top_logprobs when disabled.""" + self.engine_client.enable_logprob = False + data = {"logprobs": True, "top_logprobs": 5} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_top_logprobs_invalid_type(self): + """Test valid_parameters rejects invalid top_logprobs type.""" + self.engine_client.enable_logprob = True + data = {"logprobs": True, "top_logprobs": "invalid"} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_top_logprobs_negative(self): + """Test valid_parameters 
rejects negative top_logprobs.""" + self.engine_client.enable_logprob = True + data = {"logprobs": True, "top_logprobs": -1} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_top_logprobs_too_large(self): + """Test valid_parameters rejects top_logprobs > 20.""" + self.engine_client.enable_logprob = True + data = {"logprobs": True, "top_logprobs": 25} + + with self.assertRaises(Exception): # ParameterError + self.engine_client.valid_parameters(data) + + def test_valid_parameters_top_logprobs_valid(self): + """Test valid_parameters accepts valid top_logprobs.""" + self.engine_client.enable_logprob = True + data = {"logprobs": True, "top_logprobs": 10} + + # Should not raise exception + self.engine_client.valid_parameters(data) + + def test_check_health_healthy(self): + """Test check_health returns healthy status.""" + self.engine_client.worker_healthy_live_signal.value = np.array([time.time()]) + + result, message = self.engine_client.check_health() + + self.assertTrue(result) + self.assertEqual(message, "") + + def test_check_health_unhealthy_timeout(self): + """Test check_health returns unhealthy due to timeout.""" + # Set signal to old time (more than 30 seconds ago) + old_time = time.time() - 60 + self.engine_client.worker_healthy_live_signal.value = np.array([old_time]) + + result, message = self.engine_client.check_health(time_interval_threashold=30) + + self.assertFalse(result) + self.assertEqual(message, "Worker Service Not Healthy") + + def test_is_workers_alive_normal(self): + """Test is_workers_alive returns True when weights are normal.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.NORMAL = 0 + self.engine_client.model_weights_status_signal.value = np.array([0]) + + result, message = self.engine_client.is_workers_alive() + + self.assertTrue(result) + self.assertEqual(message, "") + + def test_is_workers_alive_no_weights(self): + """Test is_workers_alive returns False when no weights.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.NORMAL = 0 + self.engine_client.model_weights_status_signal.value = np.array([1]) + + result, message = self.engine_client.is_workers_alive() + + self.assertFalse(result) + self.assertEqual(message, "No model weight enabled") + + def test_update_model_weight_already_normal(self): + """Test update_model_weight when weights are already normal.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.NORMAL = 0 + self.engine_client.model_weights_status_signal.value = np.array([0]) + + result, message = self.engine_client.update_model_weight() + + self.assertTrue(result) + self.assertEqual(message, "") + + def test_update_model_weight_already_updating(self): + """Test update_model_weight when already updating.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.NORMAL = 0 + mock_status.UPDATING = 1 + self.engine_client.model_weights_status_signal.value = np.array([1]) + + result, message = self.engine_client.update_model_weight() + + self.assertFalse(result) + self.assertEqual(message, "worker is updating model weight already") + + def test_update_model_weight_clearing(self): + """Test update_model_weight when clearing weights.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.NORMAL = 0 + mock_status.CLEARING = 
-1 + self.engine_client.model_weights_status_signal.value = np.array([-1]) + + result, message = self.engine_client.update_model_weight() + + self.assertFalse(result) + self.assertEqual(message, "worker is clearing model weight, cannot update now") + + def test_update_model_weight_timeout(self): + """Test update_model_weight timeout scenario.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + with patch("fastdeploy.entrypoints.engine_client.KVCacheStatus") as mock_kv_status: + with patch("fastdeploy.entrypoints.engine_client.PrefixTreeStatus") as mock_prefix_status: + mock_status.NORMAL = 0 + mock_status.UPDATING = 1 + mock_status.CLEARED = -2 + mock_kv_status.NORMAL = 0 + mock_kv_status.UPDATING = 1 + mock_kv_status.CLEARED = -2 + mock_prefix_status.NORMAL = 0 + mock_prefix_status.UPDATING = 1 + mock_prefix_status.CLEARED = -2 + + self.engine_client.enable_prefix_caching = True + # Start with CLEARED status to enter the updating loop + self.engine_client.model_weights_status_signal.value = np.array([-2]) + self.engine_client.kv_cache_status_signal.value = np.array([-2]) # Start as CLEARED + self.engine_client.prefix_tree_status_signal.value = np.array([-2]) # Start as CLEARED + + result, message = self.engine_client.update_model_weight(timeout=1) + + self.assertFalse(result) + self.assertEqual(message, "Update model weight timeout") + + def test_clear_load_weight_already_cleared(self): + """Test clear_load_weight when weights are already cleared.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.CLEARED = -2 + self.engine_client.model_weights_status_signal.value = np.array([-2]) + + result, message = self.engine_client.clear_load_weight() + + self.assertTrue(result) + self.assertEqual(message, "") + + def test_clear_load_weight_already_clearing(self): + """Test clear_load_weight when already clearing.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.CLEARED = -2 + mock_status.CLEARING = -1 + self.engine_client.model_weights_status_signal.value = np.array([-1]) + + result, message = self.engine_client.clear_load_weight() + + self.assertFalse(result) + self.assertEqual(message, "worker is clearing model weight already") + + def test_clear_load_weight_updating(self): + """Test clear_load_weight when updating weights.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + mock_status.CLEARED = -2 + mock_status.CLEARING = -1 + mock_status.UPDATING = 1 + self.engine_client.model_weights_status_signal.value = np.array([1]) + + result, message = self.engine_client.clear_load_weight() + + self.assertFalse(result) + self.assertEqual(message, "worker is updating model weight, cannot clear now") + + def test_clear_load_weight_timeout(self): + """Test clear_load_weight timeout scenario.""" + with patch("fastdeploy.entrypoints.engine_client.ModelWeightsStatus") as mock_status: + with patch("fastdeploy.entrypoints.engine_client.KVCacheStatus") as mock_kv_status: + with patch("fastdeploy.entrypoints.engine_client.PrefixTreeStatus") as mock_prefix_status: + mock_status.NORMAL = 0 + mock_status.CLEARED = -2 + mock_status.CLEARING = -1 + mock_kv_status.CLEARED = -2 + mock_kv_status.CLEARING = -1 + mock_prefix_status.CLEARED = -2 + mock_prefix_status.CLEARING = -1 + + self.engine_client.enable_prefix_caching = True + # Start with NORMAL status to enter the clearing loop + 
self.engine_client.model_weights_status_signal.value = np.array([0]) + self.engine_client.kv_cache_status_signal.value = np.array([0]) # Start as NORMAL + self.engine_client.prefix_tree_status_signal.value = np.array([0]) # Start as NORMAL + + result, message = self.engine_client.clear_load_weight(timeout=1) + + self.assertFalse(result) + self.assertEqual(message, "Clear model weight timeout") + + def test_check_model_weight_status(self): + """Test check_model_weight_status returns correct status.""" + # Status < 0 indicates abnormal + self.engine_client.model_weights_status_signal.value = np.array([-1]) + result = self.engine_client.check_model_weight_status() + self.assertTrue(result) + + # Status >= 0 indicates normal + self.engine_client.model_weights_status_signal.value = np.array([0]) + result = self.engine_client.check_model_weight_status() + self.assertFalse(result) + + def test_create_zmq_client(self): + """Test create_zmq_client method.""" + mock_zmq_client = Mock() + with patch("fastdeploy.entrypoints.engine_client.ZmqIpcClient", return_value=mock_zmq_client) as mock_zmq: + self.engine_client.create_zmq_client("test_model", "test_mode") + + mock_zmq.assert_called_once_with("test_model", "test_mode") + mock_zmq_client.connect.assert_called_once() + self.assertEqual(self.engine_client.zmq_client, mock_zmq_client) + + async def test_init_with_multimodal_prefix_cache(self): + """Test EngineClient initialization with multimodal prefix cache enabled.""" + mock_model_config = Mock() + mock_model_config.enable_mm = True + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + + with ( + patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class, + patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform, + patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal, + patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs, + patch("os.getenv", return_value="50"), + patch("fastdeploy.cache_manager.cache_data.is_mm_model_disable_prefix_cache", return_value=True), + ): + mock_platform.is_iluvatar.return_value = False + mock_input_processor = Mock() + mock_processor_class.return_value = mock_input_processor + mock_processor = Mock() + mock_input_processor.create_processor.return_value = mock_processor + + mock_signal_instance = Mock() + mock_signal_instance.value = np.array([0]) + mock_ipcsignal.return_value = mock_signal_instance + mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 + + client = EngineClient( + model_name_or_path="test_model", + tokenizer=Mock(), + max_model_len=2048, + tensor_parallel_size=1, + pid=5678, + port=8080, + limit_mm_per_prompt=5, + mm_processor_kwargs={}, + config=mock_config, + reasoning_parser=None, + data_parallel_size=1, + enable_logprob=True, + workers=1, + tool_parser=None, + enable_prefix_caching=True, # Enable prefix caching + splitwise_role=None, + max_processor_cache=0, + ) + + self.assertTrue(client.enable_mm) + self.assertTrue(client.enable_prefix_caching) + self.assertTrue(client.disable_prefix_mm) + + async def test_init_as_worker_node(self): + """Test EngineClient initialization as worker node (not master).""" + mock_model_config = Mock() + mock_model_config.enable_mm = False + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + + with ( + 
patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class, + patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform, + patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal, + patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs, + patch("os.getenv", return_value="50"), + ): + mock_platform.is_iluvatar.return_value = False + mock_platform.max_chips_per_node = 8 + mock_input_processor = Mock() + mock_processor_class.return_value = mock_input_processor + mock_processor = Mock() + mock_input_processor.create_processor.return_value = mock_processor + + mock_signal_instance = Mock() + mock_signal_instance.value = np.array([0]) + mock_ipcsignal.return_value = mock_signal_instance + mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 + + # Use tensor_parallel_size > max_chips_per_node to make it a worker + client = EngineClient( + model_name_or_path="test_model", + tokenizer=Mock(), + max_model_len=2048, + tensor_parallel_size=16, # Large number to make it a worker + pid=5678, + port=8080, + limit_mm_per_prompt=5, + mm_processor_kwargs={}, + config=mock_config, + reasoning_parser=None, + data_parallel_size=1, + enable_logprob=True, + workers=1, + tool_parser=None, + enable_prefix_caching=False, + splitwise_role=None, + max_processor_cache=0, + ) + + self.assertFalse(client.is_master) + + async def test_format_and_add_data(self): + """Test format_and_add_data method.""" + prompts = {"prompt_token_ids": [1, 2, 3], "max_tokens": 50} + + with patch.object(self.engine_client, "add_requests") as mock_add: + mock_add.return_value = None + + await self.engine_client.format_and_add_data(prompts) + + mock_add.assert_called_once() + call_args = mock_add.call_args[0][0] + self.assertIn("request_id", call_args) + self.assertEqual(call_args["prompt_token_ids"], [1, 2, 3]) + self.assertEqual(call_args["max_tokens"], 50) + + async def test_rearrange_experts_disabled(self): + """Test rearrange_experts when EPLB is disabled.""" + mock_config = Mock() + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + + self.engine_client.config = mock_config + + request_dict = {"user": "test", "passwd": "test"} + content, status_code = await self.engine_client.rearrange_experts(request_dict) + + self.assertEqual(content["code"], 1) + self.assertEqual(content["msg"], "redundant expert is disabled") + self.assertEqual(status_code, 400) + + async def test_get_per_expert_tokens_stats_disabled(self): + """Test get_per_expert_tokens_stats when EPLB is disabled.""" + mock_config = Mock() + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + + self.engine_client.config = mock_config + + request_dict = {"user": "test", "passwd": "test"} + content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) + + self.assertEqual(content["code"], 1) + self.assertEqual(content["msg"], "redundant expert is disabled") + self.assertEqual(status_code, 400) + + async def test_get_per_expert_tokens_stats_invalid_auth(self): + """Test get_per_expert_tokens_stats with invalid authentication.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "correct_user" + mock_eplb_config.redundant_expert_api_password = "correct_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + 
self.engine_client.config = mock_config + + request_dict = {"user": "wrong_user", "passwd": "wrong_pass"} + content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) + + self.assertEqual(content["code"], 1) + self.assertEqual(content["msg"], "user or passwd is invalid") + self.assertEqual(status_code, 401) + + async def test_get_per_expert_tokens_stats_success(self): + """Test get_per_expert_tokens_stats successful response.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + + # Set up mock arrays + mock_local_stats = Mock() + mock_local_stats.value = np.array([1, 2, 3]) + self.engine_client.local_experts_token_stats_array_list = [mock_local_stats] + self.engine_client.signal_clear_experts_token_stats_list = [] + + request_dict = {"user": "test_user", "passwd": "test_pass"} + + content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) + + self.assertEqual(content["code"], 0) + self.assertEqual(content["msg"], "ok") + self.assertIn("data", content) + self.assertEqual(content["data"], [[1, 2, 3]]) + self.assertEqual(status_code, 200) + + async def test_get_per_expert_tokens_stats_clear_stat(self): + """Test get_per_expert_tokens_stats with clear_stat flag.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + + # Set up mock arrays and signals + mock_clear_signal = Mock() + mock_clear_signal.value = np.array([0]) + self.engine_client.signal_clear_experts_token_stats_list = [mock_clear_signal] + + mock_local_stats = Mock() + mock_local_stats.value = np.array([1, 2, 3]) + self.engine_client.local_experts_token_stats_array_list = [mock_local_stats] + + request_dict = {"user": "test_user", "passwd": "test_pass", "clear_stat": True} + + content, status_code = await self.engine_client.get_per_expert_tokens_stats(request_dict) + + self.assertEqual(content["code"], 0) + self.assertEqual(content["msg"], "ok") + self.assertEqual(mock_clear_signal.value[0], 1) # Clear signal should be set + self.assertEqual(status_code, 200) + + async def test_check_redundant_disabled(self): + """Test check_redundant when EPLB is disabled.""" + mock_config = Mock() + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + + self.engine_client.config = mock_config + + request_dict = {"user": "test", "passwd": "test"} + content, status_code = await self.engine_client.check_redundant(request_dict) + + self.assertEqual(content["code"], 1) + self.assertEqual(content["msg"], "redundant expert is disabled") + self.assertEqual(status_code, 400) + + async def test_check_redundant_invalid_auth(self): + """Test check_redundant with invalid authentication.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = 
"correct_user" + mock_eplb_config.redundant_expert_api_password = "correct_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + + request_dict = {"user": "wrong_user", "passwd": "wrong_pass"} + content, status_code = await self.engine_client.check_redundant(request_dict) + + self.assertEqual(content["code"], 1) + self.assertEqual(content["msg"], "user or passwd is invalid") + self.assertEqual(status_code, 401) + + async def test_check_redundant_wrong_rank(self): + """Test check_redundant with wrong tensor parallel rank.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 1 # Not rank 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + + request_dict = {"user": "test_user", "passwd": "test_pass"} + content, status_code = await self.engine_client.check_redundant(request_dict) + + self.assertEqual(content["code"], 1) + self.assertIn("actual rank 1, expect rank 0", content["msg"]) + self.assertEqual(status_code, 400) + + async def test_check_redundant_status_unknown(self): + """Test check_redundant with unknown status (invalid signal value).""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + self.engine_client.rearrange_experts_signal = Mock() + self.engine_client.rearrange_experts_signal.value = np.array([999]) # Invalid status + + with patch("fastdeploy.entrypoints.engine_client.RearrangeExpertStatus") as mock_status: + mock_status.side_effect = Exception("Invalid status") + + request_dict = {"user": "test_user", "passwd": "test_pass", "action": ""} + + content, status_code = await self.engine_client.check_redundant(request_dict) + + self.assertEqual(content["code"], 0) + self.assertEqual(content["msg"], "ok") + self.assertEqual(content["status"], "unknown") # Should fallback to unknown + self.assertEqual(status_code, 200) + + async def test_check_redundant_status_known(self): + """Test check_redundant with known status.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + self.engine_client.rearrange_experts_signal = Mock() + self.engine_client.rearrange_experts_signal.value = np.array([0]) # FREE status + + with patch("fastdeploy.entrypoints.engine_client.RearrangeExpertStatus") as mock_status: + mock_status_instance = Mock() + mock_status_instance.name = "FREE" + mock_status.return_value = 
mock_status_instance + + request_dict = {"user": "test_user", "passwd": "test_pass", "action": ""} + + content, status_code = await self.engine_client.check_redundant(request_dict) + + self.assertEqual(content["code"], 0) + self.assertEqual(content["msg"], "ok") + self.assertEqual(content["status"], "FREE") + self.assertEqual(status_code, 200) + + async def test_check_redundant_check_load_weight_result(self): + """Test check_redundant with check_load_weight_result action.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + + # Set up mock update_weight_from_disk_result_list + mock_result1 = Mock() + mock_result1.value = np.array([1, 2, 3]) + mock_result2 = Mock() + mock_result2.value = np.array([4, 5, 6]) + self.engine_client.update_weight_from_disk_result_list = [mock_result1, mock_result2] + + request_dict = {"user": "test_user", "passwd": "test_pass", "action": "check_load_weight_result"} + + content, status_code = await self.engine_client.check_redundant(request_dict) + + self.assertEqual(content["code"], 0) + self.assertEqual(content["msg"], "ok") + self.assertIn("data", content) + # Code does: update_weight_result.value[0].tolist(), so only first elements + self.assertEqual(content["data"], [1, 4]) + self.assertEqual(status_code, 200) + + async def test_check_redundant_invalid_action(self): + """Test check_redundant with invalid action.""" + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + + request_dict = {"user": "test_user", "passwd": "test_pass", "action": "invalid_action"} + + content, status_code = await self.engine_client.check_redundant(request_dict) + + # For invalid action, content remains None and status_code is HTTPStatus.OK + self.assertIsNone(content) + self.assertEqual(status_code, 200) + + def test_init_eplb_signals_non_zero_rank(self): + """Test init_eplb_signals returns early for non-zero tensor parallel rank.""" + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 1 # Non-zero rank + mock_parallel_config.local_data_parallel_id = 0 + + mock_config = Mock() + mock_config.parallel_config = mock_parallel_config + + # Set fd_config to ensure the method checks the correct config + self.engine_client.fd_config = mock_config + self.engine_client.config = mock_config + + # Mock IPCSignal to prevent actual file system calls + with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: + # Should return early without initializing signals + self.engine_client.init_eplb_signals("test_suffix") + + # Should not create any IPCSignal instances + mock_ipcsignal.assert_not_called() + + # Should return None (implicitly) and not create any signals + self.assertFalse(hasattr(self.engine_client, "rearrange_experts_signal")) + self.assertFalse(hasattr(self.engine_client, 
"signal_clear_experts_token_stats_list")) + + def test_init_eplb_signals_rank_zero_success(self): + """Test init_eplb_signals successful initialization for rank 0.""" + mock_model_config = Mock() + mock_model_config.num_hidden_layers = 12 + mock_model_config.moe_num_experts = 8 + + mock_eplb_config = Mock() + mock_eplb_config.redundant_expert_ip_shm_size = 1024 + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + mock_parallel_config.local_data_parallel_id = 2 + mock_parallel_config.tensor_parallel_size = 4 + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + self.engine_client.fd_config = mock_config # Also set fd_config for proper access + self.engine_client.tensor_parallel_size = 4 # Set this to match the config + + with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: + mock_signal = Mock() + mock_ipcsignal.return_value = mock_signal + + self.engine_client.init_eplb_signals("8080") + + # Check that IPCSignal was called with correct parameters + # Based on the actual implementation: 4 base signals + 4 TP ranks * 5 signals each = 24 total + self.assertEqual(mock_ipcsignal.call_count, 24) # 4 TP ranks * 5 signals each + 4 base signals = 24 total + + # Check that the suffix includes data parallel ID + call_args_list = mock_ipcsignal.call_args_list + dp_suffix_found = any("8080_dp2" in str(call) for call in call_args_list) + self.assertTrue(dp_suffix_found) + + # Check that all required signal lists were created + self.assertEqual(len(self.engine_client.signal_clear_experts_token_stats_list), 4) + self.assertEqual(len(self.engine_client.local_experts_token_stats_array_list), 4) + self.assertEqual(len(self.engine_client.expert_tokens_stats_array_list), 4) + self.assertEqual(len(self.engine_client.signal_update_weight_from_disk_array_list), 4) + self.assertEqual(len(self.engine_client.update_weight_from_disk_result_list), 4) + + # Check that base signals were created + self.assertTrue(hasattr(self.engine_client, "rearrange_experts_signal")) + self.assertTrue(hasattr(self.engine_client, "rearrange_experts_ips_size_signal")) + self.assertTrue(hasattr(self.engine_client, "shm_rearrange_experts_ips_list")) + self.assertTrue(hasattr(self.engine_client, "signal_update_weight_from_tensor_array")) + + def test_init_eplb_signals_array_dimensions(self): + """Test init_eplb_signals creates arrays with correct dimensions.""" + mock_model_config = Mock() + mock_model_config.num_hidden_layers = 6 + mock_model_config.moe_num_experts = 4 + + mock_eplb_config = Mock() + mock_eplb_config.redundant_expert_ip_shm_size = 512 + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + mock_parallel_config.local_data_parallel_id = 1 + mock_parallel_config.tensor_parallel_size = 2 + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + self.engine_client.tensor_parallel_size = 2 # Set this to match mock_parallel_config.tensor_parallel_size + self.engine_client.fd_config = mock_config # Also set fd_config to ensure proper access + + with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: + mock_signal = Mock() + mock_ipcsignal.return_value = mock_signal + + 
self.engine_client.init_eplb_signals("9090") + + # Check that IPCSignal was called with arrays of correct shape + call_args_list = mock_ipcsignal.call_args_list + + # Find calls for expert token stats arrays (should be 6x4 shape for 2D arrays) + all_experts_token_stats_calls = [call for call in call_args_list if "all_experts_token_stats" in str(call)] + local_experts_token_stats_calls = [ + call for call in call_args_list if "local_experts_token_stats" in str(call) + ] + + # These should be 2D arrays with shape (6, 4) + for call in all_experts_token_stats_calls: + array_arg = call[1]["array"] + self.assertEqual(array_arg.shape, (6, 4)) # (num_hidden_layers, moe_num_experts) + + for call in local_experts_token_stats_calls: + array_arg = call[1]["array"] + self.assertEqual(array_arg.shape, (6, 4)) # (num_hidden_layers, moe_num_experts) + + # Check that single-element signals have shape (1,) + single_element_calls = [ + call + for call in call_args_list + if "rearrange_experts_status" in str(call) + or "rearrange_experts_ips_size" in str(call) + or "signal_update_weight_from_tensor" in str(call) + ] + + for call in single_element_calls: + array_arg = call[1]["array"] + self.assertEqual(array_arg.shape, (1,)) # Single element array + + def test_init_eplb_signals_suffix_format(self): + """Test init_eplb_signals uses correct suffix format.""" + mock_model_config = Mock() + mock_model_config.num_hidden_layers = 4 + mock_model_config.moe_num_experts = 2 + + mock_eplb_config = Mock() + mock_eplb_config.redundant_expert_ip_shm_size = 256 + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + mock_parallel_config.local_data_parallel_id = 3 + mock_parallel_config.tensor_parallel_size = 1 + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + self.engine_client.fd_config = mock_config # Set fd_config as well + # Ensure tensor_parallel_size is set correctly + self.engine_client.tensor_parallel_size = 1 + + with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: + mock_signal = Mock() + mock_ipcsignal.return_value = mock_signal + + self.engine_client.init_eplb_signals("7777") + + # Check suffix format + call_args_list = mock_ipcsignal.call_args_list + + # Check DP suffix + dp_calls = [call for call in call_args_list if "rearrange_experts_status" in str(call)] + self.assertEqual(len(dp_calls), 1) + self.assertEqual(dp_calls[0][1]["suffix"], "7777_dp3") + + # Check TP suffix for TP rank 0 + tp_calls = [call for call in call_args_list if "signal_clear_experts_token_stats" in str(call)] + self.assertEqual(len(tp_calls), 1) + self.assertEqual(tp_calls[0][1]["suffix"], "7777_dp3_tp0") + + def test_init_eplb_signals_list_initialization(self): + """Test init_eplb_signals properly initializes all signal lists.""" + mock_model_config = Mock() + mock_model_config.num_hidden_layers = 2 + mock_model_config.moe_num_experts = 2 + + mock_eplb_config = Mock() + mock_eplb_config.redundant_expert_ip_shm_size = 128 + + mock_parallel_config = Mock() + mock_parallel_config.tensor_parallel_rank = 0 + mock_parallel_config.local_data_parallel_id = 0 + mock_parallel_config.tensor_parallel_size = 3 + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config = mock_parallel_config + + self.engine_client.config = mock_config + 
self.engine_client.tensor_parallel_size = 3 # Set this to match mock_parallel_config.tensor_parallel_size + self.engine_client.fd_config = mock_config # Also set fd_config to ensure proper access + + # Ensure lists start empty + self.engine_client.signal_clear_experts_token_stats_list = [] + self.engine_client.local_experts_token_stats_array_list = [] + self.engine_client.expert_tokens_stats_array_list = [] + self.engine_client.signal_update_weight_from_disk_array_list = [] + self.engine_client.update_weight_from_disk_result_list = [] + + with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: + mock_signal = Mock() + mock_ipcsignal.return_value = mock_signal + + self.engine_client.init_eplb_signals("6666") + + # Check that all lists have correct length (3 TP ranks) + self.assertEqual(len(self.engine_client.signal_clear_experts_token_stats_list), 3) + self.assertEqual(len(self.engine_client.local_experts_token_stats_array_list), 3) + self.assertEqual(len(self.engine_client.expert_tokens_stats_array_list), 3) + self.assertEqual(len(self.engine_client.signal_update_weight_from_disk_array_list), 3) + self.assertEqual(len(self.engine_client.update_weight_from_disk_result_list), 3) + + async def test_init_iluvatar_platform(self): + """Test EngineClient initialization on Iluvatar platform.""" + mock_model_config = Mock() + mock_model_config.enable_mm = False + + mock_config = Mock() + mock_config.model_config = mock_model_config + mock_config.eplb_config = Mock() + mock_config.eplb_config.enable_eplb = False + + with ( + patch("fastdeploy.entrypoints.engine_client.InputPreprocessor") as mock_processor_class, + patch("fastdeploy.entrypoints.engine_client.current_platform") as mock_platform, + patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal, + patch("fastdeploy.entrypoints.engine_client.envs") as mock_envs, + patch("os.getenv", return_value="50"), + ): + mock_platform.is_iluvatar.return_value = True # Iluvatar platform + mock_input_processor = Mock() + mock_processor_class.return_value = mock_input_processor + mock_processor = Mock() + mock_input_processor.create_processor.return_value = mock_processor + + mock_signal_instance = Mock() + mock_signal_instance.value = np.array([0]) + mock_ipcsignal.return_value = mock_signal_instance + mock_envs.FD_SUPPORT_MAX_CONNECTIONS = 100 + + client = EngineClient( + model_name_or_path="test_model", + tokenizer=Mock(), + max_model_len=2048, + tensor_parallel_size=1, + pid=5678, + port=8080, + limit_mm_per_prompt=5, + mm_processor_kwargs={}, + config=mock_config, + reasoning_parser=None, + data_parallel_size=1, + enable_logprob=True, + workers=1, + tool_parser=None, + enable_prefix_caching=False, + splitwise_role=None, + max_processor_cache=0, + ) + + self.assertTrue(client.is_master) # With 1 tensor_parallel_size, should be master even on Iluvatar + + def test_check_mm_disable_prefix_cache_without_multimodal_data(self): + """Test _check_mm_disable_prefix_cache without multimodal data.""" + self.engine_client.disable_prefix_mm = True + + task = {"multimodal_inputs": {"token_type_ids": [0, 0, 0]}} # Sum = 0 + + result = self.engine_client._check_mm_disable_prefix_cache(task) + self.assertFalse(result) + + async def test_add_requests_multimodal_prefix_cache_error(self): + """Test add_requests with multimodal data when prefix cache is enabled.""" + self.engine_client.enable_mm = True + self.engine_client.enable_prefix_caching = True + self.engine_client.disable_prefix_mm = True + self.engine_client.data_processor = 
Mock() + self.engine_client.data_processor.process_request_dict = Mock() + + task = { + "request_id": "test_request", + "user": "test_user", + "multimodal_inputs": {"token_type_ids": [1, 1, 0, 1]}, # Multimodal data present + "prompt_token_ids": [1, 2, 3], + "max_tokens": 100, + } + + with self.assertRaises(EngineError) as context: + await self.engine_client.add_requests(task) + + self.assertIn("does not support processing requests containing multimodal data", str(context.exception)) + self.assertEqual(context.exception.error_code, 400) + + async def test_add_requests_input_too_long_error(self): + """Test add_requests with input length too long.""" + self.engine_client.max_model_len = 10 + self.engine_client.data_processor = Mock() + self.engine_client.data_processor.process_request_dict = Mock() + + task = { + "request_id": "test_request", + "user": "test_user", + "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8], # length = 8 + "max_tokens": 5, # 8 + 5 = 13 >= 10 + "min_tokens": 2, + } + + with self.assertRaises(EngineError) as context: + await self.engine_client.add_requests(task) + + self.assertIn("Input text is too long", str(context.exception)) + self.assertIn("input_ids_len (8) + min_tokens(2) >= max_model_len(10)", str(context.exception)) + self.assertEqual(context.exception.error_code, 400) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_MAX_STOP_SEQS_NUM", 3) + async def test_add_requests_stop_seqs_num_exceeds_limit(self): + """Test add_requests with stop sequences number exceeding limit.""" + self.engine_client.data_processor = Mock() + self.engine_client.data_processor.process_request_dict = Mock() + + task = { + "request_id": "test_request", + "user": "test_user", + "prompt_token_ids": [1, 2, 3], + "max_tokens": 10, + "stop_seqs_len": [10, 20, 30, 40], # 4 sequences > limit of 3 + } + + with self.assertRaises(EngineError) as context: + await self.engine_client.add_requests(task) + + self.assertIn( + "Length of stop ([10, 20, 30, 40]) exceeds the limit max_stop_seqs_num(3)", str(context.exception) + ) + self.assertIn("Please reduce the number of stop or set a lager max_stop_seqs_num", str(context.exception)) + self.assertEqual(context.exception.error_code, 400) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_STOP_SEQS_MAX_LEN", 5) + async def test_add_requests_single_stop_seq_len_exceeds_limit(self): + """Test add_requests with single stop sequence length exceeding limit.""" + self.engine_client.data_processor = Mock() + self.engine_client.data_processor.process_request_dict = Mock() + + task = { + "request_id": "test_request", + "user": "test_user", + "prompt_token_ids": [1, 2, 3], + "max_tokens": 10, + "stop_seqs_len": [3, 10, 2], # 10 > limit of 5 + } + + with self.assertRaises(EngineError) as context: + await self.engine_client.add_requests(task) + + self.assertIn("Length of stop_seqs(10) exceeds the limit stop_seqs_max_len(5)", str(context.exception)) + self.assertIn( + "Please reduce the length of stop sequences or set a larger stop_seqs_max_len", str(context.exception) + ) + self.assertEqual(context.exception.error_code, 400) + + async def test_rearrange_experts_eplb_disabled(self): + """Test rearrange_experts when EPLB is disabled.""" + # Mock eplb_config with enable_eplb = False + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = False + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + + self.engine_client.config = mock_config + + request_dict = {"user": "test_user", "passwd": "test_pass"} + + content, status_code = await 
self.engine_client.rearrange_experts(request_dict) + + expected_content = {"code": 1, "msg": "redundant expert is disabled"} + self.assertEqual(content, expected_content) + self.assertEqual(status_code.value, 400) # BAD_REQUEST + + async def test_rearrange_experts_invalid_credentials(self): + """Test rearrange_experts with invalid user/password.""" + # Mock eplb_config with enable_eplb = True + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "valid_user" + mock_eplb_config.redundant_expert_api_password = "valid_pass" + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config.tensor_parallel_rank = 0 + + self.engine_client.config = mock_config + + request_dict = {"user": "invalid_user", "passwd": "invalid_pass"} + + content, status_code = await self.engine_client.rearrange_experts(request_dict) + + expected_content = {"code": 1, "msg": "user or passwd is invalid"} + self.assertEqual(content, expected_content) + self.assertEqual(status_code.value, 401) # UNAUTHORIZED + + async def test_rearrange_experts_non_rank_zero(self): + """Test rearrange_experts from non-zero rank.""" + # Mock eplb_config with enable_eplb = True + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config.tensor_parallel_rank = 2 # Non-zero rank + + self.engine_client.config = mock_config + + request_dict = {"user": "test_user", "passwd": "test_pass"} + + content, status_code = await self.engine_client.rearrange_experts(request_dict) + + expected_content = {"code": 1, "msg": "actual rank 2, expect rank 0"} + self.assertEqual(content, expected_content) + self.assertEqual(status_code.value, 400) # BAD_REQUEST + + async def test_rearrange_experts_recv_expert_weight_invalid_data(self): + """Test rearrange_experts recv_expert_weight action with invalid data.""" + # Mock eplb_config + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config.tensor_parallel_rank = 0 + + self.engine_client.config = mock_config + + request_dict = { + "user": "test_user", + "passwd": "test_pass", + "action": "recv_expert_weight", + # Missing "data" field + } + + content, status_code = await self.engine_client.rearrange_experts(request_dict) + + expected_content = {"code": 1, "msg": "data not in request or data is not a list"} + self.assertEqual(content, expected_content) + self.assertEqual(status_code.value, 400) # BAD_REQUEST + + async def test_rearrange_experts_invalid_action(self): + """Test rearrange_experts with invalid action.""" + # Mock eplb_config + mock_eplb_config = Mock() + mock_eplb_config.enable_eplb = True + mock_eplb_config.redundant_expert_api_user = "test_user" + mock_eplb_config.redundant_expert_api_password = "test_pass" + + mock_config = Mock() + mock_config.eplb_config = mock_eplb_config + mock_config.parallel_config.tensor_parallel_rank = 0 + + self.engine_client.config = mock_config + + request_dict = {"user": "test_user", "passwd": "test_pass", "action": "invalid_action"} + + content, status_code = await self.engine_client.rearrange_experts(request_dict) + + 
expected_content = {"code": 1, "msg": "invalid action invalid_action"} + self.assertEqual(content, expected_content) + self.assertEqual(status_code.value, 400) # BAD_REQUEST + if __name__ == "__main__": unittest.main() diff --git a/tests/entrypoints/test_vllm_run_engine.py b/tests/entrypoints/test_vllm_run_engine.py index 22783b19775..4ac03116544 100644 --- a/tests/entrypoints/test_vllm_run_engine.py +++ b/tests/entrypoints/test_vllm_run_engine.py @@ -1,4 +1,5 @@ -from unittest.mock import MagicMock +import os +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -14,56 +15,147 @@ def __init__(self, max_logprobs=10, ori_vocab_size=50): self.ori_vocab_size = ori_vocab_size +class DummyCacheConfig: + def __init__(self, enable_prefix_caching=False): + self.enable_prefix_caching = enable_prefix_caching + + +class DummyLLMEngineConfig: + def __init__(self, model_config=None, cache_config=None): + self.model_config = model_config or DummyModelConfig() + self.cache_config = cache_config or DummyCacheConfig() + + +class DummyLLMEngine: + def __init__(self, model_config=None, cache_config=None): + self.cfg = DummyLLMEngineConfig(model_config, cache_config) + self.data_processor = MagicMock() + # Mock tokenizer with sp_model attribute + self.data_processor.tokenizer = MagicMock() + self.data_processor.tokenizer.sp_model = MagicMock() + self.data_processor.tokenizer.sp_model.__len__ = MagicMock(return_value=100) + self.data_processor.tokenizer.vocab = MagicMock() + self.data_processor.tokenizer.vocab.__len__ = MagicMock(return_value=100) + self.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}" + self.add_requests = MagicMock() + + @pytest.fixture def mock_llm(): llm = LLM.__new__(LLM) - llm.llm_engine = MagicMock() - llm.llm_engine.add_requests = MagicMock() - llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100) - # Mock the data_processor.process_logprob_response method to return proper strings - llm.llm_engine.data_processor = MagicMock() - llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}" + llm.llm_engine = DummyLLMEngine() + return llm + + +@pytest.fixture +def mock_llm_with_prefix_caching(): + llm = LLM.__new__(LLM) + llm.llm_engine = DummyLLMEngine(cache_config=DummyCacheConfig(enable_prefix_caching=True)) return llm def test_prompt_logprobs_not_supported_with_stream(mock_llm): - sampling = SamplingParams(prompt_logprobs=5) - with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"): - mock_llm._add_request(["hi"], sampling, stream=True) + # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + sampling = SamplingParams(prompt_logprobs=5) + with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"): + mock_llm._add_request(["hi"], sampling, stream=True) + + +def test_prompt_logprobs_not_supported_with_prefix_caching(mock_llm_with_prefix_caching): + # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + sampling = SamplingParams(prompt_logprobs=5) + with pytest.raises(ValueError, match="prompt_logprobs is not supported with prefix caching enabled"): + mock_llm_with_prefix_caching._add_request(["hi"], sampling) def test_num_logprobs_exceeds_max(mock_llm): - sampling = SamplingParams(logprobs=20) - with 
pytest.raises(ValueError, match="Number of logprobs requested"):
+    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs > 20
+    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
+        sampling = SamplingParams(logprobs=20)
+        with pytest.raises(ValueError, match="Number of logprobs requested"):
+            mock_llm._add_request(["hi"], sampling)
+
+
+def test_max_logprobs_exceeds_vocab_size(mock_llm):
+    # Test case where max_logprobs > ori_vocab_size
+    mock_llm.llm_engine.cfg.model_config.max_logprobs = 150  # > vocab size (100)
+    with pytest.raises(ValueError, match="max_logprobs \\(150\\) exceeds vocabulary size \\(100\\)"):
+        mock_llm._add_request(["hi"], SamplingParams())
+
+
+def test_max_logprobs_less_than_minus_one(mock_llm):
+    # Test case where max_logprobs < -1
+    mock_llm.llm_engine.cfg.model_config.max_logprobs = -2
+    with pytest.raises(ValueError, match="max_logprobs \\(-2\\) can't be less than -1"):
+        mock_llm._add_request(["hi"], SamplingParams())
+
+
+def test_logprobs_minus_one_uses_vocab_size(mock_llm):
+    # Test that logprobs=-1 uses vocab size
+    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
+        sampling = SamplingParams(logprobs=-1)
+        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1  # Allow unlimited
         mock_llm._add_request(["hi"], sampling)
+        mock_llm.llm_engine.add_requests.assert_called_once()
 
 
 def test_num_prompt_logprobs_exceeds_max(mock_llm):
-    sampling = SamplingParams(prompt_logprobs=20)
-    with pytest.raises(ValueError, match="Number of logprobs requested"):
-        mock_llm._add_request(["hi"], sampling)
+    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
+    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
+        sampling = SamplingParams(prompt_logprobs=20)
+        with pytest.raises(ValueError, match="Number of logprobs requested"):
+            mock_llm._add_request(["hi"], sampling)
 
 
 def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
-    sampling = SamplingParams(logprobs=-1)
-    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
-    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
-    mock_llm._add_request(["hi"], sampling)
-    mock_llm.llm_engine.add_requests.assert_called_once()
-    # Get the first argument (tasks) which should be a dict
-    call_args = mock_llm.llm_engine.add_requests.call_args
-    tasks = call_args[0][0]  # First positional argument
-    assert isinstance(tasks, dict)
-    assert "prompt" in tasks
-    assert "request_id" in tasks
+    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs=-1
+    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
+        sampling = SamplingParams(logprobs=-1)
+        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
+        mock_llm._add_request(["hi"], sampling)
+        mock_llm.llm_engine.add_requests.assert_called_once()
+        # Get the first argument (tasks) which should be a dict
+        call_args = mock_llm.llm_engine.add_requests.call_args
+        tasks = call_args[0][0]  # First positional argument
+        assert isinstance(tasks, dict)
+        assert "prompt" in tasks
+        assert "request_id" in tasks
 
 
 def test_prompt_logprobs_equal_to_minus_one(mock_llm):
-    sampling = SamplingParams(prompt_logprobs=-1)
+    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support and allow -1
+    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
+        sampling = SamplingParams(prompt_logprobs=-1)
+        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
+        mock_llm._add_request(["hi"], sampling)
+        mock_llm.llm_engine.add_requests.assert_called_once()
+
+
+def test_dynamic_vocab_size_from_sp_model(mock_llm):
+    # Test
that ori_vocab_size is dynamically obtained from sp_model + mock_llm.llm_engine.data_processor.tokenizer.sp_model.__len__.return_value = 200 mock_llm.llm_engine.cfg.model_config.max_logprobs = -1 - mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25 - mock_llm._add_request(["hi"], sampling) - mock_llm.llm_engine.add_requests.assert_called_once() + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + sampling = SamplingParams(logprobs=-1) + mock_llm._add_request(["hi"], sampling) + # Should use the dynamic vocab size (200) + mock_llm.llm_engine.add_requests.assert_called_once() + + +def test_dynamic_vocab_size_from_vocab_fallback(mock_llm): + # Test fallback to vocab when sp_model is not available + del mock_llm.llm_engine.data_processor.tokenizer.sp_model + mock_llm.llm_engine.data_processor.tokenizer.vocab.__len__.return_value = 300 + mock_llm.llm_engine.cfg.model_config.max_logprobs = -1 + + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + sampling = SamplingParams(logprobs=-1) + mock_llm._add_request(["hi"], sampling) + # Should use the vocab size (300) + mock_llm.llm_engine.add_requests.assert_called_once() def test_build_prompt_logprobs_basic(mock_llm): @@ -77,12 +169,13 @@ def test_build_prompt_logprobs_basic(mock_llm): # 检查结果格式 assert isinstance(result, list) - assert len(result) == 2 + assert len(result) == 3 for pos_dict in result: - assert isinstance(pos_dict, dict) - for logprob_obj in pos_dict.values(): - assert isinstance(logprob_obj, Logprob) - assert logprob_obj.decoded_token.startswith("TOKEN_") + if pos_dict is not None: + assert isinstance(pos_dict, dict) + for logprob_obj in pos_dict.values(): + assert isinstance(logprob_obj, Logprob) + assert logprob_obj.decoded_token.startswith("TOKEN_") def test_build_prompt_logprobs_handles_minus_one(mock_llm): @@ -94,7 +187,7 @@ def test_build_prompt_logprobs_handles_minus_one(mock_llm): result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1) assert isinstance(result, list) - assert len(result) == 1 - pos_dict = result[0] + assert len(result) == 2 + pos_dict = result[1] assert 7 in pos_dict assert pos_dict[7].decoded_token == "TOKEN_7" diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py index e58f613d1c8..b27cb14c866 100644 --- a/tests/output/test_process_batch_output_use_zmq.py +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -137,7 +137,7 @@ def test_process_prompt_logprobs_failure(self): result = self.processor._process_batch_output_use_zmq([stream_data]) self.assertEqual(len(result), 1) - self.assertIsNone(getattr(result[0], "prompt_logprobs_tensors", None)) + self.assertIsNone(getattr(result[0], "prompt_logprobs", None)) def test_process_batch_with_stop_flag(self): """Test processing when stop flag is True""" diff --git a/tests/utils/test_clamp_prompt_logprobs.py b/tests/utils/test_clamp_prompt_logprobs.py new file mode 100644 index 00000000000..9ae0eeee560 --- /dev/null +++ b/tests/utils/test_clamp_prompt_logprobs.py @@ -0,0 +1,133 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest + +from fastdeploy.utils import clamp_prompt_logprobs +from fastdeploy.worker.output import Logprob + + +class TestClampPromptLogprobs(unittest.TestCase): + def test_none_input(self): + """Test case when input is None""" + result = clamp_prompt_logprobs(None) + self.assertIsNone(result) + + def test_empty_list(self): + """Test empty list input""" + result = clamp_prompt_logprobs([]) + self.assertEqual(result, []) + + def test_normal_logprobs(self): + """Test normal logprobs values (without -inf)""" + logprob_dict = { + 1: Logprob(logprob=-2.5, rank=1, decoded_token="hello"), + 2: Logprob(logprob=-1.0, rank=2, decoded_token="world"), + } + prompt_logprobs = [logprob_dict] + + result = clamp_prompt_logprobs(prompt_logprobs) + + # Original values should remain unchanged + self.assertEqual(result[0][1].logprob, -2.5) + self.assertEqual(result[0][2].logprob, -1.0) + + def test_negative_inf_logprobs_raises_error(self): + """Test that logprobs containing -inf raises AttributeError""" + logprob_dict = { + 1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"), + 2: Logprob(logprob=-1.0, rank=2, decoded_token="world"), + } + prompt_logprobs = [logprob_dict] + + # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError + with self.assertRaises(AttributeError) as context: + clamp_prompt_logprobs(prompt_logprobs) + + self.assertIn("can't set attribute", str(context.exception)) + + def test_multiple_negative_inf_raises_error(self): + """Test that multiple -inf logprobs values raise AttributeError""" + logprob_dict = { + 1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"), + 2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"), + 3: Logprob(logprob=-0.5, rank=3, decoded_token="test"), + } + prompt_logprobs = [logprob_dict] + + # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError + with self.assertRaises(AttributeError): + clamp_prompt_logprobs(prompt_logprobs) + + def test_none_dict_in_list(self): + """Test case when list contains None""" + prompt_logprobs = [None] + + result = clamp_prompt_logprobs(prompt_logprobs) + + # None should be skipped + self.assertIsNone(result[0]) + + def test_multiple_dicts_normal_values(self): + """Test multiple dictionaries case (without -inf)""" + logprob_dict1 = { + 1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"), + } + logprob_dict2 = { + 2: Logprob(logprob=-2.0, rank=1, decoded_token="world"), + } + prompt_logprobs = [logprob_dict1, logprob_dict2] + + result = clamp_prompt_logprobs(prompt_logprobs) + + # Should return normally, values remain unchanged + self.assertEqual(result[0][1].logprob, -2.0) + self.assertEqual(result[1][2].logprob, -2.0) + + def test_mixed_values_without_inf(self): + """Test mixed values case (without -inf)""" + logprob_dict = { + 1: Logprob(logprob=-9999.0, rank=1, decoded_token="hello"), + 2: Logprob(logprob=-9999.0, rank=2, decoded_token="world"), + 3: Logprob(logprob=0.0, rank=3, decoded_token="test"), + 4: Logprob(logprob=-1.5, rank=4, decoded_token="again"), + } + prompt_logprobs = 
[logprob_dict] + + result = clamp_prompt_logprobs(prompt_logprobs) + + # All values should remain unchanged + self.assertEqual(result[0][1].logprob, -9999.0) + self.assertEqual(result[0][2].logprob, -9999.0) + self.assertEqual(result[0][3].logprob, 0.0) + self.assertEqual(result[0][4].logprob, -1.5) + + def test_return_same_object(self): + """Test that function returns the same object (in-place modification attempt)""" + logprob_dict = { + 1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"), + } + prompt_logprobs = [logprob_dict] + + result = clamp_prompt_logprobs(prompt_logprobs) + + # Should return the same object (function attempts in-place modification) + self.assertIs(result, prompt_logprobs) + self.assertIs(result[0], prompt_logprobs[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/woker/test_gpu_prompt_logprobs.py b/tests/woker/test_gpu_prompt_logprobs.py index ba0b9d9fedc..052ea1bbd56 100644 --- a/tests/woker/test_gpu_prompt_logprobs.py +++ b/tests/woker/test_gpu_prompt_logprobs.py @@ -16,6 +16,7 @@ import os import time import unittest +from unittest.mock import patch import numpy as np import paddle @@ -163,44 +164,46 @@ def setup_model_runner(self): return model_runner def test_prompt_logprobs(self): - model_runner = self.setup_model_runner() - - req: Request = Request( - prompt=None, - messages=None, - history=None, - tools=None, - system=None, - eos_token_ids=None, - arrival_time=None, - request_id="asd1", - prompt_token_ids=[1, 2, 3, 4], - prompt_token_ids_len=4, - prefill_start_index=0, - prefill_end_index=4, - sampling_params=SamplingParams(prompt_logprobs=-1), - ) - req.idx = 0 - model_runner.prompt_logprobs_reqs = {req.request_id: req} - - hidden_states = paddle.rand( - [len(req.prompt_token_ids) - 1, model_runner.fd_config.model_config.hidden_size], dtype="bfloat16" - ) - ref_logits = model_runner.model.compute_logits(hidden_states) - ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits) - token_is = paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64") - - ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs( - ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is - ) - prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0] - np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04) - np.testing.assert_allclose( - ref_token_ids.numpy(), prompt_logprobs.logprob_token_ids.numpy(), rtol=1e-04, atol=1e-04 - ) - np.testing.assert_allclose( - ref_ranks.numpy(), prompt_logprobs.selected_token_ranks.numpy(), rtol=1e-04, atol=1e-04 - ) + # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + model_runner = self.setup_model_runner() + + req: Request = Request( + prompt=None, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + arrival_time=None, + request_id="asd1", + prompt_token_ids=[1, 2, 3, 4], + prompt_token_ids_len=4, + prefill_start_index=0, + prefill_end_index=4, + sampling_params=SamplingParams(prompt_logprobs=-1), + ) + req.idx = 0 + model_runner.prompt_logprobs_reqs = {req.request_id: req} + + hidden_states = paddle.rand( + [len(req.prompt_token_ids) - 1, model_runner.fd_config.model_config.hidden_size], dtype="bfloat16" + ) + ref_logits = model_runner.model.compute_logits(hidden_states) + ref_raw_logprobs = model_runner.sampler.compute_logprobs(ref_logits) + token_is = 
paddle.to_tensor(req.prompt_token_ids[1:], dtype="int64")
+
+            ref_token_ids, ref_logprobs, ref_ranks = model_runner.sampler.gather_logprobs(
+                ref_raw_logprobs, model_runner.fd_config.model_config.ori_vocab_size, token_is
+            )
+            prompt_logprobs = model_runner._get_prompt_logprobs_list(hidden_states)[0]
+            np.testing.assert_allclose(ref_logprobs.numpy(), prompt_logprobs.logprobs.numpy(), rtol=1e-04, atol=1e-04)
+            np.testing.assert_allclose(
+                ref_token_ids.numpy(), prompt_logprobs.logprob_token_ids.numpy(), rtol=1e-04, atol=1e-04
+            )
+            np.testing.assert_allclose(
+                ref_ranks.numpy(), prompt_logprobs.selected_token_ranks.numpy(), rtol=1e-04, atol=1e-04
+            )
 
 
 if __name__ == "__main__":