Skip to content

Commit

Permalink
Added support for optional dependencies in DI
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed May 9, 2022
1 parent e4470b5 commit 2d1c08a
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 9 deletions.
12 changes: 10 additions & 2 deletions docs/userguide/contexts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,19 @@ To specify a non-default name for the dependency, you can pass that name as an a
async def some_function(some_arg, some_resource: MyResourceType = resource('alternate')):
...

Resources can be declared to be optional too, by using either :data:`~typing.Optional`
or ``| None`` (Python 3.10+ only)::

@inject
async def some_function(some_arg, some_resource: Optional[MyResourceType] = resource('alternate')):
... # some_resource will be None if it's not found

Restrictions:

* The resource arguments must not be positional-only arguments
* The resources (or their relevant factories) must already be present in the context stack when
the decorated function is called, or otherwise :exc:`~.context.ResourceNotFound` is raised
* The resources (or their relevant factories) must already be present in the context
stack (unless declared optional) when the decorated function is called, or otherwise
:exc:`~.context.ResourceNotFound` is raised

.. _dependency injection: https://en.wikipedia.org/wiki/Dependency_injection

Expand Down
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Allowed resource retrieval and generation in teardown callbacks until the context has
been completely closed (this would previously raise
``RuntimeError("this context has already been closed")``)
- Allowed specifying optional dependencies with dependency injection, using either
``Optional[SomeType]`` (all Python versions) or ``SomeType | None`` (Python 3.10+)

**4.8.0** (2022-04-28)

Expand Down
39 changes: 33 additions & 6 deletions src/asphalt/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
import re
import sys
import types
import warnings
from asyncio import (
AbstractEventLoop,
Expand Down Expand Up @@ -79,9 +80,9 @@
from typing_extensions import ParamSpec

if sys.version_info >= (3, 8):
from typing import get_origin
from typing import get_args, get_origin
else:
from typing_extensions import get_origin
from typing_extensions import get_args, get_origin

logger = logging.getLogger(__name__)
factory_callback_type = Callable[["Context"], Any]
Expand Down Expand Up @@ -928,6 +929,7 @@ def require_resource(type: Type[T_Resource], name: str = "default") -> T_Resourc
class _Dependency:
name: str = "default"
cls: type = field(init=False)
optional: bool = field(init=False, default=False)


def resource(name: str = "default") -> Any:
Expand Down Expand Up @@ -965,8 +967,8 @@ def inject(func: Callable[P, Any]) -> Callable[P, Any]:
"""
Wrap the given coroutine function for use with dependency injection.
Parameters with dependencies need to be annotated and have a :class:`Dependency` instance as
the default value.
Parameters with dependencies need to be annotated and have :func:`resource` as the
default value.
"""
forward_refs_resolved = False
Expand All @@ -976,6 +978,23 @@ def resolve_forward_refs() -> None:
type_hints = get_type_hints(func)
for key, dependency in injected_resources.items():
dependency.cls = type_hints[key]
origin = get_origin(type_hints[key])
if origin is Union or (
sys.version_info >= (3, 10) and origin is types.UnionType # noqa: E721
):
args = [
arg
for arg in get_args(dependency.cls)
if arg is not type(None) # noqa: E721
]
if len(args) == 1:
dependency.optional = True
dependency.cls = args[0]
else:
raise TypeError(
"Unions are only valid with dependency injection when there "
"are exactly two items and other item is None"
)

forward_refs_resolved = True

Expand All @@ -987,7 +1006,11 @@ def sync_wrapper(*args, **kwargs) -> T_Retval:
ctx = current_context()
resources: dict[str, Any] = {}
for argname, dependency in injected_resources.items():
resource: Any = ctx.require_resource(dependency.cls, dependency.name)
if dependency.optional:
resource = ctx.get_resource(dependency.cls, dependency.name)
else:
resource = ctx.require_resource(dependency.cls, dependency.name)

resources[argname] = resource

return func(*args, **kwargs, **resources)
Expand All @@ -1000,7 +1023,11 @@ async def async_wrapper(*args, **kwargs) -> T_Retval:
ctx = current_context()
resources: dict[str, Any] = {}
for argname, dependency in injected_resources.items():
resource: Any = ctx.require_resource(dependency.cls, dependency.name)
if dependency.optional:
resource = ctx.get_resource(dependency.cls, dependency.name)
else:
resource = ctx.require_resource(dependency.cls, dependency.name)

if isawaitable(resource):
resource = await resource

Expand Down
52 changes: 51 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import isawaitable
from itertools import count
from threading import current_thread
from typing import AsyncGenerator, AsyncIterator, Dict
from typing import AsyncGenerator, AsyncIterator, Dict, Optional
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -823,6 +823,56 @@ async def injected(foo: int, bar: str = resource()):

exc.match("no matching resource was found for type=str name='default'")

@pytest.mark.parametrize(
"annotation",
[
pytest.param(Optional[str], id="optional"),
# pytest.param(Union[str, int, None], id="union"),
pytest.param(
"str | None",
id="uniontype.10",
marks=[
pytest.mark.skipif(
sys.version_info < (3, 10), reason="Requires Python 3.10+"
)
],
),
],
)
@pytest.mark.parametrize(
"sync",
[
pytest.param(True, id="sync"),
pytest.param(False, id="async"),
],
)
@pytest.mark.asyncio
async def test_inject_optional_resource_async(
self, annotation: type, sync: bool
) -> None:
if sync:

@inject
def injected(
res: annotation = resource(), # type: ignore[valid-type]
) -> annotation: # type: ignore[valid-type]
return res

else:

@inject
async def injected(
res: annotation = resource(), # type: ignore[valid-type]
) -> annotation: # type: ignore[valid-type]
return res

async with Context() as ctx:
retval = injected() if sync else (await injected())
assert retval is None
ctx.add_resource("hello")
retval = injected() if sync else (await injected())
assert retval == "hello"


def test_dependency_deprecated():
with pytest.deprecated_call():
Expand Down

0 comments on commit 2d1c08a

Please sign in to comment.