Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions injection/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,13 @@ class Module:
/,
threadsafe: bool = ...,
) -> Callable[..., Awaitable[T]]: ...
async def afind_instance[T](self, cls: _InputType[T]) -> T: ...
def find_instance[T](self, cls: _InputType[T]) -> T:
async def afind_instance[T](
self,
cls: _InputType[T],
*,
threadsafe: bool = ...,
) -> T: ...
def find_instance[T](self, cls: _InputType[T], *, threadsafe: bool = ...) -> T:
"""
Function used to retrieve an instance associated with the type passed in
parameter or an exception will be raised.
Expand All @@ -229,59 +234,66 @@ class Module:
self,
cls: _InputType[T],
default: Default,
*,
threadsafe: bool = ...,
) -> T | Default: ...
@overload
async def aget_instance[T](
self,
cls: _InputType[T],
default: None = ...,
) -> T | None: ...
default: T = ...,
*,
threadsafe: bool = ...,
) -> T: ...
@overload
def get_instance[T, Default](
self,
cls: _InputType[T],
default: Default,
*,
threadsafe: bool = ...,
) -> T | Default:
"""
Function used to retrieve an instance associated with the type passed in
parameter or return `None`.
parameter or return `NotImplemented`.
"""

@overload
def get_instance[T](
self,
cls: _InputType[T],
default: None = ...,
) -> T | None: ...
default: T = ...,
*,
threadsafe: bool = ...,
) -> T: ...
@overload
def aget_lazy_instance[T, Default](
self,
cls: _InputType[T],
default: Default,
*,
cache: bool = ...,
threadsafe: bool = ...,
) -> Awaitable[T | Default]: ...
@overload
def aget_lazy_instance[T](
self,
cls: _InputType[T],
default: None = ...,
default: T = ...,
*,
cache: bool = ...,
) -> Awaitable[T | None]: ...
threadsafe: bool = ...,
) -> Awaitable[T]: ...
@overload
def get_lazy_instance[T, Default](
self,
cls: _InputType[T],
default: Default,
*,
cache: bool = ...,
threadsafe: bool = ...,
) -> _Invertible[T | Default]:
"""
Function used to retrieve an instance associated with the type passed in
parameter or `None`. Return a `Invertible` object. To access the instance
parameter or `NotImplemented`. Return a `Invertible` object. To access the instance
contained in an invertible object, simply use a wavy line (~).
With `cache=True`, the instance retrieved will always be the same.

Example: instance = ~lazy_instance
"""
Expand All @@ -290,10 +302,10 @@ class Module:
def get_lazy_instance[T](
self,
cls: _InputType[T],
default: None = ...,
default: T = ...,
*,
cache: bool = ...,
) -> _Invertible[T | None]: ...
threadsafe: bool = ...,
) -> _Invertible[T]: ...
def init_modules(self, *modules: Module) -> Self:
"""
Function to clean modules in use and to use those passed as parameters.
Expand Down
1 change: 1 addition & 0 deletions injection/_core/common/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:

def create_semaphore(value: int) -> AsyncContextManager[Any]:
return anyio.Semaphore(value)

except ImportError: # pragma: no cover
import asyncio

Expand Down
16 changes: 1 addition & 15 deletions injection/_core/common/lazy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
from collections.abc import Callable, Iterator
from functools import partial

from injection._core.common.asynchronous import SimpleAwaitable
from injection._core.common.invertible import Invertible, SimpleInvertible


Expand All @@ -18,19 +17,6 @@ def cache() -> Iterator[T]:
return SimpleInvertible(getter)


def alazy[T](factory: Callable[..., Awaitable[T]]) -> Awaitable[T]:
async def cache() -> AsyncIterator[T]:
nonlocal factory
value = await factory()
del factory

while True:
yield value

getter = partial(anext, cache())
return SimpleAwaitable(getter)


class Lazy[T](Invertible[T]):
__slots__ = ("__invertible", "__is_set")

Expand Down
7 changes: 7 additions & 0 deletions injection/_core/common/threading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from contextlib import nullcontext
from threading import RLock
from typing import Any, ContextManager


def get_lock(threadsafe: bool) -> ContextManager[Any]:
return RLock() if threadsafe else nullcontext()
103 changes: 61 additions & 42 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from collections import OrderedDict, deque
from collections.abc import (
Expand All @@ -14,7 +13,7 @@
Iterator,
Mapping,
)
from contextlib import asynccontextmanager, contextmanager, nullcontext, suppress
from contextlib import asynccontextmanager, contextmanager, suppress
from dataclasses import dataclass, field
from enum import StrEnum
from functools import partial, partialmethod, singledispatchmethod, update_wrapper
Expand Down Expand Up @@ -51,7 +50,8 @@
from injection._core.common.event import Event, EventChannel, EventListener
from injection._core.common.invertible import Invertible, SimpleInvertible
from injection._core.common.key import new_short_key
from injection._core.common.lazy import Lazy, alazy, lazy
from injection._core.common.lazy import Lazy, lazy
from injection._core.common.threading import get_lock
from injection._core.common.type import (
InputType,
TypeInfo,
Expand Down Expand Up @@ -617,35 +617,48 @@ def make_async_factory[T](
)
return factory.__inject_metadata__.acall

async def afind_instance[T](self, cls: InputType[T]) -> T:
injectable = self[cls]
return await injectable.aget_instance()
async def afind_instance[T](
self,
cls: InputType[T],
*,
threadsafe: bool = False,
) -> T:
with get_lock(threadsafe):
injectable = self[cls]
return await injectable.aget_instance()

def find_instance[T](self, cls: InputType[T]) -> T:
injectable = self[cls]
return injectable.get_instance()
def find_instance[T](self, cls: InputType[T], *, threadsafe: bool = False) -> T:
with get_lock(threadsafe):
injectable = self[cls]
return injectable.get_instance()

@overload
async def aget_instance[T, Default](
self,
cls: InputType[T],
default: Default,
*,
threadsafe: bool = ...,
) -> T | Default: ...

@overload
async def aget_instance[T](
self,
cls: InputType[T],
default: None = ...,
) -> T | None: ...
default: T = ...,
*,
threadsafe: bool = ...,
) -> T: ...

async def aget_instance[T, Default](
self,
cls: InputType[T],
default: Default | None = None,
) -> T | Default | None:
default: Default = NotImplemented,
*,
threadsafe: bool = False,
) -> T | Default:
try:
return await self.afind_instance(cls)
return await self.afind_instance(cls, threadsafe=threadsafe)
except (KeyError, SkipInjectable):
return default

Expand All @@ -654,22 +667,28 @@ def get_instance[T, Default](
self,
cls: InputType[T],
default: Default,
*,
threadsafe: bool = ...,
) -> T | Default: ...

@overload
def get_instance[T](
self,
cls: InputType[T],
default: None = ...,
) -> T | None: ...
default: T = ...,
*,
threadsafe: bool = ...,
) -> T: ...

def get_instance[T, Default](
self,
cls: InputType[T],
default: Default | None = None,
) -> T | Default | None:
default: Default = NotImplemented,
*,
threadsafe: bool = False,
) -> T | Default:
try:
return self.find_instance(cls)
return self.find_instance(cls, threadsafe=threadsafe)
except (KeyError, SkipInjectable):
return default

Expand All @@ -679,29 +698,29 @@ def aget_lazy_instance[T, Default](
cls: InputType[T],
default: Default,
*,
cache: bool = ...,
threadsafe: bool = ...,
) -> Awaitable[T | Default]: ...

@overload
def aget_lazy_instance[T](
self,
cls: InputType[T],
default: None = ...,
default: T = ...,
*,
cache: bool = ...,
) -> Awaitable[T | None]: ...
threadsafe: bool = ...,
) -> Awaitable[T]: ...

def aget_lazy_instance[T, Default](
self,
cls: InputType[T],
default: Default | None = None,
default: Default = NotImplemented,
*,
cache: bool = False,
) -> Awaitable[T | Default | None]:
if cache:
return alazy(lambda: self.aget_instance(cls, default))

function = self.make_injected_function(lambda instance=default: instance)
threadsafe: bool = False,
) -> Awaitable[T | Default]:
function = self.make_injected_function(
lambda instance=default: instance,
threadsafe=threadsafe,
)
metadata = function.__inject_metadata__.set_owner(cls)
return SimpleAwaitable(metadata.acall)

Expand All @@ -711,29 +730,29 @@ def get_lazy_instance[T, Default](
cls: InputType[T],
default: Default,
*,
cache: bool = ...,
threadsafe: bool = ...,
) -> Invertible[T | Default]: ...

@overload
def get_lazy_instance[T](
self,
cls: InputType[T],
default: None = ...,
default: T = ...,
*,
cache: bool = ...,
) -> Invertible[T | None]: ...
threadsafe: bool = ...,
) -> Invertible[T]: ...

def get_lazy_instance[T, Default](
self,
cls: InputType[T],
default: Default | None = None,
default: Default = NotImplemented,
*,
cache: bool = False,
) -> Invertible[T | Default | None]:
if cache:
return lazy(lambda: self.get_instance(cls, default))

function = self.make_injected_function(lambda instance=default: instance)
threadsafe: bool = False,
) -> Invertible[T | Default]:
function = self.make_injected_function(
lambda instance=default: instance,
threadsafe=threadsafe,
)
metadata = function.__inject_metadata__.set_owner(cls)
return SimpleInvertible(metadata.call)

Expand Down Expand Up @@ -996,7 +1015,7 @@ class InjectMetadata[**P, T](Caller[P, T], EventListener):

def __init__(self, wrapped: Callable[P, T], /, threadsafe: bool) -> None:
self.__dependencies = Dependencies.empty()
self.__lock = threading.RLock() if threadsafe else nullcontext()
self.__lock = get_lock(threadsafe)
self.__owner = None
self.__tasks = deque()
self.__wrapped = wrapped
Expand Down
Loading
Loading