From f86f916ceeb55136d6df56b5a9a115e5cc7ae7b1 Mon Sep 17 00:00:00 2001 From: remimd Date: Tue, 13 Feb 2024 12:24:21 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20`override`=20parameter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- documentation/basic-usage.md | 14 ++++++++++ injection/_pkg.pyi | 6 ++++- injection/core/module.py | 50 +++++++++++++++++++++--------------- tests/test_injectable.py | 12 +++++++++ tests/test_singleton.py | 12 +++++++++ 5 files changed, 72 insertions(+), 22 deletions(-) diff --git a/documentation/basic-usage.md b/documentation/basic-usage.md index 0234a47..e6a31d6 100644 --- a/documentation/basic-usage.md +++ b/documentation/basic-usage.md @@ -120,6 +120,20 @@ class C(B): ... ``` +If a class is registered in a package and you want to override it, there is the `override` parameter: + +```python +@singleton +class A: + ... + +# ... + +@singleton(on=A, override=True) +class B(A): + ... +``` + ## Recipes A recipe is a function that tells the injector how to construct the instance to be injected. It is important to specify diff --git a/injection/_pkg.pyi b/injection/_pkg.pyi index 4e23afb..9e54da7 100644 --- a/injection/_pkg.pyi +++ b/injection/_pkg.pyi @@ -60,6 +60,7 @@ class Module: *, cls: type[Injectable] = ..., on: type | Iterable[type] | UnionType = ..., + override: bool = ..., ): """ Decorator applicable to a class or function. It is used to indicate how the @@ -73,6 +74,7 @@ class Module: /, *, on: type | Iterable[type] | UnionType = ..., + override: bool = ..., ): """ Decorator applicable to a class or function. It is used to indicate how the @@ -84,6 +86,8 @@ class Module: self, instance: _T, on: type | Iterable[type] | UnionType = ..., + *, + override: bool = ..., ) -> _T: """ Function for registering a specific instance to be injected. This is useful for @@ -149,7 +153,7 @@ class ModulePriorities(Enum): @runtime_checkable class Injectable(Protocol[_T]): - def __init__(self, factory: Callable[[], _T] = ..., *args, **kwargs): ... + def __init__(self, factory: Callable[[], _T] = ..., /): ... @property def is_locked(self) -> bool: ... def unlock(self): ... diff --git a/injection/core/module.py b/injection/core/module.py index c5110da..dd06ed9 100644 --- a/injection/core/module.py +++ b/injection/core/module.py @@ -61,6 +61,7 @@ class ContainerEvent(Event, ABC): @dataclass(frozen=True, slots=True) class ContainerDependenciesUpdated(ContainerEvent): classes: Collection[type] + override: bool def __str__(self) -> str: length = len(self.classes) @@ -212,26 +213,19 @@ def is_locked(self) -> bool: def __injectables(self) -> frozenset[Injectable]: return frozenset(self.__data.values()) - def update(self, classes: Types, injectable: Injectable): - classes = frozenset(get_origins(*classes)) + def update(self, classes: Types, injectable: Injectable, override: bool): + values = {origin: injectable for origin in get_origins(*classes)} - if classes: - event = ContainerDependenciesUpdated(self, classes) + if values: + event = ContainerDependenciesUpdated(self, values, override) with self.notify(event): - self.__data.update( - (self.check_if_exists(cls), injectable) for cls in classes - ) - - return self + if not override: + self.__check_if_exists(*values) - def check_if_exists(self, cls: type) -> type: - if cls in self.__data: - raise RuntimeError( - f"An injectable already exists for the class `{format_type(cls)}`." - ) + self.__data.update(values) - return cls + return self def unlock(self): for injectable in self.__injectables: @@ -244,6 +238,13 @@ def add_listener(self, listener: EventListener): def notify(self, event: Event) -> ContextManager | ContextDecorator: return self.__channel.dispatch(event) + def __check_if_exists(self, *classes: type): + for cls in classes: + if cls in self.__data: + raise RuntimeError( + f"An injectable already exists for the class `{format_type(cls)}`." + ) + """ Module @@ -280,7 +281,7 @@ def __getitem__(self, cls: type[_T] | UnionType, /) -> Injectable[_T]: raise NoInjectable(cls) def __setitem__(self, cls: type | UnionType, injectable: Injectable, /): - self.update((cls,), injectable) + self.update((cls,), injectable, override=True) def __contains__(self, cls: type | UnionType, /) -> bool: return any(cls in broker for broker in self.__brokers) @@ -304,22 +305,29 @@ def injectable( *, cls: type[Injectable] = NewInjectable, on: type | Types = None, + override: bool = False, ): def decorator(wp): factory = self.inject(wp, return_factory=True) injectable = cls(factory) classes = find_types(wp, on) - self.update(classes, injectable) + self.update(classes, injectable, override) return wp return decorator(wrapped) if wrapped else decorator singleton = partialmethod(injectable, cls=SingletonInjectable) - def set_constant(self, instance: _T, on: type | Types = None) -> _T: + def set_constant( + self, + instance: _T, + on: type | Types = None, + *, + override: bool = False, + ) -> _T: cls = type(instance) - @self.injectable(on=(cls, on)) + @self.injectable(on=(cls, on), override=override) def get_constant(): return instance @@ -364,8 +372,8 @@ def get_instance(self, cls: type[_T], none: bool = True) -> _T | None: def get_lazy_instance(self, cls: type[_T]) -> Lazy[_T | None]: return Lazy(lambda: self.get_instance(cls)) - def update(self, classes: Types, injectable: Injectable): - self.__container.update(classes, injectable) + def update(self, classes: Types, injectable: Injectable, override: bool = False): + self.__container.update(classes, injectable, override) return self def use( diff --git a/tests/test_injectable.py b/tests/test_injectable.py index 3a0fff0..46704b3 100644 --- a/tests/test_injectable.py +++ b/tests/test_injectable.py @@ -166,3 +166,15 @@ class B(A): @injectable(on=A) class C(A): pass + + def test_injectable_with_override(self): + @injectable + class A: + pass + + @injectable(on=A, override=True) + class B(A): + pass + + a = get_instance(A) + assert isinstance(a, B) diff --git a/tests/test_singleton.py b/tests/test_singleton.py index e012e22..bada834 100644 --- a/tests/test_singleton.py +++ b/tests/test_singleton.py @@ -164,3 +164,15 @@ class B(A): @singleton(on=A) class C(A): pass + + def test_injectable_with_override(self): + @singleton + class A: + pass + + @singleton(on=A, override=True) + class B(A): + pass + + a = get_instance(A) + assert isinstance(a, B)