diff --git a/Makefile b/Makefile index e3c79b84..ab7e57c6 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ doc: cd docs && make html - echo "open file://`pwd`/docs/_build/html/index.html" + @echo "open file://`pwd`/docs/_build/html/index.html" pep: pep8 aiopg examples tests @@ -18,6 +18,7 @@ vtest: pep flake cov cover coverage: pep flake py.test --cov=aiopg --cov=tests --cov-report=html --cov-report=term tests + @echo "open file://`pwd`/htmlcov/index.html" clean: find . -name __pycache__ |xargs rm -rf diff --git a/aiopg/connection.py b/aiopg/connection.py index ca4f9a59..3669656e 100644 --- a/aiopg/connection.py +++ b/aiopg/connection.py @@ -166,8 +166,23 @@ def _create_waiter(self, func_name): def _poll(self, waiter, timeout): assert waiter is self._waiter, (waiter, self._waiter) self._ready(self._weakref) + + @asyncio.coroutine + def cancel(): + if not self._isexecuting(): + return + self._waiter = asyncio.Future(loop=self._loop) + self._conn.cancel() + try: + yield from self._waiter + except psycopg2.extensions.QueryCanceledError: + pass + try: yield from asyncio.wait_for(self._waiter, timeout, loop=self._loop) + except (asyncio.CancelledError, asyncio.TimeoutError) as exc: + yield from asyncio.shield(cancel(), loop=self._loop) + raise exc finally: self._waiter = None @@ -282,14 +297,24 @@ def tpc_recover(self): @asyncio.coroutine def cancel(self, timeout=None): """Cancel the current database operation.""" - waiter = self._create_waiter('cancel') - self._conn.cancel() - if timeout is None: - timeout = self._timeout - try: - yield from self._poll(waiter, timeout) - except psycopg2.extensions.QueryCanceledError: - pass + if timeout is not None: + warnings.warn('timeout parameter is deprecated and never used', + DeprecationWarning) + if not self._isexecuting(): + return + if self._waiter is not None: + self._waiter.cancel() + + @asyncio.coroutine + def cancel(): + self._waiter = asyncio.Future(loop=self._loop) + self._conn.cancel() + try: + yield from self._waiter + except psycopg2.extensions.QueryCanceledError: + pass + + yield from asyncio.shield(cancel(), loop=self._loop) @asyncio.coroutine def reset(self): diff --git a/aiopg/cursor.py b/aiopg/cursor.py index 954ec04d..6d51a208 100644 --- a/aiopg/cursor.py +++ b/aiopg/cursor.py @@ -111,7 +111,6 @@ def execute(self, operation, parameters=None, *, timeout=None): try: yield from self._conn._poll(waiter, timeout) except asyncio.TimeoutError: - yield from self._conn.cancel() self._impl.close() raise diff --git a/tests/test_connection.py b/tests/test_connection.py index c8087dc1..130e9e0c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -340,7 +340,7 @@ def go(): self.loop.run_until_complete(go()) - def test_cancel(self): + def test_cancel_noop(self): @asyncio.coroutine def go(): @@ -354,7 +354,23 @@ def test_cancel_with_timeout(self): @asyncio.coroutine def go(): conn = yield from self.connect() - yield from conn.cancel(10) + with self.assertWarns(DeprecationWarning): + yield from conn.cancel(10) + + self.loop.run_until_complete(go()) + + def test_cancel_pending_op(self): + + @asyncio.coroutine + def go(): + conn = yield from self.connect() + cur = yield from conn.cursor() + task = asyncio.async(cur.execute("SELECT pg_sleep(10)"), + loop=self.loop) + yield from asyncio.sleep(0.01, loop=self.loop) + yield from conn.cancel() + with self.assertRaises(asyncio.CancelledError): + yield from task self.loop.run_until_complete(go())