diff --git a/src/prisma/_async_http.py b/src/prisma/_async_http.py index 615b01bf8..495c52a51 100644 --- a/src/prisma/_async_http.py +++ b/src/prisma/_async_http.py @@ -4,15 +4,23 @@ import httpx +from .utils import convert_exc from ._types import Method +from .errors import HTTPClientTimeoutError from .http_abstract import AbstractHTTP, AbstractResponse __all__ = ('HTTP', 'AsyncHTTP', 'Response', 'client') +_ASYNC_HTTP_EXC_MAPPING = { + httpx.TimeoutException: HTTPClientTimeoutError, +} + + class AsyncHTTP(AbstractHTTP[httpx.AsyncClient, httpx.Response]): session: httpx.AsyncClient + @convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override async def download(self, url: str, dest: str) -> None: async with self.session.stream('GET', url, timeout=None) as resp: @@ -21,14 +29,17 @@ async def download(self, url: str, dest: str) -> None: async for chunk in resp.aiter_bytes(): fd.write(chunk) + @convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override async def request(self, method: Method, url: str, **kwargs: Any) -> 'Response': return Response(await self.session.request(method, url, **kwargs)) + @convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override def open(self) -> None: self.session = httpx.AsyncClient(**self.session_kwargs) + @convert_exc(_ASYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override async def close(self) -> None: if self.should_close(): diff --git a/src/prisma/_sync_http.py b/src/prisma/_sync_http.py index 9228fb2b0..2636f6fd8 100644 --- a/src/prisma/_sync_http.py +++ b/src/prisma/_sync_http.py @@ -3,15 +3,23 @@ import httpx +from .utils import convert_exc from ._types import Method +from .errors import HTTPClientTimeoutError from .http_abstract import AbstractHTTP, AbstractResponse __all__ = ('HTTP', 'SyncHTTP', 'Response', 'client') +_SYNC_HTTP_EXC_MAPPING = { + httpx.TimeoutException: HTTPClientTimeoutError, +} + + class SyncHTTP(AbstractHTTP[httpx.Client, httpx.Response]): session: httpx.Client + @convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override def download(self, url: str, dest: str) -> None: with self.session.stream('GET', url, timeout=None) as resp: @@ -20,14 +28,17 @@ def download(self, url: str, dest: str) -> None: for chunk in resp.iter_bytes(): fd.write(chunk) + @convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override def request(self, method: Method, url: str, **kwargs: Any) -> 'Response': return Response(self.session.request(method, url, **kwargs)) + @convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override def open(self) -> None: self.session = httpx.Client(**self.session_kwargs) + @convert_exc(_SYNC_HTTP_EXC_MAPPING) # type: ignore[arg-type] @override def close(self) -> None: if self.should_close(): diff --git a/src/prisma/_types.py b/src/prisma/_types.py index b5fbb7a88..bb3fda9f4 100644 --- a/src/prisma/_types.py +++ b/src/prisma/_types.py @@ -23,6 +23,9 @@ FuncType = Callable[..., object] CoroType = Callable[..., Coroutine[Any, Any, object]] +ExcT = TypeVar('ExcT', bound=BaseException) +ExcMapping = Mapping[Type[BaseException], Type[BaseException]] + @runtime_checkable class InheritsGeneric(Protocol): diff --git a/src/prisma/errors.py b/src/prisma/errors.py index e2aca4af2..a6159d502 100644 --- a/src/prisma/errors.py +++ b/src/prisma/errors.py @@ -12,6 +12,7 @@ 'TableNotFoundError', 'RecordNotFoundError', 'HTTPClientClosedError', + 'HTTPClientTimeoutError', 'ClientNotConnectedError', 'PrismaWarning', 'UnsupportedSubclassWarning', @@ -44,6 +45,11 @@ def __init__(self) -> None: super().__init__('Cannot make a request from a closed client.') +class HTTPClientTimeoutError(PrismaError): + def __init__(self) -> None: + super().__init__('HTTP operation has timed out.') + + class UnsupportedDatabaseError(PrismaError): context: str database: str diff --git a/src/prisma/utils.py b/src/prisma/utils.py index 69265291f..9af2fc8cc 100644 --- a/src/prisma/utils.py +++ b/src/prisma/utils.py @@ -6,11 +6,13 @@ import inspect import logging import warnings +import functools import contextlib -from typing import TYPE_CHECKING, Any, Dict, Union, TypeVar, Iterator, NoReturn, Coroutine +from types import TracebackType +from typing import TYPE_CHECKING, Any, Dict, Type, Union, TypeVar, Callable, Iterator, NoReturn, Optional, Coroutine from importlib.util import find_spec -from ._types import CoroType, FuncType, TypeGuard +from ._types import CoroType, FuncType, TypeGuard, ExcMapping if TYPE_CHECKING: from typing_extensions import TypeGuard @@ -139,3 +141,56 @@ def make_optional(value: _T) -> _T | None: def is_dict(obj: object) -> TypeGuard[dict[object, object]]: return isinstance(obj, dict) + + +# TODO: improve typing +class SyncAsyncContextDecorator(contextlib.ContextDecorator): + """`ContextDecorator` compatible with sync/async functions.""" + + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: # type:ignore + @functools.wraps(func) + async def async_inner(*args: Any, **kwargs: Any) -> object: + async with self._recreate_cm(): # type: ignore + return await func(*args, **kwargs) + + @functools.wraps(func) + def sync_inner(*args: Any, **kwargs: Any) -> object: + with self._recreate_cm(): # type: ignore + return func(*args, **kwargs) + + if is_coroutine(func): + return async_inner + else: + return sync_inner + + +class convert_exc(SyncAsyncContextDecorator): + """`SyncAsyncContextDecorator` to convert exceptions.""" + + def __init__(self, exc_mapping: ExcMapping) -> None: + self._exc_mapping = exc_mapping + + def __enter__(self) -> 'convert_exc': + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc is not None and exc_type is not None: + for source_exc_type, target_exc_type in self._exc_mapping.items(): + if isinstance(exc, source_exc_type): + raise target_exc_type() from exc + + async def __aenter__(self) -> 'convert_exc': + return self.__enter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.__exit__(exc_type, exc, exc_tb) diff --git a/tests/test_http.py b/tests/test_http.py index 3fc6ff58e..d40509529 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -6,7 +6,7 @@ from prisma.http import HTTP from prisma.utils import _NoneType from prisma._types import Literal -from prisma.errors import HTTPClientClosedError +from prisma.errors import HTTPClientClosedError, HTTPClientTimeoutError from .utils import patch_method @@ -81,3 +81,13 @@ async def test_httpx_default_config(monkeypatch: 'MonkeyPatch') -> None: 'timeout': httpx.Timeout(30), }, ) + + +@pytest.mark.asyncio +async def test_http_timeout_error() -> None: + """Ensure that `httpx.TimeoutException` is converted to `prisma.errors.HTTPClientTimeoutError`.""" + http = HTTP(timeout=httpx.Timeout(1e-6)) + http.open() + with pytest.raises(HTTPClientTimeoutError) as exc_info: + await http.request('GET', 'https://google.com') + assert isinstance(exc_info.value.__cause__, httpx.TimeoutException)