diff --git a/asyncpg/connection.py b/asyncpg/connection.py index a7f249ba..adf84e00 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -47,6 +47,7 @@ class Connection(metaclass=ConnectionMeta): __slots__ = ('_protocol', '_transport', '_loop', '_top_xact', '_aborted', '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', + '_stmt_cache_enabled', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', @@ -79,6 +80,7 @@ def __init__(self, protocol, transport, loop, max_lifetime=config.max_cached_statement_lifetime) self._stmts_to_close = set() + self._stmt_cache_enabled = config.statement_cache_size > 0 self._listeners = {} self._log_listeners = set() @@ -381,11 +383,13 @@ async def _get_statement( # Only use the cache when: # * `statement_cache_size` is greater than 0; # * query size is less than `max_cacheable_statement_size`. - use_cache = self._stmt_cache.get_max_size() > 0 - if (use_cache and - self._config.max_cacheable_statement_size and - len(query) > self._config.max_cacheable_statement_size): - use_cache = False + use_cache = ( + self._stmt_cache_enabled + and ( + not self._config.max_cacheable_statement_size + or len(query) <= self._config.max_cacheable_statement_size + ) + ) if isinstance(named, str): stmt_name = named @@ -434,7 +438,7 @@ async def _get_statement( # for the statement. statement._init_codecs() - if need_reprepare: + if need_reprepare and self._stmt_cache_enabled: await self._protocol.prepare( stmt_name, query, @@ -1679,7 +1683,14 @@ async def __execute( record_class=None ): executor = lambda stmt, timeout: self._protocol.bind_execute( - stmt, args, '', limit, return_status, timeout) + state=stmt, + args=args, + portal_name='', + limit=limit, + return_extra=return_status, + timeout=timeout, + reparse=not self._stmt_cache_enabled, + ) timeout = self._protocol._get_timeout(timeout) return await self._do_execute( query, @@ -1691,7 +1702,12 @@ async def __execute( async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( - stmt, args, '', timeout) + state=stmt, + args=args, + portal_name='', + timeout=timeout, + reparse=not self._stmt_cache_enabled, + ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: result, _ = await self._do_execute(query, executor, timeout) diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index f21559b4..7ce4f574 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -167,7 +167,8 @@ cdef class CoreProtocol: cdef _connect(self) - cdef _prepare(self, str stmt_name, str query) + cdef _prepare_and_describe(self, str stmt_name, str query) + cdef _send_parse_message(self, str stmt_name, str query) cdef _send_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) cdef _bind_execute(self, str portal_name, str stmt_name, diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 6bf1adc4..92754484 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -237,6 +237,10 @@ cdef class CoreProtocol: # ErrorResponse self._parse_msg_error_response(True) + elif mtype == b'1': + # ParseComplete, in case `_bind_execute()` is reparsing + self.buffer.discard_message() + elif mtype == b'2': # BindComplete self.buffer.discard_message() @@ -269,6 +273,10 @@ cdef class CoreProtocol: # ErrorResponse self._parse_msg_error_response(True) + elif mtype == b'1': + # ParseComplete, in case `_bind_execute_many()` is reparsing + self.buffer.discard_message() + elif mtype == b'2': # BindComplete self.buffer.discard_message() @@ -874,7 +882,15 @@ cdef class CoreProtocol: outbuf.write_buffer(buf) self._write(outbuf) - cdef _prepare(self, str stmt_name, str query): + cdef _send_parse_message(self, str stmt_name, str query): + cdef: + WriteBuffer msg + + self._ensure_connected() + msg = self._build_parse_message(stmt_name, query) + self._write(msg) + + cdef _prepare_and_describe(self, str stmt_name, str query): cdef: WriteBuffer packet WriteBuffer buf diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 3f512a81..447fd717 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -155,7 +155,7 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: - self._prepare(stmt_name, query) # network op + self._prepare_and_describe(stmt_name, query) # network op self.last_query = query if state is None: state = PreparedStatementState( @@ -168,10 +168,16 @@ cdef class BaseProtocol(CoreProtocol): return await waiter @cython.iterable_coroutine - async def bind_execute(self, PreparedStatementState state, args, - str portal_name, int limit, return_extra, - timeout): - + async def bind_execute( + self, + state: PreparedStatementState, + args, + portal_name: str, + limit: int, + return_extra: bool, + timeout, + reparse: bool = False, + ): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -184,6 +190,9 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: + if reparse: + self._send_parse_message(state.name, state.query) + self._bind_execute( portal_name, state.name, @@ -201,9 +210,14 @@ cdef class BaseProtocol(CoreProtocol): return await waiter @cython.iterable_coroutine - async def bind_execute_many(self, PreparedStatementState state, args, - str portal_name, timeout): - + async def bind_execute_many( + self, + state: PreparedStatementState, + args, + portal_name: str, + timeout, + reparse: bool = False, + ): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -222,6 +236,9 @@ cdef class BaseProtocol(CoreProtocol): waiter = self._new_waiter(timeout) try: + if reparse: + self._send_parse_message(state.name, state.query) + more = self._bind_execute_many( portal_name, state.name,