diff --git a/injection/_core/module.py b/injection/_core/module.py index ffa688a..04c291c 100644 --- a/injection/_core/module.py +++ b/injection/_core/module.py @@ -156,6 +156,12 @@ def __str__(self) -> str: ) +@dataclass(frozen=True, slots=True) +class UnlockCalled(Event): + def __str__(self) -> str: + return "An `unlock` method has been called." + + """ Broker """ @@ -802,8 +808,11 @@ def change_priority(self, module: Module, priority: Priority | PriorityStr) -> S return self def unlock(self) -> Self: - for broker in self.__brokers: - broker.unlock() + event = UnlockCalled() + + with self.dispatch(event, lock_bypass=True): + for broker in self.__brokers: + broker.unlock() return self @@ -838,20 +847,20 @@ def remove_listener(self, listener: EventListener) -> Self: self.__channel.remove_listener(listener) return self - def on_event(self, event: Event, /) -> ContextManager[None] | None: + def on_event(self, event: Event, /) -> ContextManager[None]: self_event = ModuleEventProxy(self, event) return self.dispatch(self_event) @contextmanager - def dispatch(self, event: Event) -> Iterator[None]: - self.__check_locking() + def dispatch(self, event: Event, *, lock_bypass: bool = False) -> Iterator[None]: + if not lock_bypass: + self.__check_locking() with self.__channel.dispatch(event): try: yield finally: - message = str(event) - self.__debug(message) + self.__debug(event) def __debug(self, message: object) -> None: for logger in self.__loggers: diff --git a/tests/core/test_module.py b/tests/core/test_module.py index 532c422..eb15436 100644 --- a/tests/core/test_module.py +++ b/tests/core/test_module.py @@ -405,6 +405,22 @@ class C(A): ... assert isinstance(b1.a, A) assert isinstance(b2.a, C) + def test_unlock_with_module_in_use_raise_module_lock_error(self, module): + second_module = Module() + module.use(second_module) + + @module.singleton + class A: ... + + @second_module.singleton + class B: ... + + module.get_instance(A) + second_module.get_instance(B) + + with pytest.raises(ModuleLockError): + second_module.unlock() + def test_unlock_with_scoped_dependency(self, module): @module.scoped("test") class Dependency: ...