Skip to content

Commit

Permalink
When prepared statements are disabled, avoid relying on them harder (#…
Browse files Browse the repository at this point in the history
…1065)

It appears that PgBouncer's `transaction` pooling mode does not consider
implicit transactions properly, and so in a [`Parse`, `Flush`, `Bind`,
`Execute`, `Sync`] sequence, `Flush` would be (incorrectly) considered by
PgBouncer as a transaction boundary and it will happily send the
following `Bind` / `Execute` messages to a different backend process.

This makes it so that when `statement_cache_size` is set to `0`, asyncpg
assumes a pessimistic stance on prepared statement persistence and does
not rely on them even in implicit transactions.  The above message
sequence thus becomes `Parse`, `Flush`, `Parse` (a second time), `Bind`,
`Execute`, `Sync`.

This obviously has negative performance impact due to the extraneous
`Parse`.

Fixes: #1058
Fixes: #1041
  • Loading branch information
elprans committed Aug 17, 2023
1 parent 87ab143 commit cbf64e1
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 25 deletions.
46 changes: 31 additions & 15 deletions asyncpg/connection.py
Expand Up @@ -47,6 +47,7 @@ class Connection(metaclass=ConnectionMeta):
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_stmt_cache_enabled',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self, protocol, transport, loop,
max_lifetime=config.max_cached_statement_lifetime)

self._stmts_to_close = set()
self._stmt_cache_enabled = config.statement_cache_size > 0

self._listeners = {}
self._log_listeners = set()
Expand Down Expand Up @@ -381,11 +383,13 @@ async def _get_statement(
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
use_cache = self._stmt_cache.get_max_size() > 0
if (use_cache and
self._config.max_cacheable_statement_size and
len(query) > self._config.max_cacheable_statement_size):
use_cache = False
use_cache = (
self._stmt_cache_enabled
and (
not self._config.max_cacheable_statement_size
or len(query) <= self._config.max_cacheable_statement_size
)
)

if isinstance(named, str):
stmt_name = named
Expand Down Expand Up @@ -434,14 +438,16 @@ async def _get_statement(
# for the statement.
statement._init_codecs()

if need_reprepare:
await self._protocol.prepare(
stmt_name,
query,
timeout,
state=statement,
record_class=record_class,
)
if (
need_reprepare
or (not statement.name and not self._stmt_cache_enabled)
):
# Mark this anonymous prepared statement as "unprepared",
# causing it to get re-Parsed in next bind_execute.
# We always do this when stmt_cache_size is set to 0 assuming
# people are running PgBouncer which is mishandling implicit
# transactions.
statement.mark_unprepared()

if use_cache:
self._stmt_cache.put(
Expand Down Expand Up @@ -1679,7 +1685,13 @@ async def __execute(
record_class=None
):
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
state=stmt,
args=args,
portal_name='',
limit=limit,
return_extra=return_status,
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
query,
Expand All @@ -1691,7 +1703,11 @@ async def __execute(

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
stmt, args, '', timeout)
state=stmt,
args=args,
portal_name='',
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
result, _ = await self._do_execute(query, executor, timeout)
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/coreproto.pxd
Expand Up @@ -167,7 +167,8 @@ cdef class CoreProtocol:


cdef _connect(self)
cdef _prepare(self, str stmt_name, str query)
cdef _prepare_and_describe(self, str stmt_name, str query)
cdef _send_parse_message(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
Expand Down
18 changes: 17 additions & 1 deletion asyncpg/protocol/coreproto.pyx
Expand Up @@ -237,6 +237,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'1':
# ParseComplete, in case `_bind_execute()` is reparsing
self.buffer.discard_message()

elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
Expand Down Expand Up @@ -269,6 +273,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'1':
# ParseComplete, in case `_bind_execute_many()` is reparsing
self.buffer.discard_message()

elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
Expand Down Expand Up @@ -874,7 +882,15 @@ cdef class CoreProtocol:
outbuf.write_buffer(buf)
self._write(outbuf)

cdef _prepare(self, str stmt_name, str query):
cdef _send_parse_message(self, str stmt_name, str query):
cdef:
WriteBuffer msg

self._ensure_connected()
msg = self._build_parse_message(stmt_name, query)
self._write(msg)

cdef _prepare_and_describe(self, str stmt_name, str query):
cdef:
WriteBuffer packet
WriteBuffer buf
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/prepared_stmt.pxd
Expand Up @@ -10,6 +10,7 @@ cdef class PreparedStatementState:
readonly str name
readonly str query
readonly bint closed
readonly bint prepared
readonly int refs
readonly type record_class
readonly bint ignore_custom_codec
Expand Down
7 changes: 7 additions & 0 deletions asyncpg/protocol/prepared_stmt.pyx
Expand Up @@ -27,6 +27,7 @@ cdef class PreparedStatementState:
self.args_num = self.cols_num = 0
self.cols_desc = None
self.closed = False
self.prepared = True
self.refs = 0
self.record_class = record_class
self.ignore_custom_codec = ignore_custom_codec
Expand Down Expand Up @@ -101,6 +102,12 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True

def mark_unprepared(self):
if self.name:
raise exceptions.InternalClientError(
"named prepared statements cannot be marked unprepared")
self.prepared = False

cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
Expand Down
31 changes: 23 additions & 8 deletions asyncpg/protocol/protocol.pyx
Expand Up @@ -155,7 +155,7 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
self._prepare(stmt_name, query) # network op
self._prepare_and_describe(stmt_name, query) # network op
self.last_query = query
if state is None:
state = PreparedStatementState(
Expand All @@ -168,10 +168,15 @@ cdef class BaseProtocol(CoreProtocol):
return await waiter

@cython.iterable_coroutine
async def bind_execute(self, PreparedStatementState state, args,
str portal_name, int limit, return_extra,
timeout):

async def bind_execute(
self,
state: PreparedStatementState,
args,
portal_name: str,
limit: int,
return_extra: bool,
timeout,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
Expand All @@ -184,6 +189,9 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
if not state.prepared:
self._send_parse_message(state.name, state.query)

self._bind_execute(
portal_name,
state.name,
Expand All @@ -201,9 +209,13 @@ cdef class BaseProtocol(CoreProtocol):
return await waiter

@cython.iterable_coroutine
async def bind_execute_many(self, PreparedStatementState state, args,
str portal_name, timeout):

async def bind_execute_many(
self,
state: PreparedStatementState,
args,
portal_name: str,
timeout,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
Expand All @@ -222,6 +234,9 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
if not state.prepared:
self._send_parse_message(state.name, state.query)

more = self._bind_execute_many(
portal_name,
state.name,
Expand Down

0 comments on commit cbf64e1

Please sign in to comment.