Skip to content

Commit

Permalink
Support PEP 604 optional & union types (#203)
Browse files Browse the repository at this point in the history
Fixes #202.
  • Loading branch information
davidparsson committed Jul 11, 2022
1 parent 97aef80 commit f66285c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
6 changes: 5 additions & 1 deletion injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,10 @@ def __getattribute__(self, name: str) -> Any:


def _infer_injected_bindings(callable: Callable, only_explicit_bindings: bool) -> Dict[str, type]:
def _is_new_union_type(instance: Any) -> bool:
new_union_type = getattr(types, 'UnionType', None)
return new_union_type is not None and isinstance(instance, new_union_type)

spec = inspect.getfullargspec(callable)

try:
Expand Down Expand Up @@ -1204,7 +1208,7 @@ def _infer_injected_bindings(callable: Callable, only_explicit_bindings: bool) -

if only_explicit_bindings and _inject_marker not in metadata or _noinject_marker in metadata:
del bindings[k]
elif _is_specialization(v, Union):
elif _is_specialization(v, Union) or _is_new_union_type(v):
# We don't treat Optional parameters in any special way at the moment.
union_members = v.__args__
new_members = tuple(set(union_members) - {type(None)})
Expand Down
18 changes: 17 additions & 1 deletion injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""Functional tests for the "Injector" dependency injection framework."""

from contextlib import contextmanager
from typing import Any, NewType, Optional
from typing import Any, NewType, Optional, Union
import abc
import sys
import threading
Expand Down Expand Up @@ -1516,3 +1516,19 @@ def function11(a: int) -> 'InvalidForwardReference':
pass

assert get_bindings(function11) == {'a': int}


# Tests https://github.com/alecthomas/injector/issues/202
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
def test_get_bindings_for_pep_604():
@inject
def function1(a: int | None) -> None:
pass

assert get_bindings(function1) == {'a': int}

@inject
def function1(a: int | str) -> None:
pass

assert get_bindings(function1) == {'a': Union[int, str]}

0 comments on commit f66285c

Please sign in to comment.