Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions injection/_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"inject",
"injectable",
"set_constant",
"should_be_injectable",
"singleton",
)

Expand All @@ -20,4 +21,5 @@
inject = default_module.inject
injectable = default_module.injectable
set_constant = default_module.set_constant
should_be_injectable = default_module.should_be_injectable
singleton = default_module.singleton
17 changes: 17 additions & 0 deletions injection/_pkg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ get_lazy_instance = default_module.get_lazy_instance
inject = default_module.inject
injectable = default_module.injectable
set_constant = default_module.set_constant
should_be_injectable = default_module.should_be_injectable
singleton = default_module.singleton

@final
Expand Down Expand Up @@ -59,6 +60,7 @@ class Module:
/,
*,
cls: type[Injectable] = ...,
inject: bool = ...,
on: type | Iterable[type] | UnionType = ...,
override: bool = ...,
):
Expand All @@ -73,6 +75,7 @@ class Module:
wrapped: Callable[..., Any] = ...,
/,
*,
inject: bool = ...,
on: type | Iterable[type] | UnionType = ...,
override: bool = ...,
):
Expand All @@ -82,6 +85,20 @@ class Module:
always be the same.
"""

def should_be_injectable(
self,
wrapped: Callable[..., Any] = ...,
/,
*,
on: type | Iterable[type] | UnionType = ...,
override: bool = ...,
):
"""
Decorator applicable to a class. It is used to specify whether an injectable
should be registered. Raise an exception at injection time if the class isn't
registered.
"""

def set_constant(
self,
instance: _T,
Expand Down
75 changes: 69 additions & 6 deletions injection/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Any,
ContextManager,
NamedTuple,
NoReturn,
Protocol,
TypeVar,
cast,
Expand All @@ -33,6 +34,7 @@
from injection.common.lazy import Lazy, LazyMapping
from injection.common.tools import find_types, format_type, get_origins
from injection.exceptions import (
InjectionError,
ModuleError,
ModuleLockError,
ModuleNotUsedError,
Expand Down Expand Up @@ -185,13 +187,67 @@ def get_instance(self) -> _T:
return instance


class InjectableWarning(BaseInjectable[_T], ABC):
__slots__ = ()

def __bool__(self) -> bool:
return False

@property
def formatted_type(self) -> str:
return format_type(self.factory)

@property
@abstractmethod
def exception(self) -> BaseException:
raise NotImplementedError

def get_instance(self) -> NoReturn:
raise self.exception


class ShouldBeInjectable(InjectableWarning[_T]):
__slots__ = ()

@property
def exception(self) -> BaseException:
return InjectionError(f"`{self.formatted_type}` should be an injectable.")


"""
Broker
"""


@runtime_checkable
class Broker(Protocol):
__slots__ = ()

@abstractmethod
def __getitem__(self, cls: type[_T] | UnionType, /) -> Injectable[_T]:
raise NotImplementedError

@abstractmethod
def __contains__(self, cls: type | UnionType, /) -> bool:
raise NotImplementedError

@property
@abstractmethod
def is_locked(self) -> bool:
raise NotImplementedError

@abstractmethod
def unlock(self):
raise NotImplementedError


"""
Container
"""


@dataclass(repr=False, frozen=True, slots=True)
class Container:
class Container(Broker):
__data: dict[type, Injectable] = field(default_factory=dict, init=False)
__channel: EventChannel = field(default_factory=EventChannel, init=False)

Expand Down Expand Up @@ -242,7 +298,7 @@ def notify(self, event: Event) -> ContextManager | ContextDecorator:

def __check_if_exists(self, *classes: type):
for cls in classes:
if cls in self.__data:
if self.__data.get(cls):
raise RuntimeError(
f"An injectable already exists for the class `{format_type(cls)}`."
)
Expand All @@ -263,7 +319,7 @@ def get_default(cls):


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class Module(EventListener):
class Module(EventListener, Broker):
name: str = field(default=None)
__channel: EventChannel = field(default_factory=EventChannel, init=False)
__container: Container = field(default_factory=Container, init=False)
Expand Down Expand Up @@ -296,7 +352,7 @@ def is_locked(self) -> bool:
return any(broker.is_locked for broker in self.__brokers)

@property
def __brokers(self) -> Iterator[Container | Module]:
def __brokers(self) -> Iterator[Broker]:
yield from tuple(self.__modules)
yield self.__container

Expand All @@ -306,11 +362,12 @@ def injectable(
/,
*,
cls: type[Injectable] = NewInjectable,
inject: bool = True,
on: type | Types = None,
override: bool = False,
):
def decorator(wp):
factory = self.inject(wp, return_factory=True)
factory = self.inject(wp, return_factory=True) if inject else wp
injectable = cls(factory)
classes = find_types(wp, on)
self.update(classes, injectable, override)
Expand All @@ -319,6 +376,11 @@ def decorator(wp):
return decorator(wrapped) if wrapped else decorator

singleton = partialmethod(injectable, cls=SingletonInjectable)
should_be_injectable = partialmethod(
injectable,
cls=ShouldBeInjectable,
inject=False,
)

def set_constant(
self,
Expand All @@ -330,6 +392,7 @@ def set_constant(
cls = type(instance)
self.injectable(
lambda: instance,
inject=False,
on=(cls, on),
override=override,
)
Expand Down Expand Up @@ -366,7 +429,7 @@ def get_instance(self, cls: type[_T], none: bool = True) -> _T | None:
if none:
return None

raise exc from exc
raise exc

instance = injectable.get_instance()
return cast(cls, instance)
Expand Down
Loading