From faad8ded522c7911a10a285141c4bb6a5b725ab9 Mon Sep 17 00:00:00 2001 From: Nick Humrich Date: Fri, 4 Oct 2019 20:57:44 -0600 Subject: [PATCH] Add middleware support --- asyncpg/_testbase/__init__.py | 3 ++- asyncpg/connect_utils.py | 7 ++--- asyncpg/connection.py | 22 +++++++++++---- asyncpg/pool.py | 50 +++++++++++++++++++++++++++++++++-- docs/installation.rst | 1 + tests/test_pool.py | 42 +++++++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 11 deletions(-) diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index baf55c1b..d0a94bf0 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -264,6 +264,7 @@ def create_pool(dsn=None, *, setup=None, init=None, loop=None, + middlewares=None, pool_class=pg_pool.Pool, connection_class=pg_connection.Connection, **connect_kwargs): @@ -272,7 +273,7 @@ def create_pool(dsn=None, *, min_size=min_size, max_size=max_size, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, - connection_class=connection_class, + connection_class=connection_class, middlewares=middlewares, **connect_kwargs) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index ec3d1090..7abba0bd 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *, async def _connect_addr(*, addr, loop, timeout, params, config, - connection_class): + middlewares, connection_class): assert loop is not None if timeout <= 0: @@ -633,12 +633,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config, tr.close() raise - con = connection_class(pr, tr, loop, addr, config, params) + con = connection_class(pr, tr, loop, addr, config, params, middlewares) pr.set_connection(con) return con -async def _connect(*, loop, timeout, connection_class, **kwargs): +async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs): if loop is None: loop = asyncio.get_event_loop() @@ -652,6 +652,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs): con = await _connect_addr( addr=addr, loop=loop, timeout=timeout, params=params, config=config, + middlewares=middlewares, connection_class=connection_class) except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ef1b595d..384ed46b 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta): """ __slots__ = ('_protocol', '_transport', '_loop', - '_top_xact', '_aborted', + '_top_xact', '_aborted', '_middlewares', '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', @@ -52,7 +52,8 @@ class Connection(metaclass=ConnectionMeta): def __init__(self, protocol, transport, loop, addr: (str, int) or str, config: connect_utils._ClientConfiguration, - params: connect_utils._ConnectionParameters): + params: connect_utils._ConnectionParameters, + _middlewares=None): self._protocol = protocol self._transport = transport self._loop = loop @@ -91,7 +92,7 @@ def __init__(self, protocol, transport, loop, self._reset_query = None self._proxy = None - + self._middlewares = _middlewares # Used to serialize operations that might involve anonymous # statements. Specifically, we want to make the following # operation atomic: @@ -1399,8 +1400,13 @@ async def reload_schema_state(self): async def _execute(self, query, args, limit, timeout, return_status=False): with self._stmt_exclusive_section: - result, _ = await self.__execute( - query, args, limit, timeout, return_status=return_status) + wrapped = self.__execute + if self._middlewares: + for m in reversed(self._middlewares): + wrapped = await m(connection=self, handler=wrapped) + + result, _ = await wrapped(query, args, limit, + timeout, return_status=return_status) return result async def __execute(self, query, args, limit, timeout, @@ -1491,6 +1497,7 @@ async def connect(dsn=None, *, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, + middlewares=None, connection_class=Connection, server_settings=None): r"""A coroutine to establish a connection to a PostgreSQL server. @@ -1607,6 +1614,10 @@ async def connect(dsn=None, *, PostgreSQL documentation for a `list of supported options `_. + :param middlewares: + An optional list of middleware functions. Refer to documentation + on create_pool. + :param Connection connection_class: Class of the returned connection object. Must be a subclass of :class:`~asyncpg.connection.Connection`. @@ -1672,6 +1683,7 @@ async def connect(dsn=None, *, ssl=ssl, database=database, server_settings=server_settings, command_timeout=command_timeout, + middlewares=middlewares, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size) diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 20a3234e..207aca44 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -305,7 +305,7 @@ class Pool: """ __slots__ = ( - '_queue', '_loop', '_minsize', '_maxsize', + '_queue', '_loop', '_minsize', '_maxsize', '_middlewares', '_init', '_connect_args', '_connect_kwargs', '_working_addr', '_working_config', '_working_params', '_holders', '_initialized', '_initializing', '_closing', @@ -320,6 +320,7 @@ def __init__(self, *connect_args, max_inactive_connection_lifetime, setup, init, + middlewares, loop, connection_class, **connect_kwargs): @@ -377,6 +378,7 @@ def __init__(self, *connect_args, self._closed = False self._generation = 0 self._init = init + self._middlewares = middlewares self._connect_args = connect_args self._connect_kwargs = connect_kwargs @@ -469,6 +471,7 @@ async def _get_new_connection(self): *self._connect_args, loop=self._loop, connection_class=self._connection_class, + middlewares=self._middlewares, **self._connect_kwargs) self._working_addr = con._addr @@ -483,6 +486,7 @@ async def _get_new_connection(self): addr=self._working_addr, timeout=self._working_params.connect_timeout, config=self._working_config, + middlewares=self._middlewares, params=self._working_params, connection_class=self._connection_class) @@ -784,6 +788,29 @@ def __await__(self): return self.pool._acquire(self.timeout).__await__() +def middleware(f): + """Decorator for adding a middleware + + Can be used like such + + .. code-block:: python + + @pool.middleware + async def my_middleware(query, args, limit, + timeout, return_status, *, handler, conn): + print('do something before') + result, stmt = await handler(query, args, limit, + timeout, return_status) + print('do something after') + return result, stmt + + my_pool = await pool.create_pool(middlewares=[my_middleware]) + """ + async def middleware_factory(connection, handler): + return functools.partial(f, connection=connection, handler=handler) + return middleware_factory + + def create_pool(dsn=None, *, min_size=10, max_size=10, @@ -791,6 +818,7 @@ def create_pool(dsn=None, *, max_inactive_connection_lifetime=300.0, setup=None, init=None, + middlewares=None, loop=None, connection_class=connection.Connection, **connect_kwargs): @@ -866,6 +894,23 @@ def create_pool(dsn=None, *, or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. + :param middlewares: + A list of middleware functions to be middleware just + before a connection excecutes a statement. + Syntax of a middleware is as follows: + + .. code-block:: python + + async def middleware_factory(connection, handler): + async def middleware(query, args, limit, + timeout, return_status): + print('do something before') + result, stmt = await handler(query, args, limit, + timeout, return_status) + print('do something after') + return result, stmt + return middleware + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -893,6 +938,7 @@ def create_pool(dsn=None, *, dsn, connection_class=connection_class, min_size=min_size, max_size=max_size, - max_queries=max_queries, loop=loop, setup=setup, init=init, + max_queries=max_queries, loop=loop, setup=setup, + middlewares=middlewares, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, **connect_kwargs) diff --git a/docs/installation.rst b/docs/installation.rst index 6d9ec2ef..e9b9c344 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -30,6 +30,7 @@ If you want to build **asyncpg** from a Git checkout you will need: * CPython header files. These can usually be obtained by installing the relevant Python development package: **python3-dev** on Debian/Ubuntu, **python3-devel** on RHEL/Fedora. + * Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`) Once the above requirements are satisfied, run the following command in the root of the source checkout: diff --git a/tests/test_pool.py b/tests/test_pool.py index e51923e4..36bf312b 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -76,6 +76,48 @@ async def worker(): tasks = [worker() for _ in range(n)] await asyncio.gather(*tasks) + async def test_pool_with_middleware(self): + called = False + + async def my_middleware_factory(connection, handler): + async def middleware(query, args, limit, timeout, return_status): + nonlocal called + called = True + return await handler(query, args, limit, + timeout, return_status) + return middleware + + pool = await self.create_pool(database='postgres', + min_size=1, max_size=1, + middlewares=[my_middleware_factory]) + + con = await pool.acquire(timeout=5) + await con.fetchval('SELECT 1') + assert called + + pool.terminate() + del con + + async def test_pool_with_middleware_decorator(self): + called = False + + @pg_pool.middleware + async def my_middleware(query, args, limit, timeout, return_status, + *, connection, handler): + nonlocal called + called = True + return await handler(query, args, limit, + timeout, return_status) + + pool = await self.create_pool(database='postgres', min_size=1, + max_size=1, middlewares=[my_middleware]) + con = await pool.acquire(timeout=5) + await con.fetchval('SELECT 1') + assert called + + pool.terminate() + del con + async def test_pool_03(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1)