Skip to content

Commit

Permalink
Fix "Unable to detect disconnect when using NOTIFY/LISTEN", Closes #249
Browse files Browse the repository at this point in the history
  • Loading branch information
gjcarneiro committed Apr 5, 2021
1 parent 459cbe3 commit cdcc484
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 7 deletions.
15 changes: 11 additions & 4 deletions aiopg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
import psycopg2.extras

from .log import logger
from .utils import _ContextManager, create_completed_future, get_running_loop
from .utils import (
ClosableQueue,
_ContextManager,
create_completed_future,
get_running_loop,
)

TIMEOUT = 60.0

Expand Down Expand Up @@ -762,6 +767,7 @@ def __init__(
self._writing = False
self._echo = echo
self._notifies = asyncio.Queue() # type: ignore
self._notifies_proxy = ClosableQueue(self._notifies)
self._weakref = weakref.ref(self)
self._loop.add_reader(
self._fileno, self._ready, self._weakref # type: ignore
Expand Down Expand Up @@ -806,6 +812,7 @@ def _ready(weak_self: "weakref.ref[Any]") -> None:
# chain exception otherwise
exc2.__cause__ = exc
exc = exc2
self.notifies.close(exc)
if waiter is not None and not waiter.done():
waiter.set_exception(exc)
else:
Expand Down Expand Up @@ -1182,9 +1189,9 @@ def __del__(self) -> None:
self._loop.call_exception_handler(context)

@property
def notifies(self) -> asyncio.Queue: # type: ignore
"""Return notification queue."""
return self._notifies
def notifies(self) -> ClosableQueue:
"""Return notification queue (an asyncio.Queue -like object)."""
return self._notifies_proxy

async def _get_oids(self) -> Tuple[Any, Any]:
cursor = await self.cursor()
Expand Down
43 changes: 43 additions & 0 deletions aiopg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,46 @@ async def __anext__(self) -> _TObj:
finally:
self._obj = None
raise


class ClosableQueue:
"""
Proxy object for an asyncio.Queue that is "closable"
When the ClosableQueue is closed, with an exception object as parameter,
subsequent or ongoing attempts to read from the queue will result in that
exception being result in that exception being raised.
"""

def __init__(self, queue: asyncio.Queue):
self._queue = queue
self._close_exception = asyncio.Future() # type: asyncio.Future[None]

def close(self, exception: Exception) -> None:
if not self._close_exception.done():
self._close_exception.set_exception(exception)

async def get(self) -> Any:
loop = get_running_loop()
get = loop.create_task(self._queue.get())

_, pending = await asyncio.wait(
[get, self._close_exception],
return_when=asyncio.FIRST_COMPLETED,
)

if get.done():
return get.result()
get.cancel()
self._close_exception.result()

def empty(self) -> bool:
return self._queue.empty()

def qsize(self) -> int:
return self._queue.qsize()

def get_nowait(self) -> Any:
if self._close_exception.done():
self._close_exception.result()
return self._queue.get_nowait()
8 changes: 7 additions & 1 deletion docs/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ Example::

.. attribute:: notifies

An :class:`asyncio.Queue` instance for received notifications.
An instance of an :class:`asyncio.Queue` subclass for received notifications.

.. seealso:: :ref:`aiopg-core-notifications`

Expand Down Expand Up @@ -983,6 +983,12 @@ Receiving part should establish listening on notification channel by
`LISTEN`_ call and wait notification events from
:attr:`Connection.notifies` queue.

.. note::

calling `await connection.notifies.get()` may raise a psycopg2 exception
if the underlying connection gets disconnected while you're waiting for
notifications.

There is usage example:

.. literalinclude:: ../examples/notify.py
Expand Down
10 changes: 8 additions & 2 deletions examples/notify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio

import psycopg2

import aiopg

dsn = "dbname=aiopg user=aiopg password=passwd host=127.0.0.1"
Expand All @@ -19,8 +21,12 @@ async def listen(conn):
async with conn.cursor() as cur:
await cur.execute("LISTEN channel")
while True:
msg = await conn.notifies.get()
if msg.payload == "finish":
try:
msg = await conn.notifies.get()
except psycopg2.Error as ex:
print("ERROR: ", ex)
return
if msg.payload == 'finish':
return
else:
print("Receive <-", msg.payload)
Expand Down
35 changes: 35 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,38 @@ async def test_connection_on_server_restart(connect, pg_server, docker):
delay *= 2
else:
pytest.fail("Cannot connect to the restarted server")


async def test_connection_notify_on_server_restart(connect, pg_server, docker,
loop):
conn = await connect()

async def read_notifies():
while True:
await conn.notifies.get()

reader = loop.create_task(read_notifies())
await asyncio.sleep(0.1)

docker.restart(container=pg_server['Id'])

try:
with pytest.raises(psycopg2.OperationalError):
await asyncio.wait_for(reader, 10)
finally:
conn.close()
reader.cancel()

# Wait for postgres to be up and running again before moving on
# so as the restart won't affect other tests
delay = 0.001
for i in range(100):
try:
conn = await connect()
conn.close()
break
except psycopg2.Error:
time.sleep(delay)
delay *= 2
else:
pytest.fail("Cannot connect to the restarted server")
52 changes: 52 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio

import pytest

from aiopg.utils import ClosableQueue


async def test_closable_queue_noclose():
the_queue = asyncio.Queue()
queue = ClosableQueue(the_queue)
assert queue.empty()
assert queue.qsize() == 0

await the_queue.put(1)
assert not queue.empty()
assert queue.qsize() == 1
v = await queue.get()
assert v == 1

await the_queue.put(2)
v = queue.get_nowait()
assert v == 2


async def test_closable_queue_close(loop):
the_queue = asyncio.Queue()
queue = ClosableQueue(the_queue)
v1 = None

async def read():
nonlocal v1
v1 = await queue.get()
await queue.get()

reader = loop.create_task(read())
await the_queue.put(1)
await asyncio.sleep(0.1)
assert v1 == 1

queue.close(RuntimeError("connection closed"))
with pytest.raises(RuntimeError) as excinfo:
await reader
assert excinfo.value.args == ("connection closed",)


async def test_closable_queue_close_get_nowait(loop):
the_queue = asyncio.Queue()
queue = ClosableQueue(the_queue)
queue.close(RuntimeError("connection closed"))
with pytest.raises(RuntimeError) as excinfo:
await queue.get_nowait()
assert excinfo.value.args == ("connection closed",)

0 comments on commit cdcc484

Please sign in to comment.