Skip to content

Commit

Permalink
Disable custom data codec for internal introspection
Browse files Browse the repository at this point in the history
Fixes: #617
  • Loading branch information
fantix committed Sep 18, 2020
1 parent 98dcf96 commit 821f279
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 37 deletions.
35 changes: 25 additions & 10 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ async def _get_statement(
*,
named: bool=False,
use_cache: bool=True,
disable_custom_codec=False,
record_class=None
):
if record_class is None:
Expand Down Expand Up @@ -401,7 +402,7 @@ async def _get_statement(

# Now that types have been resolved, populate the codec pipeline
# for the statement.
statement._init_codecs()
statement._init_codecs(disable_custom_codec)

if need_reprepare:
await self._protocol.prepare(
Expand All @@ -424,7 +425,12 @@ async def _get_statement(

async def _introspect_types(self, typeoids, timeout):
return await self.__execute(
self._intro_query, (list(typeoids),), 0, timeout)
self._intro_query,
(list(typeoids),),
0,
timeout,
disable_custom_codec=True,
)

async def _introspect_type(self, typename, schema):
if (
Expand All @@ -437,20 +443,22 @@ async def _introspect_type(self, typename, schema):
[typeoid],
limit=0,
timeout=None,
disable_custom_codec=True,
)
if rows:
typeinfo = rows[0]
else:
typeinfo = None
else:
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
rows = await self._execute(
introspection.TYPE_BY_NAME,
[typename, schema],
limit=1,
timeout=None,
disable_custom_codec=True,
)

if not typeinfo:
if not rows:
raise ValueError(
'unknown type: {}.{}'.format(schema, typename))

return typeinfo
return rows[0]

def cursor(
self,
Expand Down Expand Up @@ -1587,6 +1595,7 @@ async def _execute(
timeout,
*,
return_status=False,
disable_custom_codec=False,
record_class=None
):
with self._stmt_exclusive_section:
Expand All @@ -1597,6 +1606,7 @@ async def _execute(
timeout,
return_status=return_status,
record_class=record_class,
disable_custom_codec=disable_custom_codec,
)
return result

Expand All @@ -1608,6 +1618,7 @@ async def __execute(
timeout,
*,
return_status=False,
disable_custom_codec=False,
record_class=None
):
executor = lambda stmt, timeout: self._protocol.bind_execute(
Expand All @@ -1618,6 +1629,7 @@ async def __execute(
executor,
timeout,
record_class=record_class,
disable_custom_codec=disable_custom_codec,
)

async def _executemany(self, query, args, timeout):
Expand All @@ -1635,20 +1647,23 @@ async def _do_execute(
timeout,
retry=True,
*,
disable_custom_codec=False,
record_class=None
):
if timeout is None:
stmt = await self._get_statement(
query,
None,
record_class=record_class,
disable_custom_codec=disable_custom_codec,
)
else:
before = time.monotonic()
stmt = await self._get_statement(
query,
timeout,
record_class=record_class,
disable_custom_codec=disable_custom_codec,
)
after = time.monotonic()
timeout -= after - before
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,6 @@ cdef class DataCodecConfig:
dict _derived_type_codecs
dict _custom_type_codecs

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint disable_custom_codec=*)
cdef inline Codec get_any_local_codec(self, uint32_t oid)
22 changes: 12 additions & 10 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -692,18 +692,20 @@ cdef class DataCodecConfig:

return codec

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint disable_custom_codec=False):
cdef Codec codec

codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec
if not disable_custom_codec:
codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec

codec = get_core_codec(oid, format)
if codec is not None:
Expand Down
6 changes: 3 additions & 3 deletions asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ cdef class PreparedStatementState:
tuple rows_codecs

cdef _encode_bind_msg(self, args)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
cpdef _init_codecs(self, bint disable_custom_codec)
cdef _ensure_rows_decoder(self, bint disable_custom_codec)
cdef _ensure_args_encoder(self, bint disable_custom_codec)
cdef _set_row_desc(self, object desc)
cdef _set_args_desc(self, object desc)
cdef _decode_row(self, const char* cbuf, ssize_t buf_len)
17 changes: 10 additions & 7 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ cdef class PreparedStatementState:

return missing

cpdef _init_codecs(self):
self._ensure_args_encoder()
self._ensure_rows_decoder()
cpdef _init_codecs(self, bint disable_custom_codec):

self._ensure_args_encoder(disable_custom_codec)
self._ensure_rows_decoder(disable_custom_codec)

def attach(self):
self.refs += 1
Expand Down Expand Up @@ -180,7 +181,7 @@ cdef class PreparedStatementState:

return writer

cdef _ensure_rows_decoder(self):
cdef _ensure_rows_decoder(self, bint disable_custom_codec):
cdef:
list cols_names
object cols_mapping
Expand All @@ -205,7 +206,8 @@ cdef class PreparedStatementState:
cols_mapping[col_name] = i
cols_names.append(col_name)
oid = row[3]
codec = self.settings.get_data_codec(oid)
codec = self.settings.get_data_codec(
oid, disable_custom_codec=disable_custom_codec)
if codec is None or not codec.has_decoder():
raise exceptions.InternalClientError(
'no decoder for OID {}'.format(oid))
Expand All @@ -219,7 +221,7 @@ cdef class PreparedStatementState:

self.rows_codecs = tuple(codecs)

cdef _ensure_args_encoder(self):
cdef _ensure_args_encoder(self, bint disable_custom_codec):
cdef:
uint32_t p_oid
Codec codec
Expand All @@ -230,7 +232,8 @@ cdef class PreparedStatementState:

for i from 0 <= i < self.args_num:
p_oid = self.parameters_desc[i]
codec = self.settings.get_data_codec(p_oid)
codec = self.settings.get_data_codec(
p_oid, disable_custom_codec=disable_custom_codec)
if codec is None or not codec.has_encoder():
raise exceptions.InternalClientError(
'no encoder for OID {}'.format(p_oid))
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ cdef class BaseProtocol(CoreProtocol):
# No header extension
wbuf.write_int32(0)

record_stmt._ensure_rows_decoder()
record_stmt._ensure_rows_decoder(False)
codecs = record_stmt.rows_codecs
num_cols = len(codecs)
settings = self.settings
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/settings.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ cdef class ConnectionSettings(pgproto.CodecContext):
cpdef inline set_builtin_type_codec(
self, typeoid, typename, typeschema, typekind, alias_to, format)
cpdef inline Codec get_data_codec(
self, uint32_t oid, ServerDataFormat format=*)
self, uint32_t oid, ServerDataFormat format=*,
bint disable_custom_codec=*)
12 changes: 8 additions & 4 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,18 @@ cdef class ConnectionSettings(pgproto.CodecContext):
typekind, alias_to, _format)

cpdef inline Codec get_data_codec(self, uint32_t oid,
ServerDataFormat format=PG_FORMAT_ANY):
ServerDataFormat format=PG_FORMAT_ANY,
bint disable_custom_codec=False):
if format == PG_FORMAT_ANY:
codec = self._data_codecs.get_codec(oid, PG_FORMAT_BINARY)
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_BINARY, disable_custom_codec)
if codec is None:
codec = self._data_codecs.get_codec(oid, PG_FORMAT_TEXT)
codec = self._data_codecs.get_codec(
oid, PG_FORMAT_TEXT, disable_custom_codec)
return codec
else:
return self._data_codecs.get_codec(oid, format)
return self._data_codecs.get_codec(
oid, format, disable_custom_codec)

def __getattr__(self, name):
if not name.startswith('_'):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ def tearDownClass(cls):

super().tearDownClass()

def setUp(self):
super().setUp()
self.loop.run_until_complete(self._add_custom_codec(self.con))

async def _add_custom_codec(self, conn):
# mess up with the codec - builtin introspection shouldn't be affected
await conn.set_type_codec(
"oid",
schema="pg_catalog",
encoder=lambda value: None,
decoder=lambda value: None,
format="text",
)

@tb.with_connection_options(database='asyncpg_intro_test')
async def test_introspection_on_large_db(self):
await self.con.execute(
Expand Down Expand Up @@ -142,6 +156,7 @@ async def test_introspection_retries_after_cache_bust(self):
# query would cause introspection to retry.
slow_intro_conn = await self.connect(
connection_class=SlowIntrospectionConnection)
await self._add_custom_codec(slow_intro_conn)
try:
await self.con.execute('''
CREATE DOMAIN intro_1_t AS int;
Expand Down

0 comments on commit 821f279

Please sign in to comment.