Skip to content

Commit

Permalink
Merge branch 'autoasync-fix'
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucretiel committed Sep 22, 2016
2 parents 304728b + c6d4f0e commit dba3bc4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/autocommand/autoasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,14 @@ def autoasync_wrapper(*args, **kwargs):
# installed after the autoasync decorator, it is respected at call time
local_loop = get_event_loop() if loop is None else loop

# Inject the 'loop' argument. We have to use this signature binding to
# ensure it's injected in the correct place (positional, keyword, etc)
if pass_loop:
kwargs['loop'] = local_loop
bound_args = old_sig.bind_partial()
bound_args.arguments.update(
loop=local_loop,
**new_sig.bind(*args, **kwargs).arguments)
args, kwargs = bound_args.args, bound_args.kwargs

if forever:
# Explicitly don't create a reference to the created task. This
Expand All @@ -91,8 +97,10 @@ def autoasync_wrapper(*args, **kwargs):
# Attach an updated signature, with the "loop" parameter filted out. This
# allows 'pass_loop' to be used with autoparse
if pass_loop:
sig = signature(coro)
autoasync_wrapper.__signature__ = sig.replace(parameters=(
param for name, param in sig.parameters.items() if name != "loop"))
old_sig = signature(coro)
new_sig = old_sig.replace(parameters=(
param for name, param in old_sig.parameters.items()
if name != "loop"))
autoasync_wrapper.__signature__ = new_sig

return autoasync_wrapper
28 changes: 28 additions & 0 deletions test/test_autoasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,34 @@ def async_main(loop):
assert async_main() is asyncio.get_event_loop()


def test_pass_loop_prior_argument(context_loop):
'''
Test that, if loop is the first positional argument, other arguments are
still passed correctly
'''
@autoasync(pass_loop=True)
@asyncio.coroutine
def async_main(loop, argument):
yield
return loop, argument

loop, value = async_main(10)
assert loop is asyncio.get_event_loop()
assert value == 10


def test_pass_loop_kwarg_only(context_loop):
@autoasync(pass_loop=True)
@asyncio.coroutine
def async_main(*, loop, argument):
yield
return loop, argument

loop, value = async_main(argument=10)
assert loop is asyncio.get_event_loop()
assert value == 10


def test_run_forever(context_loop):
@asyncio.coroutine
def stop_loop_after(t):
Expand Down

0 comments on commit dba3bc4

Please sign in to comment.