Skip to content

Commit

Permalink
feat: improve error handling
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
  • Loading branch information
Tomas2D committed Nov 10, 2023
1 parent 92b58e4 commit b8016b3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/genai/exceptions/genai_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@


class GenAiException(Exception):
def __init__(self, error: Union[Exception, Response, str]) -> None:
if isinstance(error, Response):
def __init__(self, error: Union[Exception, Response, str, ErrorResponse]) -> None:
if isinstance(error, ErrorResponse):
self.error = error
self.error_message = self.error.message
elif isinstance(error, Response):
try:
self.error = ErrorResponse(**error.json())
self.error_message = self.error.message
Expand Down
9 changes: 5 additions & 4 deletions src/genai/schemas/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from typing import Any, List, Optional, Type, Union

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator

from genai.schemas.generate_params import GenerateParams

Expand Down Expand Up @@ -179,12 +179,13 @@ class ErrorExtensionState(GenAiResponseModel):

class ErrorExtensions(GenAiResponseModel):
code: str
state: list[ErrorExtensionState]
state: Optional[list[ErrorExtensionState]] = None
reason: Optional[str] = None


class ErrorResponse(GenAiResponseModel):
status_code: int
error: str
status_code: int = Field(validation_alias=AliasChoices("status_code", "statusCode"))
error: str = ""
message: str
extensions: Optional[ErrorExtensions] = None

Expand Down
10 changes: 10 additions & 0 deletions src/genai/services/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from typing import Optional

Expand All @@ -7,6 +8,7 @@
from genai._version import version
from genai.exceptions import GenAiException
from genai.options import Options
from genai.schemas.responses import ErrorResponse
from genai.services.connection_manager import ConnectionManager
from genai.utils.http_provider import HttpProvider

Expand Down Expand Up @@ -297,6 +299,14 @@ def post_stream(endpoint, headers, json_data, files):
) as event_source:
try:
for sse in event_source.iter_sse():
if sse.event == "error":
if sse.data.startswith("{") and sse.data.endswith("}"):
raise GenAiException(ErrorResponse(**json.loads(sse.data)))

raise GenAiException(
f"Invalid server response during streaming!\nRetrieved data: {sse.data}"
)

yield sse.data
except SSEError as e:
response: Response = event_source.response
Expand Down

0 comments on commit b8016b3

Please sign in to comment.