Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for tuple-format custom codecs on composite types #1061

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,9 @@ async def set_type_codec(self, typename, *,
| ``time with | (``microseconds``, |
| time zone`` | ``time zone offset in seconds``) |
+-----------------+---------------------------------------------+
| any composite | Composite value elements |
| type | |
+-----------------+---------------------------------------------+

:param encoder:
Callable accepting a Python object as a single argument and
Expand Down Expand Up @@ -1208,6 +1211,10 @@ async def set_type_codec(self, typename, *,
The ``binary`` keyword argument was removed in favor of
``format``.

.. versionchanged:: 0.29.0
Custom codecs for composite types are now supported with
``format='tuple'``.

.. note::

It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
Expand All @@ -1218,11 +1225,28 @@ async def set_type_codec(self, typename, *,
codecs.
"""
self._check_open()
settings = self._protocol.get_settings()
typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
full_typeinfos = []
if introspection.is_scalar_type(typeinfo):
kind = 'scalar'
elif introspection.is_composite_type(typeinfo):
if format != 'tuple':
raise exceptions.UnsupportedClientFeatureError(
'only tuple-format codecs can be used on composite types',
hint="Use `set_type_codec(..., format='tuple')` and "
"pass/interpret data as a Python tuple. See an "
"example at https://magicstack.github.io/asyncpg/"
"current/usage.html#example-decoding-complex-types",
)
kind = 'composite'
full_typeinfos, _ = await self._introspect_types(
(typeinfo['oid'],), 10)
else:
raise exceptions.InterfaceError(
'cannot use custom codec on non-scalar type {}.{}'.format(
schema, typename))
f'cannot use custom codec on type {schema}.{typename}: '
f'it is neither a scalar type nor a composite type'
)
if introspection.is_domain_type(typeinfo):
raise exceptions.UnsupportedClientFeatureError(
'custom codecs on domain types are not supported',
Expand All @@ -1234,8 +1258,8 @@ async def set_type_codec(self, typename, *,
)

oid = typeinfo['oid']
self._protocol.get_settings().add_python_codec(
oid, typename, schema, 'scalar',
settings.add_python_codec(
oid, typename, schema, full_typeinfos, kind,
encoder, decoder, format)

# Statement cache is no longer valid due to codec changes.
Expand Down
4 changes: 4 additions & 0 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,7 @@ def is_scalar_type(typeinfo) -> bool:

def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'


def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'
3 changes: 3 additions & 0 deletions asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cdef class Codec:

encode_func c_encoder
decode_func c_decoder
Codec base_codec

object py_encoder
object py_decoder
Expand All @@ -79,6 +80,7 @@ cdef class Codec:
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
Codec base_codec,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
Expand Down Expand Up @@ -169,6 +171,7 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat)

Expand Down
90 changes: 62 additions & 28 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,25 @@ cdef class Codec:
self.oid = oid
self.type = CODEC_UNDEFINED

cdef init(self, str name, str schema, str kind,
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
Py_UCS4 element_delimiter):
cdef init(
self,
str name,
str schema,
str kind,
CodecType type,
ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
object py_encoder,
object py_decoder,
Codec element_codec,
tuple element_type_oids,
object element_names,
list element_codecs,
Py_UCS4 element_delimiter,
):

self.name = name
self.schema = schema
Expand All @@ -40,6 +51,7 @@ cdef class Codec:
self.xformat = xformat
self.c_encoder = c_encoder
self.c_decoder = c_decoder
self.base_codec = base_codec
self.py_encoder = py_encoder
self.py_decoder = py_decoder
self.element_codec = element_codec
Expand All @@ -48,6 +60,12 @@ cdef class Codec:
self.element_delimiter = element_delimiter
self.element_names = element_names

if base_codec is not None:
if c_encoder != NULL or c_decoder != NULL:
raise exceptions.InternalClientError(
'base_codec is mutually exclusive with c_encoder/c_decoder'
)

if element_names is not None:
self.record_desc = record.ApgRecordDesc_New(
element_names, tuple(element_names))
Expand Down Expand Up @@ -98,7 +116,7 @@ cdef class Codec:
codec = Codec(self.oid)
codec.init(self.name, self.schema, self.kind,
self.type, self.format, self.xformat,
self.c_encoder, self.c_decoder,
self.c_encoder, self.c_decoder, self.base_codec,
self.py_encoder, self.py_decoder,
self.element_codec,
self.element_type_oids, self.element_names,
Expand Down Expand Up @@ -196,7 +214,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
self.c_encoder(settings, buf, data)
if self.base_codec is not None:
self.base_codec.encode(settings, buf, data)
else:
self.c_encoder(settings, buf, data)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
Expand Down Expand Up @@ -295,7 +316,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
data = self.c_decoder(settings, buf)
if self.base_codec is not None:
data = self.base_codec.decode(settings, buf)
else:
data = self.c_decoder(settings, buf)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
Expand Down Expand Up @@ -367,8 +391,8 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
None, None, None, element_delimiter)
PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_codec, None, None, None, element_delimiter)
return codec

@staticmethod
Expand All @@ -379,8 +403,8 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
None, None, None, 0)
PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_codec, None, None, None, 0)
return codec

@staticmethod
Expand All @@ -391,7 +415,7 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL,
element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
None, None, element_codec, None, None, None, 0)
return codec

Expand All @@ -407,7 +431,7 @@ cdef class Codec:
codec = Codec(oid)
codec.init(name, schema, 'composite', CODEC_COMPOSITE,
format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_type_oids, element_names, element_codecs, 0)
None, element_type_oids, element_names, element_codecs, 0)
return codec

@staticmethod
Expand All @@ -419,12 +443,13 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, kind, CODEC_PY, format, xformat,
c_encoder, c_decoder, encoder, decoder,
c_encoder, c_decoder, base_codec, encoder, decoder,
None, None, None, None, 0)
return codec

Expand Down Expand Up @@ -596,34 +621,43 @@ cdef class DataCodecConfig:
self.declare_fallback_codec(oid, name, schema)

def add_python_codec(self, typeoid, typename, typeschema, typekind,
encoder, decoder, format, xformat):
typeinfos, encoder, decoder, format, xformat):
cdef:
Codec core_codec
Codec core_codec = None
encode_func c_encoder = NULL
decode_func c_decoder = NULL
Codec base_codec = None
uint32_t oid = pylong_as_oid(typeoid)
bint codec_set = False

# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)

if typeinfos:
self.add_types(typeinfos)

if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
else:
formats = (format,)

for fmt in formats:
if xformat == PG_XFORMAT_TUPLE:
core_codec = get_core_codec(oid, fmt, xformat)
if core_codec is None:
continue
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder
if typekind == "scalar":
core_codec = get_core_codec(oid, fmt, xformat)
if core_codec is None:
continue
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder
elif typekind == "composite":
base_codec = self.get_codec(oid, fmt)
if base_codec is None:
continue

self._custom_type_codecs[typeoid, fmt] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
fmt, xformat)
base_codec, fmt, xformat)
codec_set = True

if not codec_set:
Expand Down Expand Up @@ -829,7 +863,7 @@ cdef register_core_codec(uint32_t oid,

codec = Codec(oid)
codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
encode, decode, None, None, None, None, None, None, 0)
encode, decode, None, None, None, None, None, None, None, 0)
cpython.Py_INCREF(codec) # immortalize

if format == PG_FORMAT_BINARY:
Expand All @@ -853,7 +887,7 @@ cdef register_extra_codec(str name,

codec = Codec(INVALIDOID)
codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
encode, decode, None, None, None, None, None, None, 0)
encode, decode, None, None, None, None, None, None, None, 0)
EXTRA_CODECS[name, format] = codec


Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/settings.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cdef class ConnectionSettings(pgproto.CodecContext):
cpdef get_text_codec(self)
cpdef inline register_data_types(self, types)
cpdef inline add_python_codec(
self, typeoid, typename, typeschema, typekind, encoder,
self, typeoid, typename, typeschema, typeinfos, typekind, encoder,
decoder, format)
cpdef inline remove_python_codec(
self, typeoid, typename, typeschema)
Expand Down
6 changes: 4 additions & 2 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ cdef class ConnectionSettings(pgproto.CodecContext):
self._data_codecs.add_types(types)

cpdef inline add_python_codec(self, typeoid, typename, typeschema,
typekind, encoder, decoder, format):
typeinfos, typekind, encoder, decoder,
format):
cdef:
ServerDataFormat _format
ClientExchangeFormat xformat
Expand All @@ -57,7 +58,8 @@ cdef class ConnectionSettings(pgproto.CodecContext):
))

self._data_codecs.add_python_codec(typeoid, typename, typeschema,
typekind, encoder, decoder,
typekind, typeinfos,
encoder, decoder,
_format, xformat)

cpdef inline remove_python_codec(self, typeoid, typename, typeschema):
Expand Down
43 changes: 41 additions & 2 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,46 @@ JSON values using the :mod:`json <python:json>` module.
finally:
await conn.close()

asyncio.get_event_loop().run_until_complete(main())
asyncio.run(main())


Example: complex types
~~~~~~~~~~~~~~~~~~~~~~

The example below shows how to configure asyncpg to encode and decode
Python :class:`complex <python:complex>` values to a custom composite
type in PostgreSQL.

.. code-block:: python

import asyncio
import asyncpg


async def main():
conn = await asyncpg.connect()

try:
await conn.execute(
'''
CREATE TYPE mycomplex AS (
r float,
i float
);'''
)
await conn.set_type_codec(
'complex',
encoder=lambda x: (x.real, x.imag),
decoder=lambda t: complex(t[0], t[1]),
format='tuple',
)

res = await conn.fetchval('SELECT $1::mycomplex', (1+2j))

finally:
await conn.close()

asyncio.run(main())


Example: automatic conversion of PostGIS types
Expand Down Expand Up @@ -274,7 +313,7 @@ will work.
finally:
await conn.close()

asyncio.get_event_loop().run_until_complete(main())
asyncio.run(main())


Example: decoding numeric columns as floats
Expand Down
Loading