Merged
9 changes: 8 additions & 1 deletion fastdeploy/config.py
@@ -229,8 +229,15 @@ def __init__(
self.think_end_id = args.get("think_end_id", -1)
self.im_patch_id = args.get("image_patch_id", -1)
self.line_break_id = args.get("line_break_id", -1)
if self.max_logprobs < -1:

num_max_logprobs = args.get("max_logprobs", None)
if num_max_logprobs is not None and num_max_logprobs < -1:
raise ValueError(" The possible values for max_logprobs can't be less than -1 ")
if self.ori_vocab_size is not None and num_max_logprobs is not None:
if num_max_logprobs > self.ori_vocab_size:
raise ValueError(
f" The possible values for max_logprobs can't be greater than the vocabulary size {self.ori_vocab_size}"
)

self._post_init()

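The hunk above tightens the `max_logprobs` check in the model config: the value may be `-1` (meaning "the full vocabulary"), otherwise it must be non-negative and no larger than the vocabulary size. A minimal standalone sketch of that rule (the helper name and sample vocabulary size are illustrative, not part of this diff):

```python
def check_max_logprobs(num_max_logprobs, ori_vocab_size):
    """Sketch of the bound check added to ModelConfig.__init__."""
    if num_max_logprobs is None:
        return  # nothing to validate
    if num_max_logprobs < -1:
        raise ValueError("max_logprobs can't be less than -1")
    if ori_vocab_size is not None and num_max_logprobs > ori_vocab_size:
        raise ValueError(
            f"max_logprobs can't be greater than the vocabulary size {ori_vocab_size}"
        )

check_max_logprobs(-1, 32000)       # ok: -1 means "all tokens in the vocabulary"
check_max_logprobs(20, 32000)       # ok: within the vocabulary size
# check_max_logprobs(40000, 32000)  # would raise: exceeds the vocabulary size
```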
9 changes: 1 addition & 8 deletions fastdeploy/engine/request.py
@@ -31,12 +31,7 @@
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import (
LogprobsLists,
LogprobsTensors,
PromptLogprobs,
SampleLogprobs,
)
from fastdeploy.worker.output import LogprobsLists, PromptLogprobs, SampleLogprobs


class RequestStatus(Enum):
@@ -519,7 +514,6 @@ def __init__(
prompt: Optional[str] = None,
prompt_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[PromptLogprobs] = None,
prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
output_type: Optional[int] = 3,
outputs: CompletionOutput = None,
finished: bool = False,
@@ -537,7 +531,6 @@ def __init__(
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_logprobs = prompt_logprobs
self.prompt_logprobs_tensors = prompt_logprobs_tensors
self.output_type = output_type
self.outputs = outputs
self.finished = finished
20 changes: 13 additions & 7 deletions fastdeploy/engine/sampling_params.py
@@ -16,12 +16,13 @@

from __future__ import annotations

import os
import random
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, List, Optional, Union

from fastdeploy import envs


@dataclass
class SamplingParams:
@@ -207,12 +208,17 @@ def _verify_args(self) -> None:
raise ValueError(
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs < -1:
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")

if not envs.FD_USE_GET_SAVE_OUTPUT_V1: # False (0)
if self.logprobs is not None and (self.logprobs < 0 or self.logprobs > 20):
raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
if self.prompt_logprobs is not None:
raise ValueError("prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
else: # True (1)
if self.logprobs is not None and self.logprobs < -1:
raise ValueError(f"logprobs must be a non-negative value or -1, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
raise ValueError(f"prompt_logprobs a must be non-negative value or -1, got {self.prompt_logprobs}.")

if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
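To summarize the two validation regimes introduced in `_verify_args`: with `FD_USE_GET_SAVE_OUTPUT_V1` disabled, `logprobs` is capped at 20 (the legacy OpenAI-style limit) and `prompt_logprobs` is rejected outright; with it enabled, both accept any non-negative value or `-1`. A minimal sketch of that branching (the function name and the explicit boolean flag argument are illustrative, not part of this diff):

```python
def verify_logprob_args(logprobs, prompt_logprobs, use_save_output_v1):
    """Sketch of the per-request checks keyed on FD_USE_GET_SAVE_OUTPUT_V1."""
    if not use_save_output_v1:
        # Legacy path: top_logprobs limited to 20, prompt_logprobs unavailable.
        if logprobs is not None and not (0 <= logprobs <= 20):
            raise ValueError("Invalid value for 'top_logprobs': must be between 0 and 20.")
        if prompt_logprobs is not None:
            raise ValueError("prompt_logprobs is not supported when FD_USE_GET_SAVE_OUTPUT_V1 is disabled.")
    else:
        # V1 output path: any non-negative value, or -1 for the full vocabulary.
        if logprobs is not None and logprobs < -1:
            raise ValueError(f"logprobs must be a non-negative value or -1, got {logprobs}.")
        if prompt_logprobs is not None and prompt_logprobs < -1:
            raise ValueError(f"prompt_logprobs must be a non-negative value or -1, got {prompt_logprobs}.")

verify_logprob_args(logprobs=5, prompt_logprobs=None, use_save_output_v1=False)  # ok
verify_logprob_args(logprobs=-1, prompt_logprobs=3, use_save_output_v1=True)     # ok
```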
84 changes: 74 additions & 10 deletions fastdeploy/entrypoints/engine_client.py
@@ -56,10 +56,11 @@ class EngineClient:
EngineClient is a class that handles the communication between the client and the server.
"""

def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1):
def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20):
self.fd_config = fd_config
self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size
self.enable_mm = self.fd_config.model_config.enable_mm
self.max_logprobs = max_logprobs
input_processor = InputPreprocessor(
self.fd_config.model_config,
self.fd_config.structured_outputs_config.reasoning_parser,
@@ -70,6 +71,11 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers
)
self.enable_logprob = self.fd_config.model_config.enable_logprob
self.data_processor = input_processor.create_processor()
self.ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab)
)
self.max_model_len = self.fd_config.model_config.max_model_len
self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
@@ -424,6 +430,53 @@ def valid_parameters(self, data):
elif logprobs:
raise ParameterError("logprobs", "Invalid type for 'logprobs'")

max_logprobs = self.max_logprobs
if max_logprobs == -1:
max_logprobs = self.ori_vocab_size
if max_logprobs < -1:
err_msg = f"Invalid 'max_logprobs': must be >= -1, got {max_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("max_logprobs", err_msg)
if max_logprobs > self.ori_vocab_size:
err_msg = f"Invalid 'max_logprobs': must be <= vocab_size {self.ori_vocab_size}, got {max_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("max_logprobs", err_msg)

prompt_logprobs = data.get("prompt_logprobs", None)

if prompt_logprobs is not None:
if not self.enable_logprob:
err_msg = "`enable_logprob` is disabled, please enable it in startup config."
api_server_logger.error(err_msg)
raise ParameterError("prompt_logprobs", err_msg)

if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
err_msg = "prompt_logprobs is not support when FD_USE_GET_SAVE_OUTPUT_V1 is disabled."
api_server_logger.error(err_msg)
raise ParameterError("prompt_logprobs", err_msg)

if self.enable_prefix_caching:
err_msg = "prompt_logprobs is not support when prefix caching is enabled."
api_server_logger.error(err_msg)
raise ParameterError("prompt_logprobs", err_msg)

if prompt_logprobs == -1 and self.ori_vocab_size > max_logprobs:
err_msg = f"The requested value of ({self.ori_vocab_size}) for prompt_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
api_server_logger.error(err_msg)
raise ValueError("prompt_logprobs", err_msg)

if prompt_logprobs < -1:
err_msg = (
f"prompt_logprobs must be a non-negative value or -1; the current value is {prompt_logprobs}."
)
api_server_logger.error(err_msg)
raise ValueError("prompt_logprobs", err_msg)

if prompt_logprobs > max_logprobs:
err_msg = f"Number of prompt_logprobs requested ({prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
api_server_logger.error(err_msg)
raise ValueError("prompt_logprobs", err_msg)

# enable_logprob
if top_logprobs:
if not self.enable_logprob:
@@ -437,15 +490,26 @@ def valid_parameters(self, data):
api_server_logger.error(err_msg)
raise ParameterError("top_logprobs", err_msg)

if top_logprobs < 0:
err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
api_server_logger.error(err_msg)
raise ParameterError("top_logprobs", err_msg)

if top_logprobs > 20:
err_msg = "Invalid value for 'top_logprobs': must be <= 20."
api_server_logger.error(err_msg)
raise ParameterError("top_logprobs", err_msg)
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
if top_logprobs < 0 or top_logprobs > 20:
err_msg = f"top_logprobs must be between 0 and 20; the current value is {top_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)
else:
if top_logprobs == -1 and self.ori_vocab_size > max_logprobs:
err_msg = f"The requested value of ({self.ori_vocab_size}) for top_logprobs (-1) exceeds the maximum allowed value of ({max_logprobs})"
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)

if top_logprobs < -1:
err_msg = f"top_logprobs must be a non-negative value or -1; the current value is {top_logprobs}."
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)

if top_logprobs > max_logprobs:
err_msg = f"Number of logprobs requested ({top_logprobs}) exceeds maximum allowed value ({max_logprobs})."
api_server_logger.error(err_msg)
raise ValueError("top_logprobs", err_msg)

def check_health(self, time_interval_threashold=30):
"""
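The server-side checks above all reduce to one rule: resolve the configured `max_logprobs` (where `-1` means the tokenizer's vocabulary size) into an effective cap, then make sure the requested `top_logprobs` / `prompt_logprobs` values (again with `-1` expanding to the full vocabulary) fit under it. A condensed sketch of that resolution, with a hypothetical helper name and example sizes:

```python
def resolve_and_check(requested, max_logprobs, ori_vocab_size, name="top_logprobs"):
    """Sketch: expand -1 sentinels and enforce the server-wide logprobs cap."""
    cap = ori_vocab_size if max_logprobs == -1 else max_logprobs
    effective = ori_vocab_size if requested == -1 else requested
    if effective < 0:
        raise ValueError(f"{name} must be a non-negative value or -1, got {requested}.")
    if effective > cap:
        raise ValueError(
            f"Number of {name} requested ({effective}) exceeds maximum allowed value ({cap})."
        )
    return effective

resolve_and_check(5, max_logprobs=20, ori_vocab_size=32000)      # -> 5
resolve_and_check(-1, max_logprobs=-1, ori_vocab_size=32000)     # -> 32000 (whole vocabulary)
# resolve_and_check(-1, max_logprobs=20, ori_vocab_size=32000)   # would raise: -1 expands past the cap
```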
37 changes: 27 additions & 10 deletions fastdeploy/entrypoints/llm.py
@@ -335,23 +335,40 @@ def _add_request(
current_sampling_params = sampling_params[i]
else:
current_sampling_params = sampling_params
if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
raise ValueError("prompt_logprobs is not supported with streaming.")

ori_vocab_size = (
len(self.llm_engine.data_processor.tokenizer.sp_model)
if hasattr(self.llm_engine.data_processor.tokenizer, "sp_model")
else len(self.llm_engine.data_processor.tokenizer.vocab)
)
max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
if max_logprobs == -1:
max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
max_logprobs = ori_vocab_size
if max_logprobs < -1:
raise ValueError(f"max_logprobs ({max_logprobs}) can't be less than -1.")
if max_logprobs > ori_vocab_size:
raise ValueError(f"max_logprobs ({max_logprobs}) exceeds vocabulary size ({ori_vocab_size}).")

if current_sampling_params.logprobs is not None:
num_logprobs = current_sampling_params.logprobs
if num_logprobs == -1:
num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if num_logprobs == -1 and ori_vocab_size > max_logprobs:
raise ValueError(
f"Number of logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
)
if num_logprobs > max_logprobs:
raise ValueError(
f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
)
if current_sampling_params.prompt_logprobs is not None:
if self.llm_engine.cfg.cache_config.enable_prefix_caching:
raise ValueError("prompt_logprobs is not supported with prefix caching enabled.")
if kwargs.get("stream"):
raise ValueError("prompt_logprobs is not supported with streaming.")
num_prompt_logprobs = current_sampling_params.prompt_logprobs
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if num_prompt_logprobs == -1 and ori_vocab_size > max_logprobs:
raise ValueError(
f"Number of prompt_logprobs(-1) requested ({ori_vocab_size}) exceeds maximum allowed value ({max_logprobs})."
)
if num_prompt_logprobs > max_logprobs:
raise ValueError(
f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
@@ -436,7 +453,7 @@ def _build_prompt_logprobs(
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
result: Optional[PromptLogprobs] = []
result: Optional[PromptLogprobs] = [None]
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
@@ -548,11 +565,11 @@ def _run_engine(
result.outputs.logprobs = self._build_sample_logprobs(
result.outputs.top_logprobs, topk_logprobs
)
if result.prompt_logprobs_tensors and num_prompt_logprobs:
if result.prompt_logprobs is not None and num_prompt_logprobs is not None:
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
result.prompt_logprobs = self._build_prompt_logprobs(
result.prompt_logprobs_tensors, num_prompt_logprobs
result.prompt_logprobs, num_prompt_logprobs
)

output[pos] = result
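For context, an offline-usage sketch of the new knobs. The model path is a placeholder and the `LLM(...)` / `generate(...)` argument names are assumed to match the existing entry points in `fastdeploy/entrypoints/llm.py`, so treat this as illustrative rather than canonical:

```python
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

llm = LLM(model="path/to/your/model")  # placeholder model path, assumed constructor argument

# logprobs=-1 asks for the full vocabulary per generated token;
# prompt_logprobs=5 asks for the top-5 alternatives per prompt token.
params = SamplingParams(max_tokens=16, logprobs=-1, prompt_logprobs=5)

outputs = llm.generate(["Hello, world"], sampling_params=params)
for out in outputs:
    print(out.outputs.logprobs)   # per-sampled-token logprobs
    print(out.prompt_logprobs)    # per-prompt-token logprobs (None for the first position)
```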
1 change: 1 addition & 0 deletions fastdeploy/entrypoints/openai/api_server.py
@@ -181,6 +181,7 @@ async def lifespan(app: FastAPI):
port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")),
fd_config=fd_config,
workers=args.workers,
max_logprobs=args.max_logprobs,
)
await engine_client.connection_manager.initialize()
app.state.dynamic_load_weight = args.dynamic_load_weight
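On the serving side, the new `max_logprobs` argument is threaded from the CLI args into `EngineClient`, and requests can then carry `top_logprobs` / `prompt_logprobs` up to that cap. A hedged request sketch against the OpenAI-compatible endpoint; the URL, port, and model name are placeholders:

```python
import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello"}],
    "logprobs": True,
    "top_logprobs": -1,     # -1 = full vocabulary; only valid if it fits under max_logprobs
    "prompt_logprobs": 5,   # requires FD_USE_GET_SAVE_OUTPUT_V1=1 and prefix caching disabled
}

# Port and path are assumptions about the local deployment, not taken from this diff.
resp = requests.post("http://localhost:8188/v1/chat/completions", json=payload)
print(resp.json())
```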