Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions injection/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def mod(name: str = ..., /) -> Module:
"""
Short syntax for `Module.from_name`.
"""

@runtime_checkable
class Injectable[T](Protocol):
@property
Expand Down Expand Up @@ -234,9 +235,8 @@ class Module:
) -> _Decorator: ...
def scoped(
self,
scope_name: str,
/,
*,
*scope_names: str,
ignore_type_hint: bool = ...,
inject: bool = ...,
on: _TypeInfo[Any] = ...,
Expand Down
20 changes: 11 additions & 9 deletions injection/_core/injectables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, MutableMapping
from collections.abc import Awaitable, Callable, MutableMapping, Sequence
from contextlib import suppress
from dataclasses import dataclass, field
from functools import partial
Expand All @@ -16,7 +16,7 @@

from injection._core.common.asynchronous import AsyncSemaphore, Caller
from injection._core.common.type import InputType
from injection._core.scope import Scope, get_scope, in_scope_cache
from injection._core.scope import Scope, get_first_scope, in_scope_cache
from injection._core.slots import SlotKey
from injection.exceptions import EmptySlotError, InjectionError

Expand Down Expand Up @@ -129,13 +129,13 @@ def get_instance(self) -> T:
@dataclass(repr=False, eq=False, frozen=True, slots=True)
class ScopedInjectable[R, T](Injectable[T], ABC):
factory: Caller[..., R]
scope_name: str
scope_names: Sequence[str]
key: SlotKey[T] = field(default_factory=SlotKey)
logic: CacheLogic[T] = field(default_factory=CacheLogic)

@property
def is_locked(self) -> bool:
return in_scope_cache(self.key, self.scope_name)
return in_scope_cache(self.key, *self.scope_names)

@abstractmethod
async def abuild(self, scope: Scope) -> T:
Expand All @@ -157,14 +157,16 @@ def get_instance(self) -> T:

def unlock(self) -> None:
if self.is_locked:
raise RuntimeError(f"To unlock, close the `{self.scope_name}` scope.")
raise RuntimeError(
f"To unlock, close all open scopes in [{', '.join(f'`{name}`' for name in self.scope_names)}]."
)

def __get_scope(self) -> Scope:
return get_scope(self.scope_name)
return get_first_scope(*self.scope_names)

@classmethod
def bind_scope_name(cls, name: str) -> Callable[[Caller[..., R]], Self]:
return partial(cls, scope_name=name)
def bind_scope_names(cls, names: Sequence[str]) -> Callable[[Caller[..., R]], Self]:
return partial(cls, scope_names=names)


class AsyncCMScopedInjectable[T](ScopedInjectable[AsyncContextManager[T], T]):
Expand Down Expand Up @@ -211,7 +213,7 @@ async def aget_instance(self) -> T:

def get_instance(self) -> T:
scope_name = self.scope_name
scope = get_scope(scope_name)
scope = get_first_scope(scope_name)

try:
return scope.cache[self.key]
Expand Down
5 changes: 2 additions & 3 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,8 @@ def decorator(wp: Recipe[P, T]) -> Recipe[P, T]:

def scoped[**P, T](
self,
scope_name: str,
/,
*,
*scope_names: str,
ignore_type_hint: bool = False,
inject: bool = True,
on: TypeInfo[T] = (),
Expand Down Expand Up @@ -284,7 +283,7 @@ def decorator(

self.injectable(
ctx.wrapper,
cls=ctx.cls.bind_scope_name(scope_name),
cls=ctx.cls.bind_scope_names(scope_names),
ignore_type_hint=True,
inject=inject,
on=(*ctx.hints, on),
Expand Down
24 changes: 13 additions & 11 deletions injection/_core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,34 +151,36 @@ def define_scope(
if TYPE_CHECKING: # pragma: no cover

@overload
def get_scope(name: str, default: EllipsisType = ...) -> Scope: ...
def get_first_scope(*names: str, default: EllipsisType = ...) -> Scope: ...

@overload
def get_scope[T](name: str, default: T) -> Scope | T: ...
def get_first_scope[T](*names: str, default: T) -> Scope | T: ...


def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
def get_first_scope[T](*names: str, default: T | EllipsisType = ...) -> Scope | T:
for resolvers in __scope_resolvers.values():
resolver = resolvers.get(name)
if resolver and (scope := resolver.get_scope()):
return scope
for name in names:
resolver = resolvers.get(name)
if resolver and (scope := resolver.get_scope()):
return scope

if default is ...:
raise ScopeUndefinedError(
f"Scope `{name}` isn't defined in the current context."
f"No scope in [{', '.join(f'`{name}`' for name in names)}] is defined in the current context."
)

return default


def in_scope_cache(key: SlotKey[Any], scope_name: str) -> bool:
return any(key in scope.cache for scope in iter_active_scopes(scope_name))
def in_scope_cache(key: SlotKey[Any], *scope_names: str) -> bool:
return any(key in scope.cache for scope in iter_active_scopes(*scope_names))


def iter_active_scopes(name: str) -> Iterator[Scope]:
def iter_active_scopes(*names: str) -> Iterator[Scope]:
active_scopes = (
resolver.active_scopes
for resolvers in __scope_resolvers.values()
for name in names
if (resolver := resolvers.get(name))
)
return itertools.chain.from_iterable(active_scopes)
Expand All @@ -194,7 +196,7 @@ def _bind_scope(
lock = get_lock(threadsafe)

with lock:
if get_scope(name, None):
if get_first_scope(name, default=None):
raise ScopeAlreadyDefinedError(
f"Scope `{name}` is already defined in the current context."
)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ class SomeInjectable: ...

assert instance_1 is instance_2

def test_scoped_with_several_scopes(self):
@scoped("scope_2", "scope_1")
class Dependency: ...

with define_scope("scope_1"):
d1 = find_instance(Dependency)

with define_scope("scope_2"):
d2 = find_instance(Dependency)

d3 = find_instance(Dependency)

assert d1 is not d2
assert d1 is d3

def test_scoped_with_on(self):
class A: ...

Expand Down
Loading
Loading