From 50f65fbb62ab2df9fd8eddde54217b89c66f6cba Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 1 Dec 2020 17:38:46 -0800 Subject: [PATCH] Untangle custom codec confusion (#662) Asyncpg currently erroneously prefers binary I/O for the underlying type of arrays, effectively ignoring a possible custom text codec that might have been configured on a type. Fix this by removing the explicit preference for binary I/O, so that the codec selection preference is now in the following order: - custom binary codec - custom text codec - builtin binary codec - builtin text codec Fixes: #590 Reported-by: @neumond --- asyncpg/connection.py | 9 ++ asyncpg/introspection.py | 31 +++---- asyncpg/protocol/codecs/base.pxd | 3 +- asyncpg/protocol/codecs/base.pyx | 137 +++++++++++++++---------------- asyncpg/protocol/settings.pyx | 11 +-- tests/test_codecs.py | 37 +++++++++ 6 files changed, 123 insertions(+), 105 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index e2355aa8..d33db090 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1156,6 +1156,15 @@ async def set_type_codec(self, typename, *, .. versionchanged:: 0.13.0 The ``binary`` keyword argument was removed in favor of ``format``. + + .. note:: + + It is recommended to use the ``'binary'`` or ``'tuple'`` *format* + whenever possible and if the underlying type supports it. Asyncpg + currently does not support text I/O for composite and range types, + and some other functionality, such as + :meth:`Connection.copy_to_table`, does not support types with text + codecs. 
""" self._check_open() typeinfo = await self._introspect_type(typename, schema) diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py index cca07cef..64508692 100644 --- a/asyncpg/introspection.py +++ b/asyncpg/introspection.py @@ -37,23 +37,9 @@ ELSE NULL END) AS basetype, - t.typreceive::oid != 0 AND t.typsend::oid != 0 - AS has_bin_io, t.typelem AS elemtype, elem_t.typdelim AS elemdelim, range_t.rngsubtype AS range_subtype, - (CASE WHEN t.typtype = 'r' THEN - (SELECT - range_elem_t.typreceive::oid != 0 AND - range_elem_t.typsend::oid != 0 - FROM - pg_catalog.pg_type AS range_elem_t - WHERE - range_elem_t.oid = range_t.rngsubtype) - ELSE - elem_t.typreceive::oid != 0 AND - elem_t.typsend::oid != 0 - END) AS elem_has_bin_io, (CASE WHEN t.typtype = 'c' THEN (SELECT array_agg(ia.atttypid ORDER BY ia.attnum) @@ -98,12 +84,12 @@ INTRO_LOOKUP_TYPES = '''\ WITH RECURSIVE typeinfo_tree( - oid, ns, name, kind, basetype, has_bin_io, elemtype, elemdelim, - range_subtype, elem_has_bin_io, attrtypoids, attrnames, depth) + oid, ns, name, kind, basetype, elemtype, elemdelim, + range_subtype, attrtypoids, attrnames, depth) AS ( SELECT - ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io, - ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io, + ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, + ti.elemtype, ti.elemdelim, ti.range_subtype, ti.attrtypoids, ti.attrnames, 0 FROM {typeinfo} AS ti @@ -113,8 +99,8 @@ UNION ALL SELECT - ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.has_bin_io, - ti.elemtype, ti.elemdelim, ti.range_subtype, ti.elem_has_bin_io, + ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, + ti.elemtype, ti.elemdelim, ti.range_subtype, ti.attrtypoids, ti.attrnames, tt.depth + 1 FROM {typeinfo} ti, @@ -126,7 +112,10 @@ ) SELECT DISTINCT - * + *, + basetype::regtype::text AS basetype_name, + elemtype::regtype::text AS elemtype_name, + range_subtype::regtype::text AS range_subtype_name FROM typeinfo_tree ORDER BY diff --git 
a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd index e8136f7b..79d7a695 100644 --- a/asyncpg/protocol/codecs/base.pxd +++ b/asyncpg/protocol/codecs/base.pxd @@ -168,4 +168,5 @@ cdef class DataCodecConfig: cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, bint ignore_custom_codec=*) - cdef inline Codec get_any_local_codec(self, uint32_t oid) + cdef inline Codec get_custom_codec(self, uint32_t oid, + ServerDataFormat format) diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx index d24cb66d..e4a767a9 100644 --- a/asyncpg/protocol/codecs/base.pyx +++ b/asyncpg/protocol/codecs/base.pyx @@ -440,14 +440,7 @@ cdef class DataCodecConfig: for ti in types: oid = ti['oid'] - if not ti['has_bin_io']: - format = PG_FORMAT_TEXT - else: - format = PG_FORMAT_BINARY - - has_text_elements = False - - if self.get_codec(oid, format) is not None: + if self.get_codec(oid, PG_FORMAT_ANY) is not None: continue name = ti['name'] @@ -468,54 +461,50 @@ cdef class DataCodecConfig: name = name[1:] name = '{}[]'.format(name) - if ti['elem_has_bin_io']: - elem_format = PG_FORMAT_BINARY - else: - elem_format = PG_FORMAT_TEXT - - elem_codec = self.get_codec(array_element_oid, elem_format) + elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY) if elem_codec is None: - elem_format = PG_FORMAT_TEXT elem_codec = self.declare_fallback_codec( - array_element_oid, name, schema) + array_element_oid, ti['elemtype_name'], schema) elem_delim = ti['elemdelim'][0] - self._derived_type_codecs[oid, elem_format] = \ + self._derived_type_codecs[oid, elem_codec.format] = \ Codec.new_array_codec( oid, name, schema, elem_codec, elem_delim) elif ti['kind'] == b'c': + # Composite type + if not comp_type_attrs: raise exceptions.InternalClientError( - 'type record missing field types for ' - 'composite {}'.format(oid)) - - # Composite type + f'type record missing field types for composite {oid}') comp_elem_codecs = [] + has_text_elements 
= False for typoid in comp_type_attrs: - elem_codec = self.get_codec(typoid, PG_FORMAT_BINARY) - if elem_codec is None: - elem_codec = self.get_codec(typoid, PG_FORMAT_TEXT) - has_text_elements = True + elem_codec = self.get_codec(typoid, PG_FORMAT_ANY) if elem_codec is None: raise exceptions.InternalClientError( - 'no codec for composite attribute type {}'.format( - typoid)) + f'no codec for composite attribute type {typoid}') + if elem_codec.format is PG_FORMAT_TEXT: + has_text_elements = True comp_elem_codecs.append(elem_codec) element_names = collections.OrderedDict() for i, attrname in enumerate(ti['attrnames']): element_names[attrname] = i + # If at least one element is text-encoded, we must + # encode the whole composite as text. if has_text_elements: - format = PG_FORMAT_TEXT + elem_format = PG_FORMAT_TEXT + else: + elem_format = PG_FORMAT_BINARY - self._derived_type_codecs[oid, format] = \ + self._derived_type_codecs[oid, elem_format] = \ Codec.new_composite_codec( - oid, name, schema, format, comp_elem_codecs, + oid, name, schema, elem_format, comp_elem_codecs, comp_type_attrs, element_names) elif ti['kind'] == b'd': @@ -523,37 +512,28 @@ cdef class DataCodecConfig: if not base_type: raise exceptions.InternalClientError( - 'type record missing base type for domain {}'.format( - oid)) + f'type record missing base type for domain {oid}') - elem_codec = self.get_codec(base_type, format) + elem_codec = self.get_codec(base_type, PG_FORMAT_ANY) if elem_codec is None: - format = PG_FORMAT_TEXT elem_codec = self.declare_fallback_codec( - base_type, name, schema) + base_type, ti['basetype_name'], schema) - self._derived_type_codecs[oid, format] = elem_codec + self._derived_type_codecs[oid, elem_codec.format] = elem_codec elif ti['kind'] == b'r': # Range type if not range_subtype_oid: raise exceptions.InternalClientError( - 'type record missing base type for range {}'.format( - oid)) + f'type record missing base type for range {oid}') - if ti['elem_has_bin_io']: - 
elem_format = PG_FORMAT_BINARY - else: - elem_format = PG_FORMAT_TEXT - - elem_codec = self.get_codec(range_subtype_oid, elem_format) + elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) if elem_codec is None: - elem_format = PG_FORMAT_TEXT elem_codec = self.declare_fallback_codec( - range_subtype_oid, name, schema) + range_subtype_oid, ti['range_subtype_name'], schema) - self._derived_type_codecs[oid, elem_format] = \ + self._derived_type_codecs[oid, elem_codec.format] = \ Codec.new_range_codec(oid, name, schema, elem_codec) elif ti['kind'] == b'e': @@ -665,10 +645,6 @@ cdef class DataCodecConfig: def declare_fallback_codec(self, uint32_t oid, str name, str schema): cdef Codec codec - codec = self.get_codec(oid, PG_FORMAT_TEXT) - if codec is not None: - return codec - if oid <= MAXBUILTINOID: # This is a BKI type, for which asyncpg has no # defined codec. This should only happen for newly @@ -695,34 +671,49 @@ cdef class DataCodecConfig: bint ignore_custom_codec=False): cdef Codec codec - if not ignore_custom_codec: - codec = self.get_any_local_codec(oid) - if codec is not None: - if codec.format != format: - # The codec for this OID has been overridden by - # set_{builtin}_type_codec with a different format. - # We must respect that and not return a core codec. - return None - else: - return codec - - codec = get_core_codec(oid, format) - if codec is not None: + if format == PG_FORMAT_ANY: + codec = self.get_codec( + oid, PG_FORMAT_BINARY, ignore_custom_codec) + if codec is None: + codec = self.get_codec( + oid, PG_FORMAT_TEXT, ignore_custom_codec) return codec else: - try: - return self._derived_type_codecs[oid, format] - except KeyError: - return None + if not ignore_custom_codec: + codec = self.get_custom_codec(oid, PG_FORMAT_ANY) + if codec is not None: + if codec.format != format: + # The codec for this OID has been overridden by + # set_{builtin}_type_codec with a different format. + # We must respect that and not return a core codec. 
+ return None + else: + return codec + + codec = get_core_codec(oid, format) + if codec is not None: + return codec + else: + try: + return self._derived_type_codecs[oid, format] + except KeyError: + return None - cdef inline Codec get_any_local_codec(self, uint32_t oid): + cdef inline Codec get_custom_codec( + self, + uint32_t oid, + ServerDataFormat format + ): cdef Codec codec - codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY)) - if codec is None: - return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT)) + if format == PG_FORMAT_ANY: + codec = self.get_custom_codec(oid, PG_FORMAT_BINARY) + if codec is None: + codec = self.get_custom_codec(oid, PG_FORMAT_TEXT) else: - return codec + codec = self._custom_type_codecs.get((oid, format)) + + return codec cdef inline Codec get_core_codec( diff --git a/asyncpg/protocol/settings.pyx b/asyncpg/protocol/settings.pyx index 9ab32f39..b4cfa399 100644 --- a/asyncpg/protocol/settings.pyx +++ b/asyncpg/protocol/settings.pyx @@ -89,16 +89,7 @@ cdef class ConnectionSettings(pgproto.CodecContext): cpdef inline Codec get_data_codec(self, uint32_t oid, ServerDataFormat format=PG_FORMAT_ANY, bint ignore_custom_codec=False): - if format == PG_FORMAT_ANY: - codec = self._data_codecs.get_codec( - oid, PG_FORMAT_BINARY, ignore_custom_codec) - if codec is None: - codec = self._data_codecs.get_codec( - oid, PG_FORMAT_TEXT, ignore_custom_codec) - return codec - else: - return self._data_codecs.get_codec( - oid, format, ignore_custom_codec) + return self._data_codecs.get_codec(oid, format, ignore_custom_codec) def __getattr__(self, name): if not name.startswith('_'): diff --git a/tests/test_codecs.py b/tests/test_codecs.py index ae713dc5..b4ed7057 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -1329,6 +1329,34 @@ async def test_custom_codec_on_enum(self): finally: await self.con.execute('DROP TYPE custom_codec_t') + async def test_custom_codec_on_enum_array(self): + """Test encoding/decoding using a custom 
codec on an enum array. + + Bug: https://github.com/MagicStack/asyncpg/issues/590 + """ + await self.con.execute(''' + CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz') + ''') + + try: + await self.con.set_type_codec( + 'custom_codec_t', + encoder=lambda v: str(v).lstrip('enum :'), + decoder=lambda v: 'enum: ' + str(v)) + + v = await self.con.fetchval( + "SELECT ARRAY['foo', 'bar']::custom_codec_t[]") + self.assertEqual(v, ['enum: foo', 'enum: bar']) + + v = await self.con.fetchval( + 'SELECT ARRAY[$1]::custom_codec_t[]', 'foo') + self.assertEqual(v, ['enum: foo']) + + v = await self.con.fetchval("SELECT 'foo'::custom_codec_t") + self.assertEqual(v, 'enum: foo') + finally: + await self.con.execute('DROP TYPE custom_codec_t') + async def test_custom_codec_override_binary(self): """Test overriding core codecs.""" import json @@ -1374,6 +1402,14 @@ def _decoder(value): res = await conn.fetchval('SELECT $1::json', data) self.assertEqual(data, res) + res = await conn.fetchval('SELECT $1::json[]', [data]) + self.assertEqual([data], res) + + await conn.execute('CREATE DOMAIN my_json AS json') + + res = await conn.fetchval('SELECT $1::my_json', data) + self.assertEqual(data, res) + def _encoder(value): return value @@ -1389,6 +1425,7 @@ def _decoder(value): res = await conn.fetchval('SELECT $1::uuid', data) self.assertEqual(res, data) finally: + await conn.execute('DROP DOMAIN IF EXISTS my_json') await conn.close() async def test_custom_codec_override_tuple(self):