From 0ddfa4666f967b1f0a10f321ee9f87ce41a9cb7f Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Fri, 1 Jun 2018 11:08:20 -0400 Subject: [PATCH] Raise a consistent exception on input encoding errors Currently, when invalid input is passed a query argument, asyncpg will raise whatever exception was triggered by the codec function, which can be `TypeError`, `ValueError`, `decimal.InvalidOperation` etc. Additionally, these exceptions lack sufficient context as to which argument actually triggered an error. Fix this by consistently raising the new `asyncpg.DataError` exception, which is a subclass of `asyncpg.InterfaceError` and `ValueError`, and include the position of the offending argument as well as the passed value, e.g: asyncpg.exceptions.DataError: invalid input for query argument $1: 'aaa' (a bytes-like object is required, not 'str') Fixes: #260 --- asyncpg/exceptions/_base.py | 4 ++ asyncpg/protocol/codecs/float.pyx | 2 +- asyncpg/protocol/codecs/int.pyx | 12 ++--- asyncpg/protocol/codecs/tid.pyx | 6 +-- asyncpg/protocol/prepared_stmt.pyx | 26 +++++++++-- tests/test_codecs.py | 75 +++++++++++++++++------------- 6 files changed, 75 insertions(+), 50 deletions(-) diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index 8f90e3b0..d59cf671 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -209,6 +209,10 @@ def __init__(self, msg, *, detail=None, hint=None): Exception.__init__(self, msg) +class DataError(InterfaceError, ValueError): + """An error caused by invalid query input.""" + + class InterfaceWarning(InterfaceMessage, UserWarning): """A warning caused by an improper use of asyncpg API.""" diff --git a/asyncpg/protocol/codecs/float.pyx b/asyncpg/protocol/codecs/float.pyx index b1f50ff6..27b2cf3a 100644 --- a/asyncpg/protocol/codecs/float.pyx +++ b/asyncpg/protocol/codecs/float.pyx @@ -12,7 +12,7 @@ cdef float4_encode(ConnectionSettings settings, WriteBuffer buf, obj): cdef double dval = cpython.PyFloat_AsDouble(obj) cdef float fval = dval if math.isinf(fval) and not math.isinf(dval): - raise ValueError('float value too large to be encoded as FLOAT4') + raise ValueError('value out of float32 range') buf.write_int32(4) buf.write_float(fval) diff --git a/asyncpg/protocol/codecs/int.pyx b/asyncpg/protocol/codecs/int.pyx index eb65d85c..514dafff 100644 --- a/asyncpg/protocol/codecs/int.pyx +++ b/asyncpg/protocol/codecs/int.pyx @@ -28,8 +28,7 @@ cdef int2_encode(ConnectionSettings settings, WriteBuffer buf, obj): overflow = 1 if overflow or val < INT16_MIN or val > INT16_MAX: - raise OverflowError( - 'int16 value out of range: {!r}'.format(obj)) + raise OverflowError('value out of int16 range') buf.write_int32(2) buf.write_int16(val) @@ -50,8 +49,7 @@ cdef int4_encode(ConnectionSettings settings, WriteBuffer buf, obj): # "long" and "long long" have the same size for x86_64, need an extra check if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)): - raise OverflowError( - 'int32 value out of range: {!r}'.format(obj)) + raise OverflowError('value out of int32 range') buf.write_int32(4) buf.write_int32(val) @@ -72,8 +70,7 @@ cdef uint4_encode(ConnectionSettings settings, WriteBuffer buf, obj): # "long" and "long long" have the same size for x86_64, need an extra check if overflow or (sizeof(val) > 4 and val > UINT32_MAX): - raise OverflowError( - 'uint32 value out of range: {!r}'.format(obj)) + raise OverflowError('value out of uint32 range') buf.write_int32(4) buf.write_int32(val) @@ -95,8 +92,7 @@ cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj): # Just in case for systems with "long long" bigger than 8 bytes if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)): - raise OverflowError( - 'int64 value out of range: {!r}'.format(obj)) + raise OverflowError('value out of int64 range') buf.write_int32(8) buf.write_int64(val) diff --git a/asyncpg/protocol/codecs/tid.pyx b/asyncpg/protocol/codecs/tid.pyx index a64e1901..8db6338e 100644 --- a/asyncpg/protocol/codecs/tid.pyx +++ b/asyncpg/protocol/codecs/tid.pyx @@ -24,8 +24,7 @@ cdef tid_encode(ConnectionSettings settings, WriteBuffer buf, obj): # "long" and "long long" have the same size for x86_64, need an extra check if overflow or (sizeof(block) > 4 and block > UINT32_MAX): - raise OverflowError( - 'tuple id block value out of range: {!r}'.format(obj[0])) + raise OverflowError('tuple id block value out of uint32 range') try: offset = cpython.PyLong_AsUnsignedLong(obj[1]) @@ -34,8 +33,7 @@ cdef tid_encode(ConnectionSettings settings, WriteBuffer buf, obj): overflow = 1 if overflow or offset > 65535: - raise OverflowError( - 'tuple id offset value out of range: {!r}'.format(obj[1])) + raise OverflowError('tuple id offset value out of uint16 range') buf.write_int32(6) buf.write_int32(block) diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index 7e0d6e31..fb2dc55d 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -127,7 +127,7 @@ cdef class PreparedStatementState: if self.have_text_args: writer.write_int16(self.args_num) - for idx from 0 <= idx < self.args_num: + for idx in range(self.args_num): codec = (self.args_codecs[idx]) writer.write_int16(codec.format) else: @@ -136,17 +136,35 @@ cdef class PreparedStatementState: writer.write_int16(self.args_num) - for idx from 0 <= idx < self.args_num: + for idx in range(self.args_num): arg = args[idx] if arg is None: writer.write_int32(-1) else: codec = (self.args_codecs[idx]) - codec.encode(self.settings, writer, arg) + try: + codec.encode(self.settings, writer, arg) + except (AssertionError, exceptions.InternalClientError): + # These are internal errors and should raise as-is. + raise + except exceptions.InterfaceError: + # This is already a descriptive error. + raise + except Exception as e: + # Everything else is assumed to be an encoding error + # due to invalid input. + value_repr = repr(arg) + if len(value_repr) > 40: + value_repr = value_repr[:40] + '...' + + raise exceptions.DataError( + 'invalid input for query argument' + ' ${n}: {v} ({msg})'.format( + n=idx + 1, v=value_repr, msg=e)) from e if self.have_text_cols: writer.write_int16(self.cols_num) - for idx from 0 <= idx < self.cols_num: + for idx in range(self.cols_num): codec = (self.rows_codecs[idx]) writer.write_int16(codec.format) else: diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 152e1e48..593083ce 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -542,17 +542,19 @@ async def test_numeric(self): "SELECT $1::numeric", decimal.Decimal('sNaN')) self.assertTrue(res.is_nan()) - with self.assertRaisesRegex(ValueError, 'numeric type does not ' - 'support infinite values'): + with self.assertRaisesRegex(asyncpg.DataError, + 'numeric type does not ' + 'support infinite values'): await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal('-Inf')) - with self.assertRaisesRegex(ValueError, 'numeric type does not ' - 'support infinite values'): + with self.assertRaisesRegex(asyncpg.DataError, + 'numeric type does not ' + 'support infinite values'): await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal('+Inf')) - with self.assertRaises(decimal.InvalidOperation): + with self.assertRaisesRegex(asyncpg.DataError, 'invalid'): await self.con.fetchval( "SELECT $1::numeric", 'invalid') @@ -578,18 +580,18 @@ async def test_unhandled_type_fallback(self): async def test_invalid_input(self): cases = [ - ('bytea', TypeError, 'a bytes-like object is required', [ + ('bytea', 'a bytes-like object is required', [ 1, 'aaa' ]), - ('bool', TypeError, 'a boolean is required', [ + ('bool', 'a boolean is required', [ 1, ]), - ('int2', TypeError, 'an integer is required', [ + ('int2', 'an integer is required', [ '2', 'aa', ]), - ('smallint', OverflowError, 'int16 value out of range', [ + ('smallint', 'value out of int16 range', [ 2**256, # check for the same exception for any big numbers decimal.Decimal("2000000000000000000000000000000"), 0xffff, @@ -597,72 +599,76 @@ async def test_invalid_input(self): 32768, -32769 ]), - ('float4', ValueError, 'float value too large', [ + ('float4', 'value out of float32 range', [ 4.1 * 10 ** 40, -4.1 * 10 ** 40, ]), - ('int4', TypeError, 'an integer is required', [ + ('int4', 'an integer is required', [ '2', 'aa', ]), - ('int', OverflowError, 'int32 value out of range', [ + ('int', 'value out of int32 range', [ 2**256, # check for the same exception for any big numbers decimal.Decimal("2000000000000000000000000000000"), 0xffffffff, 2**31, -2**31 - 1, ]), - ('int8', TypeError, 'an integer is required', [ + ('int8', 'an integer is required', [ '2', 'aa', ]), - ('bigint', OverflowError, 'int64 value out of range', [ + ('bigint', 'value out of int64 range', [ 2**256, # check for the same exception for any big numbers decimal.Decimal("2000000000000000000000000000000"), 0xffffffffffffffff, 2**63, -2**63 - 1, ]), - ('text', TypeError, 'expected str, got bytes', [ + ('text', 'expected str, got bytes', [ b'foo' ]), - ('text', TypeError, 'expected str, got list', [ + ('text', 'expected str, got list', [ [1] ]), - ('tid', TypeError, 'list or tuple expected', [ + ('tid', 'list or tuple expected', [ b'foo' ]), - ('tid', ValueError, 'invalid number of elements in tid tuple', [ + ('tid', 'invalid number of elements in tid tuple', [ [], (), [1, 2, 3], (4,), ]), - ('tid', OverflowError, 'tuple id block value out of range', [ + ('tid', 'tuple id block value out of uint32 range', [ (-1, 0), (2**256, 0), (0xffffffff + 1, 0), (2**32, 0), ]), - ('tid', OverflowError, 'tuple id offset value out of range', [ + ('tid', 'tuple id offset value out of uint16 range', [ (0, -1), (0, 2**256), (0, 0xffff + 1), (0, 0xffffffff), (0, 65536), ]), - ('oid', OverflowError, 'uint32 value out of range', [ + ('oid', 'value out of uint32 range', [ 2 ** 32, -1, ]), ] - for typname, errcls, errmsg, data in cases: + for typname, errmsg, data in cases: stmt = await self.con.prepare("SELECT $1::" + typname) for sample in data: with self.subTest(sample=sample, typname=typname): - with self.assertRaisesRegex(errcls, errmsg): + full_errmsg = ( + r'invalid input for query argument \$1:.*' + errmsg) + + with self.assertRaisesRegex( + asyncpg.DataError, full_errmsg): await stmt.fetchval(sample) async def test_arrays(self): @@ -733,37 +739,39 @@ class SomeContainer: def __contains__(self, item): return False - with self.assertRaisesRegex(TypeError, + with self.assertRaisesRegex(asyncpg.DataError, 'sized iterable container expected'): result = await self.con.fetchval("SELECT $1::int[]", SomeContainer()) - with self.assertRaisesRegex(ValueError, 'dimensions'): + with self.assertRaisesRegex(asyncpg.DataError, 'dimensions'): await self.con.fetchval( "SELECT $1::int[]", [[[[[[[1]]]]]]]) - with self.assertRaisesRegex(ValueError, 'non-homogeneous'): + with self.assertRaisesRegex(asyncpg.DataError, 'non-homogeneous'): await self.con.fetchval( "SELECT $1::int[]", [1, [1]]) - with self.assertRaisesRegex(ValueError, 'non-homogeneous'): + with self.assertRaisesRegex(asyncpg.DataError, 'non-homogeneous'): await self.con.fetchval( "SELECT $1::int[]", [[1], 1, [2]]) - with self.assertRaisesRegex(ValueError, 'invalid array element'): + with self.assertRaisesRegex(asyncpg.DataError, + 'invalid array element'): await self.con.fetchval( "SELECT $1::int[]", [1, 't', 2]) - with self.assertRaisesRegex(ValueError, 'invalid array element'): + with self.assertRaisesRegex(asyncpg.DataError, + 'invalid array element'): await self.con.fetchval( "SELECT $1::int[]", [[1], ['t'], [2]]) - with self.assertRaisesRegex(TypeError, + with self.assertRaisesRegex(asyncpg.DataError, 'sized iterable container expected'): await self.con.fetchval( "SELECT $1::int[]", @@ -887,11 +895,11 @@ async def test_range_types(self): self.assertEqual(result, expected) with self.assertRaisesRegex( - TypeError, 'list, tuple or Range object expected'): + asyncpg.DataError, 'list, tuple or Range object expected'): await self.con.fetch("SELECT $1::int4range", 'aa') with self.assertRaisesRegex( - ValueError, 'expected 0, 1 or 2 elements'): + asyncpg.DataError, 'expected 0, 1 or 2 elements'): await self.con.fetch("SELECT $1::int4range", (0, 2, 3)) cases = [(asyncpg.Range(0, 1), asyncpg.Range(0, 1), 1), @@ -933,7 +941,8 @@ async def test_extra_codec_alias(self): self.assertEqual(res, {'foo': '2', 'bar': '3'}) - with self.assertRaisesRegex(ValueError, 'null value not allowed'): + with self.assertRaisesRegex(asyncpg.DataError, + 'null value not allowed'): await self.con.fetchval(''' SELECT $1::hstore AS result ''', {None: '1'})