Skip to content

Commit

Permalink
Initialize statement codecs immediately after Prepare
Browse files Browse the repository at this point in the history
Currently the statement codecs are populated just before the first Bind
is issued.  This is problematic as in the time since Prepare, the codec
cache for derived types (arrays, composites etc.) may have been purged
by an installation of a custom codec, or general schema state
invalidation.

Fix this by populating the codecs immediately after the statement data
types have been resolved.

Fixes: #241.
  • Loading branch information
elprans committed Jan 21, 2018
1 parent a19ce50 commit 803c115
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 11 deletions.
20 changes: 14 additions & 6 deletions asyncpg/connection.py
Expand Up @@ -291,12 +291,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
types, intro_stmt = await self.__execute(
self._intro_query, (list(ready),), 0, timeout)
self._protocol.get_settings().register_data_types(types)
if not intro_stmt.name and not statement.name:
# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
statement = await self._protocol.prepare(
stmt_name, query, timeout)
# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
need_reprepare = not intro_stmt.name and not statement.name
else:
need_reprepare = False

# Now that types have been resolved, populate the codec pipeline
# for the statement.
statement._init_codecs()

if need_reprepare:
await self._protocol.prepare(
stmt_name, query, timeout, state=statement)

if use_cache:
self._stmt_cache.put(query, statement)
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/prepared_stmt.pxd
Expand Up @@ -30,6 +30,7 @@ cdef class PreparedStatementState:
tuple rows_codecs

cdef _encode_bind_msg(self, args)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
cdef _set_row_desc(self, object desc)
Expand Down
7 changes: 4 additions & 3 deletions asyncpg/protocol/prepared_stmt.pyx
Expand Up @@ -82,6 +82,10 @@ cdef class PreparedStatementState:
else:
return True

cpdef _init_codecs(self):
self._ensure_args_encoder()
self._ensure_rows_decoder()

def attach(self):
self.refs += 1

Expand All @@ -101,9 +105,6 @@ cdef class PreparedStatementState:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')

self._ensure_args_encoder()
self._ensure_rows_decoder()

writer = WriteBuffer.new()

num_args_passed = len(args)
Expand Down
7 changes: 5 additions & 2 deletions asyncpg/protocol/protocol.pyx
Expand Up @@ -146,7 +146,8 @@ cdef class BaseProtocol(CoreProtocol):
self.is_reading = False
self.transport.pause_reading()

async def prepare(self, stmt_name, query, timeout):
async def prepare(self, stmt_name, query, timeout,
PreparedStatementState state=None):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
Expand All @@ -160,7 +161,9 @@ cdef class BaseProtocol(CoreProtocol):
try:
self._prepare(stmt_name, query) # network op
self.last_query = query
self.statement = PreparedStatementState(stmt_name, query, self)
if state is None:
state = PreparedStatementState(stmt_name, query, self)
self.statement = state
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
Expand Down
27 changes: 27 additions & 0 deletions tests/test_introspection.py
Expand Up @@ -5,6 +5,8 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import json

from asyncpg import _testbase as tb
from asyncpg import connection as apg_con

Expand Down Expand Up @@ -98,3 +100,28 @@ async def test_introspection_no_stmt_cache_03(self):
"SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2])

self.assertEqual(apg_con._uid, old_uid + 1)

async def test_introspection_sticks_for_ps(self):
# Test that the introspected codec pipeline for a prepared
# statement is not affected by a subsequent codec cache bust.

ps = await self.con._prepare('SELECT $1::json[]', use_cache=True)

try:
# Setting a custom codec blows the codec cache for derived types.
await self.con.set_type_codec(
'json', encoder=lambda v: v, decoder=json.loads,
schema='pg_catalog', format='text'
)

# The originally prepared statement should still be OK and
# use the previously selected codec.
self.assertEqual(await ps.fetchval(['{"foo": 1}']), ['{"foo": 1}'])

# The new query uses the custom codec.
v = await self.con.fetchval('SELECT $1::json[]', ['{"foo": 1}'])
self.assertEqual(v, [{'foo': 1}])

finally:
await self.con.reset_type_codec(
'json', schema='pg_catalog')

0 comments on commit 803c115

Please sign in to comment.