diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 4ec2d58636..28096af9b6 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -57,9 +57,10 @@ def create_error_response(status: HTTPStatus, message: str):
         status (HTTPStatus): HTTP status codes and reason phrases
         message (str): error message
     """
-    return JSONResponse(ErrorResponse(message=message,
-                                      type='invalid_request_error').dict(),
-                        status_code=status.value)
+    return JSONResponse(
+        ErrorResponse(message=message,
+                      type='invalid_request_error',
+                      code=status.value).dict())
 
 
 async def check_request(request) -> Optional[JSONResponse]:
@@ -117,7 +118,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     result_generator = VariableInterface.async_engine.generate_openai(
         request.messages,
         instance_id,
-        request.stream,
+        True,  # always use stream to enable batching
         request.renew_session,
         request_output_len=request.max_tokens if request.max_tokens else 512,
         stop=request.stop,
@@ -130,7 +131,7 @@ async def abort_request() -> None:
         async for _ in VariableInterface.async_engine.generate_openai(
                 request.messages,
                 instance_id,
-                request.stream,
+                True,
                 request.renew_session,
                 stop=True):
             pass
@@ -188,6 +189,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
 
     # Non-streaming response
     final_res = None
+    text = ''
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
@@ -195,11 +197,12 @@
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          'Client disconnected')
         final_res = res
+        text += res.response
     assert final_res is not None
     choices = []
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role='assistant', content=final_res.response),
+        message=ChatMessage(role='assistant', content=text),
         finish_reason=final_res.finish_reason,
     )
     choices.append(choice_data)
@@ -308,7 +311,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         finish_reason = None
         async for out in generation:
             text += out.response
-            tokens += out.generate_token_len
+            tokens = out.generate_token_len
             finish_reason = out.finish_reason
         ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
         return JSONResponse(ret)