Skip to content

Commit

Permalink
Raise a consistent exception on input encoding errors
Browse files Browse the repository at this point in the history
Currently, when invalid input is passed a query argument, asyncpg will
raise whatever exception was triggered by the codec function, which
can be `TypeError`, `ValueError`, `decimal.InvalidOperation` etc.
Additionally, these exceptions lack sufficient context as to which
argument actually triggered an error.

Fix this by consistently raising the new `asyncpg.DataError` exception,
which is a subclass of `asyncpg.InterfaceError` and `ValueError`, and
include the position of the offending argument as well as the passed
value, e.g:

    asyncpg.exceptions.DataError: invalid input for query argument $1: 'aaa'
    (a bytes-like object is required, not 'str')

Fixes: #260
  • Loading branch information
elprans committed Jun 1, 2018
1 parent 482a186 commit 0ddfa46
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 50 deletions.
4 changes: 4 additions & 0 deletions asyncpg/exceptions/_base.py
Expand Up @@ -209,6 +209,10 @@ def __init__(self, msg, *, detail=None, hint=None):
Exception.__init__(self, msg)


class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""


class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""

Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/codecs/float.pyx
Expand Up @@ -12,7 +12,7 @@ cdef float4_encode(ConnectionSettings settings, WriteBuffer buf, obj):
cdef double dval = cpython.PyFloat_AsDouble(obj)
cdef float fval = <float>dval
if math.isinf(fval) and not math.isinf(dval):
raise ValueError('float value too large to be encoded as FLOAT4')
raise ValueError('value out of float32 range')

buf.write_int32(4)
buf.write_float(fval)
Expand Down
12 changes: 4 additions & 8 deletions asyncpg/protocol/codecs/int.pyx
Expand Up @@ -28,8 +28,7 @@ cdef int2_encode(ConnectionSettings settings, WriteBuffer buf, obj):
overflow = 1

if overflow or val < INT16_MIN or val > INT16_MAX:
raise OverflowError(
'int16 value out of range: {!r}'.format(obj))
raise OverflowError('value out of int16 range')

buf.write_int32(2)
buf.write_int16(<int16_t>val)
Expand All @@ -50,8 +49,7 @@ cdef int4_encode(ConnectionSettings settings, WriteBuffer buf, obj):

# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)):
raise OverflowError(
'int32 value out of range: {!r}'.format(obj))
raise OverflowError('value out of int32 range')

buf.write_int32(4)
buf.write_int32(<int32_t>val)
Expand All @@ -72,8 +70,7 @@ cdef uint4_encode(ConnectionSettings settings, WriteBuffer buf, obj):

# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and val > UINT32_MAX):
raise OverflowError(
'uint32 value out of range: {!r}'.format(obj))
raise OverflowError('value out of uint32 range')

buf.write_int32(4)
buf.write_int32(<int32_t>val)
Expand All @@ -95,8 +92,7 @@ cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj):

# Just in case for systems with "long long" bigger than 8 bytes
if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)):
raise OverflowError(
'int64 value out of range: {!r}'.format(obj))
raise OverflowError('value out of int64 range')

buf.write_int32(8)
buf.write_int64(<int64_t>val)
Expand Down
6 changes: 2 additions & 4 deletions asyncpg/protocol/codecs/tid.pyx
Expand Up @@ -24,8 +24,7 @@ cdef tid_encode(ConnectionSettings settings, WriteBuffer buf, obj):

# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(block) > 4 and block > UINT32_MAX):
raise OverflowError(
'tuple id block value out of range: {!r}'.format(obj[0]))
raise OverflowError('tuple id block value out of uint32 range')

try:
offset = cpython.PyLong_AsUnsignedLong(obj[1])
Expand All @@ -34,8 +33,7 @@ cdef tid_encode(ConnectionSettings settings, WriteBuffer buf, obj):
overflow = 1

if overflow or offset > 65535:
raise OverflowError(
'tuple id offset value out of range: {!r}'.format(obj[1]))
raise OverflowError('tuple id offset value out of uint16 range')

buf.write_int32(6)
buf.write_int32(<int32_t>block)
Expand Down
26 changes: 22 additions & 4 deletions asyncpg/protocol/prepared_stmt.pyx
Expand Up @@ -127,7 +127,7 @@ cdef class PreparedStatementState:

if self.have_text_args:
writer.write_int16(self.args_num)
for idx from 0 <= idx < self.args_num:
for idx in range(self.args_num):
codec = <Codec>(self.args_codecs[idx])
writer.write_int16(codec.format)
else:
Expand All @@ -136,17 +136,35 @@ cdef class PreparedStatementState:

writer.write_int16(self.args_num)

for idx from 0 <= idx < self.args_num:
for idx in range(self.args_num):
arg = args[idx]
if arg is None:
writer.write_int32(-1)
else:
codec = <Codec>(self.args_codecs[idx])
codec.encode(self.settings, writer, arg)
try:
codec.encode(self.settings, writer, arg)
except (AssertionError, exceptions.InternalClientError):
# These are internal errors and should raise as-is.
raise
except exceptions.InterfaceError:
# This is already a descriptive error.
raise
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
value_repr = repr(arg)
if len(value_repr) > 40:
value_repr = value_repr[:40] + '...'

raise exceptions.DataError(
'invalid input for query argument'
' ${n}: {v} ({msg})'.format(
n=idx + 1, v=value_repr, msg=e)) from e

if self.have_text_cols:
writer.write_int16(self.cols_num)
for idx from 0 <= idx < self.cols_num:
for idx in range(self.cols_num):
codec = <Codec>(self.rows_codecs[idx])
writer.write_int16(codec.format)
else:
Expand Down
75 changes: 42 additions & 33 deletions tests/test_codecs.py
Expand Up @@ -542,17 +542,19 @@ async def test_numeric(self):
"SELECT $1::numeric", decimal.Decimal('sNaN'))
self.assertTrue(res.is_nan())

with self.assertRaisesRegex(ValueError, 'numeric type does not '
'support infinite values'):
with self.assertRaisesRegex(asyncpg.DataError,
'numeric type does not '
'support infinite values'):
await self.con.fetchval(
"SELECT $1::numeric", decimal.Decimal('-Inf'))

with self.assertRaisesRegex(ValueError, 'numeric type does not '
'support infinite values'):
with self.assertRaisesRegex(asyncpg.DataError,
'numeric type does not '
'support infinite values'):
await self.con.fetchval(
"SELECT $1::numeric", decimal.Decimal('+Inf'))

with self.assertRaises(decimal.InvalidOperation):
with self.assertRaisesRegex(asyncpg.DataError, 'invalid'):
await self.con.fetchval(
"SELECT $1::numeric", 'invalid')

Expand All @@ -578,91 +580,95 @@ async def test_unhandled_type_fallback(self):

async def test_invalid_input(self):
cases = [
('bytea', TypeError, 'a bytes-like object is required', [
('bytea', 'a bytes-like object is required', [
1,
'aaa'
]),
('bool', TypeError, 'a boolean is required', [
('bool', 'a boolean is required', [
1,
]),
('int2', TypeError, 'an integer is required', [
('int2', 'an integer is required', [
'2',
'aa',
]),
('smallint', OverflowError, 'int16 value out of range', [
('smallint', 'value out of int16 range', [
2**256, # check for the same exception for any big numbers
decimal.Decimal("2000000000000000000000000000000"),
0xffff,
0xffffffff,
32768,
-32769
]),
('float4', ValueError, 'float value too large', [
('float4', 'value out of float32 range', [
4.1 * 10 ** 40,
-4.1 * 10 ** 40,
]),
('int4', TypeError, 'an integer is required', [
('int4', 'an integer is required', [
'2',
'aa',
]),
('int', OverflowError, 'int32 value out of range', [
('int', 'value out of int32 range', [
2**256, # check for the same exception for any big numbers
decimal.Decimal("2000000000000000000000000000000"),
0xffffffff,
2**31,
-2**31 - 1,
]),
('int8', TypeError, 'an integer is required', [
('int8', 'an integer is required', [
'2',
'aa',
]),
('bigint', OverflowError, 'int64 value out of range', [
('bigint', 'value out of int64 range', [
2**256, # check for the same exception for any big numbers
decimal.Decimal("2000000000000000000000000000000"),
0xffffffffffffffff,
2**63,
-2**63 - 1,
]),
('text', TypeError, 'expected str, got bytes', [
('text', 'expected str, got bytes', [
b'foo'
]),
('text', TypeError, 'expected str, got list', [
('text', 'expected str, got list', [
[1]
]),
('tid', TypeError, 'list or tuple expected', [
('tid', 'list or tuple expected', [
b'foo'
]),
('tid', ValueError, 'invalid number of elements in tid tuple', [
('tid', 'invalid number of elements in tid tuple', [
[],
(),
[1, 2, 3],
(4,),
]),
('tid', OverflowError, 'tuple id block value out of range', [
('tid', 'tuple id block value out of uint32 range', [
(-1, 0),
(2**256, 0),
(0xffffffff + 1, 0),
(2**32, 0),
]),
('tid', OverflowError, 'tuple id offset value out of range', [
('tid', 'tuple id offset value out of uint16 range', [
(0, -1),
(0, 2**256),
(0, 0xffff + 1),
(0, 0xffffffff),
(0, 65536),
]),
('oid', OverflowError, 'uint32 value out of range', [
('oid', 'value out of uint32 range', [
2 ** 32,
-1,
]),
]

for typname, errcls, errmsg, data in cases:
for typname, errmsg, data in cases:
stmt = await self.con.prepare("SELECT $1::" + typname)

for sample in data:
with self.subTest(sample=sample, typname=typname):
with self.assertRaisesRegex(errcls, errmsg):
full_errmsg = (
r'invalid input for query argument \$1:.*' + errmsg)

with self.assertRaisesRegex(
asyncpg.DataError, full_errmsg):
await stmt.fetchval(sample)

async def test_arrays(self):
Expand Down Expand Up @@ -733,37 +739,39 @@ class SomeContainer:
def __contains__(self, item):
return False

with self.assertRaisesRegex(TypeError,
with self.assertRaisesRegex(asyncpg.DataError,
'sized iterable container expected'):
result = await self.con.fetchval("SELECT $1::int[]",
SomeContainer())

with self.assertRaisesRegex(ValueError, 'dimensions'):
with self.assertRaisesRegex(asyncpg.DataError, 'dimensions'):
await self.con.fetchval(
"SELECT $1::int[]",
[[[[[[[1]]]]]]])

with self.assertRaisesRegex(ValueError, 'non-homogeneous'):
with self.assertRaisesRegex(asyncpg.DataError, 'non-homogeneous'):
await self.con.fetchval(
"SELECT $1::int[]",
[1, [1]])

with self.assertRaisesRegex(ValueError, 'non-homogeneous'):
with self.assertRaisesRegex(asyncpg.DataError, 'non-homogeneous'):
await self.con.fetchval(
"SELECT $1::int[]",
[[1], 1, [2]])

with self.assertRaisesRegex(ValueError, 'invalid array element'):
with self.assertRaisesRegex(asyncpg.DataError,
'invalid array element'):
await self.con.fetchval(
"SELECT $1::int[]",
[1, 't', 2])

with self.assertRaisesRegex(ValueError, 'invalid array element'):
with self.assertRaisesRegex(asyncpg.DataError,
'invalid array element'):
await self.con.fetchval(
"SELECT $1::int[]",
[[1], ['t'], [2]])

with self.assertRaisesRegex(TypeError,
with self.assertRaisesRegex(asyncpg.DataError,
'sized iterable container expected'):
await self.con.fetchval(
"SELECT $1::int[]",
Expand Down Expand Up @@ -887,11 +895,11 @@ async def test_range_types(self):
self.assertEqual(result, expected)

with self.assertRaisesRegex(
TypeError, 'list, tuple or Range object expected'):
asyncpg.DataError, 'list, tuple or Range object expected'):
await self.con.fetch("SELECT $1::int4range", 'aa')

with self.assertRaisesRegex(
ValueError, 'expected 0, 1 or 2 elements'):
asyncpg.DataError, 'expected 0, 1 or 2 elements'):
await self.con.fetch("SELECT $1::int4range", (0, 2, 3))

cases = [(asyncpg.Range(0, 1), asyncpg.Range(0, 1), 1),
Expand Down Expand Up @@ -933,7 +941,8 @@ async def test_extra_codec_alias(self):

self.assertEqual(res, {'foo': '2', 'bar': '3'})

with self.assertRaisesRegex(ValueError, 'null value not allowed'):
with self.assertRaisesRegex(asyncpg.DataError,
'null value not allowed'):
await self.con.fetchval('''
SELECT $1::hstore AS result
''', {None: '1'})
Expand Down

0 comments on commit 0ddfa46

Please sign in to comment.