Merged
36 changes: 24 additions & 12 deletions fastdeploy/entrypoints/openai/serving_completion.py
@@ -276,13 +276,29 @@ async def completion_full_generator(
if dealer is not None:
await self.engine_client.connection_manager.cleanup_request(request_id)

async def _echo_back_prompt(self, request, res, idx):
if res["outputs"].get("send_idx", -1) == 0 and request.echo:
if isinstance(request.prompt, list):
def _echo_back_prompt(self, request, idx):
"""
Echo pre-processing for the smallest unit (a single choice)
"""
if isinstance(request.prompt, str):
prompt_text = request.prompt
elif isinstance(request.prompt, list):
if all(isinstance(item, str) for item in request.prompt):
prompt_text = request.prompt[idx]
elif all(isinstance(item, int) for item in request.prompt):
prompt_text = self.engine_client.data_processor.tokenizer.decode(request.prompt)
else:
prompt_text = request.prompt
res["outputs"]["text"] = prompt_text + (res["outputs"]["text"] or "")
prompt_text = self.engine_client.data_processor.tokenizer.decode(request.prompt[idx])
return prompt_text

async def _process_echo_logic(self, request, idx, res_outputs):
"""
Process the echo logic and return the modified text.
"""
if request.echo and res_outputs.get("send_idx", -1) == 0:
prompt_text = self._echo_back_prompt(request, idx)
res_outputs["text"] = prompt_text + (res_outputs["text"] or "")
return res_outputs

def calc_finish_reason(self, max_tokens, token_num, output, tool_called):
if max_tokens is None or token_num != max_tokens:
@@ -384,7 +400,7 @@ async def completion_stream_generator(
else:
arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]

await self._echo_back_prompt(request, res, idx)
await self._process_echo_logic(request, idx, res["outputs"])
output = res["outputs"]
output_top_logprobs = output["top_logprobs"]
logprobs_res: Optional[CompletionLogprobs] = None
@@ -486,7 +502,6 @@ def request_output_to_completion_response(
final_res = final_res_batch[idx]
prompt_token_ids = prompt_batched_token_ids[idx]
assert prompt_token_ids is not None
prompt_text = request.prompt
completion_token_ids = completion_batched_token_ids[idx]

output = final_res["outputs"]
@@ -497,12 +512,9 @@
aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)

if request.echo:
assert prompt_text is not None
prompt_text = self._echo_back_prompt(request, idx)
token_ids = [*prompt_token_ids, *output["token_ids"]]
if isinstance(prompt_text, list):
output_text = prompt_text[idx] + output["text"]
else:
output_text = str(prompt_text) + output["text"]
output_text = prompt_text + output["text"]
else:
token_ids = output["token_ids"]
output_text = output["text"]
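For reference, the branching introduced in the new _echo_back_prompt above reduces to four prompt shapes. The following is a minimal standalone sketch of that rule set, not the code under review; it uses a stand-in tokenizer, whereas the real method calls self.engine_client.data_processor.tokenizer.decode.

class FakeTokenizer:
    def decode(self, token_ids):
        # Stand-in for the engine tokenizer: map token ids back to text.
        return f"decoded_{token_ids}"

def echo_back_prompt(prompt, idx, tokenizer=FakeTokenizer()):
    """Resolve the prompt text to prepend for choice idx when echo is enabled."""
    if isinstance(prompt, str):
        return prompt                        # single string prompt
    if all(isinstance(item, str) for item in prompt):
        return prompt[idx]                   # batch of string prompts
    if all(isinstance(item, int) for item in prompt):
        return tokenizer.decode(prompt)      # single prompt given as token ids
    return tokenizer.decode(prompt[idx])     # batch of token-id prompts

assert echo_back_prompt("hi there", 0) == "hi there"
assert echo_back_prompt(["p1", "p2"], 1) == "p2"
assert echo_back_prompt([1, 2, 3], 0) == "decoded_[1, 2, 3]"
assert echo_back_prompt([[1, 2], [3, 4]], 1) == "decoded_[3, 4]"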
147 changes: 75 additions & 72 deletions tests/entrypoints/openai/test_completion_echo.py
@@ -15,31 +15,23 @@
"""

import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

from fastdeploy.entrypoints.openai.serving_completion import (
CompletionRequest,
OpenAIServingCompletion,
)


class YourClass:
async def _1(self, a, b, c):
if b["outputs"].get("send_idx", -1) == 0 and a.echo:
if isinstance(a.prompt, list):
text = a.prompt[c]
else:
text = a.prompt
b["outputs"]["text"] = text + (b["outputs"]["text"] or "")


class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.mock_engine = MagicMock()
self.completion_handler = None
self.mock_engine.data_processor.tokenizer.decode = lambda x: f"decoded_{x}"

def test_single_prompt_non_streaming(self):
"""测试单prompt非流式响应"""
"""Testing echo prompt in non-streaming of a single str prompt"""

def test_single_str_prompt_non_streaming(self):
self.completion_handler = OpenAIServingCompletion(
self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30
)
@@ -70,32 +62,41 @@ def test_single_prompt_non_streaming(self):

self.assertEqual(response.choices[0].text, "test prompt generated text")

async def test_echo_back_prompt_and_streaming(self):
"""测试_echo_back_prompt方法和流式响应的prompt拼接逻辑"""
"""Testing echo prompt in non-streaming of a single int prompt"""

def test_single_int_prompt_non_streaming(self):
self.completion_handler = OpenAIServingCompletion(
self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30
)

request = CompletionRequest(prompt="test prompt", max_tokens=10, stream=True, echo=True)

mock_response = {"outputs": {"text": "test output", "token_ids": [1, 2, 3], "finished": True}}

with patch.object(self.completion_handler, "_echo_back_prompt") as mock_echo:

def mock_echo_side_effect(req, res, idx):
res["outputs"]["text"] = req.prompt + res["outputs"]["text"]
request = CompletionRequest(prompt=[1, 2, 3], max_tokens=10, echo=True, logprobs=1)

mock_echo.side_effect = mock_echo_side_effect

await self.completion_handler._echo_back_prompt(request, mock_response, 0)
mock_output = {
"outputs": {
"text": " generated text",
"token_ids": [1, 2, 3],
"top_logprobs": {"token1": -0.1, "token2": -0.2},
"finished": True,
},
"output_token_ids": 3,
}
self.mock_engine.generate.return_value = [mock_output]

mock_echo.assert_called_once_with(request, mock_response, 0)
response = self.completion_handler.request_output_to_completion_response(
final_res_batch=[mock_output],
request=request,
request_id="test_id",
created_time=12345,
model_name="test_model",
prompt_batched_token_ids=[[1, 2]],
completion_batched_token_ids=[[3, 4, 5]],
text_after_process_list=["test prompt"],
)
self.assertEqual(response.choices[0].text, "decoded_[1, 2, 3] generated text")

self.assertEqual(mock_response["outputs"]["text"], "test prompttest output")
self.assertEqual(request.prompt, "test prompt")
"""Testing echo prompts in non-streaming of multiple str prompts"""

def test_multi_prompt_non_streaming(self):
"""测试多prompt非流式响应"""
def test_multi_str_prompt_non_streaming(self):
self.completion_handler = OpenAIServingCompletion(
self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30
)
@@ -129,72 +130,74 @@ def test_multi_prompt_non_streaming(self):
self.assertEqual(response.choices[0].text, "prompt1 response1")
self.assertEqual(response.choices[1].text, "prompt2 response2")

async def test_multi_prompt_streaming(self):
"""Testing echo prompts in non-streaming of multiple int prompts"""

def test_multi_int_prompt_non_streaming(self):
self.completion_handler = OpenAIServingCompletion(
self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30
)

request = CompletionRequest(prompt=["prompt1", "prompt2"], max_tokens=10, stream=True, echo=True)
request = CompletionRequest(prompt=[[1, 2, 3], [4, 5, 6]], max_tokens=10, echo=True)

mock_responses = [
{"outputs": {"text": " response1", "token_ids": [1, 2], "finished": True}},
{"outputs": {"text": " response2", "token_ids": [3, 4], "finished": True}},
mock_outputs = [
{
"outputs": {"text": " response1", "token_ids": [1, 2], "top_logprobs": None, "finished": True},
"output_token_ids": 2,
},
{
"outputs": {"text": " response2", "token_ids": [3, 4], "top_logprobs": None, "finished": True},
"output_token_ids": 2,
},
]
self.mock_engine.generate.return_value = mock_outputs

with patch.object(self.completion_handler, "_echo_back_prompt") as mock_echo:

def mock_echo_side_effect(req, res, idx):
res["outputs"]["text"] = req.prompt[idx] + res["outputs"]["text"]

mock_echo.side_effect = mock_echo_side_effect

await self.completion_handler._echo_back_prompt(request, mock_responses[0], 0)
await self.completion_handler._echo_back_prompt(request, mock_responses[1], 1)
response = self.completion_handler.request_output_to_completion_response(
final_res_batch=mock_outputs,
request=request,
request_id="test_id",
created_time=12345,
model_name="test_model",
prompt_batched_token_ids=[[1], [2]],
completion_batched_token_ids=[[1, 2], [3, 4]],
text_after_process_list=["prompt1", "prompt2"],
)

self.assertEqual(mock_echo.call_count, 2)
mock_echo.assert_any_call(request, mock_responses[0], 0)
mock_echo.assert_any_call(request, mock_responses[1], 1)
self.assertEqual(len(response.choices), 2)
self.assertEqual(response.choices[0].text, "decoded_[1, 2, 3] response1")
self.assertEqual(response.choices[1].text, "decoded_[4, 5, 6] response2")

self.assertEqual(mock_responses[0]["outputs"]["text"], "prompt1 response1")
self.assertEqual(mock_responses[1]["outputs"]["text"], "prompt2 response2")
self.assertEqual(request.prompt, ["prompt1", "prompt2"])
"""Testing echo prompts in streaming of a single str prompt"""

async def test_echo_back_prompt_and_streaming1(self):
request = CompletionRequest(echo=True, prompt=["Hello", "World"])
async def test_single_str_prompt_streaming(self):
request = CompletionRequest(prompt="test prompt", max_tokens=10, stream=True, echo=True)
res = {"outputs": {"send_idx": 0, "text": "!"}}
idx = 0

instance = OpenAIServingCompletion(self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30)
await instance._echo_back_prompt(request, res, idx)
self.assertEqual(res["outputs"]["text"], "Hello!")
res = await instance._process_echo_logic(request, idx, res["outputs"])
self.assertEqual(res["text"], "test prompt!")

"""Testing echo prompts in streaming of a single int prompt"""

async def test_1_prompt_is_string_and_send_idx_is_0(self):
request = CompletionRequest(echo=True, prompt="Hello")
async def test_single_int_prompt_streaming(self):
request = CompletionRequest(prompt=[1, 2, 3], max_tokens=10, stream=True, echo=True)
res = {"outputs": {"send_idx": 0, "text": "!"}}
idx = 0

instance = OpenAIServingCompletion(self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30)
await instance._echo_back_prompt(request, res, idx)
self.assertEqual(res["outputs"]["text"], "Hello!")
res = await instance._process_echo_logic(request, idx, res["outputs"])
self.assertEqual(res["text"], "decoded_[1, 2, 3]!")

async def test_1_send_idx_is_not_0(self):
request = CompletionRequest(echo=True, prompt="Hello")
res = {"outputs": {"send_idx": 1, "text": "!"}}
idx = 0

instance = OpenAIServingCompletion(self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30)
await instance._echo_back_prompt(request, res, idx)
self.assertEqual(res["outputs"]["text"], "!")
"""Testing echo prompts in streaming of multi str prompt"""

async def test_1_echo_is_false(self):
"""测试echo为False时,_echo_back_prompt不拼接prompt"""
request = CompletionRequest(echo=False, prompt="Hello")
async def test_multi_str_prompt_streaming(self):
request = CompletionRequest(prompt=["test prompt1", "test prompt2"], max_tokens=10, stream=True, echo=True)
res = {"outputs": {"send_idx": 0, "text": "!"}}
idx = 0

instance = OpenAIServingCompletion(self.mock_engine, models=None, pid=123, ips=None, max_waiting_time=30)
await instance._echo_back_prompt(request, res, idx)
self.assertEqual(res["outputs"]["text"], "!")
res = await instance._process_echo_logic(request, idx, res["outputs"])
self.assertEqual(res["text"], "test prompt1!")


if __name__ == "__main__":
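End to end, the behaviour these tests exercise is the OpenAI-compatible echo option: with echo=True the completion text returned by the server is prefixed with the prompt (decoded from token ids when the prompt was supplied as ids). A rough client-side sketch follows, assuming a FastDeploy OpenAI-compatible server is running locally; the base_url, api_key, and model name are placeholders, not taken from this PR.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="default",
    prompt="test prompt",
    max_tokens=10,
    echo=True,  # ask the server to echo the prompt back in front of the completion
)
print(resp.choices[0].text)  # expected to start with "test prompt"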