diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 78918314509..290224e4a2f 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -418,13 +418,16 @@ def valid_parameters(self, data):
         # logprobs
         logprobs = data.get("logprobs")
         top_logprobs = None
+        is_chat = False

-        if isinstance(logprobs, bool) and logprobs:
-            if not self.enable_logprob:
-                err_msg = "Logprobs is disabled, please enable it in startup config."
-                api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
-            top_logprobs = data.get("top_logprobs")
+        if isinstance(logprobs, bool):
+            if logprobs:
+                is_chat = True
+                if not self.enable_logprob:
+                    err_msg = "Logprobs is disabled, please enable it in startup config."
+                    api_server_logger.error(err_msg)
+                    raise ParameterError("logprobs", err_msg)
+                top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
@@ -478,38 +481,40 @@ def valid_parameters(self, data):
             raise ValueError("prompt_logprobs", err_msg)

         # enable_logprob
-        if top_logprobs:
+        if top_logprobs is not None:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)

             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
-                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                err_msg = (
+                    f"Invalid type for {'top_logprobs' if is_chat else 'logprobs'}: expected int but got {err_type}."
+                )
+                api_server_logger.error(err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
+
+            if top_logprobs > max_logprobs:
+                err_msg = f"Number of {'top_logprobs' if is_chat else 'logprobs'} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                 api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
+                raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)

             if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                if top_logprobs < 0 or top_logprobs > 20:
-                    err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
+                if top_logprobs < 0 or top_logprobs > max_logprobs:
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be between 0 and {max_logprobs}; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
             else:
                 if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
-                    err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                    err_msg = f"The requested value of ({self.ori_vocab_size}) for {'top_logprobs' if is_chat else 'logprobs'} (-1) exceeds the maximum allowed value of ({max_logprobs})"
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)

                 if top_logprobs < -1:
-                    err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
-                    api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
-
-                if top_logprobs > max_logprobs:
-                    err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)

     def check_health(self, time_interval_threashold=30):
         """
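The heart of the first hunk is the new `is_chat` flag: chat-style requests send a boolean `logprobs` with the count in a separate `top_logprobs` field, while completion-style requests send the count directly as `logprobs`, and the flag lets every error below name the field the client actually used. A condensed sketch of that dispatch (plain dict, validation and logging omitted):

```python
# Condensed sketch of the dispatch the first hunk implements; `data` is a
# plain request dict and all validation/error handling is omitted.

def resolve_top_logprobs(data: dict):
    """Chat requests send a boolean `logprobs` plus an integer `top_logprobs`;
    completion-style requests send the count directly as `logprobs`."""
    logprobs = data.get("logprobs")
    is_chat = False
    top_logprobs = None
    if isinstance(logprobs, bool):  # must precede the int check
        if logprobs:
            is_chat = True
            top_logprobs = data.get("top_logprobs")
    elif isinstance(logprobs, int):
        top_logprobs = logprobs
    return is_chat, top_logprobs

assert resolve_top_logprobs({"logprobs": True, "top_logprobs": 5}) == (True, 5)
assert resolve_top_logprobs({"logprobs": 3}) == (False, 3)
assert resolve_top_logprobs({"logprobs": False}) == (False, None)
```

Note that the `bool` check has to come before the `int` check: `isinstance(True, int)` is true in Python, so testing `int` first would misroute chat requests.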
+ err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}." api_server_logger.error(err_msg) - raise ValueError("top_logprobs", err_msg) + raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg) def check_health(self, time_interval_threashold=30): """ diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index 462e3c5950f..69bb60196c4 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -351,6 +351,10 @@ def _add_request( if current_sampling_params.logprobs is not None: num_logprobs = current_sampling_params.logprobs + if not self.llm_engine.cfg.model_config.enable_logprob: + raise ValueError( + "logprobs is only supported if `enable_logprob` is set to true in startup config." + ) if num_logprobs == -1 and ori_vocab_size > max_logprobs: raise ValueError( f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})." @@ -360,6 +364,10 @@ def _add_request( f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})." ) if current_sampling_params.prompt_logprobs is not None: + if not self.llm_engine.cfg.model_config.enable_logprob: + raise ValueError( + "prompt_logprobs is only supported if `enable_logprob` is set to true in startup config." + ) if self.llm_engine.cfg.cache_config.enable_prefix_caching: raise ValueError("prompt_logprobs is not supported with prefix caching enabled.") if kwargs.get("stream"): @@ -403,19 +411,18 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: i llm_logger.warning("Empty logprob_token_ids in LogprobsLists") return None - # exclude sampled token at index 0 - available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1 + available_topk = len(logprobs_lists.logprob_token_ids[0]) effective_topk_logprobs = min(topk_logprobs, available_topk) - if effective_topk_logprobs <= 0: + if effective_topk_logprobs < 0: llm_logger.warning( f"Invalid effective_topk_logprobs={effective_topk_logprobs}, " f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result." 
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 0bb35a284dc..e91b5d44abf 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -611,7 +611,7 @@ class ChatCompletionRequest(BaseModel):
    model: Optional[str] = "default"
    frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
    logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = 0
+    top_logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    include_draft_logprobs: Optional[bool] = False

diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py
index 8bed4d9d915..2b66ce4e138 100644
--- a/fastdeploy/worker/output.py
+++ b/fastdeploy/worker/output.py
@@ -20,7 +20,8 @@
 import paddle


-class Logprob(NamedTuple):
+@dataclass
+class Logprob:
     """
     A named tuple containing information about a token's log probability.
     """
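The `NamedTuple`-to-`@dataclass` switch is what the clamping changes at the end of this PR rely on: named-tuple fields are read-only, so `clamp_prompt_logprobs` could not rewrite `-inf` values in place. A standalone sketch of the difference (field names come from the test constructors; the defaults are assumptions):

```python
# Why the @dataclass switch matters: NamedTuple fields reject assignment,
# dataclass fields accept it. Standalone sketch, not the shipped Logprob.
from dataclasses import dataclass
from typing import NamedTuple, Optional


class LogprobTuple(NamedTuple):
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


@dataclass
class LogprobData:
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


t = LogprobTuple(logprob=float("-inf"))
try:
    t.logprob = -9999.0  # NamedTuple fields are read-only
except AttributeError as exc:
    print(exc)  # the old test matched "can't set attribute"

d = LogprobData(logprob=float("-inf"))
d.logprob = -9999.0  # dataclass: plain attribute assignment
assert d.logprob == -9999.0
```

Note that the class docstring kept by the hunk still reads "A named tuple containing..." even though the type is now a dataclass.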
diff --git a/tests/entrypoints/openai/test_build_sample_logprobs.py b/tests/entrypoints/openai/test_build_sample_logprobs.py
index 74a00fcda52..0d11f072d1e 100644
--- a/tests/entrypoints/openai/test_build_sample_logprobs.py
+++ b/tests/entrypoints/openai/test_build_sample_logprobs.py
@@ -56,8 +56,9 @@ def test_build_sample_logprobs_basic(self):

         expected = [
             {
-                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
-                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
+                100: Logprob(logprob=-0.1, rank=1, decoded_token="token_100"),
+                101: Logprob(logprob=-0.5, rank=2, decoded_token="token_101"),
+                102: Logprob(logprob=-1.0, rank=3, decoded_token="token_102"),
             }
         ]

@@ -79,7 +80,7 @@ def test_build_sample_logprobs_invalid_topk(self):
         logprobs_lists = MagicMock(spec=LogprobsLists)
         logprobs_lists.logprob_token_ids = [[100]]
         result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
-        self.assertIsNone(result)
+        self.assertEqual(result, [])

     def test_decode_token(self):
         """
diff --git a/tests/entrypoints/openai/test_chatcompletion_request.py b/tests/entrypoints/openai/test_chatcompletion_request.py
index 55aaf1944e5..a7f1985da20 100644
--- a/tests/entrypoints/openai/test_chatcompletion_request.py
+++ b/tests/entrypoints/openai/test_chatcompletion_request.py
@@ -41,7 +41,7 @@ def test_default_values(self):
         req = ChatCompletionRequest(messages=[1])
         self.assertEqual(req.model, "default")
         self.assertFalse(req.logprobs)
-        self.assertEqual(req.top_logprobs, 0)
+        self.assertIsNone(req.top_logprobs)
         self.assertEqual(req.n, 1)
         self.assertEqual(req.stop, [])

diff --git a/tests/entrypoints/test_engine_client.py b/tests/entrypoints/test_engine_client.py
index 07e5d7b8708..78d21614d5b 100644
--- a/tests/entrypoints/test_engine_client.py
+++ b/tests/entrypoints/test_engine_client.py
@@ -489,8 +489,9 @@ def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self):
         data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
         with self.assertRaises(ValueError) as context:
             self.engine_client.valid_parameters(data)
-        self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
-        self.assertIn("current value is 25", str(context.exception))
+        self.assertIn(
+            "Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(context.exception)
+        )

         # Test valid value
         data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}
diff --git a/tests/entrypoints/test_vllm_run_engine.py b/tests/entrypoints/test_vllm_run_engine.py
index 4ac03116544..2648bb3d26a 100644
--- a/tests/entrypoints/test_vllm_run_engine.py
+++ b/tests/entrypoints/test_vllm_run_engine.py
@@ -10,9 +10,10 @@


 class DummyModelConfig:
-    def __init__(self, max_logprobs=10, ori_vocab_size=50):
+    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
         self.max_logprobs = max_logprobs
         self.ori_vocab_size = ori_vocab_size
+        self.enable_logprob = enable_logprob


 class DummyCacheConfig:
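These test updates all trace back to one decision: `top_logprobs` now defaults to `None` instead of `0`, and the guards compare against `None` (`if top_logprobs is not None:` in `valid_parameters`, `topk_logprobs is not None` in `_run_engine`). With a falsy default, "client never asked" and an explicit `top_logprobs=0` were indistinguishable, so a legitimate 0 skipped validation entirely. A minimal before/after sketch of the guard:

```python
# Plain sketch of why the default moved from 0 to None; no Pydantic involved.

def old_guard(top_logprobs) -> bool:
    return bool(top_logprobs)        # 0 and None both skip validation

def new_guard(top_logprobs) -> bool:
    return top_logprobs is not None  # only an absent value skips it

assert not old_guard(0) and not old_guard(None)  # indistinguishable before
assert new_guard(0)                              # explicit 0 is now validated
assert not new_guard(None)                       # unset still skips validation
```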
diff --git a/tests/utils/test_clamp_prompt_logprobs.py b/tests/utils/test_clamp_prompt_logprobs.py
index 9ae0eeee560..1de76c24eab 100644
--- a/tests/utils/test_clamp_prompt_logprobs.py
+++ b/tests/utils/test_clamp_prompt_logprobs.py
@@ -45,22 +45,23 @@ def test_normal_logprobs(self):
         self.assertEqual(result[0][1].logprob, -2.5)
         self.assertEqual(result[0][2].logprob, -1.0)

-    def test_negative_inf_logprobs_raises_error(self):
-        """Test that logprobs containing -inf raises AttributeError"""
+    def test_negative_inf_logprobs_gets_clamped(self):
+        """Test that logprobs containing -inf get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
         }
         prompt_logprobs = [logprob_dict]

-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError) as context:
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)

-        self.assertIn("can't set attribute", str(context.exception))
+        # The -inf value should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -1.0)  # unchanged

-    def test_multiple_negative_inf_raises_error(self):
-        """Test that multiple -inf logprobs values raise AttributeError"""
+    def test_multiple_negative_inf_gets_clamped(self):
+        """Test that multiple -inf logprobs values get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
@@ -68,9 +69,13 @@ def test_multiple_negative_inf_raises_error(self):
         }
         prompt_logprobs = [logprob_dict]

-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError):
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
+
+        # All -inf values should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -9999.0)
+        self.assertEqual(result[0][3].logprob, -0.5)  # unchanged

     def test_none_dict_in_list(self):
         """Test case when list contains None"""
@@ -116,7 +121,7 @@ def test_mixed_values_without_inf(self):
         self.assertEqual(result[0][4].logprob, -1.5)

     def test_return_same_object(self):
-        """Test that function returns the same object (in-place modification attempt)"""
+        """Test that function returns the same object (in-place modification)"""
         logprob_dict = {
             1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
         }
@@ -124,7 +129,7 @@ def test_return_same_object(self):

         result = clamp_prompt_logprobs(prompt_logprobs)

-        # Should return the same object (function attempts in-place modification)
+        # Should return the same object (function performs in-place modification)
         self.assertIs(result, prompt_logprobs)
         self.assertIs(result[0], prompt_logprobs[0])
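These tests pin down the clamping contract without showing the implementation, so for orientation here is a sketch that satisfies them: `-inf` logprobs are rewritten in place to `-9999.0`, `None` rows are skipped, finite values are untouched, and the very same list object comes back. The shipped `clamp_prompt_logprobs` may differ in detail.

```python
# Sketch only: a minimal clamp consistent with the tests above; not the
# shipped FastDeploy implementation.

def clamp_prompt_logprobs_sketch(prompt_logprobs):
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:  # test_none_dict_in_list: None rows pass through
            continue
        for logprob in logprob_dict.values():
            if logprob.logprob == float("-inf"):
                logprob.logprob = -9999.0  # needs the mutable dataclass Logprob
    return prompt_logprobs  # same object, so assertIs(result, prompt_logprobs) holds
```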