diff --git a/cq/middlewares/exc.py b/cq/middlewares/exc.py new file mode 100644 index 0000000..9582c07 --- /dev/null +++ b/cq/middlewares/exc.py @@ -0,0 +1,56 @@ +from collections.abc import Awaitable, Callable, Sequence +from typing import Any, Concatenate, Self + +from cq import MiddlewareResult + +__all__ = ("CaptureExceptionMiddleware",) + + +class CaptureExceptionMiddleware[**P, Exc: BaseException]: + __slots__ = ("__exceptions", "__on_error", "__reraise") + + __exceptions: tuple[type[Exc], ...] + __on_error: Callable[Concatenate[Exc, P], Awaitable[Any]] + __reraise: bool + + def __init__( + self, + on_error: Callable[Concatenate[Exc, P], Awaitable[Any]], + /, + exceptions: Sequence[type[Exc]] | None = None, + reraise: bool = False, + ) -> None: + self.__exceptions = (Exception,) if exceptions is None else tuple(exceptions) # type: ignore[assignment] + self.__on_error = on_error + self.__reraise = reraise + + async def __call__( + self, + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> MiddlewareResult[Any]: + try: + yield + except self.__exceptions as exc: + await self.__on_error(exc, *args, **kwargs) + if self.__reraise: + raise + + @classmethod + def sync( + cls, + on_error: Callable[Concatenate[Exc, P], Any], + /, + exceptions: Sequence[type[Exc]] | None = None, + reraise: bool = False, + ) -> Self: + async def async_on_error( + exception: Exc, + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> Any: + return on_error(exception, *args, **kwargs) + + return cls(async_on_error, exceptions, reraise) diff --git a/docs/guides/configuring.md b/docs/guides/configuring.md index d012699..a7c71e4 100644 --- a/docs/guides/configuring.md +++ b/docs/guides/configuring.md @@ -131,3 +131,26 @@ The parameters are: * `exceptions`: the exception types that trigger a retry. Defaults to `(Exception,)`, which retries on any non-`BaseException` failure. If every attempt fails, the last exception is re-raised. + +### `CaptureExceptionMiddleware` + +`cq.middlewares.exc.CaptureExceptionMiddleware` catches exceptions raised by downstream handlers and forwards them to a callback. Use it to log, report, or push errors to an external sink without changing how they propagate: + +```python +from cq import new_command_bus +from cq.middlewares.exc import CaptureExceptionMiddleware + +async def report(exception, message): + sentry_sdk.capture_exception(exception) + +bus = new_command_bus() +bus.add_middlewares(CaptureExceptionMiddleware(report, reraise=True)) +``` + +The parameters are: + +* `on_error`: an async callback invoked with the captured exception followed by the same arguments the handler received (typically the message). Use `CaptureExceptionMiddleware.sync(...)` if your callback is synchronous. +* `exceptions`: the exception types to capture. Defaults to `(Exception,)`. +* `reraise`: whether to re-raise the exception after the callback returns. Defaults to `False`, in which case the exception is swallowed. + +`on_error` is meant for side effects only (logging, metrics, notifications) and must not raise. If it does, its own exception will propagate in place of the original one. diff --git a/tests/middlewares/test_exc.py b/tests/middlewares/test_exc.py new file mode 100644 index 0000000..25cd5cd --- /dev/null +++ b/tests/middlewares/test_exc.py @@ -0,0 +1,60 @@ +from typing import Any, Self + +import anyio +import pytest + +from cq import Bus +from cq.middlewares.exc import CaptureExceptionMiddleware + + +class TestCaptureExceptionMiddleware: + async def test_capture_exception_middleware_with_success( + self, + bus: Bus[Any, Any], + ) -> None: + class Handler: + async def handle(self, message: str) -> str: + raise Exception + + @classmethod + async def async_factory(cls) -> Self: + return cls() + + captured = anyio.Event() + + def capture(exc: Exception, message: str) -> None: + captured.set() + + bus.add_middlewares(CaptureExceptionMiddleware.sync(capture)) + bus.subscribe(str, Handler.async_factory) + + assert not captured.is_set() + await bus.dispatch("Hello world!") + assert captured.is_set() + + async def test_capture_exception_middleware_with_reraise( + self, + bus: Bus[Any, Any], + ) -> None: + class Handler: + async def handle(self, message: str) -> str: + raise Exception + + @classmethod + async def async_factory(cls) -> Self: + return cls() + + captured = anyio.Event() + + def capture(exc: Exception, message: str) -> None: + captured.set() + + bus.add_middlewares(CaptureExceptionMiddleware.sync(capture, reraise=True)) + bus.subscribe(str, Handler.async_factory) + + assert not captured.is_set() + + with pytest.raises(Exception): + await bus.dispatch("Hello world!") + + assert captured.is_set()