Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions injection/_pkg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .core import Module, ModulePriorities
from .core import Injectable, Module, ModulePriorities

__all__ = (
"Injectable",
"Module",
"ModulePriorities",
"default_module",
Expand All @@ -16,9 +17,7 @@

get_instance = default_module.get_instance
get_lazy_instance = default_module.get_lazy_instance

inject = default_module.inject
injectable = default_module.injectable
singleton = default_module.singleton

set_constant = default_module.set_constant
singleton = default_module.singleton
25 changes: 21 additions & 4 deletions injection/_pkg.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from abc import abstractmethod
from collections.abc import Callable, Iterable
from contextlib import ContextDecorator
from enum import Enum
from types import UnionType
from typing import Any, ContextManager, Final, TypeVar, final
from typing import (
Any,
ContextManager,
Final,
Protocol,
TypeVar,
final,
runtime_checkable,
)

from injection.common.lazy import Lazy

Expand All @@ -12,12 +21,10 @@ default_module: Final[Module] = ...

get_instance = default_module.get_instance
get_lazy_instance = default_module.get_lazy_instance

inject = default_module.inject
injectable = default_module.injectable
singleton = default_module.singleton

set_constant = default_module.set_constant
singleton = default_module.singleton

@final
class Module:
Expand All @@ -42,6 +49,7 @@ class Module:
wrapped: Callable[..., Any] = ...,
/,
*,
cls: type[Injectable] = ...,
on: type | Iterable[type] | UnionType = ...,
):
"""
Expand Down Expand Up @@ -119,3 +127,12 @@ class Module:
class ModulePriorities(Enum):
HIGH = ...
LOW = ...

@runtime_checkable
class Injectable(Protocol[_T]):
def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs): ...
@property
def is_locked(self) -> bool: ...
def unlock(self): ...
@abstractmethod
def get_instance(self) -> _T: ...
42 changes: 27 additions & 15 deletions injection/common/tools/_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from inspect import get_annotations, isfunction
from types import NoneType, UnionType
from typing import Annotated, Any, Union, get_args, get_origin

__all__ = ("format_type", "get_origins")
__all__ = ("find_types", "format_type", "get_origins")


def format_type(cls: type | Any) -> str:
Expand All @@ -12,25 +13,36 @@ def format_type(cls: type | Any) -> str:
return str(cls)


def get_origins(*classes: type | Any) -> Iterator[type | Any]:
for cls in classes:
origin = get_origin(cls) or cls
def get_origins(*types: type | Any) -> Iterator[type | Any]:
for tp in types:
origin = get_origin(tp) or tp

if origin in (None, NoneType):
continue

arguments = get_args(cls)
elif origin in (Union, UnionType):
args = get_args(tp)

if origin in (Union, UnionType):
yield from get_origins(*arguments)
elif origin is Annotated is not tp:
args = (tp.__origin__,)

else:
yield origin
continue

yield from get_origins(*args)

elif origin is Annotated:
try:
annotated = arguments[0]
except IndexError:
continue

yield from get_origins(annotated)
def find_types(*args: Any) -> Iterator[type | UnionType]:
for argument in args:
if isinstance(argument, Iterable) and not isinstance(argument, type | str):
arguments = argument

elif isfunction(argument):
arguments = (get_annotations(argument, eval_str=True).get("return"),)

else:
yield origin
yield argument
continue

yield from find_types(*arguments)
152 changes: 54 additions & 98 deletions injection/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from contextlib import ContextDecorator, contextmanager, suppress
from dataclasses import dataclass, field
from enum import Enum, auto
from functools import singledispatchmethod, wraps
from inspect import Signature, get_annotations, isclass, isfunction
from functools import partialmethod, singledispatchmethod, wraps
from inspect import Signature, isclass
from threading import RLock
from types import MappingProxyType, UnionType
from typing import (
Expand All @@ -26,13 +26,12 @@
Protocol,
TypeVar,
cast,
final,
runtime_checkable,
)

from injection.common.event import Event, EventChannel, EventListener
from injection.common.lazy import Lazy, LazyMapping
from injection.common.tools import format_type, get_origins
from injection.common.tools import find_types, format_type, get_origins
from injection.exceptions import (
ModuleError,
ModuleLockError,
Expand Down Expand Up @@ -133,6 +132,9 @@ def __str__(self) -> str:
class Injectable(Protocol[_T]):
__slots__ = ()

def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs):
...

@property
def is_locked(self) -> bool:
return False
Expand Down Expand Up @@ -288,18 +290,6 @@ def __contains__(self, cls: type | UnionType, /) -> bool:
def __str__(self) -> str:
return self.name or object.__str__(self)

@property
def inject(self) -> InjectDecorator:
return InjectDecorator(self)

@property
def injectable(self) -> InjectableDecorator:
return InjectableDecorator(self, NewInjectable)

@property
def singleton(self) -> InjectableDecorator:
return InjectableDecorator(self, SingletonInjectable)

@property
def is_locked(self) -> bool:
return any(broker.is_locked for broker in self.__brokers)
Expand All @@ -309,6 +299,25 @@ def __brokers(self) -> Iterator[Container | Module]:
yield from tuple(self.__modules)
yield self.__container

def injectable(
self,
wrapped: Callable[..., Any] = None,
/,
*,
cls: type[Injectable] = NewInjectable,
on: type | Types = None,
):
def decorator(wp):
factory = self.inject(wp, return_factory=True)
injectable = cls(factory)
classes = find_types(wp, on)
self.update(classes, injectable)
return wp

return decorator(wrapped) if wrapped else decorator

singleton = partialmethod(injectable, cls=SingletonInjectable)

def set_constant(self, instance: _T, on: type | Types = None) -> _T:
cls = type(instance)

Expand All @@ -318,6 +327,29 @@ def get_constant():

return instance

def inject(
self,
wrapped: Callable[..., Any] = None,
/,
*,
return_factory: bool = False,
):
def decorator(wp):
if not return_factory and isclass(wp):
wp.__init__ = decorator(wp.__init__)
return wp

lazy_binder = Lazy[Binder](lambda: self.__new_binder(wp))

@wraps(wp)
def wrapper(*args, **kwargs):
arguments = (~lazy_binder).bind(*args, **kwargs)
return wp(*arguments.args, **arguments.kwargs)

return wrapper

return decorator(wrapped) if wrapped else decorator

def get_instance(self, cls: type[_T]) -> _T | None:
try:
injectable = self[cls]
Expand Down Expand Up @@ -420,6 +452,12 @@ def __move_module(self, module: Module, priority: ModulePriorities):
f"`{module}` can't be found in the modules used by `{self}`."
) from exc

def __new_binder(self, target: Callable[..., Any]) -> Binder:
signature = inspect.signature(target, eval_str=True)
binder = Binder(signature).update(self)
self.add_listener(binder)
return binder


"""
Binder
Expand Down Expand Up @@ -502,85 +540,3 @@ def on_event(self, event: Event, /):
def _(self, event: ModuleEvent, /) -> ContextManager:
yield
self.update(event.on_module)


"""
Decorators
"""


@final
@dataclass(repr=False, frozen=True, slots=True)
class InjectDecorator:
__module: Module

def __call__(self, wrapped: Callable[..., Any] = None, /):
def decorator(wp):
if isclass(wp):
return self.__class_decorator(wp)

return self.__decorator(wp)

return decorator(wrapped) if wrapped else decorator

def __decorator(self, function: Callable[..., Any], /) -> Callable[..., Any]:
lazy_binder = Lazy[Binder](lambda: self.__new_binder(function))

@wraps(function)
def wrapper(*args, **kwargs):
arguments = (~lazy_binder).bind(*args, **kwargs)
return function(*arguments.args, **arguments.kwargs)

return wrapper

def __class_decorator(self, cls: type, /) -> type:
cls.__init__ = self.__decorator(cls.__init__)
return cls

def __new_binder(self, function: Callable[..., Any]) -> Binder:
signature = inspect.signature(function, eval_str=True)
binder = Binder(signature).update(self.__module)
self.__module.add_listener(binder)
return binder


@final
@dataclass(repr=False, frozen=True, slots=True)
class InjectableDecorator:
__module: Module
__injectable_type: type[BaseInjectable]

def __repr__(self) -> str:
return f"<{self.__injectable_type.__qualname__} decorator>"

def __call__(
self,
wrapped: Callable[..., Any] = None,
/,
*,
on: type | Types = None,
):
def decorator(wp):
@self.__module.inject
@wraps(wp, updated=())
def factory(*args, **kwargs):
return wp(*args, **kwargs)

injectable = self.__injectable_type(factory)
classes = self.__get_classes(wp, on)
self.__module.update(classes, injectable)
return wp

return decorator(wrapped) if wrapped else decorator

@classmethod
def __get_classes(cls, *objects: Any) -> Iterator[type | UnionType]:
for obj in objects:
if isinstance(obj, Iterable) and not isinstance(obj, type | str):
yield from cls.__get_classes(*obj)

elif isfunction(obj):
yield get_annotations(obj, eval_str=True).get("return")

else:
yield obj
2 changes: 1 addition & 1 deletion tests/core/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_instance(self) -> class_:
__getitem__
"""

def test_getitem_with_success_injectable(self, module):
def test_getitem_with_success_return_injectable(self, module):
injectable_w = self.get_test_injectable(SomeClass())
module[SomeClass] = injectable_w
assert module[SomeClass] is injectable_w
Expand Down