From 274d57a547137557d316c8a3fea640f6a3a8a827 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 20 Jun 2024 09:31:50 +0100
Subject: [PATCH 1/4] propagate error with OpenAISpec (#143)

* send status

* propagate error

* fix

* add test

---
 src/litserve/server.py       | 16 ++++++++++++----
 src/litserve/specs/openai.py | 22 +++++++++++++---------
 tests/test_specs.py          | 17 +++++++++++++++++
 3 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/src/litserve/server.py b/src/litserve/server.py
index 341c2d48..8a7ec01f 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -478,7 +478,7 @@ async def data_reader(self, read):
         asyncio.get_event_loop().remove_reader(read.fileno())
         return read.recv()
 
-    async def win_data_streamer(self, read, write):
+    async def win_data_streamer(self, read, write, send_status=False):
         # this is a workaround for Windows since asyncio loop.add_reader is not supported.
         # https://docs.python.org/3/library/asyncio-platforms.html
         while True:
@@ -493,12 +493,16 @@ def win_data_streamer(self, read, write):
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
                     )
+                    yield response, status
                     return
-                yield response
+                if send_status:
+                    yield response, status
+                else:
+                    yield response
 
             await asyncio.sleep(0.0001)
 
-    async def data_streamer(self, read: Connection, write: Connection):
+    async def data_streamer(self, read: Connection, write: Connection, send_status=False):
         data_available = asyncio.Event()
         while True:
             # Calling poll blocks the event loop, so keep the timeout low
@@ -516,8 +520,12 @@ async def data_streamer(self, read: Connection, write: Connection):
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
                     )
+                    yield response, status
                     return
-                yield response
+                if send_status:
+                    yield response, status
+                else:
+                    yield response
 
     def cleanup_request(self, request_buffer, uid):
         with contextlib.suppress(KeyError):
diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py
index 5ceeab8e..f6cf1778 100644
--- a/src/litserve/specs/openai.py
+++ b/src/litserve/specs/openai.py
@@ -26,7 +26,7 @@
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 
-from ..utils import azip
+from ..utils import azip, LitAPIStatus, load_and_raise
 from .base import LitSpec
 
 if typing.TYPE_CHECKING:
@@ -272,9 +272,9 @@ async def get_from_pipes(self, uids, pipes) -> List[AsyncGenerator]:
         choice_pipes = []
         for uid, (read, write) in zip(uids, pipes):
             if sys.version_info[0] == 3 and sys.version_info[1] >= 8 and sys.platform.startswith("win"):
-                data = self._server.win_data_streamer(read, write)
+                data = self._server.win_data_streamer(read, write, send_status=True)
             else:
-                data = self._server.data_streamer(read, write)
+                data = self._server.data_streamer(read, write, send_status=True)
             choice_pipes.append(data)
         return choice_pipes
 
@@ -320,8 +320,10 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
         usage = None
         async for streaming_response in azip(*pipe_responses):
             choices = []
-            for i, chat_msg in enumerate(streaming_response):
-                chat_msg = json.loads(chat_msg)
+            for i, (response, status) in enumerate(streaming_response):
+                if status == LitAPIStatus.ERROR:
+                    load_and_raise(response)
+                chat_msg = json.loads(response)
                 logger.debug(chat_msg)
                 chat_msg = ChoiceDelta(**chat_msg)
                 choice = ChatCompletionStreamingChoice(
@@ -345,15 +347,17 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
             yield f"data: {last_chunk}\n\n"
         yield "data: [DONE]\n\n"
 
-    async def non_streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
+    async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
         model = request.model
         usage = UsageInfo()
         choices = []
-        for i, streaming_response in enumerate(pipe_responses):
+        for i, streaming_response in enumerate(generator_list):
             msgs = []
             tool_calls = None
-            async for chat_msg in streaming_response:
-                chat_msg = json.loads(chat_msg)
+            async for response, status in streaming_response:
+                if status == LitAPIStatus.ERROR:
+                    load_and_raise(response)
+                chat_msg = json.loads(response)
                 logger.debug(chat_msg)
                 chat_msg = ChatMessage(**chat_msg)
                 msgs.append(chat_msg.content)
diff --git a/tests/test_specs.py b/tests/test_specs.py
index 3819acf1..1b34b08e 100644
--- a/tests/test_specs.py
+++ b/tests/test_specs.py
@@ -129,3 +129,20 @@ async def test_oai_prepopulated_context(openai_request_data):
     assert (
         resp.json()["choices"][0]["message"]["content"] == "This is a"
     ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter"
+
+
+class WrongLitAPI(ls.LitAPI):
+    def setup(self, device):
+        self.model = None
+
+    def predict(self, prompt):
+        yield "This is a sample generated text"
+        raise Exception("random error")
+
+
+@pytest.mark.asyncio()
+async def test_fail_http(openai_request_data):
+    server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec())
+    async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
+        resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+        assert resp.status_code == 500, "Server raises an exception so client should fail"

From 53dd29ff0c19c3fca7c9d090c0c4c98cce4027a4 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 20 Jun 2024 10:49:23 +0200
Subject: [PATCH 2/4] Bump pypa/gh-action-pypi-publish from 1.8.14 to 1.9.0 (#142)

Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.8.14 to 1.9.0.
- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases)
- [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.8.14...v1.9.0)

---
updated-dependencies:
- dependency-name: pypa/gh-action-pypi-publish
  dependency-type: direct:production
  update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 .github/workflows/release-pypi.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml
index ee7a4da7..644b178a 100644
--- a/.github/workflows/release-pypi.yml
+++ b/.github/workflows/release-pypi.yml
@@ -29,7 +29,7 @@ jobs:
     #    We do this, since failures on test.pypi aren't that bad
     #  - name: Publish to Test PyPI
     #    if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-    #    uses: pypa/gh-action-pypi-publish@v1.8.14
+    #    uses: pypa/gh-action-pypi-publish@v1.9.0
     #    with:
     #      user: __token__
     #      password: ${{ secrets.test_pypi_password }}
@@ -37,7 +37,7 @@ jobs:
 
       - name: Publish distribution 📦 to PyPI
         if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
-        uses: pypa/gh-action-pypi-publish@v1.8.14
+        uses: pypa/gh-action-pypi-publish@v1.9.0
         with:
           user: __token__
           password: ${{ secrets.pypi_password }}

From 178e6c13690b966e5b757131dc0836354b08176e Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 20 Jun 2024 12:16:54 +0100
Subject: [PATCH 3/4] remove busy wait from data_streamer (#140)

* remove blocking wait from data_streamer

* update

* remove multiple get_event_loop

* graceful error handling

* fix consumer speed

* add timeout

* remove timeout

* fix test

---
 src/litserve/server.py | 57 +++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 37 insertions(+), 20 deletions(-)

diff --git a/src/litserve/server.py b/src/litserve/server.py
index 8a7ec01f..311edd30 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -504,28 +504,45 @@ async def win_data_streamer(self, read, write, send_status=False):
 
     async def data_streamer(self, read: Connection, write: Connection, send_status=False):
         data_available = asyncio.Event()
-        while True:
-            # Calling poll blocks the event loop, so keep the timeout low
-            if not read.poll(0.001):
-                asyncio.get_event_loop().add_reader(read.fileno(), data_available.set)
+        queue = asyncio.Queue()
+        loop = asyncio.get_event_loop()
+
+        def reader():
+            try:
+                while read.poll():  # Check if there's data available to read
+                    response, status = read.recv()
+                    queue.put_nowait((response, status))
+                data_available.set()
+            except Exception as e:
+                logger.error(f"Exception in reader: {e}")
+
+        loop.add_reader(read.fileno(), reader)
+
+        try:
+            while True:
                 await data_available.wait()
                 data_available.clear()
-                asyncio.get_event_loop().remove_reader(read.fileno())
-            if read.poll(0.001):
-                response, status = read.recv()
-                if status == LitAPIStatus.FINISH_STREAMING:
-                    return
-                if status == LitAPIStatus.ERROR:
-                    logger.error(
-                        "Error occurred while streaming outputs from the inference worker. "
-                        "Please check the above traceback."
-                    )
-                    yield response, status
-                    return
-                if send_status:
-                    yield response, status
-                else:
-                    yield response
+
+                while not queue.empty():
+                    response, status = await queue.get()
+                    if status == LitAPIStatus.FINISH_STREAMING:
+                        loop.remove_reader(read.fileno())
+                        return
+                    if status == LitAPIStatus.ERROR:
+                        logger.error(
+                            "Error occurred while streaming outputs from the inference worker. "
+                            "Please check the above traceback."
+                        )
+                        loop.remove_reader(read.fileno())
+                        if send_status:
+                            yield response, status
+                        return
+                    if send_status:
+                        yield response, status
+                    else:
+                        yield response
+        finally:
+            loop.remove_reader(read.fileno())
 
     def cleanup_request(self, request_buffer, uid):
         with contextlib.suppress(KeyError):

From 51b3358fb2b48ff18878bcf0b2ce4e9e7210df56 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 20 Jun 2024 12:49:40 +0100
Subject: [PATCH 4/4] raise HTTPException (#145)

---
 _requirements/test.txt   |  3 +--
 src/litserve/utils.py    |  5 +++--
 tests/test_lit_server.py | 15 ++++++++++++++-
 tests/test_specs.py      |  8 +++++---
 4 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/_requirements/test.txt b/_requirements/test.txt
index 2e494614..c172c2f5 100644
--- a/_requirements/test.txt
+++ b/_requirements/test.txt
@@ -1,4 +1,4 @@
-coverage[toml] >=5.0
+coverage[toml] >=7.5.3
 pytest >=8.0
 pytest-cov
 mypy ==1.9.0
@@ -6,7 +6,6 @@ pytest-asyncio
 asgi-lifespan
 python-multipart
 psutil
-
 requests
 lightning >2.0.0
 torch >2.0.0
diff --git a/src/litserve/utils.py b/src/litserve/utils.py
index c5570c6b..deec4961 100644
--- a/src/litserve/utils.py
+++ b/src/litserve/utils.py
@@ -49,13 +49,14 @@ async def wait_for_queue_timeout(coro: Coroutine, timeout: Optional[float], uid:
 
 def load_and_raise(response):
     try:
-        pickle.loads(response)
-        raise HTTPException(500, "Internal Server Error")
+        exception = pickle.loads(response)
+        raise exception
     except pickle.PickleError:
         logger.exception(
             f"main process failed to load the exception from the parallel worker process. "
             f"{response} couldn't be unpickled."
         )
+        raise
 
 
 async def azip(*async_iterables):
diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py
index 297333a6..73d2c4be 100644
--- a/tests/test_lit_server.py
+++ b/tests/test_lit_server.py
@@ -18,7 +18,7 @@
 from multiprocessing import Pipe, Manager
 from asgi_lifespan import LifespanManager
 from litserve import LitAPI
-from fastapi import Request, Response
+from fastapi import Request, Response, HTTPException
 
 import torch
 import torch.nn as nn
@@ -416,3 +416,16 @@ def test_custom_api_path():
     with TestClient(server.app) as client:
         response = client.post(url, json={"input": 4.0})
         assert response.status_code == 200, "Server response should be 200 (OK)"
+
+
+class TestHTTPExceptionAPI(ls.examples.SimpleLitAPI):
+    def decode_request(self, request):
+        raise HTTPException(501, "decode request is bad")
+
+
+def test_http_exception():
+    server = LitServer(TestHTTPExceptionAPI())
+    with TestClient(server.app) as client:
+        response = client.post("/predict", json={"input": 4.0})
+        assert response.status_code == 501, "Server raises 501 error"
+        assert response.text == '{"detail":"decode request is bad"}', "decode request is bad"
diff --git a/tests/test_specs.py b/tests/test_specs.py
index 1b34b08e..bba0898b 100644
--- a/tests/test_specs.py
+++ b/tests/test_specs.py
@@ -14,6 +14,7 @@
 import pytest
 from asgi_lifespan import LifespanManager
+from fastapi import HTTPException
 from httpx import AsyncClient
 from litserve.examples.openai_spec_example import TestAPI, TestAPIWithCustomEncode, TestAPIWithToolCalls
 from litserve.specs.openai import OpenAISpec, ChatMessage
 
@@ -137,12 +138,13 @@ def setup(self, device):
 
     def predict(self, prompt):
         yield "This is a sample generated text"
-        raise Exception("random error")
+        raise HTTPException(501, "test LitAPI.predict error")
 
 
 @pytest.mark.asyncio()
 async def test_fail_http(openai_request_data):
     server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec())
     async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-        resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
-        assert resp.status_code == 500, "Server raises an exception so client should fail"
+        res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+        assert res.status_code == 501, "Server raises 501 error"
+        assert res.text == '{"detail":"test LitAPI.predict error"}'
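
Taken together, PATCH 1/4 and PATCH 4/4 make worker errors survive the trip back to the client: the streamer yields (response, status) tuples instead of bare responses, and load_and_raise now unpickles the worker's exception and re-raises it, so a raised HTTPException keeps its own status code instead of collapsing into a generic 500. The following is a minimal standalone sketch of that round trip under the same idea; the names ERROR, worker_failure and reraise are invented for the example and are not litserve APIs.

    import pickle

    from fastapi import HTTPException

    # Stand-in for LitAPIStatus.ERROR in this sketch.
    ERROR = "ERROR"

    def worker_failure(exc):
        # What the inference worker would push through the pipe when predict()
        # raises: the pickled exception, tagged with an error status.
        return pickle.dumps(exc), ERROR

    def reraise(payload):
        # Mirrors the patched load_and_raise: unpickle and re-raise the
        # original exception; if unpickling fails, re-raise that error instead.
        try:
            exception = pickle.loads(payload)
            raise exception
        except pickle.PickleError:
            raise

    response, status = worker_failure(HTTPException(501, "test LitAPI.predict error"))
    if status == ERROR:
        try:
            reraise(response)
        except HTTPException as e:
            print(e.status_code, e.detail)  # -> 501 test LitAPI.predict error

Because the re-raised exception is a FastAPI HTTPException, the framework's normal exception handling turns it into the matching HTTP response, which is exactly what the new tests in tests/test_specs.py and tests/test_lit_server.py assert.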
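
PATCH 3/4 removes the busy wait by registering the pipe's file descriptor with the event loop (loop.add_reader) and draining whatever is ready into an asyncio.Queue, rather than polling the pipe with a 1 ms timeout in a loop. Below is a self-contained sketch of that pattern, assuming a Unix event loop (loop.add_reader is not supported on Windows, which is why win_data_streamer stays on the polling path); consume_pipe and produce are illustrative names, not the litserve implementation.

    import asyncio
    import threading
    import time
    from multiprocessing import Pipe

    async def consume_pipe(read_conn):
        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()
        data_available = asyncio.Event()

        def on_readable():
            # Runs in the event loop only when the fd becomes readable.
            while read_conn.poll():
                queue.put_nowait(read_conn.recv())
            data_available.set()

        loop.add_reader(read_conn.fileno(), on_readable)
        try:
            while True:
                await data_available.wait()
                data_available.clear()
                while not queue.empty():
                    item = await queue.get()
                    if item is None:  # sentinel: the producer is done
                        return
                    print("got:", item)
        finally:
            loop.remove_reader(read_conn.fileno())

    def produce(write_conn):
        # Stands in for the inference worker pushing streamed outputs.
        for i in range(3):
            time.sleep(0.1)
            write_conn.send(f"token-{i}")
        write_conn.send(None)

    async def main():
        read_conn, write_conn = Pipe(duplex=False)
        threading.Thread(target=produce, args=(write_conn,), daemon=True).start()
        await consume_pipe(read_conn)

    asyncio.run(main())

Because the callback only fires when the descriptor is readable, the coroutine sleeps on the Event between chunks instead of waking every millisecond to poll, which is the "fix consumer speed" / busy-wait concern the patch addresses.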