Skip to content

Commit

Permalink
Merge pull request #136 from hnefatl/dev
Browse files Browse the repository at this point in the history
Allow returning raw bytes from a request, without attempting to decode.
  • Loading branch information
GrandMoff100 committed Jan 3, 2023
2 parents 208d068 + d1dd678 commit 2d69843
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
6 changes: 4 additions & 2 deletions homeassistant_api/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Module for custom error classes"""

from typing import Union


class HomeassistantAPIError(BaseException):
"""Base class for custom errors"""
Expand Down Expand Up @@ -40,10 +42,10 @@ class ParameterMissingError(HomeassistantAPIError):
class InternalServerError(HomeassistantAPIError):
"""Error raised when Home Assistant says that it got itself in trouble."""

def __init__(self, status_code: int, content: str) -> None:
def __init__(self, status_code: int, content: Union[str, bytes]) -> None:
super().__init__(
f"Home Assistant returned a response with an error status code {status_code!r}.\n"
f"{content}\n"
f"{content!r}\n"
"If this happened, "
"please report it at https://github.com/GrandMoff100/HomeAssistantAPI/issues "
"with the request status code and the request content. Thanks!"
Expand Down
20 changes: 12 additions & 8 deletions homeassistant_api/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ class Processing:
_response: AllResponseType
_processors: ClassVar[Dict[str, Tuple[ProcessorType, ...]]] = {}

def __init__(self, response: AllResponseType) -> None:
def __init__(self, response: AllResponseType, decode_bytes: bool = True) -> None:
self._response = response
self._decode_bytes = decode_bytes

@staticmethod
def processor(mimetype: str) -> Callable[[ProcessorType], ProcessorType]:
Expand Down Expand Up @@ -72,13 +73,17 @@ def process_content(self, *, async_: bool = False) -> Any:

def process(self) -> Any:
"""Validates the http status code before starting to process the repsonse content"""
content: Union[str, bytes]
if async_ := isinstance(self._response, (ClientResponse, AsyncCachedResponse)):
status_code = self._response.status
_buffer = self._response.content._buffer
content = "" if not _buffer else _buffer[0].decode()
content = b"" if not _buffer else _buffer[0]
elif isinstance(self._response, (Response, CachedResponse)):
status_code = self._response.status_code
content = self._response.content.decode()
content = self._response.content
if self._decode_bytes and isinstance(content, bytes):
content = content.decode()

if status_code in (200, 201):
return self.process_content(async_=async_)
if status_code == 400:
Expand All @@ -88,11 +93,10 @@ def process(self) -> Any:
if status_code == 404:
raise EndpointNotFoundError(str(self._response.url))
if status_code == 405:
method = (
self._response.request.method
if isinstance(self._response, (Response, CachedResponse))
else self._response.method
)
if isinstance(self._response, (Response, CachedResponse)):
method = self._response.request.method
else:
method = self._response.method
raise MethodNotAllowedError(cast(str, method))
if status_code >= 500:
raise InternalServerError(status_code, content)
Expand Down
7 changes: 4 additions & 3 deletions homeassistant_api/rawclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def request(
path,
method="GET",
headers: Dict[str, str] = None,
decode_bytes: bool = True,
**kwargs,
) -> Any:
"""Base method for making requests to the api"""
Expand All @@ -101,12 +102,12 @@ def request(
raise RequestTimeoutError(
f'Home Assistant did not respond in time (timeout: {kwargs.get("timeout", 300)} sec)'
) from err
return self.response_logic(resp)
return self.response_logic(response=resp, decode_bytes=decode_bytes)

@classmethod
def response_logic(cls, response: ResponseType) -> Any:
def response_logic(cls, response: ResponseType, decode_bytes: bool = True) -> Any:
"""Processes responses from the API and formats them"""
return Processing(response=response).process()
return Processing(response=response, decode_bytes=decode_bytes).process()

# API information methods
def get_error_log(self) -> str:
Expand Down

0 comments on commit 2d69843

Please sign in to comment.