Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Python 3.5 (async/await syntax everywhere) #323

Merged
merged 8 commits into from
Nov 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions aioredis/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
"""
import abc
import asyncio
try:
from abc import ABC
except ImportError:
class ABC(metaclass=abc.ABCMeta):
pass

from abc import ABC


__all__ = [
Expand Down
70 changes: 31 additions & 39 deletions aioredis/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import asyncio
# import warnings

from aioredis.connection import create_connection
from aioredis.pool import create_pool
from aioredis.util import _NOTSET
Expand Down Expand Up @@ -56,10 +53,9 @@ def close(self):
"""Close client connections."""
self._pool_or_conn.close()

@asyncio.coroutine
def wait_closed(self):
async def wait_closed(self):
"""Coroutine waiting until underlying connections are closed."""
yield from self._pool_or_conn.wait_closed()
await self._pool_or_conn.wait_closed()

@property
def db(self):
Expand Down Expand Up @@ -130,15 +126,14 @@ def select(self, db):

def __await__(self):
if isinstance(self._pool_or_conn, AbcPool):
conn = yield from self._pool_or_conn.acquire()
conn = yield from self._pool_or_conn.acquire().__await__()
release = self._pool_or_conn.release
else:
# TODO: probably a lock is needed here if _pool_or_conn
# is Connection instance.
conn = self._pool_or_conn
release = None
return ContextRedis(conn, release)
__iter__ = __await__


class ContextRedis(Redis):
Expand All @@ -159,48 +154,45 @@ def __exit__(self, *exc_info):
def __await__(self):
return ContextRedis(self._pool_or_conn)
yield
__iter__ = __await__


@asyncio.coroutine
def create_redis(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
parser=None, timeout=None,
connection_cls=None, loop=None):
async def create_redis(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
parser=None, timeout=None,
connection_cls=None, loop=None):
"""Creates high-level Redis interface.

This function is a coroutine.
"""
conn = yield from create_connection(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
parser=parser,
timeout=timeout,
connection_cls=connection_cls,
loop=loop)
conn = await create_connection(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
parser=parser,
timeout=timeout,
connection_cls=connection_cls,
loop=loop)
return commands_factory(conn)


@asyncio.coroutine
def create_redis_pool(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
minsize=1, maxsize=10, parser=None,
timeout=None, pool_cls=None,
connection_cls=None, loop=None):
async def create_redis_pool(address, *, db=None, password=None, ssl=None,
encoding=None, commands_factory=Redis,
minsize=1, maxsize=10, parser=None,
timeout=None, pool_cls=None,
connection_cls=None, loop=None):
"""Creates high-level Redis interface.

This function is a coroutine.
"""
pool = yield from create_pool(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
minsize=minsize,
maxsize=maxsize,
parser=parser,
create_connection_timeout=timeout,
pool_cls=pool_cls,
connection_cls=connection_cls,
loop=loop)
pool = await create_pool(address, db=db,
password=password,
ssl=ssl,
encoding=encoding,
minsize=minsize,
maxsize=maxsize,
parser=parser,
create_connection_timeout=timeout,
pool_cls=pool_cls,
connection_cls=connection_cls,
loop=loop)
return commands_factory(pool)
7 changes: 2 additions & 5 deletions aioredis/commands/pubsub.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import json

from aioredis.util import wait_make_dict
Expand Down Expand Up @@ -104,8 +103,6 @@ def in_pubsub(self):
return self._pool_or_conn.in_pubsub


@asyncio.coroutine
def wait_return_channels(fut, channels_dict):
res = yield from fut
async def wait_return_channels(fut, channels_dict):
return [channels_dict[name]
for cmd, name, count in res]
for cmd, name, count in await fut]
53 changes: 24 additions & 29 deletions aioredis/commands/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
)
from ..util import (
wait_ok,
async_task,
create_future,
_set_exception,
)

Expand Down Expand Up @@ -100,7 +98,7 @@ def __init__(self, pipeline, *, loop=None):
self._loop = loop

def execute(self, cmd, *args, **kw):
fut = create_future(loop=self._loop)
fut = self._loop.create_future()
self._pipeline.append((fut, cmd, args, kw))
return fut

Expand Down Expand Up @@ -144,17 +142,17 @@ def __getattr__(self, name):
@functools.wraps(attr)
def wrapper(*args, **kw):
try:
task = async_task(attr(*args, **kw), loop=self._loop)
task = asyncio.ensure_future(attr(*args, **kw),
loop=self._loop)
except Exception as exc:
task = create_future(loop=self._loop)
task = self._loop.create_future()
task.set_exception(exc)
self._results.append(task)
return task
return wrapper
return attr

@asyncio.coroutine
def execute(self, *, return_exceptions=False):
async def execute(self, *, return_exceptions=False):
"""Execute all buffered commands.

Any exception that is raised by any command is caught and
Expand All @@ -168,30 +166,28 @@ def execute(self, *, return_exceptions=False):

if self._pipeline:
if isinstance(self._pool_or_conn, AbcPool):
with (yield from self._pool_or_conn) as conn:
return (yield from self._do_execute(
conn, return_exceptions=return_exceptions))
async with self._pool_or_conn.get() as conn:
return await self._do_execute(
conn, return_exceptions=return_exceptions)
else:
return (yield from self._do_execute(
return await self._do_execute(
self._pool_or_conn,
return_exceptions=return_exceptions))
return_exceptions=return_exceptions)
else:
return (yield from self._gather_result(return_exceptions))
return await self._gather_result(return_exceptions)

@asyncio.coroutine
def _do_execute(self, conn, *, return_exceptions=False):
yield from asyncio.gather(*self._send_pipeline(conn),
loop=self._loop,
return_exceptions=True)
return (yield from self._gather_result(return_exceptions))
async def _do_execute(self, conn, *, return_exceptions=False):
await asyncio.gather(*self._send_pipeline(conn),
loop=self._loop,
return_exceptions=True)
return await self._gather_result(return_exceptions)

@asyncio.coroutine
def _gather_result(self, return_exceptions):
async def _gather_result(self, return_exceptions):
errors = []
results = []
for fut in self._results:
try:
res = yield from fut
res = await fut
results.append(res)
except Exception as exc:
errors.append(exc)
Expand Down Expand Up @@ -257,8 +253,7 @@ class MultiExec(Pipeline):
"""
error_class = MultiExecError

@asyncio.coroutine
def _do_execute(self, conn, *, return_exceptions=False):
async def _do_execute(self, conn, *, return_exceptions=False):
self._waiters = waiters = []
multi = conn.execute('MULTI')
coros = list(self._send_pipeline(conn))
Expand All @@ -267,9 +262,9 @@ def _do_execute(self, conn, *, return_exceptions=False):
return_exceptions=True)
last_error = None
try:
yield from asyncio.shield(gather, loop=self._loop)
await asyncio.shield(gather, loop=self._loop)
except asyncio.CancelledError:
yield from gather
await gather
except Exception as err:
last_error = err
raise
Expand All @@ -286,15 +281,15 @@ def _do_execute(self, conn, *, return_exceptions=False):
# fut.cancel()
else:
try:
results = yield from exec_
results = await exec_
except RedisError as err:
for fut in waiters:
fut.set_exception(err)
else:
assert len(results) == len(waiters), (
"Results does not match waiters", results, waiters)
self._resolve_waiters(results, return_exceptions)
return (yield from self._gather_result(return_exceptions))
return (await self._gather_result(return_exceptions))

def _resolve_waiters(self, results, return_exceptions):
errors = []
Expand All @@ -310,7 +305,7 @@ def _resolve_waiters(self, results, return_exceptions):
def _check_result(self, fut, waiter):
assert waiter not in self._waiters, (fut, waiter, self._waiters)
assert not waiter.done(), waiter
if fut.cancelled(): # yield from gather was cancelled
if fut.cancelled(): # await gather was cancelled
waiter.cancel()
elif fut.exception(): # server replied with error
waiter.set_exception(fut.exception())
Expand Down
38 changes: 17 additions & 21 deletions aioredis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
_set_exception,
coerced_keys_dict,
decode,
async_task,
create_future,
parse_url,
)
from .parser import Reader
Expand Down Expand Up @@ -45,10 +43,9 @@
)


@asyncio.coroutine
def create_connection(address, *, db=None, password=None, ssl=None,
encoding=None, parser=None, loop=None, timeout=None,
connection_cls=None):
async def create_connection(address, *, db=None, password=None, ssl=None,
encoding=None, parser=None, loop=None,
timeout=None, connection_cls=None):
"""Creates redis connection.

Opens connection to Redis server specified by address argument.
Expand Down Expand Up @@ -104,7 +101,7 @@ def create_connection(address, *, db=None, password=None, ssl=None,
if isinstance(address, (list, tuple)):
host, port = address
logger.debug("Creating tcp connection to %r", address)
reader, writer = yield from asyncio.wait_for(open_connection(
reader, writer = await asyncio.wait_for(open_connection(
host, port, limit=MAX_CHUNK_SIZE, ssl=ssl, loop=loop),
timeout, loop=loop)
sock = writer.transport.get_extra_info('socket')
Expand All @@ -114,7 +111,7 @@ def create_connection(address, *, db=None, password=None, ssl=None,
address = tuple(address[:2])
else:
logger.debug("Creating unix connection to %r", address)
reader, writer = yield from asyncio.wait_for(open_unix_connection(
reader, writer = await asyncio.wait_for(open_unix_connection(
address, ssl=ssl, limit=MAX_CHUNK_SIZE, loop=loop),
timeout, loop=loop)
sock = writer.transport.get_extra_info('socket')
Expand All @@ -127,12 +124,12 @@ def create_connection(address, *, db=None, password=None, ssl=None,

try:
if password is not None:
yield from conn.auth(password)
await conn.auth(password)
if db is not None:
yield from conn.select(db)
await conn.select(db)
except Exception:
conn.close()
yield from conn.wait_closed()
await conn.wait_closed()
raise
return conn

Expand All @@ -156,11 +153,12 @@ def __init__(self, reader, writer, *, address, encoding=None,
self._reader.set_parser(
parser(protocolError=ProtocolError, replyError=ReplyError)
)
self._reader_task = async_task(self._read_data(), loop=self._loop)
self._reader_task = asyncio.ensure_future(self._read_data(),
loop=self._loop)
self._db = 0
self._closing = False
self._closed = False
self._close_waiter = create_future(loop=self._loop)
self._close_waiter = loop.create_future()
self._reader_task.add_done_callback(self._close_waiter.set_result)
self._in_transaction = None
self._transaction_error = None # XXX: never used?
Expand All @@ -172,14 +170,13 @@ def __init__(self, reader, writer, *, address, encoding=None,
def __repr__(self):
return '<RedisConnection [db:{}]>'.format(self._db)

@asyncio.coroutine
def _read_data(self):
async def _read_data(self):
"""Response reader task."""
last_error = ConnectionClosedError(
"Connection has been closed by server")
while not self._reader.at_eof():
try:
obj = yield from self._reader.readobj()
obj = await self._reader.readobj()
except asyncio.CancelledError:
# NOTE: reader can get cancelled from `close()` method only.
last_error = RuntimeError('this is unexpected')
Expand Down Expand Up @@ -308,7 +305,7 @@ def execute(self, command, *args, encoding=_NOTSET):
cb = None
if encoding is _NOTSET:
encoding = self._encoding
fut = create_future(loop=self._loop)
fut = self._loop.create_future()
self._writer.write(encode_command(command, *args))
self._waiters.append((fut, encoding, cb))
return fut
Expand Down Expand Up @@ -338,7 +335,7 @@ def execute_pubsub(self, command, *channels):
cmd = encode_command(command, *(ch.name for ch in channels))
res = []
for ch in channels:
fut = create_future(loop=self._loop)
fut = self._loop.create_future()
res.append(fut)
cb = partial(self._update_pubsub, ch=ch)
self._waiters.append((fut, None, cb))
Expand Down Expand Up @@ -384,10 +381,9 @@ def closed(self):
self._loop.call_soon(self._do_close, None)
return closed

@asyncio.coroutine
def wait_closed(self):
async def wait_closed(self):
"""Coroutine waiting until connection is closed."""
yield from asyncio.shield(self._close_waiter, loop=self._loop)
await asyncio.shield(self._close_waiter, loop=self._loop)

@property
def db(self):
Expand Down
Loading