diff --git a/asyncpg/connection.py b/asyncpg/connection.py index a7f249ba..2d689512 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -47,6 +47,7 @@ class Connection(metaclass=ConnectionMeta): __slots__ = ('_protocol', '_transport', '_loop', '_top_xact', '_aborted', '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', + '_stmt_cache_enabled', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', @@ -79,6 +80,7 @@ def __init__(self, protocol, transport, loop, max_lifetime=config.max_cached_statement_lifetime) self._stmts_to_close = set() + self._stmt_cache_enabled = config.statement_cache_size > 0 self._listeners = {} self._log_listeners = set() @@ -381,11 +383,13 @@ async def _get_statement( # Only use the cache when: # * `statement_cache_size` is greater than 0; # * query size is less than `max_cacheable_statement_size`. - use_cache = self._stmt_cache.get_max_size() > 0 - if (use_cache and - self._config.max_cacheable_statement_size and - len(query) > self._config.max_cacheable_statement_size): - use_cache = False + use_cache = ( + self._stmt_cache_enabled + and ( + not self._config.max_cacheable_statement_size + or len(query) <= self._config.max_cacheable_statement_size + ) + ) if isinstance(named, str): stmt_name = named @@ -434,14 +438,16 @@ async def _get_statement( # for the statement. statement._init_codecs() - if need_reprepare: - await self._protocol.prepare( - stmt_name, - query, - timeout, - state=statement, - record_class=record_class, - ) + if ( + need_reprepare + or (not statement.name and not self._stmt_cache_enabled) + ): + # Mark this anonymous prepared statement as "unprepared", + # causing it to get re-Parsed in next bind_execute. + # We always do this when stmt_cache_size is set to 0 assuming + # people are running PgBouncer which is mishandling implicit + # transactions. + statement.mark_unprepared() if use_cache: self._stmt_cache.put( @@ -1679,7 +1685,13 @@ async def __execute( record_class=None ): executor = lambda stmt, timeout: self._protocol.bind_execute( - stmt, args, '', limit, return_status, timeout) + state=stmt, + args=args, + portal_name='', + limit=limit, + return_extra=return_status, + timeout=timeout, + ) timeout = self._protocol._get_timeout(timeout) return await self._do_execute( query, @@ -1691,7 +1703,11 @@ async def __execute( async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( - stmt, args, '', timeout) + state=stmt, + args=args, + portal_name='', + timeout=timeout, + ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: result, _ = await self._do_execute(query, executor, timeout) diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index f21559b4..7ce4f574 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -167,7 +167,8 @@ cdef class CoreProtocol: cdef _connect(self) - cdef _prepare(self, str stmt_name, str query) + cdef _prepare_and_describe(self, str stmt_name, str query) + cdef _send_parse_message(self, str stmt_name, str query) cdef _send_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) cdef _bind_execute(self, str portal_name, str stmt_name, diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 6bf1adc4..92754484 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -237,6 +237,10 @@ cdef class CoreProtocol: # ErrorResponse self._parse_msg_error_response(True) + elif mtype == b'1': + # ParseComplete, in case `_bind_execute()` is reparsing + self.buffer.discard_message() + elif mtype == b'2': # BindComplete self.buffer.discard_message() @@ -269,6 +273,10 @@ cdef class CoreProtocol: # ErrorResponse self._parse_msg_error_response(True) + elif mtype == b'1': + # ParseComplete, in case `_bind_execute_many()` is reparsing + self.buffer.discard_message() + elif mtype == b'2': # BindComplete self.buffer.discard_message() @@ -874,7 +882,15 @@ cdef class CoreProtocol: outbuf.write_buffer(buf) self._write(outbuf) - cdef _prepare(self, str stmt_name, str query): + cdef _send_parse_message(self, str stmt_name, str query): + cdef: + WriteBuffer msg + + self._ensure_connected() + msg = self._build_parse_message(stmt_name, query) + self._write(msg) + + cdef _prepare_and_describe(self, str stmt_name, str query): cdef: WriteBuffer packet WriteBuffer buf diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 3906af25..369db733 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -10,6 +10,7 @@ cdef class PreparedStatementState: readonly str name readonly str query readonly bint closed + readonly bint prepared readonly int refs readonly type record_class readonly bint ignore_custom_codec diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index b1f2a66d..7335825c 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -27,6 +27,7 @@ cdef class PreparedStatementState: self.args_num = self.cols_num = 0 self.cols_desc = None self.closed = False + self.prepared = True self.refs = 0 self.record_class = record_class self.ignore_custom_codec = ignore_custom_codec @@ -101,6 +102,12 @@ cdef class PreparedStatementState: def mark_closed(self): self.closed = True + def mark_unprepared(self): + if self.name: + raise exceptions.InternalClientError( + "named prepared statements cannot be marked unprepared") + self.prepared = False + cdef _encode_bind_msg(self, args, int seqno = -1): cdef: int idx diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 3f512a81..f504d9d0 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -155,7 +155,7 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: - self._prepare(stmt_name, query) # network op + self._prepare_and_describe(stmt_name, query) # network op self.last_query = query if state is None: state = PreparedStatementState( @@ -168,10 +168,15 @@ cdef class BaseProtocol(CoreProtocol): return await waiter @cython.iterable_coroutine - async def bind_execute(self, PreparedStatementState state, args, - str portal_name, int limit, return_extra, - timeout): - + async def bind_execute( + self, + state: PreparedStatementState, + args, + portal_name: str, + limit: int, + return_extra: bool, + timeout, + ): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -184,6 +189,9 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: + if not state.prepared: + self._send_parse_message(state.name, state.query) + self._bind_execute( portal_name, state.name, @@ -201,9 +209,13 @@ cdef class BaseProtocol(CoreProtocol): return await waiter @cython.iterable_coroutine - async def bind_execute_many(self, PreparedStatementState state, args, - str portal_name, timeout): - + async def bind_execute_many( + self, + state: PreparedStatementState, + args, + portal_name: str, + timeout, + ): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -222,6 +234,9 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: + if not state.prepared: + self._send_parse_message(state.name, state.query) + more = self._bind_execute_many( portal_name, state.name,