
Merge pull request #323 from aio-libs/python_35
Python 3.5 (async/await syntax everywhere)
popravich committed Nov 13, 2017
2 parents 64fac63 + e864842 commit 7a26a7c
Showing 56 changed files with 2,587 additions and 3,226 deletions.
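The bulk of this commit mechanically replaces generator-based coroutines with native `async`/`await` syntax, which requires Python 3.5+. A minimal before/after sketch of the pattern applied throughout (illustrative only, not copied from the diff):

```python
import asyncio

# Before: generator-based coroutine (pre-3.5 style used by the old code)
@asyncio.coroutine
def wait_closed_old(conn):
    yield from conn.wait_closed()

# After: native coroutine (Python 3.5+), as used throughout this commit
async def wait_closed_new(conn):
    await conn.wait_closed()
```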
7 changes: 2 additions & 5 deletions aioredis/abc.py
@@ -4,11 +4,8 @@
"""
import abc
import asyncio
try:
from abc import ABC
except ImportError:
class ABC(metaclass=abc.ABCMeta):
pass

from abc import ABC


__all__ = [
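Since the package now targets Python 3.5+, the `ImportError` fallback is unnecessary: `abc.ABC` has existed since Python 3.4, so the interfaces in `aioredis/abc.py` can inherit from it directly. A rough sketch in that style (the actual abstract methods in the module may differ):

```python
from abc import ABC, abstractmethod

class AbcConnection(ABC):
    """Illustrative interface in the style of aioredis/abc.py."""

    @abstractmethod
    def execute(self, command, *args, **kwargs):
        """Execute a Redis command."""
```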
70 changes: 31 additions & 39 deletions aioredis/commands/__init__.py
@@ -1,6 +1,3 @@
import asyncio
# import warnings

from aioredis.connection import create_connection
from aioredis.pool import create_pool
from aioredis.util import _NOTSET
@@ -56,10 +53,9 @@ def close(self):
"""Close client connections."""
self._pool_or_conn.close()

@asyncio.coroutine
def wait_closed(self):
async def wait_closed(self):
"""Coroutine waiting until underlying connections are closed."""
yield from self._pool_or_conn.wait_closed()
await self._pool_or_conn.wait_closed()

@property
def db(self):
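`wait_closed()` becomes a native coroutine here, but the shutdown idiom is unchanged: `close()` is synchronous and only initiates the shutdown, while `wait_closed()` must be awaited to know the underlying connections are really gone. A small sketch of that idiom (assumes `redis` is a client produced by one of the factories below):

```python
async def shutdown(redis):
    # close() is synchronous and only schedules the close...
    redis.close()
    # ...wait_closed() must be awaited before tearing down the event loop.
    await redis.wait_closed()
```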
@@ -130,15 +126,14 @@ def select(self, db):

def __await__(self):
if isinstance(self._pool_or_conn, AbcPool):
conn = yield from self._pool_or_conn.acquire()
conn = yield from self._pool_or_conn.acquire().__await__()
release = self._pool_or_conn.release
else:
# TODO: probably a lock is needed here if _pool_or_conn
# is Connection instance.
conn = self._pool_or_conn
release = None
return ContextRedis(conn, release)
__iter__ = __await__
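Here `__await__` is an ordinary generator function, and an ordinary generator cannot `yield from` a native coroutine object directly, which is why the pool's `acquire()` result is now driven through its `__await__()` iterator. The `__iter__ = __await__` aliases are dropped because only pre-3.5 `yield from`-style callers needed them. A standalone sketch of the same technique (hypothetical `acquire` and `Client` names, not taken from aioredis):

```python
import asyncio

async def acquire():
    """Stand-in for a native-coroutine pool.acquire()."""
    await asyncio.sleep(0)
    return "connection"

class Client:
    def __await__(self):
        # 'yield from acquire()' would raise
        # "TypeError: cannot 'yield from' a coroutine object in a
        # non-coroutine generator", hence the explicit __await__() call.
        conn = yield from acquire().__await__()
        return conn

async def main():
    print(await Client())   # -> connection

asyncio.get_event_loop().run_until_complete(main())
```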


class ContextRedis(Redis):
@@ -159,48 +154,45 @@ def __exit__(self, *exc_info):
def __await__(self):
return ContextRedis(self._pool_or_conn)
yield
__iter__ = __await__


@asyncio.coroutine
def create_redis(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
parser=None, timeout=None,
connection_cls=None, loop=None):
async def create_redis(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
parser=None, timeout=None,
connection_cls=None, loop=None):
"""Creates high-level Redis interface.
This function is a coroutine.
"""
conn = yield from create_connection(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
parser=parser,
timeout=timeout,
connection_cls=connection_cls,
loop=loop)
conn = await create_connection(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
parser=parser,
timeout=timeout,
connection_cls=connection_cls,
loop=loop)
return commands_factory(conn)


@asyncio.coroutine
def create_redis_pool(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
minsize=1, maxsize=10, parser=None,
timeout=None, pool_cls=None,
connection_cls=None, loop=None):
async def create_redis_pool(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
minsize=1, maxsize=10, parser=None,
timeout=None, pool_cls=None,
connection_cls=None, loop=None):
"""Creates high-level Redis interface.
This function is a coroutine.
"""
pool = yield from create_pool(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
minsize=minsize,
maxsize=maxsize,
parser=parser,
create_connection_timeout=timeout,
pool_cls=pool_cls,
connection_cls=connection_cls,
loop=loop)
pool = await create_pool(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
minsize=minsize,
maxsize=maxsize,
parser=parser,
create_connection_timeout=timeout,
pool_cls=pool_cls,
connection_cls=connection_cls,
loop=loop)
return commands_factory(pool)
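With both factories now native coroutines, typical high-level usage looks like the sketch below (assumes a Redis server listening on localhost:6379; `set`/`get` are the standard aioredis string commands):

```python
import asyncio
import aioredis

async def main():
    redis = await aioredis.create_redis_pool(('localhost', 6379),
                                             minsize=1, maxsize=10)
    await redis.set('my-key', 'value')
    print(await redis.get('my-key', encoding='utf-8'))   # -> value
    redis.close()
    await redis.wait_closed()

asyncio.get_event_loop().run_until_complete(main())
```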
7 changes: 2 additions & 5 deletions aioredis/commands/pubsub.py
@@ -1,4 +1,3 @@
import asyncio
import json

from aioredis.util import wait_make_dict
@@ -104,8 +103,6 @@ def in_pubsub(self):
return self._pool_or_conn.in_pubsub


@asyncio.coroutine
def wait_return_channels(fut, channels_dict):
res = yield from fut
async def wait_return_channels(fut, channels_dict):
return [channels_dict[name]
for cmd, name, count in res]
for cmd, name, count in await fut]
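For context, `wait_return_channels` is the helper the subscribe commands use to map the server's subscription replies back to `Channel` objects. A pub/sub usage sketch in the new syntax (assumes the same local server; `subscribe`, `wait_message`, and `get` are the documented aioredis channel API):

```python
import asyncio
import aioredis

async def main():
    redis = await aioredis.create_redis_pool(('localhost', 6379))
    ch, = await redis.subscribe('chan:1')        # list of Channel objects
    await redis.publish('chan:1', 'hello')       # pool supplies a second connection
    if await ch.wait_message():
        print(await ch.get(encoding='utf-8'))    # -> hello
    await redis.unsubscribe('chan:1')
    redis.close()
    await redis.wait_closed()

asyncio.get_event_loop().run_until_complete(main())
```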
53 changes: 24 additions & 29 deletions aioredis/commands/transaction.py
@@ -10,8 +10,6 @@
)
from ..util import (
wait_ok,
async_task,
create_future,
_set_exception,
)

@@ -100,7 +98,7 @@ def __init__(self, pipeline, *, loop=None):
self._loop = loop

def execute(self, cmd, *args, **kw):
fut = create_future(loop=self._loop)
fut = self._loop.create_future()
self._pipeline.append((fut, cmd, args, kw))
return fut
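The `create_future()` helper from `aioredis.util` presumably existed as a compatibility shim for event loops without `create_future()`; since Python 3.5.2 every loop provides it, and it is the preferred way to create futures (it lets alternative loops such as uvloop supply their own future implementation). A minimal sketch:

```python
import asyncio

loop = asyncio.get_event_loop()

fut = loop.create_future()             # preferred over asyncio.Future(loop=loop)
loop.call_soon(fut.set_result, 'PONG')
print(loop.run_until_complete(fut))    # -> PONG
```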

@@ -144,17 +142,17 @@ def __getattr__(self, name):
@functools.wraps(attr)
def wrapper(*args, **kw):
try:
task = async_task(attr(*args, **kw), loop=self._loop)
task = asyncio.ensure_future(attr(*args, **kw),
loop=self._loop)
except Exception as exc:
task = create_future(loop=self._loop)
task = self._loop.create_future()
task.set_exception(exc)
self._results.append(task)
return task
return wrapper
return attr
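Similarly, the `async_task()` helper gives way to the standard `asyncio.ensure_future()`, which wraps a coroutine in a `Task` and returns futures unchanged; the surrounding `try/except` keeps the pipeline usable even when merely building the coroutine raises. A small illustration of that pattern (hypothetical `might_fail` coroutine):

```python
import asyncio

async def might_fail(x):
    if x < 0:
        raise ValueError('negative')
    return x * 2

async def main():
    # Schedule the coroutine as a Task; an exception raised inside it is
    # stored on the future and re-raised only when the future is awaited.
    task = asyncio.ensure_future(might_fail(21))
    print(await task)   # -> 42

asyncio.get_event_loop().run_until_complete(main())
```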

@asyncio.coroutine
def execute(self, *, return_exceptions=False):
async def execute(self, *, return_exceptions=False):
"""Execute all buffered commands.
Any exception that is raised by any command is caught and
@@ -168,30 +166,28 @@ def execute(self, *, return_exceptions=False):

if self._pipeline:
if isinstance(self._pool_or_conn, AbcPool):
with (yield from self._pool_or_conn) as conn:
return (yield from self._do_execute(
conn, return_exceptions=return_exceptions))
async with self._pool_or_conn.get() as conn:
return await self._do_execute(
conn, return_exceptions=return_exceptions)
else:
return (yield from self._do_execute(
return await self._do_execute(
self._pool_or_conn,
return_exceptions=return_exceptions))
return_exceptions=return_exceptions)
else:
return (yield from self._gather_result(return_exceptions))
return await self._gather_result(return_exceptions)

@asyncio.coroutine
def _do_execute(self, conn, *, return_exceptions=False):
yield from asyncio.gather(*self._send_pipeline(conn),
loop=self._loop,
return_exceptions=True)
return (yield from self._gather_result(return_exceptions))
async def _do_execute(self, conn, *, return_exceptions=False):
await asyncio.gather(*self._send_pipeline(conn),
loop=self._loop,
return_exceptions=True)
return await self._gather_result(return_exceptions)

@asyncio.coroutine
def _gather_result(self, return_exceptions):
async def _gather_result(self, return_exceptions):
errors = []
results = []
for fut in self._results:
try:
res = yield from fut
res = await fut
results.append(res)
except Exception as exc:
errors.append(exc)
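With `execute()` now a native coroutine (and pool connections acquired via `async with`), pipeline usage from application code looks roughly like this (`pipeline()` and `incr()` are the standard aioredis commands API; assumes a local server):

```python
import asyncio
import aioredis

async def main():
    redis = await aioredis.create_redis_pool(('localhost', 6379))

    pipe = redis.pipeline()
    fut1 = pipe.incr('foo')      # buffered; returns a future immediately
    fut2 = pipe.incr('bar')
    results = await pipe.execute()
    print(results, await fut1, await fut2)

    redis.close()
    await redis.wait_closed()

asyncio.get_event_loop().run_until_complete(main())
```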
@@ -257,8 +253,7 @@ class MultiExec(Pipeline):
"""
error_class = MultiExecError

@asyncio.coroutine
def _do_execute(self, conn, *, return_exceptions=False):
async def _do_execute(self, conn, *, return_exceptions=False):
self._waiters = waiters = []
multi = conn.execute('MULTI')
coros = list(self._send_pipeline(conn))
@@ -267,9 +262,9 @@ def _do_execute(self, conn, *, return_exceptions=False):
return_exceptions=True)
last_error = None
try:
yield from asyncio.shield(gather, loop=self._loop)
await asyncio.shield(gather, loop=self._loop)
except asyncio.CancelledError:
yield from gather
await gather
except Exception as err:
last_error = err
raise
@@ -286,15 +281,15 @@ def _do_execute(self, conn, *, return_exceptions=False):
# fut.cancel()
else:
try:
results = yield from exec_
results = await exec_
except RedisError as err:
for fut in waiters:
fut.set_exception(err)
else:
assert len(results) == len(waiters), (
"Results does not match waiters", results, waiters)
self._resolve_waiters(results, return_exceptions)
return (yield from self._gather_result(return_exceptions))
return (await self._gather_result(return_exceptions))

def _resolve_waiters(self, results, return_exceptions):
errors = []
@@ -310,7 +305,7 @@ def _resolve_waiters(self, results, return_exceptions):
def _check_result(self, fut, waiter):
assert waiter not in self._waiters, (fut, waiter, self._waiters)
assert not waiter.done(), waiter
if fut.cancelled(): # yield from gather was cancelled
if fut.cancelled(): # await gather was cancelled
waiter.cancel()
elif fut.exception(): # server replied with error
waiter.set_exception(fut.exception())
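`MultiExec` wraps the pipeline in `MULTI`/`EXEC`, so the buffered commands execute atomically; usage mirrors the plain pipeline (again a sketch against a local server):

```python
import asyncio
import aioredis

async def main():
    redis = await aioredis.create_redis_pool(('localhost', 6379))

    tr = redis.multi_exec()
    fut1 = tr.incr('foo')
    fut2 = tr.incr('bar')
    results = await tr.execute()        # [new value of foo, new value of bar]
    assert results == [await fut1, await fut2]

    redis.close()
    await redis.wait_closed()

asyncio.get_event_loop().run_until_complete(main())
```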
38 changes: 17 additions & 21 deletions aioredis/connection.py
@@ -12,8 +12,6 @@
_set_exception,
coerced_keys_dict,
decode,
async_task,
create_future,
parse_url,
)
from .parser import Reader
@@ -45,10 +43,9 @@
)


@asyncio.coroutine
def create_connection(address, *, db=None, password=None, ssl=None,
encoding=None, parser=None, loop=None, timeout=None,
connection_cls=None):
async def create_connection(address, *, db=None, password=None, ssl=None,
encoding=None, parser=None, loop=None,
timeout=None, connection_cls=None):
"""Creates redis connection.
Opens connection to Redis server specified by address argument.
@@ -104,7 +101,7 @@ def create_connection(address, *, db=None, password=None, ssl=None,
if isinstance(address, (list, tuple)):
host, port = address
logger.debug("Creating tcp connection to %r", address)
reader, writer = yield from asyncio.wait_for(open_connection(
reader, writer = await asyncio.wait_for(open_connection(
host, port, limit=MAX_CHUNK_SIZE, ssl=ssl, loop=loop),
timeout, loop=loop)
sock = writer.transport.get_extra_info('socket')
@@ -114,7 +111,7 @@ def create_connection(address, *, db=None, password=None, ssl=None,
address = tuple(address[:2])
else:
logger.debug("Creating unix connection to %r", address)
reader, writer = yield from asyncio.wait_for(open_unix_connection(
reader, writer = await asyncio.wait_for(open_unix_connection(
address, ssl=ssl, limit=MAX_CHUNK_SIZE, loop=loop),
timeout, loop=loop)
sock = writer.transport.get_extra_info('socket')
@@ -127,12 +124,12 @@ def create_connection(address, *, db=None, password=None, ssl=None,

try:
if password is not None:
yield from conn.auth(password)
await conn.auth(password)
if db is not None:
yield from conn.select(db)
await conn.select(db)
except Exception:
conn.close()
yield from conn.wait_closed()
await conn.wait_closed()
raise
return conn
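`create_connection()` is the low-level entry point: it opens a single connection, optionally authenticates and selects a database, and returns a `RedisConnection` whose `execute()` returns a future per command. A usage sketch (local server assumed):

```python
import asyncio
import aioredis

async def main():
    conn = await aioredis.create_connection(('localhost', 6379),
                                            encoding='utf-8')
    await conn.execute('SET', 'my-key', 'value')
    print(await conn.execute('GET', 'my-key'))   # -> value
    conn.close()
    await conn.wait_closed()

asyncio.get_event_loop().run_until_complete(main())
```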

@@ -156,11 +153,12 @@ def __init__(self, reader, writer, *, address, encoding=None,
self._reader.set_parser(
parser(protocolError=ProtocolError, replyError=ReplyError)
)
self._reader_task = async_task(self._read_data(), loop=self._loop)
self._reader_task = asyncio.ensure_future(self._read_data(),
loop=self._loop)
self._db = 0
self._closing = False
self._closed = False
self._close_waiter = create_future(loop=self._loop)
self._close_waiter = loop.create_future()
self._reader_task.add_done_callback(self._close_waiter.set_result)
self._in_transaction = None
self._transaction_error = None # XXX: never used?
@@ -172,14 +170,13 @@ def __init__(self, reader, writer, *, address, encoding=None,
def __repr__(self):
return '<RedisConnection [db:{}]>'.format(self._db)

@asyncio.coroutine
def _read_data(self):
async def _read_data(self):
"""Response reader task."""
last_error = ConnectionClosedError(
"Connection has been closed by server")
while not self._reader.at_eof():
try:
obj = yield from self._reader.readobj()
obj = await self._reader.readobj()
except asyncio.CancelledError:
# NOTE: reader can get cancelled from `close()` method only.
last_error = RuntimeError('this is unexpected')
@@ -308,7 +305,7 @@ def execute(self, command, *args, encoding=_NOTSET):
cb = None
if encoding is _NOTSET:
encoding = self._encoding
fut = create_future(loop=self._loop)
fut = self._loop.create_future()
self._writer.write(encode_command(command, *args))
self._waiters.append((fut, encoding, cb))
return fut
@@ -338,7 +335,7 @@ def execute_pubsub(self, command, *channels):
cmd = encode_command(command, *(ch.name for ch in channels))
res = []
for ch in channels:
fut = create_future(loop=self._loop)
fut = self._loop.create_future()
res.append(fut)
cb = partial(self._update_pubsub, ch=ch)
self._waiters.append((fut, None, cb))
@@ -384,10 +381,9 @@ def closed(self):
self._loop.call_soon(self._do_close, None)
return closed

@asyncio.coroutine
def wait_closed(self):
async def wait_closed(self):
"""Coroutine waiting until connection is closed."""
yield from asyncio.shield(self._close_waiter, loop=self._loop)
await asyncio.shield(self._close_waiter, loop=self._loop)

@property
def db(self):
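`wait_closed()` shields the shared `_close_waiter` future so that a caller that gets cancelled while waiting does not cancel the future out from under every other waiter. A self-contained illustration of that property (hypothetical names, not the aioredis internals themselves):

```python
import asyncio

async def main():
    loop = asyncio.get_event_loop()
    close_waiter = loop.create_future()          # shared by all wait_closed() callers

    async def wait_closed():
        await asyncio.shield(close_waiter)       # cancellation stops at the shield

    caller = asyncio.ensure_future(wait_closed())
    await asyncio.sleep(0)                       # let the caller start waiting
    caller.cancel()
    try:
        await caller
    except asyncio.CancelledError:
        pass
    print(close_waiter.cancelled())              # -> False: other waiters unaffected

asyncio.get_event_loop().run_until_complete(main())
```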
(Diffs for the remaining 51 changed files are not shown here.)
