diff --git a/.gitignore b/.gitignore index fb33be7..576e154 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,4 @@ dmypy.json # vscode settings .vscode/ -src/_your_package_version.py +src/_error_handler_version.py diff --git a/README.md b/README.md index cdae84f..3449012 100644 --- a/README.md +++ b/README.md @@ -44,11 +44,13 @@ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) logger = logging.root op = aiostream.stream.iterate(range(10)) + def log_error(error: Exception, num: int): """Only log error and reraise it""" logger.error("double_only_odd_nums_except_5 failed for input %d. ", num) raise error + @error_handler.decorator(on_error=log_error) async def double_only_odd_nums_except_5(num: int) -> int: if num % 2 == 0: @@ -59,13 +61,16 @@ async def double_only_odd_nums_except_5(num: int) -> int: num *= 2 return num + def catch_value_errors(error: Exception, _: int): if not isinstance(error, ValueError): raise error + def log_success(result_num: int, provided_num: int): logger.info("Success: %d -> %d", provided_num, result_num) + op = op | error_handler.pipe.map( double_only_odd_nums_except_5, on_error=catch_value_errors, diff --git a/src/error_handler/__init__.py b/src/error_handler/__init__.py index 7c4dc58..b3fec36 100644 --- a/src/error_handler/__init__.py +++ b/src/error_handler/__init__.py @@ -7,17 +7,9 @@ from typing import TYPE_CHECKING from .context_manager import context_manager -from .decorator import decorator, retry_on_error -from .types import ( - ERRORED, - UNSET, - AsyncFunctionType, - ErroredType, - FunctionType, - SecuredAsyncFunctionType, - SecuredFunctionType, - UnsetType, -) +from .decorator import decorator, decorator_as_result, retry_on_error +from .result import NegativeResult, PositiveResult, ResultType +from .types import UNSET, AsyncFunctionType, FunctionType, SecuredAsyncFunctionType, SecuredFunctionType, UnsetType if TYPE_CHECKING: from . import pipe, stream diff --git a/src/error_handler/callback.py b/src/error_handler/callback.py index 314f6ba..4e93db3 100644 --- a/src/error_handler/callback.py +++ b/src/error_handler/callback.py @@ -102,7 +102,9 @@ class ErrorCallback(Callback[_P, _T]): signature. """ - _CALLBACK_ERROR_PARAM = inspect.Parameter("error", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Exception) + _CALLBACK_ERROR_PARAM = inspect.Parameter( + "error", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=BaseException + ) @classmethod def from_callable( diff --git a/src/error_handler/context_manager.py b/src/error_handler/context_manager.py index 8c3c1f1..569adfe 100644 --- a/src/error_handler/context_manager.py +++ b/src/error_handler/context_manager.py @@ -6,15 +6,16 @@ from typing import Any, Callable, Iterator from .callback import Callback, ErrorCallback -from .core import Catcher +from .core import Catcher, ContextCatcher from .types import UnsetType # pylint: disable=unsubscriptable-object @contextmanager def context_manager( + *, on_success: Callable[[], Any] | None = None, - on_error: Callable[[Exception], Any] | None = None, + on_error: Callable[[BaseException], Any] | None = None, on_finalize: Callable[[], Any] | None = None, suppress_recalling_on_error: bool = True, ) -> Iterator[Catcher[UnsetType]]: @@ -28,7 +29,7 @@ def context_manager( If suppress_recalling_on_error is True, the on_error callable will not be called if the error were already caught by a previous catcher. """ - catcher = Catcher[UnsetType]( + catcher = ContextCatcher( Callback.from_callable(on_success, return_type=Any) if on_success is not None else None, ErrorCallback.from_callable(on_error, return_type=Any) if on_error is not None else None, Callback.from_callable(on_finalize, return_type=Any) if on_finalize is not None else None, diff --git a/src/error_handler/core.py b/src/error_handler/core.py index 1e8dd1c..cd192f3 100644 --- a/src/error_handler/core.py +++ b/src/error_handler/core.py @@ -19,7 +19,7 @@ ResultType, ReturnValues, ) -from .types import ERRORED, UNSET, ErroredType, T, UnsetType +from .types import UNSET, T, UnsetType _T = TypeVar("_T") _U = TypeVar("_U") @@ -43,7 +43,6 @@ def __init__( on_success: Callback | None = None, on_error: Callback | None = None, on_finalize: Callback | None = None, - on_error_return_always: T | ErroredType = ERRORED, suppress_recalling_on_error: bool = True, raise_callback_errors: bool = True, no_wrap_exception_group_when_reraise: bool = True, @@ -51,7 +50,6 @@ def __init__( self.on_success = on_success self.on_error = on_error self.on_finalize = on_finalize - self.on_error_return_always = on_error_return_always self.suppress_recalling_on_error = suppress_recalling_on_error """ If this flag is set, the framework won't call the callbacks if the caught exception was already caught by @@ -208,7 +206,7 @@ def secure_call( # type: ignore[return] # Because mypy is stupid, idk. result = callable_to_secure(*args, **kwargs) self._result = PositiveResult(result=result) except BaseException as error: # pylint: disable=broad-exception-caught - self._result = NegativeResult(error=error, result=self.on_error_return_always) + self._result = NegativeResult(error=error) return self.result async def secure_await( # type: ignore[return] # Because mypy is stupid, idk. @@ -227,9 +225,15 @@ async def secure_await( # type: ignore[return] # Because mypy is stupid, idk. result = await awaitable_to_secure self._result = PositiveResult(result=result) except BaseException as error: # pylint: disable=broad-exception-caught - self._result = NegativeResult(error=error, result=self.on_error_return_always) + self._result = NegativeResult(error=error) return self.result + +class ContextCatcher(Catcher[UnsetType]): + """ + This class is a special case of the Catcher class. It is meant to use the context manager. + """ + @contextmanager def secure_context(self) -> Iterator[Self]: """ @@ -245,4 +249,4 @@ def secure_context(self) -> Iterator[Self]: yield self self._result = PositiveResult(result=UNSET) except BaseException as error: # pylint: disable=broad-exception-caught - self._result = NegativeResult(error=error, result=self.on_error_return_always) + self._result = NegativeResult(error=error) diff --git a/src/error_handler/decorator.py b/src/error_handler/decorator.py index 1cab875..e9000fe 100644 --- a/src/error_handler/decorator.py +++ b/src/error_handler/decorator.py @@ -7,20 +7,12 @@ import inspect import logging import time -from typing import Any, Callable, Concatenate, Generator, ParamSpec, TypeGuard, TypeVar, cast +from typing import Any, Callable, Concatenate, Generator, ParamSpec, Protocol, TypeGuard, TypeVar, cast, overload from .callback import Callback, ErrorCallback, SuccessCallback from .core import Catcher from .result import CallbackResultType, PositiveResult, ResultType -from .types import ( - ERRORED, - AsyncFunctionType, - ErroredType, - FunctionType, - SecuredAsyncFunctionType, - SecuredFunctionType, - UnsetType, -) +from .types import UNSET, AsyncFunctionType, FunctionType, SecuredAsyncFunctionType, SecuredFunctionType, UnsetType _P = ParamSpec("_P") _T = TypeVar("_T") @@ -35,16 +27,67 @@ def iscoroutinefunction( return asyncio.iscoroutinefunction(callable_) +# pylint: disable=too-few-public-methods +class SecureDecorator(Protocol[_P, _T]): + """ + This protocol represents a decorator that secures a callable and returns a ResultType[T]. + """ + + @overload + def __call__( # type: ignore[overload-overlap] + # This error happens, because Callable[..., Awaitable[T]] is a subtype of Callable[..., T] and + # therefore the overloads are overlapping. This leads to problems with the type checker if you use it like this: + # + # async def some_coroutine_function() -> None: ... + # callable_to_secure: FunctionType[_P, _T] = some_coroutine_function + # reveal_type(decorator(callable_to_secure)) + # + # Revealed type is 'SecuredAsyncFunctionType[_P, _T]' but mypy will think it is 'SecuredFunctionType[_P, _T]'. + # Since it is not possible to 'negate' types (e.g. something like 'Callable[..., T \ Awaitable[T]]'), + # we have no other choice than to ignore this error. Anyway, it should be fine if you are plainly decorating + # your functions, so it's ok. + # Reference: https://stackoverflow.com/a/74567241/21303427 + self, + callable_to_secure: AsyncFunctionType[_P, _T], + ) -> SecuredAsyncFunctionType[_P, _T]: ... + + @overload + def __call__(self, callable_to_secure: FunctionType[_P, _T]) -> SecuredFunctionType[_P, _T]: ... + + def __call__( + self, callable_to_secure: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] + ) -> SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T]: ... + + +# pylint: disable=too-few-public-methods +class Decorator(Protocol[_P, _T]): + """ + This protocol represents a decorator that secures a callable but does not change the return type. + """ + + @overload + def __call__( # type: ignore[overload-overlap] + self, + callable_to_secure: AsyncFunctionType[_P, _T], + ) -> AsyncFunctionType[_P, _T]: ... + + @overload + def __call__(self, callable_to_secure: FunctionType[_P, _T]) -> FunctionType[_P, _T]: ... + + def __call__( + self, + callable_to_secure: FunctionType[_P, _T] | AsyncFunctionType[_P, _T], + ) -> FunctionType[_P, _T] | AsyncFunctionType[_P, _T]: ... + + # pylint: disable=too-many-arguments -def decorator( +def decorator_as_result( + *, on_success: Callable[Concatenate[_T, _P], Any] | None = None, - on_error: Callable[Concatenate[Exception, _P], Any] | None = None, + on_error: Callable[Concatenate[BaseException, _P], Any] | None = None, on_finalize: Callable[_P, Any] | None = None, - on_error_return_always: _T | ErroredType = ERRORED, suppress_recalling_on_error: bool = True, -) -> Callable[ - [FunctionType[_P, _T] | AsyncFunctionType[_P, _T]], SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T] -]: +) -> SecureDecorator[_P, _T]: """ This decorator secures a callable (sync or async) and handles its errors. If the callable raises an error, the on_error callback will be called and the value if on_error_return_always @@ -58,6 +101,14 @@ def decorator( """ # pylint: disable=unsubscriptable-object + @overload + def decorator_inner( # type: ignore[overload-overlap] # See above + callable_to_secure: AsyncFunctionType[_P, _T], + ) -> SecuredAsyncFunctionType[_P, _T]: ... + + @overload + def decorator_inner(callable_to_secure: FunctionType[_P, _T]) -> SecuredFunctionType[_P, _T]: ... + def decorator_inner( callable_to_secure: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] ) -> SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T]: @@ -66,30 +117,27 @@ def decorator_inner( SuccessCallback.from_callable(on_success, sig, return_type=Any) if on_success is not None else None, ErrorCallback.from_callable(on_error, sig, return_type=Any) if on_error is not None else None, Callback.from_callable(on_finalize, sig, return_type=Any) if on_finalize is not None else None, - on_error_return_always, suppress_recalling_on_error, ) if iscoroutinefunction(callable_to_secure): @functools.wraps(callable_to_secure) - async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T | ErroredType: + async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> ResultType[_T]: result = await catcher.secure_await(callable_to_secure(*args, **kwargs)) catcher.handle_result_and_call_callbacks(result, *args, **kwargs) - assert not isinstance(result.result, UnsetType), "Internal error: result is unset" - return result.result + return result else: @functools.wraps(callable_to_secure) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T | ErroredType: + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> ResultType[_T]: result = catcher.secure_call( callable_to_secure, # type: ignore[arg-type] *args, **kwargs, ) catcher.handle_result_and_call_callbacks(result, *args, **kwargs) - assert not isinstance(result.result, UnsetType), "Internal error: result is unset" - return result.result + return result return_func = cast(SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T], wrapper) return_func.__catcher__ = catcher @@ -101,18 +149,17 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T | ErroredType: # pylint: disable=too-many-arguments, too-many-locals def retry_on_error( - on_error: Callable[Concatenate[Exception, int, _P], bool], + *, + on_error: Callable[Concatenate[BaseException, int, _P], bool], retry_stepping_func: Callable[[int], float] = lambda retry_count: 1.71**retry_count, # <-- with max_retries = 10 the whole decorator may wait up to 5 minutes. # because sum(1.71seconds**i for i in range(10)) == 5minutes max_retries: int = 10, on_success: Callable[Concatenate[_T, int, _P], Any] | None = None, - on_fail: Callable[Concatenate[Exception, int, _P], Any] | None = None, + on_fail: Callable[Concatenate[BaseException, int, _P], Any] | None = None, on_finalize: Callable[Concatenate[int, _P], Any] | None = None, logger: logging.Logger = logging.getLogger(__name__), -) -> Callable[ - [FunctionType[_P, _T] | AsyncFunctionType[_P, _T]], SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T] -]: +) -> Decorator[_P, _T]: """ This decorator retries a callable (sync or async) on error. The retry_stepping_func is called with the retry count and should return the time to wait until the next retry. @@ -126,7 +173,7 @@ def retry_on_error( def decorator_inner( callable_to_secure: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] - ) -> SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T]: + ) -> FunctionType[_P, _T] | AsyncFunctionType[_P, _T]: sig = inspect.signature(callable_to_secure) sig = sig.replace( parameters=[ @@ -134,13 +181,13 @@ def decorator_inner( *sig.parameters.values(), ], ) - on_error_callback: ErrorCallback[Concatenate[Exception, int, _P], bool] = ErrorCallback.from_callable( + on_error_callback: ErrorCallback[Concatenate[BaseException, int, _P], bool] = ErrorCallback.from_callable( on_error, sig, return_type=bool ) on_success_callback: SuccessCallback[Concatenate[_T, int, _P], Any] | None = ( SuccessCallback.from_callable(on_success, sig, return_type=Any) if on_success is not None else None ) - on_fail_callback: ErrorCallback[Concatenate[Exception, int, _P], Any] | None = ( + on_fail_callback: ErrorCallback[Concatenate[BaseException, int, _P], Any] | None = ( ErrorCallback.from_callable(on_fail, sig, return_type=Any) if on_fail is not None else None ) on_finalize_callback: Callback[Concatenate[int, _P], Any] | None = ( @@ -183,7 +230,6 @@ def retry_generator(*args: _P.args, **kwargs: _P.kwargs) -> Generator[int, Resul def handle_result_and_call_callbacks(result: ResultType[_T], *args: _P.args, **kwargs: _P.kwargs) -> _T: if isinstance(result, PositiveResult): - assert not isinstance(result.result, UnsetType), "Internal error: result is unset" catcher_retrier.handle_success_case( result.result, retry_count, @@ -238,9 +284,57 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: result = catcher_retrier.secure_call(retry_function_sync, *args, **kwargs) return handle_result_and_call_callbacks(result, *args, **kwargs) - return_func = cast(SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T], wrapper) - return_func.__catcher__ = catcher_retrier - return_func.__original_callable__ = callable_to_secure + return_func = cast(FunctionType[_P, _T] | AsyncFunctionType[_P, _T], wrapper) + return_func.__catcher__ = catcher_retrier # type: ignore[union-attr] + return_func.__original_callable__ = callable_to_secure # type: ignore[union-attr] return return_func - return decorator_inner + return decorator_inner # type: ignore[return-value] + + +def decorator( + *, + on_success: Callable[Concatenate[_T, _P], Any] | None = None, + on_error: Callable[Concatenate[BaseException, _P], Any] | None = None, + on_finalize: Callable[_P, Any] | None = None, + suppress_recalling_on_error: bool = True, + on_error_return_always: _T | UnsetType = UNSET, +) -> Decorator[_P, _T]: + """ + Returns a callback that converts the result of a secured function back to the original return type. + To make this work, you need to define which value should be returned in error cases. + Otherwise, if the secured function returns an error result, the error will be raised. + """ + + def decorator_inner( + func: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] + ) -> FunctionType[_P, _T] | AsyncFunctionType[_P, _T]: + secured_func = decorator_as_result( + on_success=on_success, + on_error=on_error, + on_finalize=on_finalize, + suppress_recalling_on_error=suppress_recalling_on_error, + )(func) + + def handle_result(result: ResultType[_T]) -> _T: + if isinstance(result, PositiveResult): + return result.result + if isinstance(on_error_return_always, UnsetType): + raise result.error + return on_error_return_always + + if iscoroutinefunction(secured_func): + + @functools.wraps(secured_func) + async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + return handle_result(await secured_func(*args, **kwargs)) + + else: + + @functools.wraps(secured_func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + return handle_result(secured_func(*args, **kwargs)) # type: ignore[arg-type] + + return cast(FunctionType[_P, _T] | AsyncFunctionType[_P, _T], wrapper) + + return decorator_inner # type: ignore[return-value] diff --git a/src/error_handler/result.py b/src/error_handler/result.py index 5d64528..03aee04 100644 --- a/src/error_handler/result.py +++ b/src/error_handler/result.py @@ -4,9 +4,11 @@ from dataclasses import dataclass from enum import StrEnum -from typing import Any, Generic, TypeAlias +from typing import Any, Generic, TypeAlias, TypeVar -from error_handler.types import UNSET, ErroredType, T, UnsetType +from error_handler.types import UNSET + +T = TypeVar("T") class CallbackResultType(StrEnum): @@ -59,7 +61,7 @@ class PositiveResult(Generic[T]): Represents a successful result. """ - result: T | UnsetType + result: T @dataclass(frozen=True) @@ -68,7 +70,6 @@ class NegativeResult(Generic[T]): Represents an errored result. """ - result: T | ErroredType error: BaseException diff --git a/src/error_handler/stream.py b/src/error_handler/stream.py index f8cf790..da37bd6 100644 --- a/src/error_handler/stream.py +++ b/src/error_handler/stream.py @@ -6,9 +6,10 @@ import sys from typing import Any, AsyncIterable, AsyncIterator, Callable +from . import NegativeResult, PositiveResult, ResultType from ._extra import IS_AIOSTREAM_INSTALLED -from .decorator import decorator -from .types import ERRORED, AsyncFunctionType, FunctionType, SecuredAsyncFunctionType, SecuredFunctionType, is_secured +from .decorator import decorator_as_result +from .types import AsyncFunctionType, FunctionType, SecuredAsyncFunctionType, SecuredFunctionType, is_secured if IS_AIOSTREAM_INSTALLED: import aiostream @@ -28,7 +29,7 @@ def map( ordered: bool = True, task_limit: int | None = None, on_success: Callable[[U, T], Any] | None = None, - on_error: Callable[[Exception, T], Any] | None = None, + on_error: Callable[[BaseException, T], Any] | None = None, on_finalize: Callable[[T], Any] | None = None, wrap_secured_function: bool = False, suppress_recalling_on_error: bool = True, @@ -41,11 +42,6 @@ def map( caught by a previous catcher. """ if not wrap_secured_function and is_secured(func): - if func.__catcher__.on_error_return_always is not ERRORED: - raise ValueError( - "The given function is already secured but does not return ERRORED in error case. " - "If the secured function re-raises errors you can set wrap_secured_function=True" - ) if ( on_success is not None or on_error is not None @@ -62,22 +58,32 @@ def map( ) secured_func = func else: - secured_func = decorator( # type: ignore[assignment] + # pylint: disable=duplicate-code + secured_func = decorator_as_result( # type: ignore[assignment] on_success=on_success, on_error=on_error, on_finalize=on_finalize, - on_error_return_always=ERRORED, suppress_recalling_on_error=suppress_recalling_on_error, )( func # type: ignore[arg-type] ) # Ignore that T | ErroredType is not compatible with T. All ErroredType results are filtered out # in a subsequent step. - next_source: AsyncIterator[U] = aiostream.stream.map.raw( + results: AsyncIterator[ResultType[U]] = aiostream.stream.map.raw( source, secured_func, *more_sources, ordered=ordered, task_limit=task_limit # type: ignore[arg-type] ) - next_source = aiostream.stream.filter.raw(next_source, lambda result: result is not ERRORED) - return next_source + positive_results: AsyncIterator[PositiveResult[U]] = aiostream.stream.filter.raw( + results, # type: ignore[arg-type] + # mypy can't successfully narrow the type here. + lambda result: not isinstance(result, NegativeResult), + ) + result_values: AsyncIterator[U] = aiostream.stream.map.raw( + positive_results, + lambda result: ( # type: ignore[arg-type, misc] + result.result if isinstance(result, PositiveResult) else result + ), + ) + return result_values else: from ._extra import _NotInstalled diff --git a/src/error_handler/types.py b/src/error_handler/types.py index 0a40eaa..8ec2d2e 100644 --- a/src/error_handler/types.py +++ b/src/error_handler/types.py @@ -3,10 +3,11 @@ """ import inspect -from typing import TYPE_CHECKING, Awaitable, Callable, ParamSpec, Protocol, TypeAlias, TypeGuard, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Coroutine, ParamSpec, Protocol, TypeAlias, TypeGuard, TypeVar if TYPE_CHECKING: from .core import Catcher + from .result import ResultType T = TypeVar("T") P = ParamSpec("P") @@ -51,14 +52,6 @@ def __singleton_new__(cls, *args, **kwargs): return __singleton_new__ -# pylint: disable=too-few-public-methods -class ErroredType(metaclass=SingletonMeta): - """ - This type is meant to be used as singleton. Do not instantiate it on your own. - The instance below represents an errored result. - """ - - # pylint: disable=too-few-public-methods class UnsetType(metaclass=SingletonMeta): """ @@ -72,15 +65,10 @@ class UnsetType(metaclass=SingletonMeta): """ Represents an unset value. It is used as default value for parameters that can be of any type. """ -ERRORED = ErroredType() -""" -Represents an errored result. It is used to be able to return something in error cases. See Catcher.secure_call -for more information. -""" FunctionType: TypeAlias = Callable[P, T] -AsyncFunctionType: TypeAlias = Callable[P, Awaitable[T]] +AsyncFunctionType: TypeAlias = Callable[P, Coroutine[Any, Any, T]] class SecuredFunctionType(Protocol[P, T]): @@ -91,7 +79,7 @@ class SecuredFunctionType(Protocol[P, T]): __catcher__: "Catcher[T]" __original_callable__: FunctionType[P, T] - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T | ErroredType: ... + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "ResultType[T]": ... class SecuredAsyncFunctionType(Protocol[P, T]): @@ -102,22 +90,13 @@ class SecuredAsyncFunctionType(Protocol[P, T]): __catcher__: "Catcher[T]" __original_callable__: AsyncFunctionType[P, T] - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T | ErroredType]: ... + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "ResultType[T]": ... def is_secured( func: FunctionType[P, T] | SecuredFunctionType[P, T] | AsyncFunctionType[P, T] | SecuredAsyncFunctionType[P, T] ) -> TypeGuard[SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T]]: """ - Returns True if the given function is secured, False otherwise. + Returns True if the given function is secured and returns a ResultType[T]. False otherwise. """ return hasattr(func, "__catcher__") and hasattr(func, "__original_callable__") - - -def is_unsecured( - func: FunctionType[P, T] | AsyncFunctionType[P, T] | SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T] -) -> TypeGuard[FunctionType[P, T] | AsyncFunctionType[P, T]]: - """ - Returns True if the given function is not secured, False otherwise. - """ - return not hasattr(func, "__catcher__") or not hasattr(func, "__original_callable__") diff --git a/unittests/test_callback_errors.py b/unittests/test_callback_errors.py index 5b6f351..33893ee 100644 --- a/unittests/test_callback_errors.py +++ b/unittests/test_callback_errors.py @@ -12,7 +12,7 @@ def test_decorator_callbacks_wrong_signature_call_all_callbacks(self): def on_finalize_wrong_signature(): pass - @error_handler.decorator( + @error_handler.decorator_as_result( on_success=on_success_callback, on_error=assert_not_called, on_finalize=on_finalize_wrong_signature ) def func(hello: str) -> str: @@ -36,7 +36,7 @@ def on_finalize_wrong_signature(): def on_error_callback(_: BaseException, __: str): raise ValueError("This is a test error") - @error_handler.decorator( + @error_handler.decorator_as_result( on_success=assert_not_called, on_error=on_error_callback, on_finalize=on_finalize_wrong_signature ) def func(hello: str) -> str: diff --git a/unittests/test_decorator.py b/unittests/test_decorator.py index 5cf8276..3d4f829 100644 --- a/unittests/test_decorator.py +++ b/unittests/test_decorator.py @@ -15,11 +15,10 @@ async def test_decorator_coroutine_error_case(self): error_callback, error_tracker = create_callback_tracker() finalize_callback, finalize_tracker = create_callback_tracker() - @error_handler.decorator( + @error_handler.decorator_as_result( on_error=error_callback, on_finalize=finalize_callback, on_success=assert_not_called, - on_error_return_always=error_handler.ERRORED, ) async def async_function(hello: str) -> None: raise ValueError(f"This is a test error {hello}") @@ -28,7 +27,7 @@ async def async_function(hello: str) -> None: result = await awaitable assert str(error_tracker[0][0][0]) == "This is a test error world" assert str(error_tracker[0][0][1]) == "world" - assert result == error_handler.ERRORED + assert isinstance(result, error_handler.NegativeResult) assert finalize_tracker == [(("world",), {})] async def test_decorator_coroutine_success_case(self): @@ -52,7 +51,7 @@ async def async_function(hello: str) -> str: def test_decorator_function_error_case(self): error_callback, error_tracker = create_callback_tracker() - @error_handler.decorator(on_error=error_callback, on_success=assert_not_called) + @error_handler.decorator(on_error=error_callback, on_success=assert_not_called, on_error_return_always=None) def func(hello: str) -> None: raise ValueError(f"This is a test error {hello}") @@ -60,7 +59,7 @@ def func(hello: str) -> None: assert isinstance(error_tracker[0][0][0], ValueError) assert str(error_tracker[0][0][0]) == "This is a test error world" assert error_tracker[0][0][1] == "world" - assert result == error_handler.ERRORED + assert result is None def test_decorator_function_success_case(self): on_success_callback, success_tracker = create_callback_tracker() @@ -81,7 +80,7 @@ def store_error(error: Exception, _: str): catched_error = error raise error - @error_handler.decorator(on_error=store_error) + @error_handler.decorator_as_result(on_error=store_error) async def async_function(hello: str) -> None: raise ValueError(f"This is a test error {hello}") @@ -101,7 +100,7 @@ def store_error(error: Exception, _: str): catched_error = error raise error - @error_handler.decorator(on_error=store_error) + @error_handler.decorator_as_result(on_error=store_error) def func(hello: str) -> None: raise ValueError(f"This is a test error {hello}") @@ -124,14 +123,14 @@ def __init__(self, value: int): on_error=error_callback, on_finalize=finalize_callback, on_success=assert_not_called, - on_error_return_always=error_handler.ERRORED, + on_error_return_always=None, ) def func(self, hello: str) -> None: raise ValueError(f"This is a test error {hello}") instance = MyClass(42) result = instance.func("world") - assert result == error_handler.ERRORED + assert result is None assert str(error_tracker[0][0][0]) == "This is a test error world" assert error_tracker == [((error_tracker[0][0][0], instance, "world"), {})] assert finalize_tracker == [((instance, "world"), {})] diff --git a/unittests/test_pipable_operators.py b/unittests/test_pipable_operators.py index f08e2de..369a381 100644 --- a/unittests/test_pipable_operators.py +++ b/unittests/test_pipable_operators.py @@ -63,22 +63,10 @@ def store(error: Exception, _: int): assert set(elements) == {1, 3, 5} assert errored_nums == {2, 4, 6} - async def test_secured_map_stream_double_secure_invalid_return_value(self): - op = stream.iterate([1, 2, 3, 4, 5, 6]) - - @error_handler.decorator(on_error_return_always=0) - def return_1(_: int) -> int: - return 1 - - with pytest.raises(ValueError) as error: - _ = error_handler.stream.map(op, return_1) - - assert "The given function is already secured but does not return ERRORED in error case" in str(error.value) - async def test_secured_map_stream_double_secure_invalid_arguments(self): op = stream.iterate([1, 2, 3, 4, 5, 6]) - @error_handler.decorator() + @error_handler.decorator_as_result() def return_1(_: int) -> int: return 1 @@ -93,7 +81,7 @@ async def test_secured_map_stream_double_secure_no_wrap(self): op = stream.iterate([1, 2, 3, 4, 5, 6]) - @error_handler.decorator(on_error=error_callback, on_success=success_callback) + @error_handler.decorator_as_result(on_error=error_callback, on_success=success_callback) def raise_for_even(num: int) -> int: if num % 2 == 0: raise ValueError(num)