Merge branch 'main' into aniket/optimize-batch-collate
aniketmaurya committed Jun 20, 2024
2 parents 7bffe8d + 51b3358 commit 5bff63e
Showing 7 changed files with 96 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/release-pypi.yml
@@ -29,15 +29,15 @@ jobs:
# We do this, since failures on test.pypi aren't that bad
# - name: Publish to Test PyPI
# if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
# uses: pypa/gh-action-pypi-publish@v1.8.14
# uses: pypa/gh-action-pypi-publish@v1.9.0
# with:
# user: __token__
# password: ${{ secrets.test_pypi_password }}
# repository_url: https://test.pypi.org/legacy/

- name: Publish distribution 📦 to PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
uses: pypa/gh-action-pypi-publish@v1.8.14
uses: pypa/gh-action-pypi-publish@v1.9.0
with:
user: __token__
password: ${{ secrets.pypi_password }}
3 changes: 1 addition & 2 deletions _requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
coverage[toml] >=5.0
coverage[toml] >=7.5.3
pytest >=8.0
pytest-cov
mypy ==1.9.0
pytest-asyncio
asgi-lifespan
python-multipart
psutil

requests
lightning >2.0.0
torch >2.0.0
63 changes: 44 additions & 19 deletions src/litserve/server.py
@@ -491,7 +491,7 @@ async def data_reader(self, read):
loop.remove_reader(read.fileno())
return read.recv()

async def win_data_streamer(self, read, write):
async def win_data_streamer(self, read, write, send_status=False):
# this is a workaround for Windows since asyncio loop.add_reader is not supported.
# https://docs.python.org/3/library/asyncio-platforms.html
while True:
@@ -506,31 +506,56 @@ async def win_data_streamer(self, read, write):
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
yield response, status
return
yield response
if send_status:
yield response, status
else:
yield response

await asyncio.sleep(0.0001)

async def data_streamer(self, read: Connection, write: Connection):
async def data_streamer(self, read: Connection, write: Connection, send_status=False):
data_available = asyncio.Event()
while True:
# Calling poll blocks the event loop, so keep the timeout low
if not read.poll(0.001):
asyncio.get_event_loop().add_reader(read.fileno(), data_available.set)
queue = asyncio.Queue()
loop = asyncio.get_event_loop()

def reader():
try:
while read.poll(): # Check if there's data available to read
response, status = read.recv()
queue.put_nowait((response, status))
data_available.set()
except Exception as e:
logger.error(f"Exception in reader: {e}")

loop.add_reader(read.fileno(), reader)

try:
while True:
await data_available.wait()
data_available.clear()
asyncio.get_event_loop().remove_reader(read.fileno())
if read.poll(0.001):
response, status = read.recv()
if status == LitAPIStatus.FINISH_STREAMING:
return
if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
return
yield response

while not queue.empty():
response, status = await queue.get()
if status == LitAPIStatus.FINISH_STREAMING:
loop.remove_reader(read.fileno())
return
if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
loop.remove_reader(read.fileno())
if send_status:
yield response, status
return
if send_status:
yield response, status
else:
yield response
finally:
loop.remove_reader(read.fileno())

def cleanup_request(self, request_buffer, uid):
with contextlib.suppress(KeyError):
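The rewritten data_streamer replaces the old poll-with-a-short-timeout loop: loop.add_reader() registers a callback on the pipe's file descriptor, the callback drains the pipe into an asyncio.Queue, and the coroutine just awaits an Event. Below is a minimal, self-contained sketch of that same pattern outside LitServe, assuming a Unix event loop (loop.add_reader is unavailable on Windows, which is why win_data_streamer exists); the produce/consume names and the string status value are placeholders, not LitServe APIs.

import asyncio
from multiprocessing import Pipe

FINISH = "FINISH_STREAMING"  # stand-in for LitAPIStatus.FINISH_STREAMING


async def consume(read_conn):
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()
    data_available = asyncio.Event()

    def reader():
        # Called by the event loop whenever the pipe's fd becomes readable.
        while read_conn.poll():
            queue.put_nowait(read_conn.recv())
        data_available.set()

    loop.add_reader(read_conn.fileno(), reader)
    try:
        while True:
            await data_available.wait()
            data_available.clear()
            while not queue.empty():
                response, status = await queue.get()
                if status == FINISH:
                    return
                print("got:", response)
    finally:
        loop.remove_reader(read_conn.fileno())


def produce(write_conn):
    # Stands in for the inference worker writing to its end of the pipe.
    for i in range(3):
        write_conn.send((f"token-{i}", "OK"))
    write_conn.send((None, FINISH))


if __name__ == "__main__":
    read_conn, write_conn = Pipe(duplex=False)
    produce(write_conn)
    asyncio.run(consume(read_conn))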
22 changes: 13 additions & 9 deletions src/litserve/specs/openai.py
@@ -26,7 +26,7 @@
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from ..utils import azip
from ..utils import azip, LitAPIStatus, load_and_raise
from .base import LitSpec

if typing.TYPE_CHECKING:
@@ -272,9 +272,9 @@ async def get_from_pipes(self, uids, pipes) -> List[AsyncGenerator]:
choice_pipes = []
for uid, (read, write) in zip(uids, pipes):
if sys.version_info[0] == 3 and sys.version_info[1] >= 8 and sys.platform.startswith("win"):
data = self._server.win_data_streamer(read, write)
data = self._server.win_data_streamer(read, write, send_status=True)
else:
data = self._server.data_streamer(read, write)
data = self._server.data_streamer(read, write, send_status=True)

choice_pipes.append(data)
return choice_pipes
Expand Down Expand Up @@ -320,8 +320,10 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
usage = None
async for streaming_response in azip(*pipe_responses):
choices = []
for i, chat_msg in enumerate(streaming_response):
chat_msg = json.loads(chat_msg)
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR:
load_and_raise(response)
chat_msg = json.loads(response)
logger.debug(chat_msg)
chat_msg = ChoiceDelta(**chat_msg)
choice = ChatCompletionStreamingChoice(
@@ -345,15 +347,17 @@
yield f"data: {last_chunk}\n\n"
yield "data: [DONE]\n\n"

async def non_streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
model = request.model
usage = UsageInfo()
choices = []
for i, streaming_response in enumerate(pipe_responses):
for i, streaming_response in enumerate(generator_list):
msgs = []
tool_calls = None
async for chat_msg in streaming_response:
chat_msg = json.loads(chat_msg)
async for response, status in streaming_response:
if status == LitAPIStatus.ERROR:
load_and_raise(response)
chat_msg = json.loads(response)
logger.debug(chat_msg)
chat_msg = ChatMessage(**chat_msg)
msgs.append(chat_msg.content)
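streaming_completion now advances the per-choice generators in lock-step with azip, and because the streamers are created with send_status=True each item it unpacks is a (response, status) pair, so a worker error can be detected per choice before the JSON payload is decoded. The sketch below shows that consumption pattern with a hand-rolled async zip; azip_sketch and fake_choice are illustrative stand-ins and may differ from the actual litserve.utils.azip implementation.

import asyncio
import json


async def azip_sketch(*async_iterables):
    # Yield one tuple per step, taking the next item from every iterable.
    iterators = [ait.__aiter__() for ait in async_iterables]
    while True:
        try:
            yield tuple([await it.__anext__() for it in iterators])
        except StopAsyncIteration:
            return


async def fake_choice(name):
    # Stands in for one per-choice generator returned by get_from_pipes.
    for i in range(2):
        yield json.dumps({"content": f"{name}-{i}"}), "OK"


async def main():
    async for step in azip_sketch(fake_choice("a"), fake_choice("b")):
        for i, (response, status) in enumerate(step):
            if status == "ERROR":  # the spec calls load_and_raise here
                raise RuntimeError("worker reported an error")
            print(i, json.loads(response)["content"])


if __name__ == "__main__":
    asyncio.run(main())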
5 changes: 3 additions & 2 deletions src/litserve/utils.py
@@ -49,13 +49,14 @@ async def wait_for_queue_timeout(coro: Coroutine, timeout: Optional[float], uid:

def load_and_raise(response):
try:
pickle.loads(response)
raise HTTPException(500, "Internal Server Error")
exception = pickle.loads(response)
raise exception
except pickle.PickleError:
logger.exception(
f"main process failed to load the exception from the parallel worker process. "
f"{response} couldn't be unpickled."
)
raise


async def azip(*async_iterables):
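With the load_and_raise fix, the exception that the inference worker pickled is re-raised unchanged in the main process, so an HTTPException raised inside a LitAPI keeps its original status code instead of collapsing into a generic 500. A small sketch of that round trip, assuming litserve and fastapi are importable and that the exception pickles cleanly (FastAPI/Starlette's HTTPException does in recent releases); send_error is an illustrative stand-in for the worker-side plumbing, not a LitServe function.

import pickle

from fastapi import HTTPException
from litserve.utils import load_and_raise


def send_error(exc: Exception) -> bytes:
    # In LitServe this payload would travel over the worker's pipe/queue.
    return pickle.dumps(exc)


payload = send_error(HTTPException(501, "decode request is bad"))
try:
    load_and_raise(payload)
except HTTPException as exc:
    print(exc.status_code, exc.detail)  # 501 decode request is bad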
15 changes: 14 additions & 1 deletion tests/test_lit_server.py
@@ -18,7 +18,7 @@
from multiprocessing import Pipe, Manager
from asgi_lifespan import LifespanManager
from litserve import LitAPI
from fastapi import Request, Response
from fastapi import Request, Response, HTTPException

import torch
import torch.nn as nn
@@ -416,3 +416,16 @@ def test_custom_api_path():
with TestClient(server.app) as client:
response = client.post(url, json={"input": 4.0})
assert response.status_code == 200, "Server response should be 200 (OK)"


class TestHTTPExceptionAPI(ls.examples.SimpleLitAPI):
def decode_request(self, request):
raise HTTPException(501, "decode request is bad")


def test_http_exception():
server = LitServer(TestHTTPExceptionAPI())
with TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 501, "Server raises 501 error"
assert response.text == '{"detail":"decode request is bad"}', "decode request is bad"
19 changes: 19 additions & 0 deletions tests/test_specs.py
@@ -14,6 +14,7 @@

import pytest
from asgi_lifespan import LifespanManager
from fastapi import HTTPException
from httpx import AsyncClient
from litserve.examples.openai_spec_example import TestAPI, TestAPIWithCustomEncode, TestAPIWithToolCalls
from litserve.specs.openai import OpenAISpec, ChatMessage
@@ -129,3 +130,21 @@ async def test_oai_prepopulated_context(openai_request_data):
assert (
resp.json()["choices"][0]["message"]["content"] == "This is a"
), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter"


class WrongLitAPI(ls.LitAPI):
def setup(self, device):
self.model = None

def predict(self, prompt):
yield "This is a sample generated text"
raise HTTPException(501, "test LitAPI.predict error")


@pytest.mark.asyncio()
async def test_fail_http(openai_request_data):
server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec())
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert res.status_code == 501, "Server raises 501 error"
assert res.text == '{"detail":"test LitAPI.predict error"}'
