-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* It allows use of Go-like defer() function. * It also offers the async-aware adefer() function.
- Loading branch information
Showing
2 changed files
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import functools | ||
import inspect | ||
from typing import ( | ||
Union, | ||
Awaitable, | ||
Callable, | ||
) | ||
|
||
|
||
def defer(func): | ||
assert not inspect.iscoroutinefunction(func), \ | ||
'the decorated function must not be async' | ||
|
||
@functools.wraps(func) | ||
def _wrapped(*args, **kwargs): | ||
deferreds = [] | ||
|
||
def defer(f: Callable) -> None: | ||
assert not inspect.iscoroutinefunction(f), \ | ||
'the deferred function must not be async' | ||
assert not inspect.iscoroutine(f), \ | ||
'the deferred object must not be a coroutine' | ||
deferreds.append(f) | ||
|
||
try: | ||
return func(defer, *args, **kwargs) | ||
finally: | ||
for f in reversed(deferreds): | ||
f() | ||
|
||
return _wrapped | ||
|
||
|
||
def adefer(func): | ||
assert inspect.iscoroutinefunction(func), \ | ||
'the decorated function must be async' | ||
|
||
@functools.wraps(func) | ||
async def _wrapped(*args, **kwargs): | ||
deferreds = [] | ||
|
||
def defer(f: Union[Callable, Awaitable]) -> None: | ||
deferreds.append(f) | ||
|
||
try: | ||
return await func(defer, *args, **kwargs) | ||
finally: | ||
for f in reversed(deferreds): | ||
if inspect.iscoroutinefunction(f): | ||
await f() | ||
elif inspect.iscoroutine(f): | ||
await f | ||
else: | ||
f() | ||
|
||
return _wrapped |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import pytest | ||
|
||
import asyncio | ||
from aiotools.defer import defer, adefer | ||
from aiotools.func import apartial | ||
|
||
|
||
def test_defer(): | ||
|
||
x = [] | ||
|
||
@defer | ||
def myfunc(defer): | ||
x.append(1) | ||
defer(lambda: x.append(1)) | ||
x.append(2) | ||
defer(lambda: x.append(2)) | ||
x.append(3) | ||
defer(lambda: x.append(3)) | ||
|
||
myfunc() | ||
assert x == [1, 2, 3, 3, 2, 1] | ||
|
||
|
||
def test_defer_inner_exception(): | ||
|
||
x = [] | ||
|
||
@defer | ||
def myfunc(defer): | ||
x.append(1) | ||
defer(lambda: x.append(1)) | ||
x.append(2) | ||
defer(lambda: x.append(2)) | ||
raise ZeroDivisionError | ||
x.append(3) | ||
defer(lambda: x.append(3)) | ||
|
||
with pytest.raises(ZeroDivisionError): | ||
myfunc() | ||
assert x == [1, 2, 2, 1] | ||
|
||
|
||
def test_defer_wrong_func(): | ||
with pytest.raises(AssertionError): | ||
@defer | ||
async def myfunc(defer): | ||
pass | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_adefer(): | ||
|
||
x = [] | ||
|
||
@adefer | ||
async def myfunc(defer): | ||
x.append(1) | ||
defer(lambda: x.append(1)) | ||
x.append(2) | ||
defer(lambda: x.append(2)) | ||
x.append(3) | ||
defer(lambda: x.append(3)) | ||
|
||
await myfunc() | ||
assert x == [1, 2, 3, 3, 2, 1] | ||
|
||
|
||
def test_adefer_wrong_func(): | ||
with pytest.raises(AssertionError): | ||
@adefer | ||
def myfunc(defer): | ||
pass | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_adefer_coro(): | ||
|
||
x = [] | ||
|
||
async def async_append(target, item): | ||
target.append(item) | ||
await asyncio.sleep(0) | ||
|
||
@adefer | ||
async def myfunc(defer): | ||
x.append(1) | ||
defer(async_append(x, 1)) | ||
x.append(2) | ||
defer(async_append(x, 2)) | ||
x.append(3) | ||
defer(async_append(x, 3)) | ||
|
||
await myfunc() | ||
assert x == [1, 2, 3, 3, 2, 1] | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_adefer_corofunc(): | ||
|
||
x = [] | ||
|
||
async def async_append(target, item): | ||
target.append(item) | ||
await asyncio.sleep(0) | ||
|
||
@adefer | ||
async def myfunc(defer): | ||
x.append(1) | ||
defer(apartial(async_append, x, 1)) | ||
x.append(2) | ||
defer(apartial(async_append, x, 2)) | ||
x.append(3) | ||
defer(apartial(async_append, x, 3)) | ||
|
||
await myfunc() | ||
assert x == [1, 2, 3, 3, 2, 1] | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_adefer_inner_exception(): | ||
|
||
x = [] | ||
|
||
async def async_append(target, item): | ||
target.append(item) | ||
await asyncio.sleep(0) | ||
|
||
@adefer | ||
async def myfunc(defer): | ||
x.append(1) | ||
defer(apartial(async_append, x, 1)) | ||
x.append(2) | ||
defer(apartial(async_append, x, 2)) | ||
raise ZeroDivisionError | ||
x.append(3) | ||
defer(apartial(async_append, x, 3)) | ||
|
||
with pytest.raises(ZeroDivisionError): | ||
await myfunc() | ||
assert x == [1, 2, 2, 1] |