Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions asyncpg/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import inspect
import logging
import os
import re
import time
import unittest

Expand Down Expand Up @@ -95,6 +96,30 @@ def assertRunUnder(self, delta):
raise AssertionError(
'running block took longer than {}'.format(delta))

@contextlib.contextmanager
def assertLoopErrorHandlerCalled(self, msg_re: str):
contexts = []

def handler(loop, ctx):
contexts.append(ctx)

old_handler = self.loop.get_exception_handler()
self.loop.set_exception_handler(handler)
try:
yield

for ctx in contexts:
msg = ctx.get('message')
if msg and re.search(msg_re, msg):
return

raise AssertionError(
'no message matching {!r} was logged with '
'loop.call_exception_handler()'.format(msg_re))

finally:
self.loop.set_exception_handler(old_handler)


_default_cluster = None

Expand Down
20 changes: 14 additions & 6 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,13 +512,20 @@ def _get_reset_query(self):

caps = self._server_caps

_reset_query = ''
_reset_query = []
if self._protocol.is_in_transaction() or self._top_xact is not None:
self._loop.call_exception_handler({
'message': 'Resetting connection with an '
'active transaction {!r}'.format(self)
})
self._top_xact = None
_reset_query.append('ROLLBACK;')
if caps.advisory_locks:
_reset_query += 'SELECT pg_advisory_unlock_all();\n'
_reset_query.append('SELECT pg_advisory_unlock_all();')
if caps.cursors:
_reset_query += 'CLOSE ALL;\n'
_reset_query.append('CLOSE ALL;')
if caps.notifications and caps.plpgsql:
_reset_query += '''
_reset_query.append('''
DO $$
BEGIN
PERFORM * FROM pg_listening_channels() LIMIT 1;
Expand All @@ -527,10 +534,11 @@ def _get_reset_query(self):
END IF;
END;
$$;
'''
''')
if caps.sql_reset:
_reset_query += 'RESET ALL;\n'
_reset_query.append('RESET ALL;')

_reset_query = '\n'.join(_reset_query)
self._reset_query = _reset_query

return _reset_query
Expand Down
3 changes: 3 additions & 0 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ cdef class BaseProtocol(CoreProtocol):
def get_settings(self):
return self.settings

def is_in_transaction(self):
return self.xact_status == PQTRANS_INTRANS

async def prepare(self, stmt_name, query, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
Expand Down
4 changes: 4 additions & 0 deletions asyncpg/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ async def start(self):
con = self._connection

if con._top_xact is None:
if con._protocol.is_in_transaction():
raise apg_errors.InterfaceError(
'cannot use Connection.transaction() in '
'a manually started transaction')
con._top_xact = self
else:
# Nested transaction block
Expand Down
29 changes: 29 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,35 @@ async def sleep_and_release():
async with pool.acquire() as con:
await con.fetchval('SELECT 1')

async def test_pool_release_in_xact(self):
"""Test that Connection.reset() closes any open transaction."""
async with self.create_pool(database='postgres',
min_size=1, max_size=1) as pool:
async def get_xact_id(con):
return await con.fetchval('select txid_current()')

with self.assertLoopErrorHandlerCalled('an active transaction'):
async with pool.acquire() as con:
real_con = con._con # unwrap PoolConnectionProxy

id1 = await get_xact_id(con)

tr = con.transaction()
self.assertIsNone(con._con._top_xact)
await tr.start()
self.assertIs(real_con._top_xact, tr)

id2 = await get_xact_id(con)
self.assertNotEqual(id1, id2)

self.assertIsNone(real_con._top_xact)

async with pool.acquire() as con:
self.assertIs(con._con, real_con)
self.assertIsNone(con._con._top_xact)
id3 = await get_xact_id(con)
self.assertNotEqual(id2, id3)


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestHostStandby(tb.ConnectedTestCase):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ def test_tests_fail_1(self):
suite.run(result)

self.assertIn('ZeroDivisionError', result.errors[0][1])


class TestHelpers(tb.TestCase):

async def test_tests_assertLoopErrorHandlerCalled_01(self):
with self.assertRaisesRegex(AssertionError, r'no message.*was logged'):
with self.assertLoopErrorHandlerCalled('aa'):
self.loop.call_exception_handler({'message': 'bb a bb'})

with self.assertLoopErrorHandlerCalled('aa'):
self.loop.call_exception_handler({'message': 'bbaabb'})
18 changes: 18 additions & 0 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,21 @@ async def test_transaction_interface_errors(self):
async with tr:
async with tr:
pass

async def test_transaction_within_manual_transaction(self):
self.assertIsNone(self.con._top_xact)

await self.con.execute('BEGIN')

tr = self.con.transaction()
self.assertIsNone(self.con._top_xact)

with self.assertRaisesRegex(asyncpg.InterfaceError,
'cannot use Connection.transaction'):
await tr.start()

with self.assertLoopErrorHandlerCalled(
'Resetting connection with an active transaction'):
await self.con.reset()

self.assertIsNone(self.con._top_xact)