Skip to content

Commit

Permalink
Rewrite executemany() to batch args for performance.
Browse files Browse the repository at this point in the history
Now `Bind` and `Execute` pairs are batched into 4 x 32KB buffers to take
advantage of `writelines()`. A single `Sync` is sent at last, so that
all args live in the same transaction.

Closes: #289
  • Loading branch information
fantix committed Nov 24, 2020
1 parent 92aa806 commit 557f515
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 97 deletions.
10 changes: 10 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
.. versionchanged:: 0.11.0
`timeout` became a keyword-only parameter.
.. versionchanged:: 0.22.0
The execution was changed to be in an implicit transaction if there
was no explicit transaction, so that it will no longer end up with
partial success. If you still need the previous behavior to
progressively execute many args, please use a loop with prepared
statement instead.
"""
self._check_open()
return await self._executemany(command, args, timeout)
Expand Down Expand Up @@ -1010,6 +1017,9 @@ async def _copy_in(self, copy_stmt, source, timeout):
f = source
elif isinstance(source, collections.abc.AsyncIterable):
# assuming calling output returns an awaitable.
# copy_in() is designed to handle very large amounts of data, and
# the source async iterable is allowed to return an arbitrary
# amount of data on every iteration.
reader = source
else:
# assuming source is an instance supporting the buffer protocol.
Expand Down
24 changes: 21 additions & 3 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,24 @@ async def fetchrow(self, *args, timeout=None):
return None
return data[0]

async def __bind_execute(self, args, limit, timeout):
@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
"""Execute the statement for each sequence of arguments in *args*.
:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
.. versionadded:: 0.22.0
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))

async def __do_execute(self, executor):
protocol = self._connection._protocol
try:
data, status, _ = await protocol.bind_execute(
self._state, args, '', limit, True, timeout)
return await executor(protocol)
except exceptions.OutdatedSchemaCacheError:
await self._connection.reload_schema_state()
# We can not find all manually created prepared statements, so just
Expand All @@ -215,6 +228,11 @@ async def __bind_execute(self, args, limit, timeout):
# invalidate themselves (unfortunately, clearing caches again).
self._state.mark_closed()
raise

async def __bind_execute(self, args, limit, timeout):
data, status, _ = await self.__do_execute(
lambda protocol: protocol.bind_execute(
self._state, args, '', limit, True, timeout))
self._last_status = status
return data

Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/consts.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
DEF _MAXINT32 = 2**31 - 1
DEF _COPY_BUFFER_SIZE = 524288
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
DEF _EXECUTE_MANY_BUF_NUM = 4
DEF _EXECUTE_MANY_BUF_SIZE = 32768
11 changes: 9 additions & 2 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ cdef class CoreProtocol:
# True - completed, False - suspended
bint result_execute_completed

cpdef is_in_transaction(self)
cdef _process__auth(self, char mtype)
cdef _process__prepare(self, char mtype)
cdef _process__bind_execute(self, char mtype)
Expand Down Expand Up @@ -146,6 +147,7 @@ cdef class CoreProtocol:
cdef _auth_password_message_sasl_continue(self, bytes server_response)

cdef _write(self, buf)
cdef _writelines(self, list buffers)

cdef _read_server_messages(self)

Expand All @@ -155,9 +157,13 @@ cdef class CoreProtocol:

cdef _ensure_connected(self)

cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data)
cdef WriteBuffer _build_empty_bind_data(self)
cdef WriteBuffer _build_execute_message(self, str portal_name,
int32_t limit)


cdef _connect(self)
Expand All @@ -166,8 +172,9 @@ cdef class CoreProtocol:
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
WriteBuffer bind_data)
cdef _execute(self, str portal_name, int32_t limit)
Expand Down
174 changes: 129 additions & 45 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ cdef class CoreProtocol:

self._reset_result()

cpdef is_in_transaction(self):
# PQTRANS_INTRANS = idle, within transaction block
# PQTRANS_INERROR = idle, within failed transaction
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)

cdef _read_server_messages(self):
cdef:
char mtype
Expand Down Expand Up @@ -263,27 +268,16 @@ cdef class CoreProtocol:
elif mtype == b'Z':
# ReadyForQuery
self._parse_msg_ready_for_query()
if self.result_type == RESULT_FAILED:
self._push_result()
else:
try:
buf = <WriteBuffer>next(self._execute_iter)
except StopIteration:
self._push_result()
except Exception as e:
self.result_type = RESULT_FAILED
self.result = e
self._push_result()
else:
# Next iteration over the executemany() arg sequence
self._send_bind_message(
self._execute_portal_name, self._execute_stmt_name,
buf, 0)
self._push_result()

elif mtype == b'I':
# EmptyQueryResponse
self.buffer.discard_message()

elif mtype == b'1':
# ParseComplete
self.buffer.discard_message()

cdef _process__bind(self, char mtype):
if mtype == b'E':
# ErrorResponse
Expand Down Expand Up @@ -780,6 +774,17 @@ cdef class CoreProtocol:
if self.con_status != CONNECTION_OK:
raise apg_exc.InternalClientError('not connected')

cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
cdef WriteBuffer buf

buf = WriteBuffer.new_message(b'P')
buf.write_str(stmt_name, self.encoding)
buf.write_str(query, self.encoding)
buf.write_int16(0)

buf.end_message()
return buf

cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data):
Expand All @@ -795,6 +800,25 @@ cdef class CoreProtocol:
buf.end_message()
return buf

cdef WriteBuffer _build_empty_bind_data(self):
cdef WriteBuffer buf
buf = WriteBuffer.new()
buf.write_int16(0) # The number of parameter format codes
buf.write_int16(0) # The number of parameter values
buf.write_int16(0) # The number of result-column format codes
return buf

cdef WriteBuffer _build_execute_message(self, str portal_name,
int32_t limit):
cdef WriteBuffer buf

buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(limit) # number of rows to return; 0 - all

buf.end_message()
return buf

# API for subclasses

cdef _connect(self):
Expand Down Expand Up @@ -845,12 +869,7 @@ cdef class CoreProtocol:
self._ensure_connected()
self._set_state(PROTOCOL_PREPARE)

buf = WriteBuffer.new_message(b'P')
buf.write_str(stmt_name, self.encoding)
buf.write_str(query, self.encoding)
buf.write_int16(0)
buf.end_message()
packet = buf
packet = self._build_parse_message(stmt_name, query)

buf = WriteBuffer.new_message(b'D')
buf.write_byte(b'S')
Expand All @@ -872,10 +891,7 @@ cdef class CoreProtocol:
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
packet = buf

buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(limit) # number of rows to return; 0 - all
buf.end_message()
buf = self._build_execute_message(portal_name, limit)
packet.write_buffer(buf)

packet.write_bytes(SYNC_MESSAGE)
Expand All @@ -894,11 +910,8 @@ cdef class CoreProtocol:

self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):

cdef WriteBuffer buf

cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

Expand All @@ -907,17 +920,88 @@ cdef class CoreProtocol:
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
return self._bind_execute_many_more(True)

try:
buf = <WriteBuffer>next(bind_data)
except StopIteration:
self._push_result()
except Exception as e:
self.result_type = RESULT_FAILED
self.result = e
self._push_result()
else:
self._send_bind_message(portal_name, stmt_name, buf, 0)
cdef bint _bind_execute_many_more(self, bint first=False):
cdef:
WriteBuffer packet
WriteBuffer buf
list buffers = []

# as we keep sending, the server may return an error early
if self.result_type == RESULT_FAILED:
self._write(SYNC_MESSAGE)
return False

# collect up to four 32KB buffers to send
# https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051
while len(buffers) < _EXECUTE_MANY_BUF_NUM:
packet = WriteBuffer.new()

# fill one 32KB buffer
while packet.len() < _EXECUTE_MANY_BUF_SIZE:
try:
# grab one item from the input
buf = <WriteBuffer>next(self._execute_iter)

# reached the end of the input
except StopIteration:
if first:
# if we never send anything, simply set the result
self._push_result()
else:
# otherwise, append SYNC and send the buffers
packet.write_bytes(SYNC_MESSAGE)
buffers.append(packet)
self._writelines(buffers)
return False

# error in input, give up the buffers and cleanup
except Exception as ex:
self.result_type = RESULT_FAILED
self.result = ex
if first:
self._push_result()
elif self.is_in_transaction():
# we're in an explicit transaction, just SYNC
self._write(SYNC_MESSAGE)
else:
# In an implicit transaction, if `ignore_till_sync`,
# `ROLLBACK` will be ignored and `Sync` will restore
# the state; or the transaction will be rolled back
# with a warning saying that there was no transaction,
# but rollback is done anyway, so we could safely
# ignore this warning.
# GOTCHA: simple query message will be ignored if
# `ignore_till_sync` is set.
buf = self._build_parse_message('', 'ROLLBACK')
buf.write_buffer(self._build_bind_message(
'', '', self._build_empty_bind_data()))
buf.write_buffer(self._build_execute_message('', 0))
buf.write_bytes(SYNC_MESSAGE)
self._write(buf)
return False

# all good, write to the buffer
first = False
packet.write_buffer(
self._build_bind_message(
self._execute_portal_name,
self._execute_stmt_name,
buf,
)
)
packet.write_buffer(
self._build_execute_message(self._execute_portal_name, 0,
)
)

# collected one buffer
buffers.append(packet)

# write to the wire, and signal the caller for more to send
self._writelines(buffers)
return True

cdef _execute(self, str portal_name, int32_t limit):
cdef WriteBuffer buf
Expand All @@ -927,10 +1011,7 @@ cdef class CoreProtocol:

self.result = []

buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(limit) # number of rows to return; 0 - all
buf.end_message()
buf = self._build_execute_message(portal_name, limit)

buf.write_bytes(SYNC_MESSAGE)

Expand Down Expand Up @@ -1013,6 +1094,9 @@ cdef class CoreProtocol:
cdef _write(self, buf):
raise NotImplementedError

cdef _writelines(self, list buffers):
raise NotImplementedError

cdef _decode_row(self, const char* buf, ssize_t buf_len):
pass

Expand Down
Loading

0 comments on commit 557f515

Please sign in to comment.