Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix set_type_codec() to accept standard SQL type names #619

Merged
merged 2 commits into from
Sep 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,32 @@ async def _introspect_types(self, typeoids, timeout):
return await self.__execute(
self._intro_query, (list(typeoids),), 0, timeout)

async def _introspect_type(self, typename, schema):
if (
schema == 'pg_catalog'
and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP
):
typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()]
rows = await self._execute(
introspection.TYPE_BY_OID,
[typeoid],
limit=0,
timeout=None,
)
if rows:
typeinfo = rows[0]
else:
typeinfo = None
else:
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)

if not typeinfo:
raise ValueError(
'unknown type: {}.{}'.format(schema, typename))

return typeinfo

def cursor(
self,
query,
Expand Down Expand Up @@ -1108,12 +1134,7 @@ async def set_type_codec(self, typename, *,
``format``.
"""
self._check_open()

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise ValueError(
'cannot use custom codec on non-scalar type {}.{}'.format(
Expand All @@ -1140,15 +1161,9 @@ async def reset_type_codec(self, typename, *, schema='public'):
.. versionadded:: 0.12.0
"""

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

oid = typeinfo['oid']

typeinfo = await self._introspect_type(typename, schema)
self._protocol.get_settings().remove_python_codec(
oid, typename, schema)
typeinfo['oid'], typename, schema)

# Statement cache is no longer valid due to codec changes.
self._drop_local_statement_cache()
Expand Down Expand Up @@ -1189,13 +1204,7 @@ async def set_builtin_type_codec(self, typename, *,
core data type. Added the *format* keyword argument.
"""
self._check_open()

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise exceptions.InterfaceError(
'unknown type: {}.{}'.format(schema, typename))

typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise exceptions.InterfaceError(
'cannot alias non-scalar type {}.{}'.format(
Expand Down
12 changes: 12 additions & 0 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@
'''


TYPE_BY_OID = '''\
SELECT
t.oid,
t.typelem AS elemtype,
t.typtype AS kind
FROM
pg_catalog.pg_type AS t
WHERE
t.oid = $1
'''


# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')

Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

# flake8: NOQA

from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
18 changes: 18 additions & 0 deletions asyncpg/protocol/pgtypes.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,23 @@ BUILTIN_TYPE_NAME_MAP['double precision'] = \
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamptz']

BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamp']

BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timetz']

BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
BUILTIN_TYPE_NAME_MAP['time']

BUILTIN_TYPE_NAME_MAP['char'] = \
BUILTIN_TYPE_NAME_MAP['bpchar']

BUILTIN_TYPE_NAME_MAP['character'] = \
BUILTIN_TYPE_NAME_MAP['bpchar']

BUILTIN_TYPE_NAME_MAP['character varying'] = \
BUILTIN_TYPE_NAME_MAP['varchar']

BUILTIN_TYPE_NAME_MAP['bit varying'] = \
BUILTIN_TYPE_NAME_MAP['varbit']
33 changes: 33 additions & 0 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,39 @@ async def test_custom_codec_on_domain(self):
finally:
await self.con.execute('DROP DOMAIN custom_codec_t')

async def test_custom_codec_on_stdsql_types(self):
types = [
'smallint',
'int',
'integer',
'bigint',
'decimal',
'real',
'double precision',
'timestamp with timezone',
'time with timezone',
'timestamp without timezone',
'time without timezone',
'char',
'character',
'character varying',
'bit varying',
'CHARACTER VARYING'
]

for t in types:
with self.subTest(type=t):
try:
await self.con.set_type_codec(
t,
schema='pg_catalog',
encoder=str,
decoder=str,
format='text'
)
finally:
await self.con.reset_type_codec(t, schema='pg_catalog')

async def test_custom_codec_on_enum(self):
"""Test encoding/decoding using a custom codec on an enum."""
await self.con.execute('''
Expand Down