Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#1559): handle api connection err gracefully #1636

Merged
merged 8 commits into from Aug 8, 2022
12 changes: 10 additions & 2 deletions src/rubrix/client/sdk/commons/api.py
Expand Up @@ -31,6 +31,7 @@
import httpx

from rubrix.client.sdk.client import AuthenticatedClient
from rubrix.client.sdk.commons.errors import GenericApiError
from rubrix.client.sdk.commons.errors_handler import handle_response_error
from rubrix.client.sdk.commons.models import (
BulkResponse,
Expand Down Expand Up @@ -112,12 +113,19 @@ def build_data_response(
response: httpx.Response, data_type: Type[T]
) -> Response[List[T]]:
if 200 <= response.status_code < 400:
parsed_response = [data_type(**json.loads(r)) for r in response.iter_lines()]
parsed_responses = []
for r in response.iter_lines():
parsed_record = json.loads(r)
try:
parsed_response = data_type(**parsed_record)
except Exception as err:
raise GenericApiError(message=parsed_record.get("message", "Cannot process response!"), error=parsed_record.get("type","ServerError")) from None
parsed_responses.append(parsed_response)
return Response(
status_code=response.status_code,
content=b"",
headers=response.headers,
parsed=parsed_response,
parsed=parsed_responses,
)

content = next(response.iter_lines())
Expand Down
26 changes: 15 additions & 11 deletions src/rubrix/server/apis/v0/handlers/text_classification.py
Expand Up @@ -19,6 +19,7 @@
from fastapi import APIRouter, Depends, Query, Security
from fastapi.responses import StreamingResponse

from rubrix.server import errors
from rubrix.server.apis.v0.config.tasks_factory import TaskFactory
from rubrix.server.apis.v0.handlers import text_classification_dataset_settings
from rubrix.server.apis.v0.helpers import takeuntil
Expand Down Expand Up @@ -238,17 +239,20 @@ def grouper(n, iterable, fillvalue=None):
if limit:
stream = takeuntil(stream, limit=limit)

for batch in grouper(
n=chunk_size,
iterable=stream,
):
filtered_records = filter(lambda r: r is not None, batch)
yield "\n".join(
map(
lambda r: r.json(by_alias=True, exclude_none=True), filtered_records
)
) + "\n"

try:
for batch in grouper(
n=chunk_size,
iterable=stream,
):
filtered_records = filter(lambda r: r is not None, batch)
yield "\n".join(
map(
lambda r: r.json(by_alias=True, exclude_none=True), filtered_records
)
) + "\n"
except Exception as error:
yield errors.exception_to_rubrix_error(error)
return
return StreamingResponse(
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
stream_generator(data_stream), media_type="application/json"
)
Expand Down
4 changes: 4 additions & 0 deletions src/rubrix/server/errors/base_errors.py
@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, Optional, Type, Union

import pydantic
Expand Down Expand Up @@ -71,6 +72,9 @@ def api_documentation(cls):
},
}

def encode(self, charset="utf-8"):
return bytes(json.dumps({"type": self.type, "message":self.message, "args": self.args}), charset)

frascuchon marked this conversation as resolved.
Show resolved Hide resolved

class ForbiddenOperationError(RubrixServerError):
"""Forbidden operation"""
Expand Down