-
Notifications
You must be signed in to change notification settings - Fork 690
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added the psycopg event broker (#917)
- Loading branch information
Showing
4 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import AsyncGenerator, Mapping | ||
from contextlib import AsyncExitStack, asynccontextmanager | ||
from logging import Logger | ||
from typing import TYPE_CHECKING, Any, NoReturn | ||
from urllib.parse import urlunparse | ||
|
||
import attrs | ||
from anyio import ( | ||
EndOfStream, | ||
create_memory_object_stream, | ||
move_on_after, | ||
) | ||
from anyio.abc import TaskStatus | ||
from anyio.streams.memory import MemoryObjectSendStream | ||
from attr.validators import instance_of | ||
from psycopg import AsyncConnection, InterfaceError | ||
|
||
from .._events import Event | ||
from .._exceptions import SerializationError | ||
from .._validators import positive_number | ||
from .base import BaseExternalEventBroker | ||
|
||
if TYPE_CHECKING: | ||
from sqlalchemy.ext.asyncio import AsyncEngine | ||
|
||
|
||
def convert_options(value: Mapping[str, Any]) -> dict[str, Any]: | ||
return dict(value, autocommit=True) | ||
|
||
|
||
@attrs.define(eq=False) | ||
class PsycopgEventBroker(BaseExternalEventBroker): | ||
""" | ||
An asynchronous, psycopg_ based event broker that uses a PostgreSQL server to | ||
broadcast events using its ``NOTIFY`` mechanism. | ||
.. psycopg: https://pypi.org/project/psycopg/ | ||
:param conninfo: a libpq connection string (e.g. | ||
``postgres://user:pass@host:port/dbname``) | ||
:param channel: the ``NOTIFY`` channel to use | ||
:param max_idle_time: maximum time (in seconds) to let the connection go idle, | ||
before sending a ``SELECT 1`` query to prevent a connection timeout | ||
""" | ||
|
||
conninfo: str = attrs.field(validator=instance_of(str)) | ||
options: Mapping[str, Any] = attrs.field( | ||
factory=dict, converter=convert_options, validator=instance_of(Mapping) | ||
) | ||
channel: str = attrs.field( | ||
kw_only=True, default="apscheduler", validator=instance_of(str) | ||
) | ||
max_idle_time: float = attrs.field( | ||
kw_only=True, default=10, validator=[instance_of((int, float)), positive_number] | ||
) | ||
|
||
_send: MemoryObjectSendStream[str] = attrs.field(init=False) | ||
|
||
@classmethod | ||
def from_async_sqla_engine( | ||
cls, | ||
engine: AsyncEngine, | ||
options: Mapping[str, Any] | None = None, | ||
**kwargs: Any, | ||
) -> PsycopgEventBroker: | ||
""" | ||
Create a new psycopg event broker from a SQLAlchemy engine. | ||
The engine will only be used to create the appropriate options for | ||
:meth:`psycopg.AsyncConnection.connect`. | ||
:param engine: an asynchronous SQLAlchemy engine using asyncpg as the driver | ||
:type engine: ~sqlalchemy.ext.asyncio.AsyncEngine | ||
:param options: extra keyword arguments passed to :func:`asyncpg.connect` (will | ||
override any automatically generated arguments based on the engine) | ||
:param kwargs: keyword arguments to pass to the initializer of this class | ||
:return: the newly created event broker | ||
""" | ||
if engine.dialect.driver != "psycopg": | ||
raise ValueError( | ||
f'The driver in the engine must be "psycopg" (current: ' | ||
f"{engine.dialect.driver})" | ||
) | ||
|
||
conninfo = urlunparse( | ||
[ | ||
"postgres", | ||
engine.url.username, | ||
engine.url.password, | ||
engine.url.host, | ||
engine.url.database, | ||
] | ||
) | ||
opts = dict(options, autocommit=True) | ||
return cls(conninfo, opts, **kwargs) | ||
|
||
@property | ||
def _temporary_failure_exceptions(self) -> tuple[type[Exception], ...]: | ||
return OSError, InterfaceError | ||
|
||
@asynccontextmanager | ||
async def _connect(self) -> AsyncGenerator[AsyncConnection, None]: | ||
async for attempt in self._retry(): | ||
with attempt: | ||
conn = await AsyncConnection.connect(self.conninfo, **self.options) | ||
try: | ||
yield conn | ||
finally: | ||
await conn.close() | ||
|
||
async def start(self, exit_stack: AsyncExitStack, logger: Logger) -> None: | ||
await super().start(exit_stack, logger) | ||
await self._task_group.start(self._listen_notifications) | ||
exit_stack.callback(self._task_group.cancel_scope.cancel) | ||
self._send = await self._task_group.start(self._publish_notifications) | ||
await exit_stack.enter_async_context(self._send) | ||
|
||
async def _listen_notifications(self, *, task_status: TaskStatus[None]) -> None: | ||
task_started_sent = False | ||
while True: | ||
async with self._connect() as conn: | ||
try: | ||
await conn.execute(f"LISTEN {self.channel}") | ||
|
||
if not task_started_sent: | ||
task_status.started() | ||
task_started_sent = True | ||
|
||
self._logger.debug("Listen connection established") | ||
async for notify in conn.notifies(): | ||
event = self.reconstitute_event_str(notify.payload) | ||
await self.publish_local(event) | ||
except InterfaceError as exc: | ||
self._logger.error("Connection error: %s", exc) | ||
|
||
async def _publish_notifications( | ||
self, *, task_status: TaskStatus[MemoryObjectSendStream[str]] | ||
) -> NoReturn: | ||
send, receive = create_memory_object_stream[str](100) | ||
task_started_sent = False | ||
with receive: | ||
while True: | ||
async with self._connect() as conn: | ||
if not task_started_sent: | ||
task_status.started(send) | ||
task_started_sent = True | ||
|
||
self._logger.debug("Publish connection established") | ||
notification: str | None = None | ||
while True: | ||
with move_on_after(self.max_idle_time): | ||
try: | ||
notification = await receive.receive() | ||
except EndOfStream: | ||
return | ||
|
||
if notification: | ||
await conn.execute( | ||
"SELECT pg_notify(%t, %t)", [self.channel, notification] | ||
) | ||
else: | ||
await conn.execute("SELECT 1") | ||
|
||
async def publish(self, event: Event) -> None: | ||
notification = self.generate_notification_str(event) | ||
if len(notification) > 7999: | ||
raise SerializationError( | ||
"Serialized event object exceeds 7999 bytes in size" | ||
) | ||
|
||
await self._send.send(notification) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters