diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 34c7c712..fa657231 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -80,6 +80,7 @@ cdef class CoreProtocol: ConnectionStatus con_status ProtocolState state + ProtocolState cancelled_from_state TransactionStatus xact_status str encoding diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index da96c412..c978a675 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -33,6 +33,7 @@ cdef class CoreProtocol: self.con_params = con_params self.con_status = CONNECTION_BAD self.state = PROTOCOL_IDLE + self.cancelled_from_state = PROTOCOL_IDLE self.xact_status = PQTRANS_IDLE self.encoding = 'utf-8' # type of `scram` is `SCRAMAuthentcation` @@ -835,11 +836,13 @@ cdef class CoreProtocol: pass else: self.state = new_state + self.cancelled_from_state = PROTOCOL_IDLE elif new_state == PROTOCOL_FAILED: self.state = PROTOCOL_FAILED elif new_state == PROTOCOL_CANCELLED: + self.cancelled_from_state = self.state self.state = PROTOCOL_CANCELLED elif new_state == PROTOCOL_TERMINATING: diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index acce4e9f..be1b65c7 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -851,39 +851,46 @@ cdef class BaseProtocol(CoreProtocol): waiter.set_exception(exc) return + state = self.state + if state == PROTOCOL_CANCELLED: + state = self.cancelled_from_state + if state == PROTOCOL_IDLE: + waiter.set_exception(asyncio.CancelledError()) + return + try: - if self.state == PROTOCOL_AUTH: + if state == PROTOCOL_AUTH: self._on_result__connect(waiter) - elif self.state == PROTOCOL_PREPARE: + elif state == PROTOCOL_PREPARE: self._on_result__prepare(waiter) - elif self.state == PROTOCOL_BIND_EXECUTE: + elif state == PROTOCOL_BIND_EXECUTE: self._on_result__bind_and_exec(waiter) - elif self.state == PROTOCOL_BIND_EXECUTE_MANY: + elif state == PROTOCOL_BIND_EXECUTE_MANY: self._on_result__bind_and_exec(waiter) - elif self.state == PROTOCOL_EXECUTE: + elif state == PROTOCOL_EXECUTE: self._on_result__bind_and_exec(waiter) - elif self.state == PROTOCOL_BIND: + elif state == PROTOCOL_BIND: self._on_result__bind(waiter) - elif self.state == PROTOCOL_CLOSE_STMT_PORTAL: + elif state == PROTOCOL_CLOSE_STMT_PORTAL: self._on_result__close_stmt_or_portal(waiter) - elif self.state == PROTOCOL_SIMPLE_QUERY: + elif state == PROTOCOL_SIMPLE_QUERY: self._on_result__simple_query(waiter) - elif (self.state == PROTOCOL_COPY_OUT_DATA or - self.state == PROTOCOL_COPY_OUT_DONE): + elif (state == PROTOCOL_COPY_OUT_DATA or + state == PROTOCOL_COPY_OUT_DONE): self._on_result__copy_out(waiter) - elif self.state == PROTOCOL_COPY_IN_DATA: + elif state == PROTOCOL_COPY_IN_DATA: self._on_result__copy_in(waiter) - elif self.state == PROTOCOL_TERMINATING: + elif state == PROTOCOL_TERMINATING: # We are waiting for the connection to drop, so # ignore any stray results at this point. pass