Skip to content

Commit

Permalink
fix: handle stream api connection errors gracefully (#1636)
Browse files Browse the repository at this point in the history
Closes #1559

Thanks to @Ankush-Chander

(cherry picked from commit 4ab30b9)
  • Loading branch information
Ankush-Chander authored and frascuchon committed Aug 22, 2022
1 parent 0e8c635 commit 3fa15bc
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 10 deletions.
12 changes: 10 additions & 2 deletions src/rubrix/client/sdk/commons/api.py
Expand Up @@ -31,6 +31,7 @@
import httpx

from rubrix.client.sdk.client import AuthenticatedClient
from rubrix.client.sdk.commons.errors import GenericApiError
from rubrix.client.sdk.commons.errors_handler import handle_response_error
from rubrix.client.sdk.commons.models import (
BulkResponse,
Expand Down Expand Up @@ -121,12 +122,19 @@ def build_data_response(
response: httpx.Response, data_type: Type[T]
) -> Response[List[T]]:
if 200 <= response.status_code < 400:
parsed_response = [data_type(**json.loads(r)) for r in response.iter_lines()]
parsed_responses = []
for r in response.iter_lines():
parsed_record = json.loads(r)
try:
parsed_response = data_type(**parsed_record)
except Exception as err:
raise GenericApiError(**parsed_record) from None
parsed_responses.append(parsed_response)
return Response(
status_code=response.status_code,
content=b"",
headers=response.headers,
parsed=parsed_response,
parsed=parsed_responses,
)

content = next(response.iter_lines())
Expand Down
5 changes: 3 additions & 2 deletions src/rubrix/server/apis/v0/handlers/text2text.py
Expand Up @@ -35,6 +35,7 @@
Text2TextSearchResults,
)
from rubrix.server.errors import EntityNotFoundError
from rubrix.server.responses import StreamingResponseWithErrorHandling
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
Expand Down Expand Up @@ -204,7 +205,7 @@ def scan_data_response(
data_stream: Iterable[Text2TextRecord],
chunk_size: int = 1000,
limit: Optional[int] = None,
) -> StreamingResponse:
) -> StreamingResponseWithErrorHandling:
"""Generate an textual stream data response for a dataset scan"""

async def stream_generator(stream):
Expand All @@ -228,7 +229,7 @@ def grouper(n, iterable, fillvalue=None):
)
) + "\n"

return StreamingResponse(
return StreamingResponseWithErrorHandling(
stream_generator(data_stream), media_type="application/json"
)

Expand Down
9 changes: 5 additions & 4 deletions src/rubrix/server/apis/v0/handlers/text_classification.py
Expand Up @@ -42,6 +42,7 @@
)
from rubrix.server.apis.v0.validators.text_classification import DatasetValidator
from rubrix.server.errors import EntityNotFoundError
from rubrix.server.responses import StreamingResponseWithErrorHandling
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
Expand Down Expand Up @@ -226,7 +227,7 @@ def scan_data_response(
data_stream: Iterable[TextClassificationRecord],
chunk_size: int = 1000,
limit: Optional[int] = None,
) -> StreamingResponse:
) -> StreamingResponseWithErrorHandling:
"""Generate an textual stream data response for a dataset scan"""

async def stream_generator(stream):
Expand All @@ -240,8 +241,8 @@ def grouper(n, iterable, fillvalue=None):
stream = takeuntil(stream, limit=limit)

for batch in grouper(
n=chunk_size,
iterable=stream,
n=chunk_size,
iterable=stream,
):
filtered_records = filter(lambda r: r is not None, batch)
yield "\n".join(
Expand All @@ -250,7 +251,7 @@ def grouper(n, iterable, fillvalue=None):
)
) + "\n"

return StreamingResponse(
return StreamingResponseWithErrorHandling(
stream_generator(data_stream), media_type="application/json"
)

Expand Down
5 changes: 3 additions & 2 deletions src/rubrix/server/apis/v0/handlers/token_classification.py
Expand Up @@ -37,6 +37,7 @@
)
from rubrix.server.apis.v0.validators.token_classification import DatasetValidator
from rubrix.server.errors import EntityNotFoundError
from rubrix.server.responses import StreamingResponseWithErrorHandling
from rubrix.server.security import auth
from rubrix.server.security.model import User
from rubrix.server.services.datasets import DatasetsService
Expand Down Expand Up @@ -221,7 +222,7 @@ def scan_data_response(
data_stream: Iterable[TokenClassificationRecord],
chunk_size: int = 1000,
limit: Optional[int] = None,
) -> StreamingResponse:
) -> StreamingResponseWithErrorHandling:
"""Generate an textual stream data response for a dataset scan"""

async def stream_generator(stream):
Expand All @@ -245,7 +246,7 @@ def grouper(n, iterable, fillvalue=None):
)
) + "\n"

return StreamingResponse(
return StreamingResponseWithErrorHandling(
stream_generator(data_stream), media_type="application/json"
)

Expand Down
1 change: 1 addition & 0 deletions src/rubrix/server/responses/__init__.py
@@ -0,0 +1 @@
from .api_responses import *
14 changes: 14 additions & 0 deletions src/rubrix/server/responses/api_responses.py
@@ -0,0 +1,14 @@
from starlette.responses import JSONResponse, StreamingResponse
from starlette.types import Send

from rubrix.server.errors import APIErrorHandler


class StreamingResponseWithErrorHandling(StreamingResponse):

async def stream_response(self, send: Send) -> None:
try:
return await super().stream_response(send)
except Exception as ex:
json_response: JSONResponse = await APIErrorHandler.common_exception_handler(send, error=ex)
await send({"type": "http.response.body", "body": json_response.body, "more_body": False})

0 comments on commit 3fa15bc

Please sign in to comment.