fastdeploy/entrypoints/engine_client.py: 12 changes (8 additions, 4 deletions)

@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import inspect
 import os
 import time
 import traceback
@@ -112,7 +113,7 @@ def create_zmq_client(self, model, mode):
         self.zmq_client = ZmqClient(model, mode)
         self.zmq_client.connect()

-    def format_and_add_data(self, prompts: dict):
+    async def format_and_add_data(self, prompts: dict):
         """
         Format the request data and send the request to the server.
         """
@@ -123,10 +124,10 @@ def format_and_add_data(self, prompts: dict):
         if "max_tokens" not in prompts:
             prompts["max_tokens"] = self.max_model_len - 1

-        self.add_requests(prompts)
+        await self.add_requests(prompts)
         return prompts["prompt_token_ids"]

-    def add_requests(self, task):
+    async def add_requests(self, task):
         """
         Add a new request to the queue.

@@ -140,7 +141,10 @@ def add_requests(self, task):

         task["preprocess_start_time"] = time.time()
         try:
-            self.data_processor.process_request_dict(task, self.max_model_len)
+            if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
+                await self.data_processor.process_request_dict(task, self.max_model_len)
+            else:
+                self.data_processor.process_request_dict(task, self.max_model_len)

             task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
             input_ids_len = task["prompt_token_ids_len"]
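
The core of this change is the dispatch inside add_requests: the data processor's process_request_dict may be either a plain function or a coroutine function, so the caller checks with inspect.iscoroutinefunction and awaits only in the async case. A minimal, self-contained sketch of that pattern is below; SyncProcessor, AsyncProcessor, and the toy token ids are hypothetical stand-ins, not FastDeploy code.

import asyncio
import inspect


class SyncProcessor:
    def process_request_dict(self, task, max_len):
        task["prompt_token_ids"] = [1, 2, 3][:max_len]


class AsyncProcessor:
    async def process_request_dict(self, task, max_len):
        await asyncio.sleep(0)  # stands in for an awaitable preprocessing step
        task["prompt_token_ids"] = [1, 2, 3][:max_len]


async def add_requests(processor, task, max_len):
    # Await only when the processor exposes a coroutine; call directly otherwise.
    fn = processor.process_request_dict
    if inspect.iscoroutinefunction(fn):
        await fn(task, max_len)
    else:
        fn(task, max_len)
    return task


async def main():
    print(await add_requests(SyncProcessor(), {}, 8))
    print(await add_requests(AsyncProcessor(), {}, 8))


asyncio.run(main())

Keeping the synchronous branch means existing processors keep working unchanged, while async processors are awaited instead of silently returning an un-run coroutine.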
fastdeploy/entrypoints/openai/serving_chat.py: 2 changes (1 addition, 1 deletion)

@@ -119,7 +119,7 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
             if "chat_template" not in current_req_dict:
                 current_req_dict["chat_template"] = self.chat_template
             current_req_dict["arrival_time"] = time.time()
-            prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+            prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)
             text_after_process = current_req_dict.get("text_after_process")
             if isinstance(prompt_token_ids, np.ndarray):
                 prompt_token_ids = prompt_token_ids.tolist()
fastdeploy/entrypoints/openai/serving_completion.py: 2 changes (1 addition, 1 deletion)

@@ -146,7 +146,7 @@ async def create_completion(self, request: CompletionRequest):
                 request_id_idx = f"{request_id}-{idx}"
                 current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
                 current_req_dict["arrival_time"] = time.time()
-                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)  # tokenize
+                prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)  # tokenize
                 if isinstance(prompt_token_ids, np.ndarray):
                     prompt_token_ids = prompt_token_ids.tolist()
                 text_after_process_list.append(current_req_dict.get("text_after_process"))
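
Because format_and_add_data is now a coroutine, both OpenAI-compatible serving paths must await it from their already-async endpoint handlers; calling it without await would return a coroutine object instead of the prompt token ids. A simplified sketch of that calling contract, with hypothetical MiniEngineClient and placeholder token ids standing in for the real EngineClient:

import asyncio


class MiniEngineClient:
    # Simplified stand-in for EngineClient.format_and_add_data: async because
    # request preprocessing may itself be a coroutine (see engine_client.py above).
    async def format_and_add_data(self, req: dict) -> list:
        req.setdefault("max_tokens", 16)
        return [101, 2023, 102]  # placeholder token ids


async def create_completion(client: MiniEngineClient, req: dict) -> list:
    # The handler is async and awaits the client; forgetting the await here
    # would hand back an un-run coroutine instead of token ids.
    return await client.format_and_add_data(req)


print(asyncio.run(create_completion(MiniEngineClient(), {"prompt": "hi"})))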
tests/utils/test_custom_chat_template.py: 4 changes (2 additions, 2 deletions)

@@ -70,7 +70,7 @@ async def mock_chat_completion_full_generator(
        ):
            return prompt_token_ids

-        def mock_format_and_add_data(current_req_dict):
+        async def mock_format_and_add_data(current_req_dict):
            return current_req_dict

        self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
@@ -97,7 +97,7 @@ async def mock_chat_completion_full_generator(
        ):
            return prompt_token_ids

-        def mock_format_and_add_data(current_req_dict):
+        async def mock_format_and_add_data(current_req_dict):
            return current_req_dict

        self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
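
The test doubles follow the production signature: since the handler now awaits format_and_add_data, a plain function mock would break with "object dict can't be used in 'await' expression", so the mocks are declared async def. A minimal sketch of the same idea using unittest.IsolatedAsyncioTestCase; DummyHandler and FakeEngineClient are hypothetical, not the classes patched in the real test:

import unittest


class DummyHandler:
    def __init__(self, engine_client):
        self.engine_client = engine_client

    async def create_chat_completion(self, req: dict):
        # Mirrors the production call site: the mock must be awaitable.
        return await self.engine_client.format_and_add_data(req)


class TestAsyncMock(unittest.IsolatedAsyncioTestCase):
    async def test_async_mock_is_awaitable(self):
        class FakeEngineClient:
            async def format_and_add_data(self, current_req_dict):
                return current_req_dict  # echo the request back, as the patched tests do

        handler = DummyHandler(FakeEngineClient())
        result = await handler.create_chat_completion({"prompt": "hi"})
        self.assertEqual(result, {"prompt": "hi"})


if __name__ == "__main__":
    unittest.main()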