From a106ec4541ae9c90e61c835173c11876d249e775 Mon Sep 17 00:00:00 2001 From: Ankush Chander Date: Mon, 8 Aug 2022 14:27:06 +0530 Subject: [PATCH] fix: handle stream api connection errors gracefully (#1636) Closes #1559 Thanks to @Ankush-Chander (cherry picked from commit 4ab30b935c96f41a6fa4a75e5b42dc28a88b7495) --- src/rubrix/client/sdk/commons/api.py | 12 ++++++++++-- src/rubrix/server/apis/v0/handlers/text2text.py | 5 +++-- .../server/apis/v0/handlers/text_classification.py | 9 +++++---- .../apis/v0/handlers/token_classification.py | 5 +++-- src/rubrix/server/responses/__init__.py | 1 + src/rubrix/server/responses/api_responses.py | 14 ++++++++++++++ 6 files changed, 36 insertions(+), 10 deletions(-) create mode 100644 src/rubrix/server/responses/__init__.py create mode 100644 src/rubrix/server/responses/api_responses.py diff --git a/src/rubrix/client/sdk/commons/api.py b/src/rubrix/client/sdk/commons/api.py index 0f159acdde..ed5def7f2f 100644 --- a/src/rubrix/client/sdk/commons/api.py +++ b/src/rubrix/client/sdk/commons/api.py @@ -31,6 +31,7 @@ import httpx from rubrix.client.sdk.client import AuthenticatedClient +from rubrix.client.sdk.commons.errors import GenericApiError from rubrix.client.sdk.commons.errors_handler import handle_response_error from rubrix.client.sdk.commons.models import ( BulkResponse, @@ -121,12 +122,19 @@ def build_data_response( response: httpx.Response, data_type: Type[T] ) -> Response[List[T]]: if 200 <= response.status_code < 400: - parsed_response = [data_type(**json.loads(r)) for r in response.iter_lines()] + parsed_responses = [] + for r in response.iter_lines(): + parsed_record = json.loads(r) + try: + parsed_response = data_type(**parsed_record) + except Exception as err: + raise GenericApiError(**parsed_record) from None + parsed_responses.append(parsed_response) return Response( status_code=response.status_code, content=b"", headers=response.headers, - parsed=parsed_response, + parsed=parsed_responses, ) content = next(response.iter_lines()) diff --git a/src/rubrix/server/apis/v0/handlers/text2text.py b/src/rubrix/server/apis/v0/handlers/text2text.py index b11f3a9946..a2edd5aaa7 100644 --- a/src/rubrix/server/apis/v0/handlers/text2text.py +++ b/src/rubrix/server/apis/v0/handlers/text2text.py @@ -35,6 +35,7 @@ Text2TextSearchResults, ) from rubrix.server.errors import EntityNotFoundError +from rubrix.server.responses import StreamingResponseWithErrorHandling from rubrix.server.security import auth from rubrix.server.security.model import User from rubrix.server.services.datasets import DatasetsService @@ -204,7 +205,7 @@ def scan_data_response( data_stream: Iterable[Text2TextRecord], chunk_size: int = 1000, limit: Optional[int] = None, -) -> StreamingResponse: +) -> StreamingResponseWithErrorHandling: """Generate an textual stream data response for a dataset scan""" async def stream_generator(stream): @@ -228,7 +229,7 @@ def grouper(n, iterable, fillvalue=None): ) ) + "\n" - return StreamingResponse( + return StreamingResponseWithErrorHandling( stream_generator(data_stream), media_type="application/json" ) diff --git a/src/rubrix/server/apis/v0/handlers/text_classification.py b/src/rubrix/server/apis/v0/handlers/text_classification.py index b42f568797..976646c1a3 100644 --- a/src/rubrix/server/apis/v0/handlers/text_classification.py +++ b/src/rubrix/server/apis/v0/handlers/text_classification.py @@ -42,6 +42,7 @@ ) from rubrix.server.apis.v0.validators.text_classification import DatasetValidator from rubrix.server.errors import EntityNotFoundError +from rubrix.server.responses import StreamingResponseWithErrorHandling from rubrix.server.security import auth from rubrix.server.security.model import User from rubrix.server.services.datasets import DatasetsService @@ -226,7 +227,7 @@ def scan_data_response( data_stream: Iterable[TextClassificationRecord], chunk_size: int = 1000, limit: Optional[int] = None, -) -> StreamingResponse: +) -> StreamingResponseWithErrorHandling: """Generate an textual stream data response for a dataset scan""" async def stream_generator(stream): @@ -240,8 +241,8 @@ def grouper(n, iterable, fillvalue=None): stream = takeuntil(stream, limit=limit) for batch in grouper( - n=chunk_size, - iterable=stream, + n=chunk_size, + iterable=stream, ): filtered_records = filter(lambda r: r is not None, batch) yield "\n".join( @@ -250,7 +251,7 @@ def grouper(n, iterable, fillvalue=None): ) ) + "\n" - return StreamingResponse( + return StreamingResponseWithErrorHandling( stream_generator(data_stream), media_type="application/json" ) diff --git a/src/rubrix/server/apis/v0/handlers/token_classification.py b/src/rubrix/server/apis/v0/handlers/token_classification.py index 1df960a42c..ed56beff43 100644 --- a/src/rubrix/server/apis/v0/handlers/token_classification.py +++ b/src/rubrix/server/apis/v0/handlers/token_classification.py @@ -37,6 +37,7 @@ ) from rubrix.server.apis.v0.validators.token_classification import DatasetValidator from rubrix.server.errors import EntityNotFoundError +from rubrix.server.responses import StreamingResponseWithErrorHandling from rubrix.server.security import auth from rubrix.server.security.model import User from rubrix.server.services.datasets import DatasetsService @@ -221,7 +222,7 @@ def scan_data_response( data_stream: Iterable[TokenClassificationRecord], chunk_size: int = 1000, limit: Optional[int] = None, -) -> StreamingResponse: +) -> StreamingResponseWithErrorHandling: """Generate an textual stream data response for a dataset scan""" async def stream_generator(stream): @@ -245,7 +246,7 @@ def grouper(n, iterable, fillvalue=None): ) ) + "\n" - return StreamingResponse( + return StreamingResponseWithErrorHandling( stream_generator(data_stream), media_type="application/json" ) diff --git a/src/rubrix/server/responses/__init__.py b/src/rubrix/server/responses/__init__.py new file mode 100644 index 0000000000..26c7490261 --- /dev/null +++ b/src/rubrix/server/responses/__init__.py @@ -0,0 +1 @@ +from .api_responses import * \ No newline at end of file diff --git a/src/rubrix/server/responses/api_responses.py b/src/rubrix/server/responses/api_responses.py new file mode 100644 index 0000000000..abcf69c9bb --- /dev/null +++ b/src/rubrix/server/responses/api_responses.py @@ -0,0 +1,14 @@ +from starlette.responses import JSONResponse, StreamingResponse +from starlette.types import Send + +from rubrix.server.errors import APIErrorHandler + + +class StreamingResponseWithErrorHandling(StreamingResponse): + + async def stream_response(self, send: Send) -> None: + try: + return await super().stream_response(send) + except Exception as ex: + json_response: JSONResponse = await APIErrorHandler.common_exception_handler(send, error=ex) + await send({"type": "http.response.body", "body": json_response.body, "more_body": False}) \ No newline at end of file