Skip to content

Commit

Permalink
feat(client): improve error message for http timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanblade committed Feb 15, 2024
1 parent b690482 commit af9fa20
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 3 deletions.
11 changes: 11 additions & 0 deletions src/prisma/_async_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@

import httpx

from .utils import convert_exc
from ._types import Method
from .errors import HTTPClientTimeoutError
from .http_abstract import AbstractHTTP, AbstractResponse

__all__ = ('HTTP', 'AsyncHTTP', 'Response', 'client')


_ASYNC_HTTP_EXC_MAPPING = {
httpx.TimeoutException: HTTPClientTimeoutError,
}


class AsyncHTTP(AbstractHTTP[httpx.AsyncClient, httpx.Response]):
session: httpx.AsyncClient

@convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
async def download(self, url: str, dest: str) -> None:
async with self.session.stream('GET', url, timeout=None) as resp:
Expand All @@ -21,14 +29,17 @@ async def download(self, url: str, dest: str) -> None:
async for chunk in resp.aiter_bytes():
fd.write(chunk)

@convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
async def request(self, method: Method, url: str, **kwargs: Any) -> 'Response':
return Response(await self.session.request(method, url, **kwargs))

@convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
def open(self) -> None:
self.session = httpx.AsyncClient(**self.session_kwargs)

@convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
async def close(self) -> None:
if self.should_close():
Expand Down
11 changes: 11 additions & 0 deletions src/prisma/_sync_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@

import httpx

from .utils import convert_exc
from ._types import Method
from .errors import HTTPClientTimeoutError
from .http_abstract import AbstractHTTP, AbstractResponse

__all__ = ('HTTP', 'SyncHTTP', 'Response', 'client')


_SYNC_HTTP_EXC_MAPPING = {
httpx.TimeoutException: HTTPClientTimeoutError,
}


class SyncHTTP(AbstractHTTP[httpx.Client, httpx.Response]):
session: httpx.Client

@convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
def download(self, url: str, dest: str) -> None:
with self.session.stream('GET', url, timeout=None) as resp:
Expand All @@ -20,14 +28,17 @@ def download(self, url: str, dest: str) -> None:
for chunk in resp.iter_bytes():
fd.write(chunk)

@convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
def request(self, method: Method, url: str, **kwargs: Any) -> 'Response':
return Response(self.session.request(method, url, **kwargs))

@convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
def open(self) -> None:
self.session = httpx.Client(**self.session_kwargs)

@convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type]
@override
def close(self) -> None:
if self.should_close():
Expand Down
3 changes: 3 additions & 0 deletions src/prisma/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
FuncType = Callable[..., object]
CoroType = Callable[..., Coroutine[Any, Any, object]]

ExcT = TypeVar('ExcT', bound=BaseException)
ExcMapping = Mapping[Type[BaseException], Type[BaseException]]


@runtime_checkable
class InheritsGeneric(Protocol):
Expand Down
6 changes: 6 additions & 0 deletions src/prisma/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
'TableNotFoundError',
'RecordNotFoundError',
'HTTPClientClosedError',
'HTTPClientTimeoutError',
'ClientNotConnectedError',
'PrismaWarning',
'UnsupportedSubclassWarning',
Expand Down Expand Up @@ -44,6 +45,11 @@ def __init__(self) -> None:
super().__init__('Cannot make a request from a closed client.')


class HTTPClientTimeoutError(PrismaError):
def __init__(self) -> None:
super().__init__('HTTP operation has timed out.')


class UnsupportedDatabaseError(PrismaError):
context: str
database: str
Expand Down
59 changes: 57 additions & 2 deletions src/prisma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import inspect
import logging
import warnings
import functools
import contextlib
from typing import TYPE_CHECKING, Any, Dict, Union, TypeVar, Iterator, NoReturn, Coroutine
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, Type, Union, TypeVar, Callable, Iterator, NoReturn, Optional, Coroutine
from importlib.util import find_spec

from ._types import CoroType, FuncType, TypeGuard
from ._types import CoroType, FuncType, TypeGuard, ExcMapping

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand Down Expand Up @@ -139,3 +141,56 @@ def make_optional(value: _T) -> _T | None:

def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
return isinstance(obj, dict)


# TODO: improve typing
class SyncAsyncContextDecorator(contextlib.ContextDecorator):
"""`ContextDecorator` compatible with sync/async functions."""

def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: # type:ignore
@functools.wraps(func)
async def async_inner(*args: Any, **kwargs: Any) -> object:
async with self._recreate_cm(): # type: ignore
return await func(*args, **kwargs)

@functools.wraps(func)
def sync_inner(*args: Any, **kwargs: Any) -> object:
with self._recreate_cm(): # type: ignore
return func(*args, **kwargs)

if is_coroutine(func):
return async_inner
else:
return sync_inner


class convert_exc(SyncAsyncContextDecorator):
"""`SyncAsyncContextDecorator` to convert exceptions."""

def __init__(self, exc_mapping: ExcMapping) -> None:
self._exc_mapping = exc_mapping

def __enter__(self) -> 'convert_exc':
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc is not None and exc_type is not None:
for source_exc_type, target_exc_type in self._exc_mapping.items():
if isinstance(exc, source_exc_type):
raise target_exc_type() from exc

async def __aenter__(self) -> 'convert_exc':
return self.__enter__()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.__exit__(exc_type, exc, exc_tb)
12 changes: 11 additions & 1 deletion tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from prisma.http import HTTP
from prisma.utils import _NoneType
from prisma._types import Literal
from prisma.errors import HTTPClientClosedError
from prisma.errors import HTTPClientClosedError, HTTPClientTimeoutError

from .utils import patch_method

Expand Down Expand Up @@ -81,3 +81,13 @@ async def test_httpx_default_config(monkeypatch: 'MonkeyPatch') -> None:
'timeout': httpx.Timeout(30),
},
)


@pytest.mark.asyncio
async def test_http_timeout_error() -> None:
"""Ensure that `httpx.TimeoutException` is converted to `prisma.errors.HTTPClientTimeoutError`."""
http = HTTP(timeout=httpx.Timeout(1e-6))
http.open()
with pytest.raises(HTTPClientTimeoutError) as exc_info:
await http.request('GET', 'https://google.com')
assert isinstance(exc_info.value.__cause__, httpx.TimeoutException)

0 comments on commit af9fa20

Please sign in to comment.