From 36c6d713698969d865de79d90a20448fb0abd035 Mon Sep 17 00:00:00 2001 From: remimd Date: Tue, 16 Jan 2024 23:59:04 +0100 Subject: [PATCH 1/2] =?UTF-8?q?refactoring:=20=E2=99=BB=EF=B8=8F=20Remove?= =?UTF-8?q?=20decorator=20classes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- injection/_pkg.py | 7 +- injection/_pkg.pyi | 25 +++++- injection/common/tools/_type.py | 42 +++++---- injection/core/module.py | 152 ++++++++++++-------------------- 4 files changed, 105 insertions(+), 121 deletions(-) diff --git a/injection/_pkg.py b/injection/_pkg.py index c250026..95027a5 100644 --- a/injection/_pkg.py +++ b/injection/_pkg.py @@ -1,6 +1,7 @@ -from .core import Module, ModulePriorities +from .core import Injectable, Module, ModulePriorities __all__ = ( + "Injectable", "Module", "ModulePriorities", "default_module", @@ -16,9 +17,7 @@ get_instance = default_module.get_instance get_lazy_instance = default_module.get_lazy_instance - inject = default_module.inject injectable = default_module.injectable -singleton = default_module.singleton - set_constant = default_module.set_constant +singleton = default_module.singleton diff --git a/injection/_pkg.pyi b/injection/_pkg.pyi index 0b4691e..6221529 100644 --- a/injection/_pkg.pyi +++ b/injection/_pkg.pyi @@ -1,8 +1,17 @@ +from abc import abstractmethod from collections.abc import Callable, Iterable from contextlib import ContextDecorator from enum import Enum from types import UnionType -from typing import Any, ContextManager, Final, TypeVar, final +from typing import ( + Any, + ContextManager, + Final, + Protocol, + TypeVar, + final, + runtime_checkable, +) from injection.common.lazy import Lazy @@ -12,12 +21,10 @@ default_module: Final[Module] = ... get_instance = default_module.get_instance get_lazy_instance = default_module.get_lazy_instance - inject = default_module.inject injectable = default_module.injectable -singleton = default_module.singleton - set_constant = default_module.set_constant +singleton = default_module.singleton @final class Module: @@ -42,6 +49,7 @@ class Module: wrapped: Callable[..., Any] = ..., /, *, + cls: type[Injectable] = ..., on: type | Iterable[type] | UnionType = ..., ): """ @@ -119,3 +127,12 @@ class Module: class ModulePriorities(Enum): HIGH = ... LOW = ... + +@runtime_checkable +class Injectable(Protocol[_T]): + def __init__(self, factory: Callable[[], _T], *args, **kwargs): ... + @property + def is_locked(self) -> bool: ... + def unlock(self): ... + @abstractmethod + def get_instance(self) -> _T: ... diff --git a/injection/common/tools/_type.py b/injection/common/tools/_type.py index 2d026df..e344d28 100644 --- a/injection/common/tools/_type.py +++ b/injection/common/tools/_type.py @@ -1,8 +1,9 @@ -from collections.abc import Iterator +from collections.abc import Iterable, Iterator +from inspect import get_annotations, isfunction from types import NoneType, UnionType from typing import Annotated, Any, Union, get_args, get_origin -__all__ = ("format_type", "get_origins") +__all__ = ("find_types", "format_type", "get_origins") def format_type(cls: type | Any) -> str: @@ -12,25 +13,36 @@ def format_type(cls: type | Any) -> str: return str(cls) -def get_origins(*classes: type | Any) -> Iterator[type | Any]: - for cls in classes: - origin = get_origin(cls) or cls +def get_origins(*types: type | Any) -> Iterator[type | Any]: + for tp in types: + origin = get_origin(tp) or tp if origin in (None, NoneType): continue - arguments = get_args(cls) + elif origin in (Union, UnionType): + args = get_args(tp) - if origin in (Union, UnionType): - yield from get_origins(*arguments) + elif origin is Annotated is not tp: + args = (tp.__origin__,) + + else: + yield origin + continue + + yield from get_origins(*args) - elif origin is Annotated: - try: - annotated = arguments[0] - except IndexError: - continue - yield from get_origins(annotated) +def find_types(*args: Any) -> Iterator[type | UnionType]: + for argument in args: + if isinstance(argument, Iterable) and not isinstance(argument, type | str): + arguments = argument + + elif isfunction(argument): + arguments = (get_annotations(argument, eval_str=True).get("return"),) else: - yield origin + yield argument + continue + + yield from find_types(*arguments) diff --git a/injection/core/module.py b/injection/core/module.py index 775af5d..725aaee 100644 --- a/injection/core/module.py +++ b/injection/core/module.py @@ -15,8 +15,8 @@ from contextlib import ContextDecorator, contextmanager, suppress from dataclasses import dataclass, field from enum import Enum, auto -from functools import singledispatchmethod, wraps -from inspect import Signature, get_annotations, isclass, isfunction +from functools import partialmethod, singledispatchmethod, wraps +from inspect import Signature, isclass from threading import RLock from types import MappingProxyType, UnionType from typing import ( @@ -26,13 +26,12 @@ Protocol, TypeVar, cast, - final, runtime_checkable, ) from injection.common.event import Event, EventChannel, EventListener from injection.common.lazy import Lazy, LazyMapping -from injection.common.tools import format_type, get_origins +from injection.common.tools import find_types, format_type, get_origins from injection.exceptions import ( ModuleError, ModuleLockError, @@ -133,6 +132,9 @@ def __str__(self) -> str: class Injectable(Protocol[_T]): __slots__ = () + def __init__(self, factory: Callable[[], _T], *args, **kwargs): + ... + @property def is_locked(self) -> bool: return False @@ -288,18 +290,6 @@ def __contains__(self, cls: type | UnionType, /) -> bool: def __str__(self) -> str: return self.name or object.__str__(self) - @property - def inject(self) -> InjectDecorator: - return InjectDecorator(self) - - @property - def injectable(self) -> InjectableDecorator: - return InjectableDecorator(self, NewInjectable) - - @property - def singleton(self) -> InjectableDecorator: - return InjectableDecorator(self, SingletonInjectable) - @property def is_locked(self) -> bool: return any(broker.is_locked for broker in self.__brokers) @@ -309,6 +299,25 @@ def __brokers(self) -> Iterator[Container | Module]: yield from tuple(self.__modules) yield self.__container + def injectable( + self, + wrapped: Callable[..., Any] = None, + /, + *, + cls: type[Injectable] = NewInjectable, + on: type | Types = None, + ): + def decorator(wp): + factory = self.inject(wp, return_factory=True) + injectable = cls(factory) + classes = find_types(wp, on) + self.update(classes, injectable) + return wp + + return decorator(wrapped) if wrapped else decorator + + singleton = partialmethod(injectable, cls=SingletonInjectable) + def set_constant(self, instance: _T, on: type | Types = None) -> _T: cls = type(instance) @@ -318,6 +327,29 @@ def get_constant(): return instance + def inject( + self, + wrapped: Callable[..., Any] = None, + /, + *, + return_factory: bool = False, + ): + def decorator(wp): + if not return_factory and isclass(wp): + wp.__init__ = decorator(wp.__init__) + return wp + + lazy_binder = Lazy[Binder](lambda: self.__new_binder(wp)) + + @wraps(wp) + def wrapper(*args, **kwargs): + arguments = (~lazy_binder).bind(*args, **kwargs) + return wp(*arguments.args, **arguments.kwargs) + + return wrapper + + return decorator(wrapped) if wrapped else decorator + def get_instance(self, cls: type[_T]) -> _T | None: try: injectable = self[cls] @@ -420,6 +452,12 @@ def __move_module(self, module: Module, priority: ModulePriorities): f"`{module}` can't be found in the modules used by `{self}`." ) from exc + def __new_binder(self, target: Callable[..., Any]) -> Binder: + signature = inspect.signature(target, eval_str=True) + binder = Binder(signature).update(self) + self.add_listener(binder) + return binder + """ Binder @@ -502,85 +540,3 @@ def on_event(self, event: Event, /): def _(self, event: ModuleEvent, /) -> ContextManager: yield self.update(event.on_module) - - -""" -Decorators -""" - - -@final -@dataclass(repr=False, frozen=True, slots=True) -class InjectDecorator: - __module: Module - - def __call__(self, wrapped: Callable[..., Any] = None, /): - def decorator(wp): - if isclass(wp): - return self.__class_decorator(wp) - - return self.__decorator(wp) - - return decorator(wrapped) if wrapped else decorator - - def __decorator(self, function: Callable[..., Any], /) -> Callable[..., Any]: - lazy_binder = Lazy[Binder](lambda: self.__new_binder(function)) - - @wraps(function) - def wrapper(*args, **kwargs): - arguments = (~lazy_binder).bind(*args, **kwargs) - return function(*arguments.args, **arguments.kwargs) - - return wrapper - - def __class_decorator(self, cls: type, /) -> type: - cls.__init__ = self.__decorator(cls.__init__) - return cls - - def __new_binder(self, function: Callable[..., Any]) -> Binder: - signature = inspect.signature(function, eval_str=True) - binder = Binder(signature).update(self.__module) - self.__module.add_listener(binder) - return binder - - -@final -@dataclass(repr=False, frozen=True, slots=True) -class InjectableDecorator: - __module: Module - __injectable_type: type[BaseInjectable] - - def __repr__(self) -> str: - return f"<{self.__injectable_type.__qualname__} decorator>" - - def __call__( - self, - wrapped: Callable[..., Any] = None, - /, - *, - on: type | Types = None, - ): - def decorator(wp): - @self.__module.inject - @wraps(wp, updated=()) - def factory(*args, **kwargs): - return wp(*args, **kwargs) - - injectable = self.__injectable_type(factory) - classes = self.__get_classes(wp, on) - self.__module.update(classes, injectable) - return wp - - return decorator(wrapped) if wrapped else decorator - - @classmethod - def __get_classes(cls, *objects: Any) -> Iterator[type | UnionType]: - for obj in objects: - if isinstance(obj, Iterable) and not isinstance(obj, type | str): - yield from cls.__get_classes(*obj) - - elif isfunction(obj): - yield get_annotations(obj, eval_str=True).get("return") - - else: - yield obj From beb9d63a3bc601f920a1313089574a7ab56e811f Mon Sep 17 00:00:00 2001 From: remimd Date: Wed, 17 Jan 2024 17:34:57 +0100 Subject: [PATCH 2/2] fix --- injection/_pkg.pyi | 2 +- injection/core/module.py | 2 +- tests/core/test_module.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/injection/_pkg.pyi b/injection/_pkg.pyi index 6221529..edfa351 100644 --- a/injection/_pkg.pyi +++ b/injection/_pkg.pyi @@ -130,7 +130,7 @@ class ModulePriorities(Enum): @runtime_checkable class Injectable(Protocol[_T]): - def __init__(self, factory: Callable[[], _T], *args, **kwargs): ... + def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs): ... @property def is_locked(self) -> bool: ... def unlock(self): ... diff --git a/injection/core/module.py b/injection/core/module.py index 725aaee..15b7576 100644 --- a/injection/core/module.py +++ b/injection/core/module.py @@ -132,7 +132,7 @@ def __str__(self) -> str: class Injectable(Protocol[_T]): __slots__ = () - def __init__(self, factory: Callable[[], _T], *args, **kwargs): + def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs): ... @property diff --git a/tests/core/test_module.py b/tests/core/test_module.py index 5cf615f..7d2b844 100644 --- a/tests/core/test_module.py +++ b/tests/core/test_module.py @@ -25,7 +25,7 @@ def get_instance(self) -> class_: __getitem__ """ - def test_getitem_with_success_injectable(self, module): + def test_getitem_with_success_return_injectable(self, module): injectable_w = self.get_test_injectable(SomeClass()) module[SomeClass] = injectable_w assert module[SomeClass] is injectable_w