lmdeploy/serve/openai/api_server.py (17 changes: 10 additions & 7 deletions)
@@ -57,9 +57,10 @@ def create_error_response(status: HTTPStatus, message: str):
         status (HTTPStatus): HTTP status codes and reason phrases
         message (str): error message
     """
-    return JSONResponse(ErrorResponse(message=message,
-                                      type='invalid_request_error').dict(),
-                        status_code=status.value)
+    return JSONResponse(
+        ErrorResponse(message=message,
+                      type='invalid_request_error',
+                      code=status.value).dict())


 async def check_request(request) -> Optional[JSONResponse]:
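
Note on the hunk above: the HTTP status moves out of the JSONResponse wrapper and into the error body's `code` field, so errors are now delivered with an HTTP 200 and an OpenAI-style error payload. A minimal self-contained sketch of the new shape, assuming a pydantic-v1-style ErrorResponse model with the message/type/code fields seen in the diff (the real model lives elsewhere in lmdeploy):

    from http import HTTPStatus

    from fastapi.responses import JSONResponse
    from pydantic import BaseModel


    class ErrorResponse(BaseModel):
        # Fields as used in the diff; stand-in for lmdeploy's own model.
        message: str
        type: str
        code: int


    def create_error_response(status: HTTPStatus, message: str):
        # The status now travels inside the body as `code`; the HTTP layer
        # defaults to 200 because no status_code is passed to JSONResponse.
        return JSONResponse(
            ErrorResponse(message=message,
                          type='invalid_request_error',
                          code=status.value).dict())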
@@ -117,7 +118,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     result_generator = VariableInterface.async_engine.generate_openai(
         request.messages,
         instance_id,
-        request.stream,
+        True,  # always use stream to enable batching
         request.renew_session,
         request_output_len=request.max_tokens if request.max_tokens else 512,
         stop=request.stop,
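
Note: the handler now always drives the engine in streaming mode (the hard-coded `True` above) so the engine can batch concurrent requests; whether the client asked for streaming only decides how the handler consumes the generator. A runnable sketch of that pattern, using a hypothetical stand-in engine (all names below are illustrative, not lmdeploy's API):

    import asyncio


    class Chunk:
        # Hypothetical stand-in for the engine's output objects.
        def __init__(self, response):
            self.response = response


    async def engine_stream(prompt):
        # Stand-in engine: always yields incremental text deltas.
        for piece in ('Hel', 'lo', '!'):
            yield Chunk(piece)


    async def handle(prompt, client_wants_stream):
        gen = engine_stream(prompt)  # engine side is always streaming
        if client_wants_stream:
            # Streaming client: forward each delta as it arrives.
            return [chunk.response async for chunk in gen]
        # Non-streaming client: fold the deltas into one message.
        return ''.join([chunk.response async for chunk in gen])


    print(asyncio.run(handle('hi', client_wants_stream=False)))  # Hello!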
@@ -130,7 +131,7 @@ async def abort_request() -> None:
         async for _ in VariableInterface.async_engine.generate_openai(
                 request.messages,
                 instance_id,
-                request.stream,
+                True,
                 request.renew_session,
                 stop=True):
             pass
@@ -188,18 +189,20 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:

     # Non-streaming response
     final_res = None
+    text = ''
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
             await abort_request()
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          'Client disconnected')
         final_res = res
+        text += res.response
     assert final_res is not None
     choices = []
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role='assistant', content=final_res.response),
+        message=ChatMessage(role='assistant', content=text),
         finish_reason=final_res.finish_reason,
     )
     choices.append(choice_data)
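
Note on the hunk above: once the engine always streams, each `res.response` carries only the newest delta, so returning `final_res.response` would drop everything but the last chunk; the new accumulator rebuilds the full message while `final_res` is still kept for its metadata. A tiny runnable sketch of the failure mode, with hypothetical (delta, finish_reason) pairs mimicking streamed chunks:

    chunks = [('Hel', None), ('lo', None), ('!', 'stop')]

    text = ''
    final = None
    for response, finish_reason in chunks:
        text += response                # new code: accumulate the deltas
        final = (response, finish_reason)

    assert text == 'Hello!'             # full message comes from the accumulator
    assert final[0] == '!'              # old code would have returned just this
    assert final[1] == 'stop'           # the last chunk still supplies finish_reason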
@@ -308,7 +311,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     finish_reason = None
     async for out in generation:
         text += out.response
-        tokens += out.generate_token_len
+        tokens = out.generate_token_len
         finish_reason = out.finish_reason
     ret = {'text': text, 'tokens': tokens, 'finish_reason': finish_reason}
     return JSONResponse(ret)
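
Note: the switch from `+=` to `=` suggests `generate_token_len` reports the cumulative token count for the generation so far rather than a per-chunk delta (an assumption read off the diff, not verified against the engine), in which case accumulating it double-counts. A minimal illustration:

    # Assumed: the engine reports cumulative totals, not per-chunk deltas.
    cumulative = [3, 5, 8]

    tokens_old = 0
    tokens_new = 0
    for n in cumulative:
        tokens_old += n   # old code: 3 + 5 + 8 = 16, overcounts
        tokens_new = n    # new code: ends at 8, the true total

    assert tokens_old == 16
    assert tokens_new == 8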