Skip to content

Commit

Permalink
protocol: Use try-finally explicitly every time we create a waiter
Browse files Browse the repository at this point in the history
  • Loading branch information
1st1 committed Sep 15, 2017
1 parent f29de23 commit 50edd8c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 75 deletions.
7 changes: 6 additions & 1 deletion asyncpg/exceptions/_base.py
Expand Up @@ -10,7 +10,8 @@


__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage')
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -190,6 +191,10 @@ def __init__(self, msg, *, detail=None, hint=None):
Warning.__init__(self, msg)


class InternalClientError(Exception):
pass


class PostgresLogMessage(PostgresMessage):
"""A base class for non-error server messages."""

Expand Down
18 changes: 9 additions & 9 deletions asyncpg/protocol/coreproto.pyx
Expand Up @@ -125,7 +125,7 @@ cdef class CoreProtocol:
self.buffer.consume_message()

else:
raise RuntimeError(
raise apg_exc.InternalClientError(
'protocol is in an unknown state {}'.format(state))

except Exception as ex:
Expand Down Expand Up @@ -472,7 +472,7 @@ cdef class CoreProtocol:

if ASYNCPG_DEBUG:
if buf.get_message_type() != b'D':
raise RuntimeError(
raise apg_exc.InternalClientError(
'_parse_data_msgs: first message is not "D"')

if self._discard_data:
Expand All @@ -484,7 +484,7 @@ cdef class CoreProtocol:

if ASYNCPG_DEBUG:
if type(self.result) is not list:
raise RuntimeError(
raise apg_exc.InternalClientError(
'_parse_data_msgs: result is not a list, but {!r}'.
format(self.result))

Expand Down Expand Up @@ -639,11 +639,11 @@ cdef class CoreProtocol:
cdef _set_state(self, ProtocolState new_state):
if new_state == PROTOCOL_IDLE:
if self.state == PROTOCOL_FAILED:
raise RuntimeError(
raise apg_exc.InternalClientError(
'cannot switch to "idle" state; '
'protocol is in the "failed" state')
elif self.state == PROTOCOL_IDLE:
raise RuntimeError(
raise apg_exc.InternalClientError(
'protocol is already in the "idle" state')
else:
self.state = new_state
Expand Down Expand Up @@ -671,18 +671,18 @@ cdef class CoreProtocol:
self.state = new_state

elif self.state == PROTOCOL_FAILED:
raise RuntimeError(
raise apg_exc.InternalClientError(
'cannot switch to state {}; '
'protocol is in the "failed" state'.format(new_state))
else:
raise RuntimeError(
raise apg_exc.InternalClientError(
'cannot switch to state {}; '
'another operation ({}) is in progress'.format(
new_state, self.state))

cdef _ensure_connected(self):
if self.con_status != CONNECTION_OK:
raise RuntimeError('not connected')
raise apg_exc.InternalClientError('not connected')

cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
Expand All @@ -707,7 +707,7 @@ cdef class CoreProtocol:
WriteBuffer outbuf

if self.con_status != CONNECTION_BAD:
raise RuntimeError('already connected')
raise apg_exc.InternalClientError('already connected')

self._set_state(PROTOCOL_AUTH)
self.con_status = CONNECTION_STARTED
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/protocol.pxd
Expand Up @@ -51,6 +51,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef _get_timeout_impl(self, timeout)
cdef _check_state(self)
cdef _new_waiter(self, timeout)
cdef _coreproto_error(self)

cdef _on_result__connect(self, object waiter)
cdef _on_result__prepare(self, object waiter)
Expand Down
175 changes: 111 additions & 64 deletions asyncpg/protocol/protocol.pyx
Expand Up @@ -156,11 +156,16 @@ cdef class BaseProtocol(CoreProtocol):
self._check_state()
timeout = self._get_timeout_impl(timeout)

self._prepare(stmt_name, query)
self.last_query = query
self.statement = PreparedStatementState(stmt_name, query, self)

return await self._new_waiter(timeout)
waiter = self._new_waiter(timeout)
try:
self._prepare(stmt_name, query) # network op
self.last_query = query
self.statement = PreparedStatementState(stmt_name, query, self)
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

async def bind_execute(self, PreparedStatementState state, args,
str portal_name, int limit, return_extra,
Expand All @@ -174,19 +179,25 @@ cdef class BaseProtocol(CoreProtocol):

self._check_state()
timeout = self._get_timeout_impl(timeout)
args_buf = state._encode_bind_msg(args)

self._bind_execute(
portal_name,
state.name,
state._encode_bind_msg(args),
limit)

self.last_query = state.query
self.statement = state
self.return_extra = return_extra
self.queries_count += 1

return await self._new_waiter(timeout)
waiter = self._new_waiter(timeout)
try:
self._bind_execute(
portal_name,
state.name,
args_buf,
limit) # network op

self.last_query = state.query
self.statement = state
self.return_extra = return_extra
self.queries_count += 1
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

async def bind_execute_many(self, PreparedStatementState state, args,
str portal_name, timeout):
Expand All @@ -207,18 +218,21 @@ cdef class BaseProtocol(CoreProtocol):
arg_bufs = iter(data_gen)

waiter = self._new_waiter(timeout)
try:
self._bind_execute_many(
portal_name,
state.name,
arg_bufs) # network op

self._bind_execute_many(
portal_name,
state.name,
arg_bufs)

self.last_query = state.query
self.statement = state
self.return_extra = False
self.queries_count += 1

return await waiter
self.last_query = state.query
self.statement = state
self.return_extra = False
self.queries_count += 1
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

async def bind(self, PreparedStatementState state, args,
str portal_name, timeout):
Expand All @@ -231,16 +245,22 @@ cdef class BaseProtocol(CoreProtocol):

self._check_state()
timeout = self._get_timeout_impl(timeout)
args_buf = state._encode_bind_msg(args)

self._bind(
portal_name,
state.name,
state._encode_bind_msg(args))

self.last_query = state.query
self.statement = state

return await self._new_waiter(timeout)
waiter = self._new_waiter(timeout)
try:
self._bind(
portal_name,
state.name,
args_buf) # network op

self.last_query = state.query
self.statement = state
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

async def execute(self, PreparedStatementState state,
str portal_name, int limit, return_extra,
Expand All @@ -255,16 +275,21 @@ cdef class BaseProtocol(CoreProtocol):
self._check_state()
timeout = self._get_timeout_impl(timeout)

self._execute(
portal_name,
limit)

self.last_query = state.query
self.statement = state
self.return_extra = return_extra
self.queries_count += 1

return await self._new_waiter(timeout)
waiter = self._new_waiter(timeout)
try:
self._execute(
portal_name,
limit) # network op

self.last_query = state.query
self.statement = state
self.return_extra = return_extra
self.queries_count += 1
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

async def query(self, query, timeout):
if self.cancel_waiter is not None:
Expand All @@ -279,11 +304,16 @@ cdef class BaseProtocol(CoreProtocol):
# prepare/bind/execute methods.
timeout = self._get_timeout(timeout)

self._simple_query(query)
self.last_query = query
self.queries_count += 1

return await self._new_waiter(timeout)
waiter = self._new_waiter(timeout)
try:
self._simple_query(query) # network op
self.last_query = query
self.queries_count += 1
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

async def copy_out(self, copy_stmt, sink, timeout):
if self.cancel_waiter is not None:
Expand Down Expand Up @@ -378,7 +408,7 @@ cdef class BaseProtocol(CoreProtocol):
for codec in codecs:
if (not codec.has_encoder() or
codec.format != PG_FORMAT_BINARY):
raise RuntimeError(
raise apg_exc.InternalClientError(
'no binary format encoder for '
'type {} (OID {})'.format(codec.name, codec.oid))

Expand Down Expand Up @@ -439,7 +469,7 @@ cdef class BaseProtocol(CoreProtocol):
except TimeoutError:
raise
else:
raise RuntimeError('TimoutError was not raised')
raise apg_exc.InternalClientError('TimoutError was not raised')

except Exception as e:
self._write_copy_fail_msg(str(e))
Expand All @@ -460,16 +490,22 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_sent_waiter = None

self._check_state()
timeout = self._get_timeout_impl(timeout)

if state.refs != 0:
raise RuntimeError(
raise apg_exc.InternalClientError(
'cannot close prepared statement; refs == {} != 0'.format(
state.refs))

self._close(state.name, False)
state.closed = True
return await self._new_waiter(timeout)
timeout = self._get_timeout_impl(timeout)
waiter = self._new_waiter(timeout)
try:
self._close(state.name, False) # network op
state.closed = True
except Exception as ex:
waiter.set_exception(ex)
self._coreproto_error()
finally:
return await waiter

def is_closed(self):
return self.closing
Expand Down Expand Up @@ -579,6 +615,17 @@ cdef class BaseProtocol(CoreProtocol):
raise apg_exc.InterfaceError(
'cannot perform operation: another operation is in progress')

cdef _coreproto_error(self):
try:
if self.waiter is not None:
if not self.waiter.done():
raise apg_exc.InternalClientError(
'waiter is not done while handling critical '
'protocol error')
self.waiter = None
finally:
self.abort()

cdef _new_waiter(self, timeout):
if self.waiter is not None:
raise apg_exc.InterfaceError(
Expand All @@ -596,7 +643,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef _on_result__prepare(self, object waiter):
if ASYNCPG_DEBUG:
if self.statement is None:
raise RuntimeError(
raise apg_exc.InternalClientError(
'_on_result__prepare: statement is None')

if self.result_param_desc is not None:
Expand Down Expand Up @@ -643,7 +690,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef _decode_row(self, const char* buf, ssize_t buf_len):
if ASYNCPG_DEBUG:
if self.statement is None:
raise RuntimeError(
raise apg_exc.InternalClientError(
'_decode_row: statement is None')

return self.statement._decode_row(buf, buf_len)
Expand All @@ -654,13 +701,13 @@ cdef class BaseProtocol(CoreProtocol):

if ASYNCPG_DEBUG:
if waiter is None:
raise RuntimeError('_on_result: waiter is None')
raise apg_exc.InternalClientError('_on_result: waiter is None')

if waiter.cancelled():
return

if waiter.done():
raise RuntimeError('_on_result: waiter is done')
raise apg_exc.InternalClientError('_on_result: waiter is done')

if self.result_type == RESULT_FAILED:
if isinstance(self.result, dict):
Expand Down Expand Up @@ -704,7 +751,7 @@ cdef class BaseProtocol(CoreProtocol):
self._on_result__copy_in(waiter)

else:
raise RuntimeError(
raise apg_exc.InternalClientError(
'got result for unknown protocol state {}'.
format(self.state))

Expand Down

0 comments on commit 50edd8c

Please sign in to comment.