Skip to content

Commit

Permalink
Support forward references in provider method return types
Browse files Browse the repository at this point in the history
Closes GH-130.
  • Loading branch information
jstasiak committed Dec 14, 2019
1 parent bcae038 commit d31b83f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
29 changes: 25 additions & 4 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,20 @@ def __call__(self, binder: Binder) -> None:
binding = None
if hasattr(function, '__binding__'):
binding = function.__binding__
if binding.interface == '__deferred__':
# We could not evaluate a forward reference at @provider-decoration time, we need to
# try again now.
try:
annotations = get_type_hints(function)
except NameError as e:
raise NameError(
'Cannot avaluate forward reference annotation(s) in method %r belonging to %r: %s'
% (function.__name__, type(self), e)
) from e
return_type = annotations['return']
binding = function.__func__.__binding__ = Binding(
interface=return_type, provider=binding.provider, scope=binding.scope
)
bind_method = binder.multibind if binding.is_multibinding() else binder.bind
bind_method( # type: ignore
binding.interface, to=types.MethodType(binding.provider, self), scope=binding.scope
Expand Down Expand Up @@ -1206,16 +1220,23 @@ def provide_strs_also(self) -> List[str]:

def _mark_provider_function(function: Callable, *, allow_multi: bool) -> None:
scope_ = getattr(function, '__scope__', None)
annotations = inspect.getfullargspec(function).annotations
return_type = annotations['return']
try:
annotations = get_type_hints(function)
except NameError:
return_type = '__deferred__'
else:
return_type = annotations['return']
_validate_provider_return_type(function, cast(type, return_type), allow_multi)
function.__binding__ = Binding(return_type, inject(function), scope_) # type: ignore


def _validate_provider_return_type(function: Callable, return_type: type, allow_multi: bool) -> None:
origin = _get_origin(_punch_through_alias(return_type))
if origin in {dict, list} and not allow_multi:
raise Error(
'Function %s needs to be decorated with multiprovider instead of provider if it is to '
'provide values to a multibinding of type %s' % (function.__name__, return_type)
)
binding = Binding(return_type, inject(function), scope_)
function.__binding__ = binding # type: ignore


ConstructorOrClassT = TypeVar('ConstructorOrClassT', bound=Union[Callable, Type])
Expand Down
9 changes: 6 additions & 3 deletions injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,8 +1176,11 @@ def __init__(self, s: str, *args: int, **kwargs: str):

def test_forward_references_in_annotations_are_handled():
# See https://www.python.org/dev/peps/pep-0484/#forward-references for details
def configure(binder):
binder.bind(X, to=X('hello'))

class CustomModule(Module):
@provider
def provide_x(self) -> 'X':
return X('hello')

@inject
def fun(s: 'X') -> 'X':
Expand All @@ -1193,7 +1196,7 @@ def __init__(self, message: str) -> None:
self.message = message

try:
injector = Injector(configure)
injector = Injector(CustomModule)
assert injector.call_with_injection(fun).message == 'hello'
finally:
del X
Expand Down

0 comments on commit d31b83f

Please sign in to comment.