diff --git a/asyncpg/_testbase.py b/asyncpg/_testbase.py index 71ea3053..cab04631 100644 --- a/asyncpg/_testbase.py +++ b/asyncpg/_testbase.py @@ -12,6 +12,7 @@ import inspect import logging import os +import re import time import unittest @@ -95,6 +96,30 @@ def assertRunUnder(self, delta): raise AssertionError( 'running block took longer than {}'.format(delta)) + @contextlib.contextmanager + def assertLoopErrorHandlerCalled(self, msg_re: str): + contexts = [] + + def handler(loop, ctx): + contexts.append(ctx) + + old_handler = self.loop.get_exception_handler() + self.loop.set_exception_handler(handler) + try: + yield + + for ctx in contexts: + msg = ctx.get('message') + if msg and re.search(msg_re, msg): + return + + raise AssertionError( + 'no message matching {!r} was logged with ' + 'loop.call_exception_handler()'.format(msg_re)) + + finally: + self.loop.set_exception_handler(old_handler) + _default_cluster = None diff --git a/asyncpg/connection.py b/asyncpg/connection.py index f57c3a17..97ab8998 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -512,13 +512,20 @@ def _get_reset_query(self): caps = self._server_caps - _reset_query = '' + _reset_query = [] + if self._protocol.is_in_transaction() or self._top_xact is not None: + self._loop.call_exception_handler({ + 'message': 'Resetting connection with an ' + 'active transaction {!r}'.format(self) + }) + self._top_xact = None + _reset_query.append('ROLLBACK;') if caps.advisory_locks: - _reset_query += 'SELECT pg_advisory_unlock_all();\n' + _reset_query.append('SELECT pg_advisory_unlock_all();') if caps.cursors: - _reset_query += 'CLOSE ALL;\n' + _reset_query.append('CLOSE ALL;') if caps.notifications and caps.plpgsql: - _reset_query += ''' + _reset_query.append(''' DO $$ BEGIN PERFORM * FROM pg_listening_channels() LIMIT 1; @@ -527,10 +534,11 @@ def _get_reset_query(self): END IF; END; $$; - ''' + ''') if caps.sql_reset: - _reset_query += 'RESET ALL;\n' + _reset_query.append('RESET ALL;') + _reset_query = '\n'.join(_reset_query) self._reset_query = _reset_query return _reset_query diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 98b8d698..59a3d387 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -120,6 +120,9 @@ cdef class BaseProtocol(CoreProtocol): def get_settings(self): return self.settings + def is_in_transaction(self): + return self.xact_status == PQTRANS_INTRANS + async def prepare(self, stmt_name, query, timeout): if self.cancel_waiter is not None: await self.cancel_waiter diff --git a/asyncpg/transaction.py b/asyncpg/transaction.py index edbdeec1..f06c3af7 100644 --- a/asyncpg/transaction.py +++ b/asyncpg/transaction.py @@ -84,6 +84,10 @@ async def start(self): con = self._connection if con._top_xact is None: + if con._protocol.is_in_transaction(): + raise apg_errors.InterfaceError( + 'cannot use Connection.transaction() in ' + 'a manually started transaction') con._top_xact = self else: # Nested transaction block diff --git a/tests/test_pool.py b/tests/test_pool.py index 8c4e0cbe..bb166445 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -349,6 +349,35 @@ async def sleep_and_release(): async with pool.acquire() as con: await con.fetchval('SELECT 1') + async def test_pool_release_in_xact(self): + """Test that Connection.reset() closes any open transaction.""" + async with self.create_pool(database='postgres', + min_size=1, max_size=1) as pool: + async def get_xact_id(con): + return await con.fetchval('select txid_current()') + + with self.assertLoopErrorHandlerCalled('an active transaction'): + async with pool.acquire() as con: + real_con = con._con # unwrap PoolConnectionProxy + + id1 = await get_xact_id(con) + + tr = con.transaction() + self.assertIsNone(con._con._top_xact) + await tr.start() + self.assertIs(real_con._top_xact, tr) + + id2 = await get_xact_id(con) + self.assertNotEqual(id1, id2) + + self.assertIsNone(real_con._top_xact) + + async with pool.acquire() as con: + self.assertIs(con._con, real_con) + self.assertIsNone(con._con._top_xact) + id3 = await get_xact_id(con) + self.assertNotEqual(id2, id3) + @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') class TestHostStandby(tb.ConnectedTestCase): diff --git a/tests/test_test.py b/tests/test_test.py index de19d30f..61820945 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -33,3 +33,14 @@ def test_tests_fail_1(self): suite.run(result) self.assertIn('ZeroDivisionError', result.errors[0][1]) + + +class TestHelpers(tb.TestCase): + + async def test_tests_assertLoopErrorHandlerCalled_01(self): + with self.assertRaisesRegex(AssertionError, r'no message.*was logged'): + with self.assertLoopErrorHandlerCalled('aa'): + self.loop.call_exception_handler({'message': 'bb a bb'}) + + with self.assertLoopErrorHandlerCalled('aa'): + self.loop.call_exception_handler({'message': 'bbaabb'}) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 2f606379..2e70dad4 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -139,3 +139,21 @@ async def test_transaction_interface_errors(self): async with tr: async with tr: pass + + async def test_transaction_within_manual_transaction(self): + self.assertIsNone(self.con._top_xact) + + await self.con.execute('BEGIN') + + tr = self.con.transaction() + self.assertIsNone(self.con._top_xact) + + with self.assertRaisesRegex(asyncpg.InterfaceError, + 'cannot use Connection.transaction'): + await tr.start() + + with self.assertLoopErrorHandlerCalled( + 'Resetting connection with an active transaction'): + await self.con.reset() + + self.assertIsNone(self.con._top_xact)