Skip to content

Commit

Permalink
Keep track of timeout in executemany properly
Browse files Browse the repository at this point in the history
  • Loading branch information
elprans authored and fantix committed Nov 24, 2020
1 parent c4c5b5e commit b137184
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 17 deletions.
24 changes: 20 additions & 4 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ cdef class BaseProtocol(CoreProtocol):

self._check_state()
timeout = self._get_timeout_impl(timeout)
timer = Timer(timeout)

# Make sure the argument sequence is encoded lazily with
# this generator expression to keep the memory pressure under
Expand All @@ -230,12 +231,20 @@ cdef class BaseProtocol(CoreProtocol):
self.queries_count += 1

while more:
await self.writing_allowed.wait()
# On Windows the above event somehow won't allow context
# switch, so forcing one with sleep(0) here
await asyncio.sleep(0)
with timer:
await asyncio.wait_for(
self.writing_allowed.wait(),
timeout=timer.get_remaining_budget())
# On Windows the above event somehow won't allow context
# switch, so forcing one with sleep(0) here
await asyncio.sleep(0)
if not timer.has_budget_greater_than(0):
raise asyncio.TimeoutError
more = self._bind_execute_many_more() # network op

except asyncio.TimeoutError as e:
self._bind_execute_many_fail(e) # network op

except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
Expand Down Expand Up @@ -951,6 +960,13 @@ class Timer:
def get_remaining_budget(self):
return self._budget

def has_budget_greater_than(self, amount):
if self._budget is None:
# Unlimited budget.
return True
else:
return self._budget > amount


class Protocol(BaseProtocol, asyncio.Protocol):
pass
Expand Down
57 changes: 44 additions & 13 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ class TestExecuteMany(tb.ConnectedTestCase):
def setUp(self):
super().setUp()
self.loop.run_until_complete(self.con.execute(
'CREATE TEMP TABLE exmany (a text, b int PRIMARY KEY)'))
'CREATE TABLE exmany (a text, b int PRIMARY KEY)'))

def tearDown(self):
self.loop.run_until_complete(self.con.execute('DROP TABLE exmany'))
super().tearDown()

async def test_basic(self):
async def test_executemany_basic(self):
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [
Expand Down Expand Up @@ -139,7 +139,7 @@ async def test_basic(self):
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

async def test_bad_input(self):
async def test_executemany_bad_input(self):
bad_data = ([1 / 0] for v in range(10))

with self.assertRaises(ZeroDivisionError):
Expand All @@ -154,7 +154,7 @@ async def test_bad_input(self):
INSERT INTO exmany (b)VALUES($1)
''', good_data)

async def test_server_failure(self):
async def test_executemany_server_failure(self):
with self.assertRaises(UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
Expand All @@ -164,7 +164,7 @@ async def test_server_failure(self):
result = await self.con.fetch('SELECT * FROM exmany')
self.assertEqual(result, [])

async def test_server_failure_after_writes(self):
async def test_executemany_server_failure_after_writes(self):
with self.assertRaises(UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
Expand All @@ -174,7 +174,7 @@ async def test_server_failure_after_writes(self):
result = await self.con.fetch('SELECT b FROM exmany')
self.assertEqual(result, [])

async def test_server_failure_during_writes(self):
async def test_executemany_server_failure_during_writes(self):
# failure at the beginning, server error detected in the middle
pos = 0

Expand All @@ -195,23 +195,54 @@ def gen():
self.assertEqual(result, [])
self.assertLess(pos, 128, 'should stop early')

async def test_client_failure_after_writes(self):
async def test_executemany_client_failure_after_writes(self):
with self.assertRaises(ZeroDivisionError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', (('a' * 32768, y + y / y) for y in range(10, -1, -1)))
result = await self.con.fetch('SELECT b FROM exmany')
self.assertEqual(result, [])

async def test_timeout(self):
async def test_executemany_timeout(self):
with self.assertRaises(asyncio.TimeoutError):
await self.con.executemany('''
INSERT INTO exmany VALUES(pg_sleep(0.1), $1)
''', [[x] for x in range(128)], timeout=0.5)
INSERT INTO exmany VALUES(pg_sleep(0.1) || $1, $2)
''', [('a' * 32768, x) for x in range(128)], timeout=0.5)
result = await self.con.fetch('SELECT * FROM exmany')
self.assertEqual(result, [])

async def test_client_failure_in_transaction(self):
async def test_executemany_timeout_flow_control(self):
event = asyncio.Event()

async def locker():
test_func = getattr(self, self._testMethodName).__func__
opts = getattr(test_func, '__connect_options__', {})
conn = await self.connect(**opts)
try:
tx = conn.transaction()
await tx.start()
await conn.execute("UPDATE exmany SET a = '1' WHERE b = 10")
event.set()
await asyncio.sleep(1)
await tx.rollback()
finally:
event.set()
await conn.close()

await self.con.executemany('''
INSERT INTO exmany VALUES(NULL, $1)
''', [(x,) for x in range(128)])
fut = asyncio.ensure_future(locker())
await event.wait()
with self.assertRaises(asyncio.TimeoutError):
await self.con.executemany('''
UPDATE exmany SET a = $1 WHERE b = $2
''', [('a' * 32768, x) for x in range(128)], timeout=0.5)
await fut
result = await self.con.fetch('SELECT * FROM exmany WHERE a IS NOT NULL')
self.assertEqual(result, [])

async def test_executemany_client_failure_in_transaction(self):
tx = self.con.transaction()
await tx.start()
with self.assertRaises(ZeroDivisionError):
Expand All @@ -226,7 +257,7 @@ async def test_client_failure_in_transaction(self):
result = await self.con.fetch('SELECT b FROM exmany')
self.assertEqual(result, [])

async def test_client_server_failure_conflict(self):
async def test_executemany_client_server_failure_conflict(self):
self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64)
with self.assertRaises(UniqueViolationError):
await self.con.executemany('''
Expand All @@ -235,7 +266,7 @@ async def test_client_server_failure_conflict(self):
result = await self.con.fetch('SELECT b FROM exmany')
self.assertEqual(result, [])

async def test_prepare(self):
async def test_executemany_prepare(self):
stmt = await self.con.prepare('''
INSERT INTO exmany VALUES($1, $2)
''')
Expand Down

0 comments on commit b137184

Please sign in to comment.