Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions injection/_pkg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ from typing import (
runtime_checkable,
)

from injection.common.lazy import Lazy
from injection.common.invertible import Invertible

_T = TypeVar("_T")

Expand Down Expand Up @@ -104,11 +104,16 @@ class Module:
will be raised.
"""

def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]:
def get_lazy_instance(
self,
cls: type[_T],
cache: bool = ...,
) -> Invertible[_T | None]:
"""
Function used to retrieve an instance associated with the type passed in
parameter or `None`. Return a `Lazy` object. To access the instance contained
in a lazy object, simply use a wavy line (~).
parameter or `None`. Return a `Invertible` object. To access the instance
contained in an invertible object, simply use a wavy line (~).
With `cache=True`, the instance retrieved will always be the same.

Example: instance = ~lazy_instance
"""
Expand Down
23 changes: 23 additions & 0 deletions injection/common/invertible.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Protocol, TypeVar, runtime_checkable

__all__ = ("Invertible", "SimpleInvertible")

_T = TypeVar("_T")


@runtime_checkable
class Invertible(Protocol[_T]):
@abstractmethod
def __invert__(self) -> _T:
raise NotImplementedError


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class SimpleInvertible(Invertible[_T]):
callable: Callable[[], _T]

def __invert__(self) -> _T:
return self.callable()
10 changes: 6 additions & 4 deletions injection/common/lazy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Callable, Iterator, Mapping
from types import MappingProxyType
from typing import Generic, TypeVar
from typing import TypeVar

from injection.common.invertible import Invertible

__all__ = ("Lazy", "LazyMapping")

Expand All @@ -9,7 +11,7 @@
_V = TypeVar("_V")


class Lazy(Generic[_T]):
class Lazy(Invertible[_T]):
__slots__ = ("__cache", "__is_set")

def __init__(self, factory: Callable[[], _T]):
Expand All @@ -23,7 +25,7 @@ def is_set(self) -> bool:
return self.__is_set

def __setup_cache(self, factory: Callable[[], _T]):
def new_cache() -> Iterator[_T]:
def cache_generator() -> Iterator[_T]:
nonlocal factory
cached = factory()
self.__is_set = True
Expand All @@ -32,7 +34,7 @@ def new_cache() -> Iterator[_T]:
while True:
yield cached

self.__cache = new_cache()
self.__cache = cache_generator()
self.__is_set = False


Expand Down
17 changes: 8 additions & 9 deletions injection/common/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def add(self, item: _T):
return self


class NoQueue(Queue[_T]):
class DeadQueue(Queue[_T]):
__slots__ = ()

def __bool__(self) -> bool:
Expand All @@ -42,23 +42,22 @@ def __next__(self) -> NoReturn:
raise StopIteration

def add(self, item: _T) -> NoReturn:
raise TypeError("Queue doesn't exist.")
raise TypeError("Queue is dead.")


@dataclass(repr=False, slots=True)
class LimitedQueue(Queue[_T]):
__queue: Queue[_T] = field(default_factory=SimpleQueue)
__state: Queue[_T] = field(default_factory=SimpleQueue)

def __next__(self) -> _T:
if not self.__queue:
raise StopIteration

try:
return next(self.__queue)
return next(self.__state)
except StopIteration as exc:
self.__queue = NoQueue()
if self.__state:
self.__state = DeadQueue()

raise exc

def add(self, item: _T):
self.__queue.add(item)
self.__state.add(item)
return self
4 changes: 2 additions & 2 deletions injection/common/tools/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

__all__ = ("synchronized",)

__thread_lock = RLock()
__lock = RLock()


@contextmanager
def synchronized() -> ContextManager | ContextDecorator:
with __thread_lock:
with __lock:
yield
81 changes: 49 additions & 32 deletions injection/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)

from injection.common.event import Event, EventChannel, EventListener
from injection.common.invertible import Invertible, SimpleInvertible
from injection.common.lazy import Lazy, LazyMapping
from injection.common.queue import LimitedQueue
from injection.common.tools.threading import synchronized
Expand Down Expand Up @@ -157,6 +158,13 @@ def get_instance(self) -> _T:
raise NotImplementedError


class FallbackInjectable(Injectable[_T], ABC):
__slots__ = ()

def __bool__(self) -> bool:
return False


@dataclass(repr=False, frozen=True, slots=True)
class BaseInjectable(Injectable[_T], ABC):
factory: Callable[[], _T]
Expand Down Expand Up @@ -197,12 +205,9 @@ def get_instance(self) -> _T:


@dataclass(repr=False, frozen=True, slots=True)
class ShouldBeInjectable(Injectable[_T]):
class ShouldBeInjectable(FallbackInjectable[_T]):
cls: type[_T]

def __bool__(self) -> bool:
return False

def get_instance(self) -> NoReturn:
raise InjectionError(f"`{format_type(self.cls)}` should be an injectable.")

Expand Down Expand Up @@ -260,28 +265,28 @@ def is_locked(self) -> bool:

@property
def __classes(self) -> frozenset[type]:
return frozenset(self.__data.keys())
return frozenset(self.__data)

@property
def __injectables(self) -> frozenset[Injectable]:
return frozenset(self.__data.values())

@synchronized()
def update(self, classes: Iterable[type], injectable: Injectable, override: bool):
classes = frozenset(get_origins(*classes))

with synchronized():
if not injectable:
classes -= self.__classes
override = True
if not injectable:
classes -= self.__classes
override = True

if classes:
event = ContainerDependenciesUpdated(self, classes, override)
if classes:
event = ContainerDependenciesUpdated(self, classes, override)

with self.notify(event):
if not override:
self.__check_if_exists(classes)
with self.notify(event):
if not override:
self.__check_if_exists(classes)

self.__data.update((cls, injectable) for cls in classes)
self.__data.update((cls, injectable) for cls in classes)

return self

Expand All @@ -290,6 +295,8 @@ def unlock(self):
for injectable in self.__injectables:
injectable.unlock()

return self

def add_listener(self, listener: EventListener):
self.__channel.add_listener(listener)
return self
Expand Down Expand Up @@ -438,8 +445,17 @@ def get_instance(self, cls: type[_T], none: bool = True) -> _T | None:
instance = injectable.get_instance()
return cast(cls, instance)

def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]:
return Lazy(lambda: self.get_instance(cls))
def get_lazy_instance(
self,
cls: type[_T],
cache: bool = False,
) -> Invertible[_T | None]:
if cache:
return Lazy(lambda: self.get_instance(cls))

function = self.inject(lambda instance=None: instance)
function.set_owner(cls)
return SimpleInvertible(function)

def update(
self,
Expand Down Expand Up @@ -503,6 +519,8 @@ def unlock(self):
for broker in self.__brokers:
broker.unlock()

return self

def add_listener(self, listener: EventListener):
self.__channel.add_listener(listener)
return self
Expand Down Expand Up @@ -661,15 +679,7 @@ def __get__(self, instance: object = None, owner: type = None):
return self.__wrapper.__get__(instance, owner)

def __set_name__(self, owner: type, name: str):
if self.__dependencies.are_resolved:
raise TypeError(
"Function owner must be assigned before dependencies are resolved."
)

if self.__owner:
raise TypeError("Function owner is already defined.")

self.__owner = owner
self.set_owner(owner)

@property
def signature(self) -> Signature:
Expand All @@ -692,14 +702,21 @@ def bind(
)
return Arguments(bound.args, bound.kwargs)

def update(self, module: Module):
with synchronized():
self.__dependencies = Dependencies.resolve(
self.signature,
module,
self.__owner,
def set_owner(self, owner: type):
if self.__dependencies.are_resolved:
raise TypeError(
"Function owner must be assigned before dependencies are resolved."
)

if self.__owner:
raise TypeError("Function owner is already defined.")

self.__owner = owner
return self

@synchronized()
def update(self, module: Module):
self.__dependencies = Dependencies.resolve(self.signature, module, self.__owner)
return self

def on_setup(self, wrapped: Callable[[], Any] = None, /):
Expand Down
28 changes: 21 additions & 7 deletions tests/core/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,32 @@ def test_get_instance_with_empty_annotated_return_none(self, module):
"""

def test_get_lazy_instance_with_success_return_lazy_instance(self, module):
module[SomeClass] = self.get_test_injectable(SomeClass())
@module.injectable
class A:
pass

lazy_instance = module.get_lazy_instance(SomeClass)
assert not lazy_instance.is_set
assert isinstance(~lazy_instance, SomeClass)
assert lazy_instance.is_set
lazy_instance = module.get_lazy_instance(A)
instance1 = ~lazy_instance
instance2 = ~lazy_instance
assert isinstance(instance1, A)
assert isinstance(instance2, A)
assert instance1 is not instance2

def test_get_lazy_instance_with_cache_return_lazy_instance(self, module):
@module.injectable
class A:
pass

lazy_instance = module.get_lazy_instance(A, cache=True)
instance1 = ~lazy_instance
instance2 = ~lazy_instance
assert isinstance(instance1, A)
assert isinstance(instance2, A)
assert instance1 is instance2

def test_get_lazy_instance_with_no_injectable_return_lazy_none(self, module):
lazy_instance = module.get_lazy_instance(SomeClass)
assert not lazy_instance.is_set
assert ~lazy_instance is None
assert lazy_instance.is_set

"""
set_constant
Expand Down