diff --git a/injection/_pkg.pyi b/injection/_pkg.pyi index e6be959..5afcd05 100644 --- a/injection/_pkg.pyi +++ b/injection/_pkg.pyi @@ -13,7 +13,7 @@ from typing import ( runtime_checkable, ) -from injection.common.lazy import Lazy +from injection.common.invertible import Invertible _T = TypeVar("_T") @@ -104,11 +104,16 @@ class Module: will be raised. """ - def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]: + def get_lazy_instance( + self, + cls: type[_T], + cache: bool = ..., + ) -> Invertible[_T | None]: """ Function used to retrieve an instance associated with the type passed in - parameter or `None`. Return a `Lazy` object. To access the instance contained - in a lazy object, simply use a wavy line (~). + parameter or `None`. Return a `Invertible` object. To access the instance + contained in an invertible object, simply use a wavy line (~). + With `cache=True`, the instance retrieved will always be the same. Example: instance = ~lazy_instance """ diff --git a/injection/common/invertible.py b/injection/common/invertible.py new file mode 100644 index 0000000..2d96aaa --- /dev/null +++ b/injection/common/invertible.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Protocol, TypeVar, runtime_checkable + +__all__ = ("Invertible", "SimpleInvertible") + +_T = TypeVar("_T") + + +@runtime_checkable +class Invertible(Protocol[_T]): + @abstractmethod + def __invert__(self) -> _T: + raise NotImplementedError + + +@dataclass(repr=False, eq=False, frozen=True, slots=True) +class SimpleInvertible(Invertible[_T]): + callable: Callable[[], _T] + + def __invert__(self) -> _T: + return self.callable() diff --git a/injection/common/lazy.py b/injection/common/lazy.py index aa093b0..9173495 100644 --- a/injection/common/lazy.py +++ b/injection/common/lazy.py @@ -1,6 +1,8 @@ from collections.abc import Callable, Iterator, Mapping from types import MappingProxyType -from typing import Generic, TypeVar +from typing import TypeVar + +from injection.common.invertible import Invertible __all__ = ("Lazy", "LazyMapping") @@ -9,7 +11,7 @@ _V = TypeVar("_V") -class Lazy(Generic[_T]): +class Lazy(Invertible[_T]): __slots__ = ("__cache", "__is_set") def __init__(self, factory: Callable[[], _T]): @@ -23,7 +25,7 @@ def is_set(self) -> bool: return self.__is_set def __setup_cache(self, factory: Callable[[], _T]): - def new_cache() -> Iterator[_T]: + def cache_generator() -> Iterator[_T]: nonlocal factory cached = factory() self.__is_set = True @@ -32,7 +34,7 @@ def new_cache() -> Iterator[_T]: while True: yield cached - self.__cache = new_cache() + self.__cache = cache_generator() self.__is_set = False diff --git a/injection/common/queue.py b/injection/common/queue.py index 66863eb..ca95eaa 100644 --- a/injection/common/queue.py +++ b/injection/common/queue.py @@ -32,7 +32,7 @@ def add(self, item: _T): return self -class NoQueue(Queue[_T]): +class DeadQueue(Queue[_T]): __slots__ = () def __bool__(self) -> bool: @@ -42,23 +42,22 @@ def __next__(self) -> NoReturn: raise StopIteration def add(self, item: _T) -> NoReturn: - raise TypeError("Queue doesn't exist.") + raise TypeError("Queue is dead.") @dataclass(repr=False, slots=True) class LimitedQueue(Queue[_T]): - __queue: Queue[_T] = field(default_factory=SimpleQueue) + __state: Queue[_T] = field(default_factory=SimpleQueue) def __next__(self) -> _T: - if not self.__queue: - raise StopIteration - try: - return next(self.__queue) + return next(self.__state) except StopIteration as exc: - self.__queue = NoQueue() + if self.__state: + self.__state = DeadQueue() + raise exc def add(self, item: _T): - self.__queue.add(item) + self.__state.add(item) return self diff --git a/injection/common/tools/threading.py b/injection/common/tools/threading.py index ed02082..f62709f 100644 --- a/injection/common/tools/threading.py +++ b/injection/common/tools/threading.py @@ -4,10 +4,10 @@ __all__ = ("synchronized",) -__thread_lock = RLock() +__lock = RLock() @contextmanager def synchronized() -> ContextManager | ContextDecorator: - with __thread_lock: + with __lock: yield diff --git a/injection/core/module.py b/injection/core/module.py index 0ee0d40..7b638c0 100644 --- a/injection/core/module.py +++ b/injection/core/module.py @@ -37,6 +37,7 @@ ) from injection.common.event import Event, EventChannel, EventListener +from injection.common.invertible import Invertible, SimpleInvertible from injection.common.lazy import Lazy, LazyMapping from injection.common.queue import LimitedQueue from injection.common.tools.threading import synchronized @@ -157,6 +158,13 @@ def get_instance(self) -> _T: raise NotImplementedError +class FallbackInjectable(Injectable[_T], ABC): + __slots__ = () + + def __bool__(self) -> bool: + return False + + @dataclass(repr=False, frozen=True, slots=True) class BaseInjectable(Injectable[_T], ABC): factory: Callable[[], _T] @@ -197,12 +205,9 @@ def get_instance(self) -> _T: @dataclass(repr=False, frozen=True, slots=True) -class ShouldBeInjectable(Injectable[_T]): +class ShouldBeInjectable(FallbackInjectable[_T]): cls: type[_T] - def __bool__(self) -> bool: - return False - def get_instance(self) -> NoReturn: raise InjectionError(f"`{format_type(self.cls)}` should be an injectable.") @@ -260,28 +265,28 @@ def is_locked(self) -> bool: @property def __classes(self) -> frozenset[type]: - return frozenset(self.__data.keys()) + return frozenset(self.__data) @property def __injectables(self) -> frozenset[Injectable]: return frozenset(self.__data.values()) + @synchronized() def update(self, classes: Iterable[type], injectable: Injectable, override: bool): classes = frozenset(get_origins(*classes)) - with synchronized(): - if not injectable: - classes -= self.__classes - override = True + if not injectable: + classes -= self.__classes + override = True - if classes: - event = ContainerDependenciesUpdated(self, classes, override) + if classes: + event = ContainerDependenciesUpdated(self, classes, override) - with self.notify(event): - if not override: - self.__check_if_exists(classes) + with self.notify(event): + if not override: + self.__check_if_exists(classes) - self.__data.update((cls, injectable) for cls in classes) + self.__data.update((cls, injectable) for cls in classes) return self @@ -290,6 +295,8 @@ def unlock(self): for injectable in self.__injectables: injectable.unlock() + return self + def add_listener(self, listener: EventListener): self.__channel.add_listener(listener) return self @@ -438,8 +445,17 @@ def get_instance(self, cls: type[_T], none: bool = True) -> _T | None: instance = injectable.get_instance() return cast(cls, instance) - def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]: - return Lazy(lambda: self.get_instance(cls)) + def get_lazy_instance( + self, + cls: type[_T], + cache: bool = False, + ) -> Invertible[_T | None]: + if cache: + return Lazy(lambda: self.get_instance(cls)) + + function = self.inject(lambda instance=None: instance) + function.set_owner(cls) + return SimpleInvertible(function) def update( self, @@ -503,6 +519,8 @@ def unlock(self): for broker in self.__brokers: broker.unlock() + return self + def add_listener(self, listener: EventListener): self.__channel.add_listener(listener) return self @@ -661,15 +679,7 @@ def __get__(self, instance: object = None, owner: type = None): return self.__wrapper.__get__(instance, owner) def __set_name__(self, owner: type, name: str): - if self.__dependencies.are_resolved: - raise TypeError( - "Function owner must be assigned before dependencies are resolved." - ) - - if self.__owner: - raise TypeError("Function owner is already defined.") - - self.__owner = owner + self.set_owner(owner) @property def signature(self) -> Signature: @@ -692,14 +702,21 @@ def bind( ) return Arguments(bound.args, bound.kwargs) - def update(self, module: Module): - with synchronized(): - self.__dependencies = Dependencies.resolve( - self.signature, - module, - self.__owner, + def set_owner(self, owner: type): + if self.__dependencies.are_resolved: + raise TypeError( + "Function owner must be assigned before dependencies are resolved." ) + if self.__owner: + raise TypeError("Function owner is already defined.") + + self.__owner = owner + return self + + @synchronized() + def update(self, module: Module): + self.__dependencies = Dependencies.resolve(self.signature, module, self.__owner) return self def on_setup(self, wrapped: Callable[[], Any] = None, /): diff --git a/tests/core/test_module.py b/tests/core/test_module.py index 3e29b36..c6655da 100644 --- a/tests/core/test_module.py +++ b/tests/core/test_module.py @@ -109,18 +109,32 @@ def test_get_instance_with_empty_annotated_return_none(self, module): """ def test_get_lazy_instance_with_success_return_lazy_instance(self, module): - module[SomeClass] = self.get_test_injectable(SomeClass()) + @module.injectable + class A: + pass - lazy_instance = module.get_lazy_instance(SomeClass) - assert not lazy_instance.is_set - assert isinstance(~lazy_instance, SomeClass) - assert lazy_instance.is_set + lazy_instance = module.get_lazy_instance(A) + instance1 = ~lazy_instance + instance2 = ~lazy_instance + assert isinstance(instance1, A) + assert isinstance(instance2, A) + assert instance1 is not instance2 + + def test_get_lazy_instance_with_cache_return_lazy_instance(self, module): + @module.injectable + class A: + pass + + lazy_instance = module.get_lazy_instance(A, cache=True) + instance1 = ~lazy_instance + instance2 = ~lazy_instance + assert isinstance(instance1, A) + assert isinstance(instance2, A) + assert instance1 is instance2 def test_get_lazy_instance_with_no_injectable_return_lazy_none(self, module): lazy_instance = module.get_lazy_instance(SomeClass) - assert not lazy_instance.is_set assert ~lazy_instance is None - assert lazy_instance.is_set """ set_constant