Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions documentation/basic-usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ class C(B):
...
```

If a class is registered in a package and you want to override it, there is the `override` parameter:

```python
@singleton
class A:
...

# ...

@singleton(on=A, override=True)
class B(A):
...
```

## Recipes

A recipe is a function that tells the injector how to construct the instance to be injected. It is important to specify
Expand Down
6 changes: 5 additions & 1 deletion injection/_pkg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Module:
*,
cls: type[Injectable] = ...,
on: type | Iterable[type] | UnionType = ...,
override: bool = ...,
):
"""
Decorator applicable to a class or function. It is used to indicate how the
Expand All @@ -73,6 +74,7 @@ class Module:
/,
*,
on: type | Iterable[type] | UnionType = ...,
override: bool = ...,
):
"""
Decorator applicable to a class or function. It is used to indicate how the
Expand All @@ -84,6 +86,8 @@ class Module:
self,
instance: _T,
on: type | Iterable[type] | UnionType = ...,
*,
override: bool = ...,
) -> _T:
"""
Function for registering a specific instance to be injected. This is useful for
Expand Down Expand Up @@ -149,7 +153,7 @@ class ModulePriorities(Enum):

@runtime_checkable
class Injectable(Protocol[_T]):
def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs): ...
def __init__(self, factory: Callable[[], _T] = ..., /): ...
@property
def is_locked(self) -> bool: ...
def unlock(self): ...
Expand Down
50 changes: 29 additions & 21 deletions injection/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ContainerEvent(Event, ABC):
@dataclass(frozen=True, slots=True)
class ContainerDependenciesUpdated(ContainerEvent):
classes: Collection[type]
override: bool

def __str__(self) -> str:
length = len(self.classes)
Expand Down Expand Up @@ -212,26 +213,19 @@ def is_locked(self) -> bool:
def __injectables(self) -> frozenset[Injectable]:
return frozenset(self.__data.values())

def update(self, classes: Types, injectable: Injectable):
classes = frozenset(get_origins(*classes))
def update(self, classes: Types, injectable: Injectable, override: bool):
values = {origin: injectable for origin in get_origins(*classes)}

if classes:
event = ContainerDependenciesUpdated(self, classes)
if values:
event = ContainerDependenciesUpdated(self, values, override)

with self.notify(event):
self.__data.update(
(self.check_if_exists(cls), injectable) for cls in classes
)

return self
if not override:
self.__check_if_exists(*values)

def check_if_exists(self, cls: type) -> type:
if cls in self.__data:
raise RuntimeError(
f"An injectable already exists for the class `{format_type(cls)}`."
)
self.__data.update(values)

return cls
return self

def unlock(self):
for injectable in self.__injectables:
Expand All @@ -244,6 +238,13 @@ def add_listener(self, listener: EventListener):
def notify(self, event: Event) -> ContextManager | ContextDecorator:
return self.__channel.dispatch(event)

def __check_if_exists(self, *classes: type):
for cls in classes:
if cls in self.__data:
raise RuntimeError(
f"An injectable already exists for the class `{format_type(cls)}`."
)


"""
Module
Expand Down Expand Up @@ -280,7 +281,7 @@ def __getitem__(self, cls: type[_T] | UnionType, /) -> Injectable[_T]:
raise NoInjectable(cls)

def __setitem__(self, cls: type | UnionType, injectable: Injectable, /):
self.update((cls,), injectable)
self.update((cls,), injectable, override=True)

def __contains__(self, cls: type | UnionType, /) -> bool:
return any(cls in broker for broker in self.__brokers)
Expand All @@ -304,22 +305,29 @@ def injectable(
*,
cls: type[Injectable] = NewInjectable,
on: type | Types = None,
override: bool = False,
):
def decorator(wp):
factory = self.inject(wp, return_factory=True)
injectable = cls(factory)
classes = find_types(wp, on)
self.update(classes, injectable)
self.update(classes, injectable, override)
return wp

return decorator(wrapped) if wrapped else decorator

singleton = partialmethod(injectable, cls=SingletonInjectable)

def set_constant(self, instance: _T, on: type | Types = None) -> _T:
def set_constant(
self,
instance: _T,
on: type | Types = None,
*,
override: bool = False,
) -> _T:
cls = type(instance)

@self.injectable(on=(cls, on))
@self.injectable(on=(cls, on), override=override)
def get_constant():
return instance

Expand Down Expand Up @@ -364,8 +372,8 @@ def get_instance(self, cls: type[_T], none: bool = True) -> _T | None:
def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]:
return Lazy(lambda: self.get_instance(cls))

def update(self, classes: Types, injectable: Injectable):
self.__container.update(classes, injectable)
def update(self, classes: Types, injectable: Injectable, override: bool = False):
self.__container.update(classes, injectable, override)
return self

def use(
Expand Down
12 changes: 12 additions & 0 deletions tests/test_injectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,15 @@ class B(A):
@injectable(on=A)
class C(A):
pass

def test_injectable_with_override(self):
@injectable
class A:
pass

@injectable(on=A, override=True)
class B(A):
pass

a = get_instance(A)
assert isinstance(a, B)
12 changes: 12 additions & 0 deletions tests/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,15 @@ class B(A):
@singleton(on=A)
class C(A):
pass

def test_injectable_with_override(self):
@singleton
class A:
pass

@singleton(on=A, override=True)
class B(A):
pass

a = get_instance(A)
assert isinstance(a, B)