Merged
49 changes: 27 additions & 22 deletions fastdeploy/entrypoints/engine_client.py
@@ -418,13 +418,16 @@ def valid_parameters(self, data):
         # logprobs
         logprobs = data.get("logprobs")
         top_logprobs = None
+        is_chat = False
 
-        if isinstance(logprobs, bool) and logprobs:
-            if not self.enable_logprob:
-                err_msg = "Logprobs is disabled, please enable it in startup config."
-                api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
-            top_logprobs = data.get("top_logprobs")
+        if isinstance(logprobs, bool):
+            if logprobs:
+                is_chat = True
+                if not self.enable_logprob:
+                    err_msg = "Logprobs is disabled, please enable it in startup config."
+                    api_server_logger.error(err_msg)
+                    raise ParameterError("logprobs", err_msg)
+                top_logprobs = data.get("top_logprobs")
         elif isinstance(logprobs, int):
             top_logprobs = logprobs
         elif logprobs:
@@ -478,38 +481,40 @@ def valid_parameters(self, data):
                 raise ValueError("prompt_logprobs", err_msg)
 
         # enable_logprob
-        if top_logprobs:
+        if top_logprobs is not None:
             if not self.enable_logprob:
                 err_msg = "Logprobs is disabled, please enable it in startup config."
                 api_server_logger.error(err_msg)
-                raise ParameterError("logprobs", err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
 
             if not isinstance(top_logprobs, int):
                 err_type = type(top_logprobs).__name__
-                err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
+                err_msg = (
+                    f"Invalid type for {'top_logprobs' if is_chat else 'logprobs'}: expected int but got {err_type}."
+                )
                 api_server_logger.error(err_msg)
+                raise ParameterError("top_logprobs" if is_chat else "logprobs", err_msg)
+
+            if top_logprobs > max_logprobs:
+                err_msg = f"Number of {'top_logprobs' if is_chat else 'logprobs'} requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                api_server_logger.error(err_msg)
-                raise ParameterError("top_logprobs", err_msg)
+                raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
 
             if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                if top_logprobs < 0 or top_logprobs > 20:
-                    err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
+                if top_logprobs < 0 or top_logprobs > max_logprobs:
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be between 0 and {max_logprobs}; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
             else:
                 if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
-                    err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
+                    err_msg = f"The requested value of ({self.ori_vocab_size}) for {'top_logprobs' if is_chat else 'logprobs'} (-1) exceeds the maximum allowed value of ({max_logprobs})"
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
 
                 if top_logprobs < -1:
-                    err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
-                    api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
-
-                if top_logprobs > max_logprobs:
-                    err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
+                    err_msg = f"{'top_logprobs' if is_chat else 'logprobs'} must be a non-negative value or -1; the current value is {top_logprobs}."
                     api_server_logger.error(err_msg)
-                    raise ValueError("top_logprobs", err_msg)
+                    raise ValueError("top_logprobs" if is_chat else "logprobs", err_msg)
 
     def check_health(self, time_interval_threashold=30):
         """
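A minimal sketch of the parameter-naming rule the updated validator follows (not the actual `EngineClient` code; the helper name and bare-dict input are assumptions): chat requests send a boolean `logprobs` flag plus an integer `top_logprobs`, while completion requests send an integer `logprobs`, so the field name reported in validation errors should match whichever one the caller actually set.

```python
# Illustrative sketch only; mirrors the is_chat logic added to valid_parameters above.
def logprob_param_name(data: dict) -> str:
    logprobs = data.get("logprobs")
    # Chat requests use a boolean `logprobs` plus an integer `top_logprobs`;
    # completion requests pass an integer `logprobs` directly.
    is_chat = isinstance(logprobs, bool) and logprobs
    return "top_logprobs" if is_chat else "logprobs"


print(logprob_param_name({"logprobs": True, "top_logprobs": 5}))  # -> "top_logprobs"
print(logprob_param_name({"logprobs": 5}))                        # -> "logprobs"
```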
19 changes: 13 additions & 6 deletions fastdeploy/entrypoints/llm.py
@@ -351,6 +351,10 @@ def _add_request(
 
             if current_sampling_params.logprobs is not None:
                 num_logprobs = current_sampling_params.logprobs
+                if not self.llm_engine.cfg.model_config.enable_logprob:
+                    raise ValueError(
+                        "logprobs is only supported if `enable_logprob` is set to true in startup config."
+                    )
                 if num_logprobs == -1 and ori_vocab_size > max_logprobs:
                     raise ValueError(
                         f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
@@ -360,6 +364,10 @@ def _add_request(
                         f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                     )
             if current_sampling_params.prompt_logprobs is not None:
+                if not self.llm_engine.cfg.model_config.enable_logprob:
+                    raise ValueError(
+                        "prompt_logprobs is only supported if `enable_logprob` is set to true in startup config."
+                    )
                 if self.llm_engine.cfg.cache_config.enable_prefix_caching:
                     raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
                 if kwargs.get("stream"):
@@ -403,19 +411,18 @@ def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: i
             llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
             return None
 
-        # exclude sampled token at index 0
-        available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+        available_topk = len(logprobs_lists.logprob_token_ids[0])
         effective_topk_logprobs = min(topk_logprobs, available_topk)
 
-        if effective_topk_logprobs <= 0:
+        if effective_topk_logprobs < 0:
             llm_logger.warning(
                 f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
                 f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
             )
             return None
 
-        # sliced 1 ~ (1 + effective_topk_logprobs)
-        sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+        # sliced 0 ~ effective_topk_logprobs+1
+        sliced_logprobs_lists = logprobs_lists.slice_columns(0, effective_topk_logprobs + 1)
         result = []
         for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
 
@@ -559,7 +566,7 @@ def _run_engine(
                     result = self.llm_engine.data_processor.process_response(result)
 
                     # filter logprobs
-                    if result.outputs.top_logprobs and topk_logprobs:
+                    if result.outputs.top_logprobs is not None and topk_logprobs is not None:
                         if topk_logprobs == -1:
                             topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                         result.outputs.logprobs = self._build_sample_logprobs(
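A small sketch of the slicing change in `_build_sample_logprobs`, using plain lists rather than the real `LogprobsLists` type (the row values below are assumed data): the old code dropped column 0, i.e. the sampled token, and returned only the alternatives, while the new code keeps column 0, which is why the updated test expects token 100 at rank 1.

```python
# Plain-list illustration only, not the real LogprobsLists API.
row_token_ids = [100, 101, 102]   # column 0 holds the sampled token
topk_logprobs = 2

# Old behavior: exclude the sampled token and cap top-k by the remaining columns.
old_available = len(row_token_ids) - 1
old_slice = row_token_ids[1 : 1 + min(topk_logprobs, old_available)]    # [101, 102]

# New behavior: keep column 0, so the sampled token is returned at rank 1.
new_available = len(row_token_ids)
new_slice = row_token_ids[0 : min(topk_logprobs, new_available) + 1]    # [100, 101, 102]
```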
2 changes: 1 addition & 1 deletion fastdeploy/entrypoints/openai/protocol.py
@@ -611,7 +611,7 @@ class ChatCompletionRequest(BaseModel):
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = 0
+    top_logprobs: Optional[int] = None
     prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
 
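With the default changed from `0` to `None`, an unset `top_logprobs` can now be distinguished from an explicit request for zero entries. A quick illustration, reusing the minimal `messages=[1]` payload from tests/entrypoints/openai/test_chatcompletion_request.py further down in this diff:

```python
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

# logprobs=True without top_logprobs now means "not requested" rather than "request 0".
req = ChatCompletionRequest(messages=[1], logprobs=True)
assert req.top_logprobs is None  # previously this defaulted to 0
```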
3 changes: 2 additions & 1 deletion fastdeploy/worker/output.py
@@ -20,7 +20,8 @@
 import paddle
 
 
-class Logprob(NamedTuple):
+@dataclass
+class Logprob:
     """
     A named tuple containing information about a token's log probability.
     """
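Switching `Logprob` from a `NamedTuple` to a `dataclass` makes its fields mutable, which is what lets clamp-style helpers rewrite `-inf` values in place instead of raising `AttributeError`. A minimal sketch under that assumption; the real `clamp_prompt_logprobs` is not part of this diff, so the body below is inferred from the updated tests and the local `Logprob` definition is a stand-in:

```python
from dataclasses import dataclass


@dataclass
class Logprob:
    logprob: float
    rank: int
    decoded_token: str


def clamp_prompt_logprobs(prompt_logprobs):
    # Inferred behavior (see tests/utils/test_clamp_prompt_logprobs.py below):
    # clamp -inf to -9999.0 in place and return the same object.
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for token_logprob in logprob_dict.values():
            if token_logprob.logprob == float("-inf"):
                token_logprob.logprob = -9999.0
    return prompt_logprobs


clamped = clamp_prompt_logprobs([{1: Logprob(float("-inf"), 1, "hello")}])
print(clamped[0][1].logprob)  # -9999.0
```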
7 changes: 4 additions & 3 deletions tests/entrypoints/openai/test_build_sample_logprobs.py
@@ -56,8 +56,9 @@ def test_build_sample_logprobs_basic(self):
 
         expected = [
             {
-                101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
-                102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
+                100: Logprob(logprob=-0.1, rank=1, decoded_token="token_100"),
+                101: Logprob(logprob=-0.5, rank=2, decoded_token="token_101"),
+                102: Logprob(logprob=-1.0, rank=3, decoded_token="token_102"),
             }
         ]
 
@@ -79,7 +80,7 @@ def test_build_sample_logprobs_invalid_topk(self):
         logprobs_lists = MagicMock(spec=LogprobsLists)
         logprobs_lists.logprob_token_ids = [[100]]
         result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
-        self.assertIsNone(result)
+        self.assertEqual(result, [])
 
     def test_decode_token(self):
         """
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_chatcompletion_request.py
@@ -41,7 +41,7 @@ def test_default_values(self):
         req = ChatCompletionRequest(messages=[1])
         self.assertEqual(req.model, "default")
         self.assertFalse(req.logprobs)
-        self.assertEqual(req.top_logprobs, 0)
+        self.assertIsNone(req.top_logprobs)
         self.assertEqual(req.n, 1)
         self.assertEqual(req.stop, [])
 
5 changes: 3 additions & 2 deletions tests/entrypoints/test_engine_client.py
@@ -489,8 +489,9 @@ def test_top_logprobs_validation_with_fd_use_get_save_output_v1_disabled(self):
         data = {"logprobs": True, "top_logprobs": 25, "request_id": "test"}
         with self.assertRaises(ValueError) as context:
             self.engine_client.valid_parameters(data)
-        self.assertIn("top_logprobs must be between 0 and 20", str(context.exception))
-        self.assertIn("current value is 25", str(context.exception))
+        self.assertIn(
+            "Number of top_logprobs requested (25) exceeds maximum allowed value (20)", str(context.exception)
+        )
 
         # Test valid value
         data = {"logprobs": True, "top_logprobs": 10, "request_id": "test"}
3 changes: 2 additions & 1 deletion tests/entrypoints/test_vllm_run_engine.py
@@ -10,9 +10,10 @@
 
 
 class DummyModelConfig:
-    def __init__(self, max_logprobs=10, ori_vocab_size=50):
+    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
         self.max_logprobs = max_logprobs
         self.ori_vocab_size = ori_vocab_size
+        self.enable_logprob = enable_logprob
 
 
 class DummyCacheConfig:
31 changes: 18 additions & 13 deletions tests/utils/test_clamp_prompt_logprobs.py
@@ -45,32 +45,37 @@ def test_normal_logprobs(self):
         self.assertEqual(result[0][1].logprob, -2.5)
         self.assertEqual(result[0][2].logprob, -1.0)
 
-    def test_negative_inf_logprobs_raises_error(self):
-        """Test that logprobs containing -inf raises AttributeError"""
+    def test_negative_inf_logprobs_gets_clamped(self):
+        """Test that logprobs containing -inf get clamped to -9999.0"""
         logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=-1.0, rank=2, decoded_token="world"),
         }
         prompt_logprobs = [logprob_dict]
 
-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError) as context:
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
 
-        self.assertIn("can't set attribute", str(context.exception))
+        # The -inf value should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -1.0)  # unchanged
 
-    def test_multiple_negative_inf_raises_error(self):
-        """Test that multiple -inf logprobs values raise AttributeError"""
+    def test_multiple_negative_inf_gets_clamped(self):
+        """Test that multiple -inf logprobs values get clamped to -9999.0"""
        logprob_dict = {
             1: Logprob(logprob=float("-inf"), rank=1, decoded_token="hello"),
             2: Logprob(logprob=float("-inf"), rank=2, decoded_token="world"),
             3: Logprob(logprob=-0.5, rank=3, decoded_token="test"),
         }
         prompt_logprobs = [logprob_dict]
 
-        # Since Logprob is a NamedTuple, its fields cannot be modified, should raise AttributeError
-        with self.assertRaises(AttributeError):
-            clamp_prompt_logprobs(prompt_logprobs)
+        # Since Logprob is now a dataclass, its fields can be modified
+        result = clamp_prompt_logprobs(prompt_logprobs)
+
+        # All -inf values should be clamped to -9999.0
+        self.assertEqual(result[0][1].logprob, -9999.0)
+        self.assertEqual(result[0][2].logprob, -9999.0)
+        self.assertEqual(result[0][3].logprob, -0.5)  # unchanged
 
     def test_none_dict_in_list(self):
         """Test case when list contains None"""
@@ -116,15 +121,15 @@ def test_mixed_values_without_inf(self):
         self.assertEqual(result[0][4].logprob, -1.5)
 
     def test_return_same_object(self):
-        """Test that function returns the same object (in-place modification attempt)"""
+        """Test that function returns the same object (in-place modification)"""
         logprob_dict = {
             1: Logprob(logprob=-2.0, rank=1, decoded_token="hello"),
         }
         prompt_logprobs = [logprob_dict]
 
         result = clamp_prompt_logprobs(prompt_logprobs)
 
-        # Should return the same object (function attempts in-place modification)
+        # Should return the same object (function performs in-place modification)
         self.assertIs(result, prompt_logprobs)
         self.assertIs(result[0], prompt_logprobs[0])
 