diff --git a/injection/__init__.pyi b/injection/__init__.pyi index 8201633..379c341 100644 --- a/injection/__init__.pyi +++ b/injection/__init__.pyi @@ -125,6 +125,11 @@ class Module: that no dependencies are resolved, so the module doesn't need to be locked. """ + def make_injected_function[**P, T]( + self, + wrapped: Callable[P, T], + /, + ) -> Callable[P, T]: ... def find_instance[T](self, cls: _InputType[T]) -> T: """ Function used to retrieve an instance associated with the type passed in diff --git a/injection/_core/module.py b/injection/_core/module.py index 76f04f6..a00bbbb 100644 --- a/injection/_core/module.py +++ b/injection/_core/module.py @@ -477,7 +477,7 @@ def injectable[**P, T]( # type: ignore[no-untyped-def] mode: Mode | ModeStr = Mode.get_default(), ): def decorator(wp): # type: ignore[no-untyped-def] - factory = self.inject(wp, return_factory=True) if inject else wp + factory = self.make_injected_function(wp) if inject else wp classes = get_return_types(wp, on) updater = Updater( factory=factory, @@ -544,28 +544,29 @@ def set_constant[T]( ) return self - def inject[**P, T]( # type: ignore[no-untyped-def] - self, - wrapped: Callable[P, T] | None = None, - /, - *, - return_factory: bool = False, - ): + def inject[**P, T](self, wrapped: Callable[P, T] | None = None, /): # type: ignore[no-untyped-def] def decorator(wp): # type: ignore[no-untyped-def] - if not return_factory and isclass(wp): + if isclass(wp): wp.__init__ = self.inject(wp.__init__) return wp - injected = Injected(wp) + return self.make_injected_function(wp) - @injected.on_setup - def listen() -> None: - injected.update(self) - self.add_listener(injected) + return decorator(wrapped) if wrapped else decorator - return InjectedFunction(injected) + def make_injected_function[**P, T]( + self, + wrapped: Callable[P, T], + /, + ) -> InjectedFunction[P, T]: + injected = Injected(wrapped) - return decorator(wrapped) if wrapped else decorator + @injected.on_setup + def listen() -> None: + injected.update(self) + self.add_listener(injected) + + return InjectedFunction(injected) def find_instance[T](self, cls: InputType[T]) -> T: injectable = self[cls]