Skip to content

Commit

Permalink
Rewrite (again) executemany() to batch args for performance.
Browse files Browse the repository at this point in the history
Now `Bind` and `Execute` pairs are batched into 4 x 32KB buffers to take
advantage of `writelines()`. A single `Sync` is sent at last, so that
all args live in the same transaction.

Closes: #289
  • Loading branch information
fantix committed Nov 28, 2018
1 parent 43a7b21 commit 46ecc07
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 112 deletions.
10 changes: 10 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
.. versionchanged:: 0.11.0
`timeout` became a keyword-only parameter.
.. versionchanged:: 0.19.0
The execution was changed to be in an implicit transaction if there
was no explicit transaction, so that it will no longer end up with
partial success. If you still need the previous behavior to
progressively execute many args, please use a loop with prepared
statement instead.
"""
self._check_open()
return await self._executemany(command, args, timeout)
Expand Down Expand Up @@ -821,6 +828,9 @@ async def _copy_in(self, copy_stmt, source, timeout):
f = source
elif isinstance(source, collections.abc.AsyncIterable):
# assuming calling output returns an awaitable.
# copy_in() is designed to handle very large amounts of data, and
# the source async iterable is allowed to return an arbitrary
# amount of data on every iteration.
reader = source
else:
# assuming source is an instance supporting the buffer protocol.
Expand Down
24 changes: 21 additions & 3 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,24 @@ async def fetchrow(self, *args, timeout=None):
return None
return data[0]

async def __bind_execute(self, args, limit, timeout):
@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
"""Execute the statement for each sequence of arguments in *args*.
:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
.. versionadded:: 0.19.0
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))

async def __do_execute(self, executor):
protocol = self._connection._protocol
try:
data, status, _ = await protocol.bind_execute(
self._state, args, '', limit, True, timeout)
return await executor(protocol)
except exceptions.OutdatedSchemaCacheError:
await self._connection.reload_schema_state()
# We can not find all manually created prepared statements, so just
Expand All @@ -209,6 +222,11 @@ async def __bind_execute(self, args, limit, timeout):
# invalidate themselves (unfortunately, clearing caches again).
self._state.mark_closed()
raise

async def __bind_execute(self, args, limit, timeout):
data, status, _ = await self.__do_execute(
lambda protocol: protocol.bind_execute(
self._state, args, '', limit, True, timeout))
self._last_status = status
return data

Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/consts.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
DEF _MAXINT32 = 2**31 - 1
DEF _COPY_BUFFER_SIZE = 524288
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
DEF _EXECUTE_MANY_BUF_NUM = 4
DEF _EXECUTE_MANY_BUF_SIZE = 32768
18 changes: 11 additions & 7 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ cdef class CoreProtocol:
bint _skip_discard
bint _discard_data

# executemany support data
object _execute_iter
str _execute_portal_name
str _execute_stmt_name

ConnectionStatus con_status
ProtocolState state
TransactionStatus xact_status
Expand All @@ -105,6 +100,7 @@ cdef class CoreProtocol:
# True - completed, False - suspended
bint result_execute_completed

cpdef is_in_transaction(self)
cdef _process__auth(self, char mtype)
cdef _process__prepare(self, char mtype)
cdef _process__bind_execute(self, char mtype)
Expand Down Expand Up @@ -135,6 +131,7 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)

cdef _write(self, buf)
cdef _writelines(self, list buffers)

cdef _read_server_messages(self)

Expand All @@ -144,9 +141,13 @@ cdef class CoreProtocol:

cdef _ensure_connected(self)

cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data)
cdef WriteBuffer _build_empty_bind_data(self)
cdef WriteBuffer _build_execute_message(self, str portal_name,
int32_t limit)


cdef _connect(self)
Expand All @@ -155,8 +156,11 @@ cdef class CoreProtocol:
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
cdef _execute_many_init(self)
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
object bind_data)
cdef _execute_many_done(self, bint data_sent)
cdef _execute_many_fail(self, object error)
cdef _bind(self, str portal_name, str stmt_name,
WriteBuffer bind_data)
cdef _execute(self, str portal_name, int32_t limit)
Expand Down
154 changes: 103 additions & 51 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ cdef class CoreProtocol:
self.xact_status = PQTRANS_IDLE
self.encoding = 'utf-8'

# executemany support data
self._execute_iter = None
self._execute_portal_name = None
self._execute_stmt_name = None

self._reset_result()

cpdef is_in_transaction(self):
# PQTRANS_INTRANS = idle, within transaction block
# PQTRANS_INERROR = idle, within failed transaction
return self.xact_status in (PQTRANS_INTRANS, PQTRANS_INERROR)

cdef _read_server_messages(self):
cdef:
char mtype
Expand Down Expand Up @@ -253,22 +253,7 @@ cdef class CoreProtocol:
elif mtype == b'Z':
# ReadyForQuery
self._parse_msg_ready_for_query()
if self.result_type == RESULT_FAILED:
self._push_result()
else:
try:
buf = <WriteBuffer>next(self._execute_iter)
except StopIteration:
self._push_result()
except Exception as e:
self.result_type = RESULT_FAILED
self.result = e
self._push_result()
else:
# Next iteration over the executemany() arg sequence
self._send_bind_message(
self._execute_portal_name, self._execute_stmt_name,
buf, 0)
self._push_result()

elif mtype == b'I':
# EmptyQueryResponse
Expand Down Expand Up @@ -687,6 +672,17 @@ cdef class CoreProtocol:
if self.con_status != CONNECTION_OK:
raise apg_exc.InternalClientError('not connected')

cdef WriteBuffer _build_parse_message(self, str stmt_name, str query):
cdef WriteBuffer buf

buf = WriteBuffer.new_message(b'P')
buf.write_str(stmt_name, self.encoding)
buf.write_str(query, self.encoding)
buf.write_int16(0)

buf.end_message()
return buf

cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data):
Expand All @@ -702,6 +698,25 @@ cdef class CoreProtocol:
buf.end_message()
return buf

cdef WriteBuffer _build_empty_bind_data(self):
cdef WriteBuffer buf
buf = WriteBuffer.new()
buf.write_int16(0) # The number of parameter format codes
buf.write_int16(0) # The number of parameter values
buf.write_int16(0) # The number of result-column format codes
return buf

cdef WriteBuffer _build_execute_message(self, str portal_name,
int32_t limit):
cdef WriteBuffer buf

buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(limit) # number of rows to return; 0 - all

buf.end_message()
return buf

# API for subclasses

cdef _connect(self):
Expand Down Expand Up @@ -752,12 +767,7 @@ cdef class CoreProtocol:
self._ensure_connected()
self._set_state(PROTOCOL_PREPARE)

buf = WriteBuffer.new_message(b'P')
buf.write_str(stmt_name, self.encoding)
buf.write_str(query, self.encoding)
buf.write_int16(0)
buf.end_message()
packet = buf
packet = self._build_parse_message(stmt_name, query)

buf = WriteBuffer.new_message(b'D')
buf.write_byte(b'S')
Expand All @@ -779,10 +789,7 @@ cdef class CoreProtocol:
buf = self._build_bind_message(portal_name, stmt_name, bind_data)
packet = buf

buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(limit) # number of rows to return; 0 - all
buf.end_message()
buf = self._build_execute_message(portal_name, limit)
packet.write_buffer(buf)

packet.write_bytes(SYNC_MESSAGE)
Expand All @@ -801,30 +808,75 @@ cdef class CoreProtocol:

self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):

cdef WriteBuffer buf

cdef _execute_many_init(self):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

self.result = None
self._discard_data = True
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name

try:
buf = <WriteBuffer>next(bind_data)
except StopIteration:
self._push_result()
except Exception as e:
self.result_type = RESULT_FAILED
self.result = e
cdef _execute_many_writelines(self, str portal_name, str stmt_name,
object bind_data):
cdef:
WriteBuffer packet
WriteBuffer buf
list buffers = []

if self.result_type == RESULT_FAILED:
raise StopIteration(True)

while len(buffers) < _EXECUTE_MANY_BUF_NUM:
packet = WriteBuffer.new()

while packet.len() < _EXECUTE_MANY_BUF_SIZE:
try:
buf = <WriteBuffer>next(bind_data)
except StopIteration:
if packet.len() > 0:
buffers.append(packet)
if len(buffers) > 0:
self._writelines(buffers)
raise StopIteration(True)
else:
raise StopIteration(False)
except Exception as ex:
raise StopIteration(ex)
packet.write_buffer(
self._build_bind_message(portal_name, stmt_name, buf))
packet.write_buffer(
self._build_execute_message(portal_name, 0))
buffers.append(packet)
self._writelines(buffers)

cdef _execute_many_done(self, bint data_sent):
if data_sent:
self._write(SYNC_MESSAGE)
else:
self._push_result()

cdef _execute_many_fail(self, object error):
cdef WriteBuffer buf

self.result_type = RESULT_FAILED
self.result = error

# We shall rollback in an implicit transaction to prevent partial
# commit, while do nothing in an explicit transaction and leaving the
# error to the user
if self.is_in_transaction():
self._execute_many_done(True)
else:
self._send_bind_message(portal_name, stmt_name, buf, 0)
# Here if the implicit transaction is in `ignore_till_sync` mode,
# the `ROLLBACK` will be ignored and `Sync` will restore the state;
# or else the implicit transaction will be rolled back with a
# warning saying that there was no transaction, but rollback is
# done anyway, so we could ignore this warning.
buf = self._build_parse_message('', 'ROLLBACK')
buf.write_buffer(self._build_bind_message(
'', '', self._build_empty_bind_data()))
buf.write_buffer(self._build_execute_message('', 0))
buf.write_bytes(SYNC_MESSAGE)
self._write(buf)

cdef _execute(self, str portal_name, int32_t limit):
cdef WriteBuffer buf
Expand All @@ -834,10 +886,7 @@ cdef class CoreProtocol:

self.result = []

buf = WriteBuffer.new_message(b'E')
buf.write_str(portal_name, self.encoding) # name of the portal
buf.write_int32(limit) # number of rows to return; 0 - all
buf.end_message()
buf = self._build_execute_message(portal_name, limit)

buf.write_bytes(SYNC_MESSAGE)

Expand Down Expand Up @@ -920,6 +969,9 @@ cdef class CoreProtocol:
cdef _write(self, buf):
raise NotImplementedError

cdef _writelines(self, list buffers):
raise NotImplementedError

cdef _decode_row(self, const char* buf, ssize_t buf_len):
pass

Expand Down
Loading

0 comments on commit 46ecc07

Please sign in to comment.