Skip to content

Commit f06960a

Browse files
committed
Expose structured error fields as exception attributes
Exception instances can now have the following attributes: * severity * sqlstate * message * detail * hint * position * query * internal_position * internal_query * context * schema_name * table_name * column_name * data_type_name * constraint_name * server_source_filename * server_source_line * server_source_function Refer to [1] for the meaning of these fields. [1] https://www.postgresql.org/docs/current/static/protocol-error-fields.html
1 parent ccfd5fe commit f06960a

File tree

7 files changed

+75
-15
lines changed

7 files changed

+75
-15
lines changed

asyncpg/exceptions.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,25 @@
33

44
class ErrorMeta(type):
55
_error_map = {}
6+
_field_map = {
7+
'S': 'severity',
8+
'C': 'sqlstate',
9+
'M': 'message',
10+
'D': 'detail',
11+
'H': 'hint',
12+
'P': 'position',
13+
'p': 'internal_position',
14+
'q': 'internal_query',
15+
'W': 'context',
16+
's': 'schema_name',
17+
't': 'table_name',
18+
'c': 'column_name',
19+
'd': 'data_type_name',
20+
'n': 'constraint_name',
21+
'F': 'server_source_filename',
22+
'L': 'server_source_line',
23+
'R': 'server_source_function'
24+
}
625

726
def __new__(mcls, name, bases, dct):
827
global __all__
@@ -11,6 +30,10 @@ def __new__(mcls, name, bases, dct):
1130
if cls.__module__ == 'asyncpg.exceptions':
1231
__all__ += (name,)
1332

33+
if name == 'Error':
34+
for f in mcls._field_map.values():
35+
setattr(cls, f, None)
36+
1437
code = dct.get('code')
1538
if code is not None:
1639
mcls._error_map[code] = cls
@@ -23,7 +46,33 @@ def get_error_for_code(mcls, code):
2346

2447

2548
class Error(Exception, metaclass=ErrorMeta):
26-
pass
49+
def __str__(self):
50+
msg = self.message
51+
if self.detail:
52+
msg += '\nDETAIL: {}'.format(self.detail)
53+
if self.hint:
54+
msg += '\nHINT: {}'.format(self.hint)
55+
56+
return msg
57+
58+
@classmethod
59+
def new(cls, fields, query=None):
60+
errcode = fields.get('C')
61+
mcls = cls.__class__
62+
exccls = mcls.get_error_for_code(errcode)
63+
mapped = {
64+
'query': query
65+
}
66+
67+
for k, v in fields.items():
68+
field = mcls._field_map.get(k)
69+
if field:
70+
mapped[field] = v
71+
72+
e = exccls(mapped.get('message'))
73+
e.__dict__.update(mapped)
74+
75+
return e
2776

2877

2978
class FatalError(Error):

asyncpg/protocol/coreproto.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,6 @@ cdef class CoreProtocol:
8989
str _encoding
9090

9191
####### Connection State:
92-
93-
ConnectionSettings _settings
94-
9592
int _backend_pid
9693
int _backend_secret
9794

asyncpg/protocol/prepared_stmt.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
cdef class PreparedStatementState:
22
cdef:
33
readonly str name
4+
readonly str query
45
list row_desc
56
list parameters_desc
67

asyncpg/protocol/prepared_stmt.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ cdef class Record:
5151

5252
cdef class PreparedStatementState:
5353

54-
def __cinit__(self, str name, BaseProtocol protocol):
54+
def __cinit__(self, str name, str query, BaseProtocol protocol):
5555
self.name = name
56+
self.query = query
5657
self.protocol = protocol
5758
self.settings = protocol._settings
5859
self.row_desc = self.parameters_desc = None

asyncpg/protocol/protocol.pxd

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ cdef class BaseProtocol(CoreProtocol):
2929

3030
cdef:
3131
object _loop
32-
33-
object _connect_waiter
34-
object _waiter
3532
object _address
3633
tuple _hash
34+
ConnectionSettings _settings
35+
str _last_query
36+
object _connect_waiter
37+
object _waiter
3738

3839
ProtocolState _state
3940

asyncpg/protocol/protocol.pyx

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ cdef class BaseProtocol(CoreProtocol):
5858
self._address = address
5959
self._hash = (self._address, self._database)
6060
self._settings = ConnectionSettings(self._hash)
61-
61+
self._last_query = None
6262
self._connect_waiter = connect_waiter
6363
self._waiter = None
6464
self._state = STATE_NOT_CONNECTED
@@ -74,6 +74,7 @@ cdef class BaseProtocol(CoreProtocol):
7474
def query(self, query):
7575
self._start_state(STATE_QUERY)
7676
self._waiter = self._create_future()
77+
self._last_query = query
7778
self._query(query)
7879
return self._waiter
7980

@@ -85,7 +86,7 @@ cdef class BaseProtocol(CoreProtocol):
8586
if self._prepared_stmt is not None:
8687
raise RuntimeError('another prepared statement is set')
8788

88-
self._prepared_stmt = PreparedStatementState(name, self)
89+
self._prepared_stmt = PreparedStatementState(name, query, self)
8990

9091
self._waiter = self._create_future()
9192
self._parse(name, query)
@@ -99,6 +100,8 @@ cdef class BaseProtocol(CoreProtocol):
99100
self._start_state(STATE_EXECUTE)
100101
self._prepared_stmt = <PreparedStatementState>state
101102

103+
self._last_query = self._prepared_stmt.query
104+
102105
self._bind(
103106
"",
104107
state.name,
@@ -156,11 +159,8 @@ cdef class BaseProtocol(CoreProtocol):
156159

157160
if result.status == PGRES_FATAL_ERROR:
158161
self._prepared_stmt = None
159-
msg = '\n'.join(['{}: {}'.format(k, v)
160-
for k, v in result.err_fields.items()])
161-
exc_cls = exceptions.ErrorMeta.get_error_for_code(
162-
result.err_fields.get('C'))
163-
exc = exc_cls(msg)
162+
exc = exceptions.Error.new(result.err_fields,
163+
query=self._last_query)
164164
waiter.set_exception(exc)
165165
self._state = STATE_READY
166166
self._waiter = None

tests/test_exceptions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,14 @@ class TestExceptions(tb.ConnectedTestCase):
77
def test_exceptions_exported(self):
88
self.assertTrue(hasattr(asyncpg, 'ConnectionError'))
99
self.assertIn('ConnectionError', asyncpg.__all__)
10+
11+
async def test_exceptions_unpacking(self):
12+
with self.assertRaises(asyncpg.Error):
13+
try:
14+
await self.con.execute('SELECT * FROM _nonexistent_')
15+
except asyncpg.Error as e:
16+
self.assertEqual(e.sqlstate, '42P01')
17+
self.assertEqual(e.position, '15')
18+
self.assertEqual(e.query, 'SELECT * FROM _nonexistent_')
19+
self.assertIsNotNone(e.severity)
20+
raise

0 commit comments

Comments
 (0)