10 changes: 9 additions & 1 deletion fastdeploy/entrypoints/engine_client.py
@@ -293,7 +293,15 @@ async def add_requests(self, task):
 
         task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
         input_ids_len = task["prompt_token_ids_len"]
-        task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
+
+        completion_token_len = len(task["completion_token_ids"]) if task.get("completion_token_ids") else 0
+        task["max_tokens"] = min(
+            self.max_model_len - input_ids_len, max(0, task.get("max_tokens") - completion_token_len)
+        )
+
+        if task.get("min_tokens") is not None:
+            task["min_tokens"] = max(1, task["min_tokens"] - completion_token_len)
+
         min_tokens = task.get("min_tokens", 1)
         if "messages" in task:
             del task["messages"]
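For reference, a minimal standalone sketch of the budget arithmetic the new add_requests logic performs for chat-continuation requests. The helper name clamp_token_budgets and the plain dict-shaped task are illustrative only; the field names and formulas are taken from the diff above.

def clamp_token_budgets(task: dict, max_model_len: int) -> dict:
    """Illustrative re-statement of the continuation budget math above (not the library API)."""
    input_ids_len = len(task["prompt_token_ids"])
    # Tokens already generated in a previous turn and carried back in.
    completion_token_len = len(task["completion_token_ids"]) if task.get("completion_token_ids") else 0

    # Generation budget: bounded by the remaining model window and by the
    # user-requested max_tokens minus what was already produced.
    task["max_tokens"] = min(
        max_model_len - input_ids_len,
        max(0, task["max_tokens"] - completion_token_len),
    )

    # min_tokens shrinks by the same amount but never drops below 1.
    if task.get("min_tokens") is not None:
        task["min_tokens"] = max(1, task["min_tokens"] - completion_token_len)
    return task


# Second turn of a continuation: 80 prompt tokens (50 new + 30 carried over),
# a 300-token window, max_tokens=200, min_tokens=100 -- mirrors the test below.
task = {
    "prompt_token_ids": [101] * 80,
    "completion_token_ids": [103] * 30,
    "max_tokens": 200,
    "min_tokens": 100,
}
clamp_token_budgets(task, max_model_len=300)
print(task["max_tokens"], task["min_tokens"])  # 170 70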
12 changes: 10 additions & 2 deletions fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -241,8 +241,11 @@ def process_request_dict(self, request, max_model_len=None):
         else:
             raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
 
+        completion_token_len = 0
         if request.get("completion_token_ids"):
+            completion_token_len = len(request.get("completion_token_ids"))
             self.append_completion_tokens(outputs, request["completion_token_ids"])
+
         outputs = self.pack_outputs(outputs)
         request["prompt_token_ids"] = outputs["input_ids"].tolist()
         request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
@@ -251,12 +254,17 @@ def process_request_dict(self, request, max_model_len=None):
         # Truncate prompts that exceed the length limit
         if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
+
+        tmp_max_tokens = 0
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
+            tmp_max_tokens = request["max_tokens"]
         else:
-            request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
+            tmp_max_tokens = min(
+                max_model_len - len(request["prompt_token_ids"]), max(0, request["max_tokens"] - completion_token_len)
+            )
         if request.get("reasoning_max_tokens") is None:
-            request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
+            request["reasoning_max_tokens"] = max(int(tmp_max_tokens * 0.8), 1)
         data_processor_logger.info(f"Processed request {request}")
 
         return request
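A similar sketch of how the multimodal processor now derives the default reasoning_max_tokens from the continuation-adjusted budget (tmp_max_tokens) instead of the raw max_tokens. The helper default_reasoning_budget is hypothetical; the 0.8 factor and the min/max clamps come from the diff above.

def default_reasoning_budget(prompt_len, max_model_len, max_tokens=None, completion_token_len=0):
    """Hypothetical helper mirroring the tmp_max_tokens logic above."""
    if max_tokens is None:
        # No explicit limit: everything left in the window, at least 1 token.
        tmp_max_tokens = max(1, max_model_len - prompt_len)
    else:
        # Explicit limit: remaining window vs. remaining user budget.
        tmp_max_tokens = min(max_model_len - prompt_len, max(0, max_tokens - completion_token_len))
    # Reasoning gets 80% of whatever budget remains, but never less than 1.
    return max(int(tmp_max_tokens * 0.8), 1)


# 80-token prompt, 300-token window, max_tokens=200, 30 tokens already generated:
print(default_reasoning_budget(80, 300, max_tokens=200, completion_token_len=30))  # int(170 * 0.8) = 136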
139 changes: 139 additions & 0 deletions tests/entrypoints/openai/test_max_and_min_tokens.py
@@ -0,0 +1,139 @@
import unittest
from unittest.mock import MagicMock, patch

from fastdeploy.entrypoints.engine_client import EngineClient, EngineError
from fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor import (
Ernie4_5_VLProcessor,
)


class TestChatContinuationPreprocess(unittest.IsolatedAsyncioTestCase):

async def asyncSetUp(self):
with patch(
"fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.DataProcessor"
) as mock_data_processor:
mock_ernie4_5_processor = MagicMock()
mock_data_processor.return_value = mock_ernie4_5_processor

mock_tokenizer = MagicMock()
mock_tokenizer.eos_token_id = 102
mock_tokenizer.pad_token_id = 0
mock_ernie4_5_processor.tokenizer = mock_tokenizer
mock_ernie4_5_processor.eval = MagicMock()
mock_ernie4_5_processor.image_patch_id = MagicMock()
mock_ernie4_5_processor.spatial_conv_size = MagicMock()

self.ernie_processor = Ernie4_5_VLProcessor(model_name_or_path="mock_model_path")
self.ernie_processor.ernie4_5_processor = mock_ernie4_5_processor

def _create_mock_tensor(initial_ids):
mock_tensor = MagicMock()
mock_tensor._data = initial_ids
mock_tensor.extend = lambda x: mock_tensor._data.extend(x)
mock_tensor.tolist = lambda: mock_tensor._data
return mock_tensor

self.ernie_processor.ernie4_5_processor.request2ids.return_value = {
"input_ids": _create_mock_tensor([101] * 200)
}
self.ernie_processor.pack_outputs = lambda x: x

def mock_append_completion_tokens(multimodal_inputs, completion_token_ids):
multimodal_inputs["input_ids"].extend(completion_token_ids)

self.ernie_processor.append_completion_tokens = MagicMock(side_effect=mock_append_completion_tokens)
self.ernie_processor.eos_token_ids = [102]
self.ernie_processor._parse_limits = MagicMock(return_value=None)

with patch.object(EngineClient, "__init__", return_value=None):
self.engine_client = EngineClient("mock_model_path")
self.engine_client.data_processor = self.ernie_processor
self.engine_client.max_model_len = 300
self.engine_client.enable_mm = False
self.engine_client.enable_prefix_caching = False
self.engine_client.zmq_client = MagicMock()
self.engine_client.valid_parameters = MagicMock()

self.mock_api_logger = patch("fastdeploy.entrypoints.engine_client.api_server_logger").start()
self.mock_data_logger = patch(
"fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.data_processor_logger"
).start()

async def asyncTearDown(self):
patch.stopall()

def _update_processor_token_ids(self, prompt_token_ids_len: int):
def _create_mock_tensor(initial_ids):
mock_tensor = MagicMock()
mock_tensor._data = initial_ids
mock_tensor.extend = lambda x: mock_tensor._data.extend(x)
mock_tensor.tolist = lambda: mock_tensor._data
return mock_tensor

self.ernie_processor.ernie4_5_processor.request2ids.return_value = {
"input_ids": _create_mock_tensor([101] * prompt_token_ids_len)
}

@patch("uuid.uuid4", return_value="test-request-id")
async def test_continuation_first_request(self, mock_uuid):
request = {"messages": [{"role": "user", "content": "描述这张图片"}], "max_tokens": 50, "min_tokens": 10}

await self.engine_client.format_and_add_data(request)

self.assertEqual(request["max_tokens"], 50)
self.assertEqual(request["min_tokens"], 10)
self.assertEqual(len(request["prompt_token_ids"]), 200)

@patch("uuid.uuid4", return_value="test-request-id-2")
async def test_continuation_second_request(self, mock_uuid):
self._update_processor_token_ids(prompt_token_ids_len=50)

request = {
"messages": [{"role": "user", "content": "描述这张图片"}],
"completion_token_ids": [103] * 30,
"max_tokens": 200,
"min_tokens": 100,
}

await self.engine_client.format_and_add_data(request)

self.assertEqual(request["max_tokens"], 170)
self.assertEqual(request["min_tokens"], 70)
self.assertEqual(len(request["prompt_token_ids"]), 80)

@patch("uuid.uuid4", return_value="test-request-id-3")
async def test_continuation_boundary_max_tokens_exhausted(self, mock_uuid):
self._update_processor_token_ids(prompt_token_ids_len=100)

request = {
"messages": [{"role": "user", "content": "描述这张图片"}],
"completion_token_ids": [103] * 190,
"max_tokens": 200,
"min_tokens": 5,
}

await self.engine_client.format_and_add_data(request)

self.assertEqual(request["max_tokens"], 10)
self.assertEqual(request["min_tokens"], 1)

@patch("uuid.uuid4", return_value="test-request-id-4")
async def test_continuation_boundary_no_capacity(self, mock_uuid):
self._update_processor_token_ids(prompt_token_ids_len=260)

request = {
"messages": [{"role": "user", "content": "描述这张图片"}],
"completion_token_ids": [103] * 50,
"max_tokens": 200,
"min_tokens": 5,
}

with self.assertRaises(EngineError) as ctx:
await self.engine_client.format_and_add_data(request)

self.assertIn("Input text is too long", str(ctx.exception))


if __name__ == "__main__":
unittest.main()