diff --git a/asyncpg/connection.py b/asyncpg/connection.py index e21c12c4..23141f4d 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -291,12 +291,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False, types, intro_stmt = await self.__execute( self._intro_query, (list(ready),), 0, timeout) self._protocol.get_settings().register_data_types(types) - if not intro_stmt.name and not statement.name: - # The introspection query has used an anonymous statement, - # which has blown away the anonymous statement we've prepared - # for the query, so we need to re-prepare it. - statement = await self._protocol.prepare( - stmt_name, query, timeout) + # The introspection query has used an anonymous statement, + # which has blown away the anonymous statement we've prepared + # for the query, so we need to re-prepare it. + need_reprepare = not intro_stmt.name and not statement.name + else: + need_reprepare = False + + # Now that types have been resolved, populate the codec pipeline + # for the statement. + statement._init_codecs() + + if need_reprepare: + await self._protocol.prepare( + stmt_name, query, timeout, state=statement) if use_cache: self._stmt_cache.put(query, statement) diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 8dab35b1..9749113c 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -30,6 +30,7 @@ cdef class PreparedStatementState: tuple rows_codecs cdef _encode_bind_msg(self, args) + cpdef _init_codecs(self) cdef _ensure_rows_decoder(self) cdef _ensure_args_encoder(self) cdef _set_row_desc(self, object desc) diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index 3edb56f0..e69369d3 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -82,6 +82,10 @@ cdef class PreparedStatementState: else: return True + cpdef _init_codecs(self): + self._ensure_args_encoder() + self._ensure_rows_decoder() + def attach(self): self.refs += 1 @@ -101,9 +105,6 @@ cdef class PreparedStatementState: raise exceptions.InterfaceError( 'the number of query arguments cannot exceed 32767') - self._ensure_args_encoder() - self._ensure_rows_decoder() - writer = WriteBuffer.new() num_args_passed = len(args) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 09fc8c11..983c0ea1 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -146,7 +146,8 @@ cdef class BaseProtocol(CoreProtocol): self.is_reading = False self.transport.pause_reading() - async def prepare(self, stmt_name, query, timeout): + async def prepare(self, stmt_name, query, timeout, + PreparedStatementState state=None): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -160,7 +161,9 @@ cdef class BaseProtocol(CoreProtocol): try: self._prepare(stmt_name, query) # network op self.last_query = query - self.statement = PreparedStatementState(stmt_name, query, self) + if state is None: + state = PreparedStatementState(stmt_name, query, self) + self.statement = state except Exception as ex: waiter.set_exception(ex) self._coreproto_error() diff --git a/tests/test_introspection.py b/tests/test_introspection.py index d46095f8..fcf5885d 100644 --- a/tests/test_introspection.py +++ b/tests/test_introspection.py @@ -5,6 +5,8 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +import json + from asyncpg import _testbase as tb from asyncpg import connection as apg_con @@ -98,3 +100,28 @@ async def test_introspection_no_stmt_cache_03(self): "SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2]) self.assertEqual(apg_con._uid, old_uid + 1) + + async def test_introspection_sticks_for_ps(self): + # Test that the introspected codec pipeline for a prepared + # statement is not affected by a subsequent codec cache bust. + + ps = await self.con._prepare('SELECT $1::json[]', use_cache=True) + + try: + # Setting a custom codec blows the codec cache for derived types. + await self.con.set_type_codec( + 'json', encoder=lambda v: v, decoder=json.loads, + schema='pg_catalog', format='text' + ) + + # The originally prepared statement should still be OK and + # use the previously selected codec. + self.assertEqual(await ps.fetchval(['{"foo": 1}']), ['{"foo": 1}']) + + # The new query uses the custom codec. + v = await self.con.fetchval('SELECT $1::json[]', ['{"foo": 1}']) + self.assertEqual(v, [{'foo': 1}]) + + finally: + await self.con.reset_type_codec( + 'json', schema='pg_catalog')