diff --git a/src/adcp/webhook_supervisor.py b/src/adcp/webhook_supervisor.py index 9c3a7e5ba..ddb219d47 100644 --- a/src/adcp/webhook_supervisor.py +++ b/src/adcp/webhook_supervisor.py @@ -214,6 +214,8 @@ async def send_mcp( result: Any = None, token: str | None = None, sequence_key: str | None = None, + breaker_key: str | None = None, + notification_type: str | None = None, ) -> WebhookDeliveryResult | None: ... diff --git a/src/adcp/webhook_supervisor_pg.py b/src/adcp/webhook_supervisor_pg.py new file mode 100644 index 000000000..74d90271f --- /dev/null +++ b/src/adcp/webhook_supervisor_pg.py @@ -0,0 +1,807 @@ +"""PostgreSQL-backed :class:`WebhookDeliverySupervisor` for multi-worker durability. + +Gives multi-instance AdCP sellers a shared circuit-breaker state and a +durable retry queue so webhook deliveries survive process crashes and are +never double-sent across concurrent workers. + +.. rubric:: REQUIRED: run_worker() + +Unlike :class:`~adcp.webhook_supervisor.InMemoryWebhookDeliverySupervisor`, +:meth:`PgWebhookDeliverySupervisor.send_mcp` does **not** deliver +immediately — it enqueues the job and returns ``None``. You **MUST** start +:meth:`run_worker` somewhere in your application; without it enqueued jobs +accumulate and are never sent: + +:: + + from psycopg_pool import AsyncConnectionPool + from adcp.webhook_supervisor_pg import PgWebhookDeliverySupervisor + + pool = AsyncConnectionPool("postgresql://...", min_size=4, max_size=20) + supervisor = PgWebhookDeliverySupervisor(pool=pool, sender=my_sender) + await supervisor.create_schema() # idempotent; call once per boot + + asyncio.create_task(supervisor.run_worker()) # REQUIRED before serve() + + serve(my_handler, webhook_supervisor=supervisor) + +.. rubric:: Observability + +``send_mcp`` always returns ``None``. For per-attempt outcomes configure a +:class:`~adcp.webhook_supervisor.DeliveryLogSink` or query the +``adcp_webhook_delivery_log`` table directly. + +.. 
rubric:: Cross-process guarantees + +* **Circuit breaker state** — shared via ``adcp_webhook_circuit_state`` (one + row per *breaker_key*). Half-open probe atomicity is enforced via + ``UPDATE … RETURNING`` so two concurrent workers can't both transition to + ``closed`` from a single success. +* **Sequence numbers** — the ``BIGSERIAL`` queue-row id is the sequence + number (monotonically increasing, crash-durable across restarts). +* **Worker leasing** — ``SELECT … FOR UPDATE SKIP LOCKED LIMIT 1`` prevents + double-delivery when multiple workers poll the same queue. The lock is + held for the duration of the HTTP send; a crashed worker automatically + releases the job (transaction rollback). + +.. rubric:: DDL for migration tools + +:: + + -- adcp_webhook_circuit_state + CREATE TABLE adcp_webhook_circuit_state ( + breaker_key TEXT COLLATE "C" NOT NULL PRIMARY KEY, + state TEXT NOT NULL DEFAULT 'closed', + failure_count INT NOT NULL DEFAULT 0, + success_count INT NOT NULL DEFAULT 0, + opened_at TIMESTAMPTZ, + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() + ); + + -- adcp_webhook_delivery_queue + CREATE TABLE adcp_webhook_delivery_queue ( + id BIGSERIAL PRIMARY KEY, + breaker_key TEXT NOT NULL, + url TEXT NOT NULL, + task_id TEXT NOT NULL, + task_type TEXT, + status_str TEXT NOT NULL DEFAULT 'pending', + result_json TEXT, + token TEXT, + sequence_key TEXT, + attempt_count INT NOT NULL DEFAULT 0, + max_attempts INT NOT NULL DEFAULT 3, + scheduled_at TIMESTAMPTZ NOT NULL DEFAULT now(), + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + idempotency_key TEXT, + sent_body BYTEA, + notification_type TEXT + ); + CREATE INDEX adcp_webhook_delivery_queue_work_idx + ON adcp_webhook_delivery_queue (status_str, scheduled_at) + WHERE status_str IN ('pending', 'retry'); + + -- adcp_webhook_delivery_log + CREATE TABLE adcp_webhook_delivery_log ( + id BIGSERIAL PRIMARY KEY, + queue_id BIGINT, + breaker_key TEXT NOT NULL, + url TEXT NOT NULL, + task_id TEXT NOT NULL, + sequence_key 
TEXT, + sequence_number BIGINT, + attempt_number INT NOT NULL, + max_attempts INT NOT NULL, + outcome TEXT NOT NULL, + http_status_code INT, + error_message TEXT, + response_time_ms INT NOT NULL DEFAULT 0, + occurred_at TIMESTAMPTZ NOT NULL, + will_retry BOOLEAN NOT NULL DEFAULT false, + next_retry_at TIMESTAMPTZ, + task_type TEXT, + payload_size_bytes INT, + notification_type TEXT + ); + CREATE INDEX adcp_webhook_delivery_log_queue_id_idx + ON adcp_webhook_delivery_log (queue_id); + CREATE INDEX adcp_webhook_delivery_log_task_id_idx + ON adcp_webhook_delivery_log (task_id); + +.. rubric:: Stranded-job note + +Jobs remain in ``status_str = 'pending'`` (via rollback) if a worker +process crashes mid-send. No separate recovery sweep is needed — the next +worker poll picks them up automatically. Jobs that raised a Python exception +on the final attempt are deleted from the queue and logged with +``outcome = 'failure'``. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +import time +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from adcp.types import GeneratedTaskStatus + from adcp.webhook_sender import WebhookDeliveryResult, WebhookSender + from adcp.webhook_supervisor import DeliveryLogSink + +try: + import psycopg # noqa: F401 + import psycopg_pool # noqa: F401 + + PG_AVAILABLE = True +except ImportError: + PG_AVAILABLE = False + +from adcp.webhook_supervisor import ( + CircuitBreakerPolicy, + DeliveryAttempt, + RetryPolicy, +) + +UTC = timezone.utc +logger = logging.getLogger(__name__) + +# Byte-level ASCII identifier guard — same rationale as PgReplayStore. +# str.islower() accepts non-ASCII Unicode letters (é, µ, etc.) which would +# format verbatim into SQL as a different table than configured. +_SAFE_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]{0,62}$") + +_INSTALL_HINT = ( + "PgWebhookDeliverySupervisor requires psycopg3 and psycopg-pool. 
DEFAULT_CIRCUIT_TABLE = "adcp_webhook_circuit_state"
DEFAULT_QUEUE_TABLE = "adcp_webhook_delivery_queue"
DEFAULT_LOG_TABLE = "adcp_webhook_delivery_log"


def _safe_identifier(name: str) -> str:
    """Return *name* unchanged if it is a safe ASCII SQL identifier.

    Table names are formatted directly into SQL text, so anything that does
    not match ``_SAFE_IDENTIFIER_RE`` is rejected.

    :raises ValueError: if *name* does not match the identifier pattern.
    """
    if not _SAFE_IDENTIFIER_RE.fullmatch(name):
        raise ValueError(
            f"Table name must match [a-z_][a-z0-9_]{{0,62}} (ASCII only), got {name!r}"
        )
    return name


class PgWebhookDeliverySupervisor:
    """Postgres-backed :class:`~adcp.webhook_supervisor.WebhookDeliverySupervisor`.

    Parameters
    ----------
    pool:
        A ``psycopg_pool.AsyncConnectionPool`` owned by the caller. Each
        operation acquires a short-lived connection (or holds one for the
        duration of a delivery). We don't open, own, or close the pool.
    sender:
        The underlying :class:`~adcp.webhook_sender.WebhookSender` used for
        the actual HTTP-Signatures POST. Must be non-None.
    retry:
        Retry policy (default: 3 attempts, exponential backoff with jitter).
    circuit:
        Circuit-breaker tuning (default: 5 failures open, 60s recovery,
        2 successes close).
    log_sink:
        Optional :class:`~adcp.webhook_supervisor.DeliveryLogSink` called
        after each attempt. Failures are swallowed. Use for custom
        observability pipelines in addition to the built-in log table.
    circuit_table / queue_table / log_table:
        Override table names (ASCII only, 1–63 chars). Useful for multi-
        tenant schemas. Defaults are ``adcp_webhook_circuit_state``,
        ``adcp_webhook_delivery_queue``, and ``adcp_webhook_delivery_log``.

    Concurrency
    -----------
    Safe to share across asyncio tasks in a single process. Multiple
    processes/pods each running :meth:`run_worker` are explicitly supported
    via ``FOR UPDATE SKIP LOCKED``.
    """

    def __init__(
        self,
        pool: Any,  # psycopg_pool.AsyncConnectionPool; Any avoids import at runtime
        sender: WebhookSender,
        *,
        retry: RetryPolicy | None = None,
        circuit: CircuitBreakerPolicy | None = None,
        log_sink: DeliveryLogSink | None = None,
        circuit_table: str = DEFAULT_CIRCUIT_TABLE,
        queue_table: str = DEFAULT_QUEUE_TABLE,
        log_table: str = DEFAULT_LOG_TABLE,
    ) -> None:
        """Validate dependencies and table names, then pre-format all SQL.

        :raises ImportError: if psycopg / psycopg-pool are not installed.
        :raises ValueError: if *sender* is ``None`` or a table name is not a
            safe ASCII identifier.
        """
        if not PG_AVAILABLE:
            raise ImportError(_INSTALL_HINT)
        if sender is None:
            raise ValueError(
                "PgWebhookDeliverySupervisor requires a non-None WebhookSender. "
                "Construct one via WebhookSender.from_jwk(...) or "
                "WebhookSender.from_pem(...) and pass it as the first positional argument."
            )
        self._pool = pool
        self._sender = sender
        self._retry = retry or RetryPolicy()
        self._circuit_policy = circuit or CircuitBreakerPolicy()
        self._log_sink = log_sink

        ct = _safe_identifier(circuit_table)
        qt = _safe_identifier(queue_table)
        lt = _safe_identifier(log_table)
        self._circuit_t = ct
        self._queue_t = qt
        self._log_t = lt

        # Pre-format SQL at construction time to avoid per-call f-string overhead
        # and to bake in the validated table names once.
        self._sql_circuit_get = (
            f"SELECT state, opened_at FROM {ct} WHERE breaker_key = %s"  # noqa: S608
        )
        self._sql_circuit_set_half_open = (
            f"UPDATE {ct} SET state = 'half_open', success_count = 0, "  # noqa: S608
            f"updated_at = now() WHERE breaker_key = %s AND state = 'open'"
        )
        self._sql_circuit_failure = self._build_circuit_failure_sql(
            ct, self._circuit_policy.failure_threshold
        )
        # Atomic upsert: increment success_count; close circuit when threshold crossed.
        # RETURNING lets us read the post-update state without a second query.
        # int() guards the interpolation: only a number may reach the SQL text.
        success_t = int(self._circuit_policy.success_threshold)
        self._sql_circuit_success = (
            f"INSERT INTO {ct} "  # noqa: S608
            f"(breaker_key, state, failure_count, success_count, updated_at) "
            f"VALUES (%s, 'closed', 0, 0, now()) "
            f"ON CONFLICT (breaker_key) DO UPDATE SET "
            f"  failure_count = 0, "
            f"  success_count = CASE "
            f"    WHEN {ct}.state = 'half_open' THEN {ct}.success_count + 1 ELSE 0 END, "
            f"  state = CASE "
            f"    WHEN {ct}.state = 'half_open' AND {ct}.success_count + 1 >= {success_t} "
            f"    THEN 'closed' ELSE {ct}.state END, "
            f"  updated_at = now() "
            f"RETURNING state, success_count"
        )
        self._sql_enqueue = (
            f"INSERT INTO {qt} "  # noqa: S608
            f"(breaker_key, url, task_id, task_type, status_str, result_json, "
            f"token, sequence_key, max_attempts, notification_type) "
            f"VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id"
        )
        self._sql_poll = (
            f"SELECT id, breaker_key, url, task_id, task_type, status_str, "  # noqa: S608
            f"result_json, token, sequence_key, attempt_count, max_attempts, "
            f"idempotency_key, sent_body, notification_type "
            f"FROM {qt} "
            f"WHERE status_str IN ('pending', 'retry') AND scheduled_at <= now() "
            f"ORDER BY scheduled_at LIMIT 1 FOR UPDATE SKIP LOCKED"
        )
        self._sql_delete_job = f"DELETE FROM {qt} WHERE id = %s"  # noqa: S608
        self._sql_reschedule = (
            f"UPDATE {qt} SET "  # noqa: S608
            f"  status_str = 'retry', "
            f"  attempt_count = %s, "
            f"  scheduled_at = %s, "
            f"  sent_body = %s, "
            f"  idempotency_key = %s "
            f"WHERE id = %s"
        )
        self._sql_log_insert = (
            f"INSERT INTO {lt} "  # noqa: S608
            f"(queue_id, breaker_key, url, task_id, sequence_key, sequence_number, "
            f"attempt_number, max_attempts, outcome, http_status_code, error_message, "
            f"response_time_ms, occurred_at, will_retry, next_retry_at, task_type, "
            f"payload_size_bytes, notification_type) "
            f"VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
        )

        self._worker_started = False
        self._worker_warned = False

    # ------------------------------------------------------------ pure helpers

    @staticmethod
    def _build_circuit_failure_sql(ct: str, failure_threshold: int) -> str:
        """Return the atomic "record one failure" upsert for circuit table *ct*.

        *ct* must already have been validated by ``_safe_identifier``.

        The INSERT arm covers the first failure ever seen for a breaker key.
        When the threshold is 1 (or lower) that single failure must already
        open the circuit, so the initial row is written as ``'open'`` with
        ``opened_at = now()`` instead of unconditionally ``'closed'``.

        In the ``DO UPDATE`` arm ``{ct}.column`` refers to the existing row's
        value (before the update).
        """
        failure_t = int(failure_threshold)  # only a number may reach the SQL text
        if failure_t <= 1:
            initial_state, initial_opened_at = "'open'", "now()"
        else:
            initial_state, initial_opened_at = "'closed'", "NULL"
        return (
            f"INSERT INTO {ct} "  # noqa: S608
            f"(breaker_key, state, failure_count, success_count, opened_at, updated_at) "
            f"VALUES (%s, {initial_state}, 1, 0, {initial_opened_at}, now()) "
            f"ON CONFLICT (breaker_key) DO UPDATE SET "
            f"  failure_count = {ct}.failure_count + 1, "
            f"  success_count = 0, "
            f"  state = CASE "
            f"    WHEN {ct}.failure_count + 1 >= {failure_t} OR {ct}.state = 'half_open' "
            f"    THEN 'open' ELSE {ct}.state END, "
            f"  opened_at = CASE "
            f"    WHEN ({ct}.failure_count + 1 >= {failure_t} OR {ct}.state = 'half_open') "
            f"      AND {ct}.state != 'open' THEN now() "
            f"    ELSE {ct}.opened_at END, "
            f"  updated_at = now()"
        )

    @staticmethod
    def _status_to_str(status: Any) -> str:
        """Return the wire string for *status*.

        Enum members are reduced to their ``value``; ``str(member)`` on a plain
        ``Enum`` would store ``"ClassName.member"`` in the queue instead.
        """
        from enum import Enum  # noqa: PLC0415 - local import keeps this block self-contained

        if isinstance(status, Enum):
            return str(status.value)
        return status if isinstance(status, str) else str(status)

    @staticmethod
    def _result_to_json(result: Any) -> str | None:
        """Serialise *result* for the ``result_json`` column (``None`` stays ``None``).

        Objects exposing a callable ``model_dump`` (pydantic-style models) are
        dumped to plain JSON-compatible data first; passing them straight to
        ``json.dumps`` raises ``TypeError``.
        """
        if result is None:
            return None
        dump = getattr(result, "model_dump", None)
        if callable(dump):
            result = dump(mode="json")
        return json.dumps(result)

    @staticmethod
    def _replay_state(
        delivery_result: Any,
        prev_body: bytes | None,
        prev_key: str | None,
    ) -> tuple[bytes | None, str | None]:
        """Return the ``(sent_body, idempotency_key)`` pair to persist for the next retry.

        When the attempt produced no result (the sender raised), the values
        already stored on the queue row are carried forward. Overwriting them
        with ``NULL`` would make the next attempt build a fresh payload with a
        new idempotency key.
        """
        if delivery_result is None:
            return prev_body, prev_key
        return (
            delivery_result.sent_body or prev_body,
            delivery_result.idempotency_key or prev_key,
        )

    # ------------------------------------------------------------------ schema

    async def create_schema(self) -> None:
        """Bootstrap all three tables and their indexes. Idempotent.

        Safe to call on every app boot. Each DDL statement is executed
        separately — psycopg does not split on ``;``.

        ::

            pool = AsyncConnectionPool("postgresql://...")
            supervisor = PgWebhookDeliverySupervisor(pool=pool, sender=sender)
            await supervisor.create_schema()
        """
        ct, qt, lt = self._circuit_t, self._queue_t, self._log_t
        statements = [
            f"""CREATE TABLE IF NOT EXISTS {ct} (
                breaker_key TEXT COLLATE "C" NOT NULL PRIMARY KEY,
                state TEXT NOT NULL DEFAULT 'closed',
                failure_count INT NOT NULL DEFAULT 0,
                success_count INT NOT NULL DEFAULT 0,
                opened_at TIMESTAMPTZ,
                updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
            )""",
            f"""CREATE TABLE IF NOT EXISTS {qt} (
                id BIGSERIAL PRIMARY KEY,
                breaker_key TEXT NOT NULL,
                url TEXT NOT NULL,
                task_id TEXT NOT NULL,
                task_type TEXT,
                status_str TEXT NOT NULL DEFAULT 'pending',
                result_json TEXT,
                token TEXT,
                sequence_key TEXT,
                attempt_count INT NOT NULL DEFAULT 0,
                max_attempts INT NOT NULL DEFAULT 3,
                scheduled_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                idempotency_key TEXT,
                sent_body BYTEA,
                notification_type TEXT
            )""",
            # Partial index on work-eligible rows; avoids scanning completed/in-flight rows.
            f"""CREATE INDEX IF NOT EXISTS {qt}_work_idx
                ON {qt} (status_str, scheduled_at)
                WHERE status_str IN ('pending', 'retry')""",
            f"""CREATE TABLE IF NOT EXISTS {lt} (
                id BIGSERIAL PRIMARY KEY,
                queue_id BIGINT,
                breaker_key TEXT NOT NULL,
                url TEXT NOT NULL,
                task_id TEXT NOT NULL,
                sequence_key TEXT,
                sequence_number BIGINT,
                attempt_number INT NOT NULL,
                max_attempts INT NOT NULL,
                outcome TEXT NOT NULL,
                http_status_code INT,
                error_message TEXT,
                response_time_ms INT NOT NULL DEFAULT 0,
                occurred_at TIMESTAMPTZ NOT NULL,
                will_retry BOOLEAN NOT NULL DEFAULT false,
                next_retry_at TIMESTAMPTZ,
                task_type TEXT,
                payload_size_bytes INT,
                notification_type TEXT
            )""",
            f"CREATE INDEX IF NOT EXISTS {lt}_queue_id_idx ON {lt} (queue_id)",
            f"CREATE INDEX IF NOT EXISTS {lt}_task_id_idx ON {lt} (task_id)",
        ]
        async with self._pool.connection() as conn:
            for stmt in statements:
                await conn.execute(stmt)

    # ----------------------------------------------------------------- send

    async def send_mcp(
        self,
        *,
        url: str,
        task_id: str,
        status: GeneratedTaskStatus | str,
        task_type: str | None = None,
        result: Any = None,
        token: str | None = None,
        sequence_key: str | None = None,
        breaker_key: str | None = None,
        notification_type: str | None = None,
    ) -> WebhookDeliveryResult | None:
        """Enqueue one MCP-style webhook delivery; **always returns None**.

        The actual HTTP send happens in :meth:`run_worker`. This method
        writes the job to the delivery queue table and returns immediately.
        ``None`` is returned whether the circuit is open (delivery skipped)
        or the job was successfully enqueued.

        For delivery outcomes use a
        :class:`~adcp.webhook_supervisor.DeliveryLogSink` or query the
        delivery log table directly.

        :param status: Task status; enum members are stored by ``value``.
        :param result: JSON-serialisable payload, or an object exposing
            ``model_dump``; stored as JSON text.
        :param breaker_key: Override the circuit-breaker lookup key (default:
            ``url``). Multi-tenant sellers whose buyers share a SaaS receiver
            URL MUST pass a tenant-scoped key (e.g. ``f"{tenant_id}:{url}"``).
        :param notification_type: Passed through to the delivery log for
            delivery-report webhooks (``scheduled`` / ``final`` / etc.).
        """
        if not self._worker_started and not self._worker_warned:
            self._worker_warned = True
            logger.warning(
                "[adcp.webhook_supervisor_pg] send_mcp() called before run_worker() "
                "has been started. Deliveries will be queued but not sent until a worker "
                "is running. Call asyncio.create_task(supervisor.run_worker()) at startup."
            )

        bkey = breaker_key or url

        # Check circuit state; reject immediately if OPEN within the timeout window.
        async with self._pool.connection() as conn:
            cur = await conn.execute(self._sql_circuit_get, (bkey,))
            row = await cur.fetchone()

        if row is not None:
            state: str = row[0]
            opened_at: datetime | None = row[1]
            if state == "open" and opened_at is not None:
                if opened_at.tzinfo is None:
                    # Treat a naive timestamp as UTC so the subtraction below is valid.
                    opened_at = opened_at.replace(tzinfo=UTC)
                elapsed = (datetime.now(UTC) - opened_at).total_seconds()
                if elapsed < self._circuit_policy.open_timeout_seconds:
                    occurred_at = datetime.now(UTC)
                    attempt = DeliveryAttempt(
                        url=url,
                        sequence_key=sequence_key,
                        sequence_number=None,
                        attempt_number=0,
                        max_attempts=self._retry.max_attempts,
                        outcome="circuit_open",
                        http_status_code=None,
                        error_message=f"circuit breaker OPEN for {bkey} — skipped delivery",
                        response_time_ms=0,
                        occurred_at=occurred_at,
                        will_retry=False,
                        next_retry_at=None,
                        task_type=task_type,
                        task_id=task_id,
                        payload_size_bytes=None,
                        notification_type=notification_type,
                    )
                    await self._log_circuit_open(bkey, attempt)
                    await self._call_sink(attempt)
                    logger.warning(
                        "[adcp.webhook_supervisor_pg] circuit OPEN for %s — skipped %s",
                        bkey,
                        task_type or "webhook",
                    )
                    return None
                # Open timeout elapsed; transition to half_open so next worker probes.
                async with self._pool.connection() as conn:
                    await conn.execute(self._sql_circuit_set_half_open, (bkey,))

        status_str = self._status_to_str(status)
        result_json = self._result_to_json(result)

        async with self._pool.connection() as conn:
            cur = await conn.execute(
                self._sql_enqueue,
                (
                    bkey,
                    url,
                    task_id,
                    task_type,
                    status_str,
                    result_json,
                    token,
                    sequence_key,
                    self._retry.max_attempts,
                    notification_type,
                ),
            )
            enqueue_row = await cur.fetchone()

        queue_id = enqueue_row[0] if enqueue_row else None
        logger.debug(
            "[adcp.webhook_supervisor_pg] enqueued %s → %s (queue_id=%s)",
            task_id,
            url,
            queue_id,
        )
        return None

    # ----------------------------------------------------------------- worker

    async def run_worker(self, *, poll_interval: float = 0.5) -> None:
        """Poll the delivery queue with ``FOR UPDATE SKIP LOCKED``; runs forever.

        **REQUIRED** — start at app startup::

            asyncio.create_task(supervisor.run_worker())

        Multiple processes or coroutines can run concurrently; ``SKIP LOCKED``
        keeps two workers from leasing the same row at the same time. The DB
        connection is held for the duration of the HTTP send so a crashed
        worker releases the job back to the queue via transaction rollback.

        :param poll_interval: Seconds to sleep when the queue is empty or
            after an unexpected worker error.
        :raises asyncio.CancelledError: re-raised when the task is cancelled.
        """
        self._worker_started = True
        logger.info(
            "[adcp.webhook_supervisor_pg] worker started (poll_interval=%.2fs)",
            poll_interval,
        )
        while True:
            try:
                delivered = await self._poll_and_process()
                if not delivered:
                    await asyncio.sleep(poll_interval)
            except asyncio.CancelledError:
                logger.info("[adcp.webhook_supervisor_pg] worker cancelled — shut down")
                raise
            except Exception:
                # Top-level boundary: log and keep the worker alive.
                logger.exception(
                    "[adcp.webhook_supervisor_pg] worker error; will retry after %.2fs",
                    poll_interval,
                )
                await asyncio.sleep(poll_interval)

    # ---------------------------------------------------------- internal helpers

    async def _poll_and_process(self) -> bool:
        """Lease one job from the queue, send it, update state.

        :returns: ``True`` if a job was processed, ``False`` if the queue had
            no eligible row.
        :raises Exception: the sender's exception, re-raised after the
            transaction commits, when it occurred on the final attempt.
        """
        # The connection stays open for the duration of the HTTP send so the
        # FOR UPDATE lock is held throughout. On crash (or CancelledError),
        # the transaction rolls back and the job returns to the queue.
        async with self._pool.connection() as conn:
            cur = await conn.execute(self._sql_poll)
            row = await cur.fetchone()
            if row is None:
                return False

            (
                queue_id,
                bkey,
                url,
                task_id,
                task_type,
                status_str,
                result_json,
                token,
                sequence_key,
                attempt_count,
                max_attempts,
                idempotency_key,
                sent_body,
                notification_type,
            ) = row

            attempt_number = attempt_count + 1
            will_retry = attempt_number < max_attempts

            # Deferred import avoids circular module dependency at package init.
            from adcp.webhook_sender import WebhookDeliveryResult  # noqa: PLC0415

            t0 = time.monotonic()
            occurred_at = datetime.now(UTC)
            delivery_result: WebhookDeliveryResult | None = None
            exc_caught: BaseException | None = None

            try:
                if sent_body:
                    # Spec-compliant retry: replay the exact bytes so the receiver
                    # can dedup via the same idempotency_key (per mcp-webhook-payload.json:
                    # "Publishers MUST … reuse the same key on every retry").
                    prev = WebhookDeliveryResult(
                        status_code=0,
                        idempotency_key=idempotency_key or "",
                        url=url,
                        response_headers={},
                        response_body=b"",
                        sent_body=sent_body,
                    )
                    delivery_result = await self._sender.resend(prev)
                else:
                    result_obj = json.loads(result_json) if result_json else None
                    delivery_result = await self._sender.send_mcp(
                        url=url,
                        task_id=task_id,
                        status=status_str,
                        task_type=task_type,
                        result=result_obj,
                        token=token,
                    )
            except Exception as exc:
                # Deliberately broad: any sender failure counts as a failed attempt.
                exc_caught = exc

            response_time_ms = int((time.monotonic() - t0) * 1000)
            success = delivery_result is not None and delivery_result.ok

            next_retry_at: datetime | None = None
            err_msg: str | None = None

            if success:
                await conn.execute(self._sql_delete_job, (queue_id,))
                await conn.execute(self._sql_circuit_success, (bkey,))
            else:
                await conn.execute(self._sql_circuit_failure, (bkey,))

                next_delay = (
                    self._retry.delay_for_attempt(attempt_number + 1) if will_retry else None
                )
                if next_delay is not None:
                    next_retry_at = occurred_at + timedelta(seconds=next_delay)

                if will_retry:
                    # Keep the previously stored wire bytes / key when this attempt
                    # raised, so later retries still replay the same payload.
                    stored_body, stored_ikey = self._replay_state(
                        delivery_result, sent_body, idempotency_key
                    )
                    await conn.execute(
                        self._sql_reschedule,
                        (
                            attempt_number,
                            next_retry_at or occurred_at,
                            stored_body,
                            stored_ikey,
                            queue_id,
                        ),
                    )
                else:
                    await conn.execute(self._sql_delete_job, (queue_id,))

                if delivery_result is not None:
                    err_msg = (
                        f"HTTP {delivery_result.status_code}: "
                        f"{delivery_result.response_body[:200]!r}"
                    )
                elif exc_caught is not None:
                    err_msg = f"{type(exc_caught).__name__}: {exc_caught}"

            http_status: int | None = None
            psize: int | None = None
            if delivery_result is not None:
                http_status = delivery_result.status_code
                if delivery_result.sent_body:
                    psize = len(delivery_result.sent_body)

            attempt = DeliveryAttempt(
                url=url,
                sequence_key=sequence_key,
                sequence_number=queue_id,
                attempt_number=attempt_number,
                max_attempts=max_attempts,
                outcome="success" if success else "failure",
                http_status_code=http_status,
                error_message=err_msg,
                response_time_ms=response_time_ms,
                occurred_at=occurred_at,
                will_retry=False if success else will_retry,
                next_retry_at=next_retry_at,
                task_type=task_type,
                task_id=task_id,
                payload_size_bytes=psize,
                notification_type=notification_type,
            )

            await self._log_attempt_via_conn(conn, queue_id, bkey, attempt)
            # connection __aexit__ commits here; lock released

        # Sink is called outside the transaction so a slow/broken sink
        # cannot interfere with the queue update.
        await self._call_sink(attempt)

        if exc_caught is not None and not will_retry:
            raise exc_caught  # propagate to run_worker's exception logger

        return True

    async def _log_attempt_via_conn(
        self,
        conn: Any,
        queue_id: int | None,
        bkey: str,
        attempt: DeliveryAttempt,
    ) -> None:
        """Write a delivery attempt row using an already-open connection.

        Best-effort: failures are logged and swallowed. The INSERT runs inside
        ``conn.transaction()`` (a savepoint when a transaction is already in
        progress) so that a failed INSERT does not leave the caller's
        transaction in the aborted state. Without the savepoint, the queue and
        circuit updates made earlier on *conn* would be rolled back at commit
        time.
        """
        try:
            async with conn.transaction():
                await conn.execute(
                    self._sql_log_insert,
                    (
                        queue_id,
                        bkey,
                        attempt.url,
                        attempt.task_id,
                        attempt.sequence_key,
                        attempt.sequence_number,
                        attempt.attempt_number,
                        attempt.max_attempts,
                        attempt.outcome,
                        attempt.http_status_code,
                        attempt.error_message,
                        attempt.response_time_ms,
                        attempt.occurred_at,
                        attempt.will_retry,
                        attempt.next_retry_at,
                        attempt.task_type,
                        attempt.payload_size_bytes,
                        attempt.notification_type,
                    ),
                )
        except Exception:
            logger.warning(
                "[adcp.webhook_supervisor_pg] failed to log attempt for %s — delivery unaffected",
                attempt.url,
                exc_info=True,
            )

    async def _log_circuit_open(self, bkey: str, attempt: DeliveryAttempt) -> None:
        """Write a circuit_open attempt row using a fresh connection (best-effort)."""
        try:
            async with self._pool.connection() as conn:
                await self._log_attempt_via_conn(conn, None, bkey, attempt)
        except Exception:
            logger.warning(
                "[adcp.webhook_supervisor_pg] failed to log circuit_open for %s",
                attempt.url,
                exc_info=True,
            )

    async def _call_sink(self, attempt: DeliveryAttempt) -> None:
        """Forward *attempt* to the optional log sink; timeouts and errors are swallowed."""
        if self._log_sink is None:
            return
        try:
            await asyncio.wait_for(
                self._log_sink.record(attempt),
                timeout=self._retry.sink_timeout_seconds,
            )
        except asyncio.TimeoutError:
            logger.warning(
                "[adcp.webhook_supervisor_pg] DeliveryLogSink timed out for %s",
                attempt.url,
            )
        except Exception:
            logger.warning(
                "[adcp.webhook_supervisor_pg] DeliveryLogSink raised for %s",
                attempt.url,
                exc_info=True,
            )
"PG_AVAILABLE", + "PgWebhookDeliverySupervisor", +] diff --git a/src/adcp/webhooks.py b/src/adcp/webhooks.py index 47e71a81b..33aad44f2 100644 --- a/src/adcp/webhooks.py +++ b/src/adcp/webhooks.py @@ -1157,6 +1157,9 @@ def _validate_header_value(name: str, value: Any) -> None: WebhookDeliveryResult, WebhookSender, ) +from adcp.webhook_supervisor_pg import ( # noqa: E402 + PgWebhookDeliverySupervisor, +) __all__ = [ # Sender — payload builders @@ -1193,4 +1196,6 @@ def _validate_header_value(name: str, value: Any) -> None: # Dedup / idempotency backends (re-exported so one import root suffices) "MemoryBackend", "WebhookDedupStore", + # Pg-backed supervisor (requires adcp[pg] extra) + "PgWebhookDeliverySupervisor", ] diff --git a/tests/test_webhook_supervisor_pg.py b/tests/test_webhook_supervisor_pg.py new file mode 100644 index 000000000..597a8b7ae --- /dev/null +++ b/tests/test_webhook_supervisor_pg.py @@ -0,0 +1,649 @@ +"""Tests for :mod:`adcp.webhook_supervisor_pg` — unit tests with mock psycopg pool. + +psycopg3 is an optional dependency (``adcp[pg]``); these tests mock the pool +and connection entirely so they pass without a real Postgres instance or the +psycopg package installed. + +Behaviour under test: + +* :class:`PgWebhookDeliverySupervisor` raises on construction when pg deps + are absent, sender is None, or a table name is invalid. +* ``send_mcp`` emits a warning when called before ``run_worker`` is started. +* ``send_mcp`` checks circuit state from the DB and rejects OPEN circuits. +* ``send_mcp`` enqueues to the delivery queue and returns ``None``. +* ``_poll_and_process`` (worker core) handles success, failure, and retry. +* Retry path uses ``sender.resend()`` to replay the same wire bytes + (spec-compliant idempotency-key reuse). +* Circuit success uses RETURNING to get post-update state (half-open atomicity). +* Sink timeouts and exceptions are swallowed; a broken sink must not cascade. 
+* ``create_schema`` executes exactly one statement per DDL item (psycopg + does not split on semicolons; each must be a separate ``execute`` call). +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timezone +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from adcp.webhook_sender import WebhookDeliveryResult +from adcp.webhook_supervisor import ( + DeliveryAttempt, + RetryPolicy, +) + +UTC = timezone.utc + +# --------------------------------------------------------------------------- helpers + + +def _ok(url: str = "https://buyer.example/wh") -> WebhookDeliveryResult: + return WebhookDeliveryResult( + status_code=200, + idempotency_key="ikey-1", + url=url, + response_headers={}, + response_body=b"{}", + sent_body=b'{"task":"done"}', + ) + + +def _fail(url: str = "https://buyer.example/wh") -> WebhookDeliveryResult: + return WebhookDeliveryResult( + status_code=503, + idempotency_key="ikey-1", + url=url, + response_headers={}, + response_body=b"upstream error", + sent_body=b'{"task":"done"}', + ) + + +def _cursor(val: Any = None) -> AsyncMock: + """Fake async cursor whose fetchone() returns val.""" + cur = AsyncMock() + cur.fetchone = AsyncMock(return_value=val) + return cur + + +def _make_conn(*fetchone_vals: Any) -> AsyncMock: + """Fake AsyncConnection whose sequential execute() calls return cursors. + + Each positional arg becomes the fetchone() return value for the + corresponding execute() call. 
+ """ + cursors = [_cursor(v) for v in fetchone_vals] + conn = AsyncMock() + conn.__aenter__ = AsyncMock(return_value=conn) + conn.__aexit__ = AsyncMock(return_value=False) + conn.execute = AsyncMock(side_effect=cursors) + return conn + + +def _make_pool(*conns: AsyncMock) -> MagicMock: + """Fake AsyncConnectionPool that yields the given connections in order.""" + ctxs = [] + for conn in conns: + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=conn) + ctx.__aexit__ = AsyncMock(return_value=False) + ctxs.append(ctx) + pool = MagicMock() + pool.connection = MagicMock(side_effect=ctxs) + return pool + + +def _make_sender(send_result: WebhookDeliveryResult | None = None) -> AsyncMock: + sender = AsyncMock() + sender.send_mcp = AsyncMock(return_value=send_result or _ok()) + sender.resend = AsyncMock(return_value=send_result or _ok()) + return sender + + +def _make_supervisor(pool: Any, sender: Any, **kwargs: Any) -> Any: + """Construct PgWebhookDeliverySupervisor with PG_AVAILABLE patched to True.""" + from adcp.webhook_supervisor_pg import PgWebhookDeliverySupervisor + + with patch("adcp.webhook_supervisor_pg.PG_AVAILABLE", True): + sup = PgWebhookDeliverySupervisor.__new__(PgWebhookDeliverySupervisor) + PgWebhookDeliverySupervisor.__init__(sup, pool, sender, **kwargs) + return sup + + +# ---------------------------------------------------------------------- fixtures + +_QUEUE_ROW_BASE = ( + # id, breaker_key, url, task_id, task_type, status_str, + 1, + "https://buyer.example/wh", + "https://buyer.example/wh", + "task-123", + "sync_completion", + "pending", + # result_json, token, sequence_key, attempt_count, max_attempts + None, + None, + None, + 0, + 3, + # idempotency_key, sent_body, notification_type + None, + None, + None, +) + + +def _queue_row(**overrides: Any) -> tuple: + row = list(_QUEUE_ROW_BASE) + _fields = [ + "id", + "breaker_key", + "url", + "task_id", + "task_type", + "status_str", + "result_json", + "token", + "sequence_key", + 
"attempt_count", + "max_attempts", + "idempotency_key", + "sent_body", + "notification_type", + ] + for k, v in overrides.items(): + row[_fields.index(k)] = v + return tuple(row) + + +# ----------------------------------------------------------------------- tests + + +class TestConstruction: + def test_raises_without_pg(self) -> None: + from adcp.webhook_supervisor_pg import PgWebhookDeliverySupervisor + + with patch("adcp.webhook_supervisor_pg.PG_AVAILABLE", False): + with pytest.raises(ImportError, match="pip install 'adcp\\[pg\\]'"): + PgWebhookDeliverySupervisor(MagicMock(), _make_sender()) + + def test_raises_for_none_sender(self) -> None: + with pytest.raises(ValueError, match="non-None WebhookSender"): + _make_supervisor(MagicMock(), None) # type: ignore[arg-type] + + def test_raises_for_invalid_table_name(self) -> None: + with pytest.raises(ValueError, match="ASCII only"): + _make_supervisor(MagicMock(), _make_sender(), circuit_table="bad-name!") + + def test_raises_for_unicode_table_name(self) -> None: + with pytest.raises(ValueError, match="ASCII only"): + _make_supervisor(MagicMock(), _make_sender(), queue_table="adcp_wébhook") + + def test_custom_table_names_accepted(self) -> None: + sup = _make_supervisor( + MagicMock(), + _make_sender(), + circuit_table="my_circuit", + queue_table="my_queue", + log_table="my_log", + ) + assert sup._circuit_t == "my_circuit" + assert sup._queue_t == "my_queue" + assert sup._log_t == "my_log" + + def test_preformatted_sql_uses_table_names(self) -> None: + sup = _make_supervisor( + MagicMock(), + _make_sender(), + circuit_table="my_circuit", + queue_table="my_queue", + ) + assert "my_circuit" in sup._sql_circuit_get + assert "my_queue" in sup._sql_enqueue + assert "my_queue" in sup._sql_poll + + +class TestCreateSchema: + @pytest.mark.asyncio + async def test_executes_all_ddl_statements_separately(self) -> None: + conn = AsyncMock() + conn.__aenter__ = AsyncMock(return_value=conn) + conn.__aexit__ = 
AsyncMock(return_value=False) + conn.execute = AsyncMock() + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=conn) + ctx.__aexit__ = AsyncMock(return_value=False) + pool = MagicMock() + pool.connection = MagicMock(return_value=ctx) + + sup = _make_supervisor(pool, _make_sender()) + await sup.create_schema() + + # 6 statements: 3 tables + 3 indexes (1 partial + 2 standard) + assert conn.execute.call_count == 6 + + @pytest.mark.asyncio + async def test_each_statement_contains_table_name(self) -> None: + conn = AsyncMock() + conn.__aenter__ = AsyncMock(return_value=conn) + conn.__aexit__ = AsyncMock(return_value=False) + conn.execute = AsyncMock() + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=conn) + ctx.__aexit__ = AsyncMock(return_value=False) + pool = MagicMock() + pool.connection = MagicMock(return_value=ctx) + + sup = _make_supervisor(pool, _make_sender()) + await sup.create_schema() + + sqls = [call.args[0] for call in conn.execute.call_args_list] + assert any("adcp_webhook_circuit_state" in s for s in sqls) + assert any("adcp_webhook_delivery_queue" in s for s in sqls) + assert any("adcp_webhook_delivery_log" in s for s in sqls) + + +class TestSendMcp: + @pytest.mark.asyncio + async def test_returns_none_on_success(self) -> None: + conn_circuit = _make_conn(None) # no circuit row (first send) + conn_enqueue = _make_conn((42,)) # queue_id = 42 + pool = _make_pool(conn_circuit, conn_enqueue) + + sup = _make_supervisor(pool, _make_sender()) + sup._worker_started = True # suppress the warning + + result = await sup.send_mcp(url="https://b.example/wh", task_id="t1", status="completed") + assert result is None + + @pytest.mark.asyncio + async def test_warning_emitted_before_worker_starts(self, caplog: Any) -> None: + conn_circuit = _make_conn(None) + conn_enqueue = _make_conn((1,)) + pool = _make_pool(conn_circuit, conn_enqueue) + + sup = _make_supervisor(pool, _make_sender()) + # Do NOT set _worker_started + + import logging + + with 
caplog.at_level(logging.WARNING, logger="adcp.webhook_supervisor_pg"): + await sup.send_mcp(url="https://b.example/wh", task_id="t1", status="completed") + + assert any("run_worker" in r.message for r in caplog.records) + + @pytest.mark.asyncio + async def test_warning_emitted_only_once(self, caplog: Any) -> None: + import logging + + # Two calls on same supervisor → one warning + conn_c1 = _make_conn(None) + conn_e1 = _make_conn((1,)) + conn_c2 = _make_conn(None) + conn_e2 = _make_conn((2,)) + pool = _make_pool(conn_c1, conn_e1, conn_c2, conn_e2) + sup = _make_supervisor(pool, _make_sender()) + with caplog.at_level(logging.WARNING, logger="adcp.webhook_supervisor_pg"): + await sup.send_mcp(url="u", task_id="t", status="s") + await sup.send_mcp(url="u", task_id="t", status="s") + + warn_msgs = [r.message for r in caplog.records if "run_worker" in r.message] + assert len(warn_msgs) == 1 + + @pytest.mark.asyncio + async def test_circuit_open_rejects_and_returns_none(self) -> None: + from datetime import timedelta + + opened_at = datetime.now(UTC) - timedelta(seconds=5) # opened 5s ago, timeout=60s + + # 1st pool.connection: circuit_get → OPEN row + conn_circuit = _make_conn(("open", opened_at)) + # 2nd: log_circuit_open → no fetchone needed + conn_log = _make_conn(None) + pool = _make_pool(conn_circuit, conn_log) + + sup = _make_supervisor(pool, _make_sender()) + sup._worker_started = True + + result = await sup.send_mcp(url="https://b.example/wh", task_id="t", status="s") + assert result is None + # Should NOT have called enqueue + enqueue_calls = [ + c + for c in conn_circuit.execute.call_args_list + if "INSERT INTO" in (c.args[0] if c.args else "") + and "delivery_queue" in (c.args[0] if c.args else "") + ] + assert len(enqueue_calls) == 0 + + @pytest.mark.asyncio + async def test_circuit_open_timeout_transitions_to_half_open(self) -> None: + from datetime import timedelta + + opened_at = datetime.now(UTC) - timedelta(seconds=90) # 90s > 60s timeout + + 
conn_circuit = _make_conn(("open", opened_at)) + conn_half_open = _make_conn(None) # set_half_open + conn_enqueue = _make_conn((7,)) + pool = _make_pool(conn_circuit, conn_half_open, conn_enqueue) + + sup = _make_supervisor(pool, _make_sender()) + sup._worker_started = True + + result = await sup.send_mcp(url="https://b.example/wh", task_id="t", status="s") + assert result is None # always None from send_mcp + + # set_half_open was called + half_open_sql = conn_half_open.execute.call_args_list[0].args[0] + assert "half_open" in half_open_sql + + @pytest.mark.asyncio + async def test_breaker_key_used_as_circuit_lookup_key(self) -> None: + conn_circuit = _make_conn(None) + conn_enqueue = _make_conn((1,)) + pool = _make_pool(conn_circuit, conn_enqueue) + + sup = _make_supervisor(pool, _make_sender()) + sup._worker_started = True + + await sup.send_mcp( + url="https://shared.example/wh", + task_id="t", + status="s", + breaker_key="tenant-42:https://shared.example/wh", + ) + + circuit_params = conn_circuit.execute.call_args_list[0].args[1] + assert circuit_params[0] == "tenant-42:https://shared.example/wh" + + +class TestWorkerSuccess: + @pytest.mark.asyncio + async def test_success_deletes_job_and_updates_circuit(self) -> None: + sender = _make_sender(_ok()) + poll_row = _queue_row() + + # All within one connection (worker keeps it open) + conn = _make_conn( + poll_row, # poll + None, # delete_job + ("closed", 0), # circuit_success RETURNING + None, # log_insert + ) + pool = _make_pool(conn) + + sup = _make_supervisor(pool, sender) + sup._worker_started = True + + delivered = await sup._poll_and_process() + assert delivered is True + + sql_calls = [c.args[0] for c in conn.execute.call_args_list] + assert any("FOR UPDATE SKIP LOCKED" in s for s in sql_calls) + assert any("DELETE FROM" in s for s in sql_calls) + assert any("circuit_success" in s or "success_count" in s for s in sql_calls) + + @pytest.mark.asyncio + async def test_success_calls_sender_send_mcp(self) -> 
None: + sender = _make_sender(_ok()) + conn = _make_conn( + _queue_row(), + None, + ("closed", 0), + None, + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, sender) + + await sup._poll_and_process() + + sender.send_mcp.assert_awaited_once() + sender.resend.assert_not_awaited() + + @pytest.mark.asyncio + async def test_empty_queue_returns_false(self) -> None: + conn = _make_conn(None) # poll returns None + pool = _make_pool(conn) + sup = _make_supervisor(pool, _make_sender()) + + delivered = await sup._poll_and_process() + assert delivered is False + + +class TestWorkerFailureAndRetry: + @pytest.mark.asyncio + async def test_failure_reschedules_when_retries_remain(self) -> None: + sender = _make_sender(_fail()) + conn = _make_conn( + _queue_row(attempt_count=0, max_attempts=3), # poll + None, # circuit_failure + None, # reschedule (not delete) + None, # log_insert + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, sender) + + await sup._poll_and_process() + + sql_calls = [c.args[0] for c in conn.execute.call_args_list] + # Should reschedule, not delete + assert any("status_str = 'retry'" in s for s in sql_calls) + assert not any("DELETE FROM" in s for s in sql_calls) + + @pytest.mark.asyncio + async def test_final_failure_deletes_job(self) -> None: + sender = _make_sender(_fail()) + conn = _make_conn( + _queue_row(attempt_count=2, max_attempts=3), # attempt 3/3 + None, # circuit_failure + None, # delete_job + None, # log_insert + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, sender) + + await sup._poll_and_process() + + sql_calls = [c.args[0] for c in conn.execute.call_args_list] + assert any("DELETE FROM" in s for s in sql_calls) + assert not any("status_str = 'retry'" in s for s in sql_calls) + + @pytest.mark.asyncio + async def test_retry_uses_resend_when_sent_body_stored(self) -> None: + """Second attempt must call resend() with stored bytes for idempotency-key parity.""" + sender = _make_sender(_ok()) + stored_body = 
b'{"original":"payload"}' + conn = _make_conn( + _queue_row( + attempt_count=1, # this is attempt 2 + max_attempts=3, + sent_body=stored_body, + idempotency_key="original-ikey", + ), + None, # delete_job (success) + ("closed", 0), # circuit_success + None, # log_insert + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, sender) + + await sup._poll_and_process() + + sender.resend.assert_awaited_once() + sender.send_mcp.assert_not_awaited() + + # The resend call should receive the stored idempotency_key + resend_arg = sender.resend.call_args.args[0] + assert resend_arg.idempotency_key == "original-ikey" + assert resend_arg.sent_body == stored_body + + @pytest.mark.asyncio + async def test_reschedule_stores_sent_body_for_next_attempt(self) -> None: + fail_result = _fail() + sender = _make_sender(fail_result) + conn = _make_conn( + _queue_row(attempt_count=0, max_attempts=3), # attempt 1/3 + None, # circuit_failure + None, # reschedule + None, # log_insert + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, sender) + + await sup._poll_and_process() + + sql_calls = conn.execute.call_args_list + reschedule_call = next( + (c for c in sql_calls if "status_str = 'retry'" in c.args[0]), None + ) + assert reschedule_call is not None + # sent_body and idempotency_key are positional params 3 and 4 (0-indexed) + params = reschedule_call.args[1] + assert params[2] == fail_result.sent_body # sent_body + assert params[3] == fail_result.idempotency_key # idempotency_key + + +class TestWorkerCircuitState: + @pytest.mark.asyncio + async def test_success_circuit_query_uses_returning(self) -> None: + conn = _make_conn( + _queue_row(), + None, + ("closed", 0), + None, + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, _make_sender()) + + await sup._poll_and_process() + + circuit_success_sql = conn.execute.call_args_list[2].args[0] + assert "RETURNING" in circuit_success_sql.upper() + + @pytest.mark.asyncio + async def test_failure_circuit_query_is_upsert(self) 
-> None: + conn = _make_conn( + _queue_row(attempt_count=2, max_attempts=3), + None, + None, + None, + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, _make_sender(_fail())) + + await sup._poll_and_process() + + circuit_fail_sql = conn.execute.call_args_list[1].args[0] + assert "ON CONFLICT" in circuit_fail_sql.upper() + assert "failure_count" in circuit_fail_sql + + +class TestSinkBehavior: + @pytest.mark.asyncio + async def test_slow_sink_is_swallowed(self) -> None: + async def _slow_record(attempt: DeliveryAttempt) -> None: + await asyncio.sleep(99) + + from adcp.webhook_supervisor import DeliveryLogSink + + class _SlowSink(DeliveryLogSink): + async def record(self, attempt: DeliveryAttempt) -> None: # type: ignore[override] + await asyncio.sleep(99) + + conn = _make_conn( + _queue_row(), + None, + ("closed", 0), + None, + ) + pool = _make_pool(conn) + # Use a very short timeout so the test doesn't actually sleep 99s + retry = RetryPolicy(sink_timeout_seconds=0.01) + sup = _make_supervisor(pool, _make_sender(), retry=retry, log_sink=_SlowSink()) + + # Must not raise — sink timeout must be swallowed + delivered = await sup._poll_and_process() + assert delivered is True + + @pytest.mark.asyncio + async def test_exploding_sink_is_swallowed(self) -> None: + from adcp.webhook_supervisor import DeliveryLogSink + + class _ExplodingSink(DeliveryLogSink): + async def record(self, attempt: DeliveryAttempt) -> None: # type: ignore[override] + raise RuntimeError("BOOM") + + conn = _make_conn( + _queue_row(), + None, + ("closed", 0), + None, + ) + pool = _make_pool(conn) + sup = _make_supervisor(pool, _make_sender(), log_sink=_ExplodingSink()) + + delivered = await sup._poll_and_process() + assert delivered is True + + +class TestLogAttemptFault: + @pytest.mark.asyncio + async def test_log_insert_failure_does_not_crash_worker(self) -> None: + conn = AsyncMock() + conn.__aenter__ = AsyncMock(return_value=conn) + conn.__aexit__ = AsyncMock(return_value=False) + + 
poll_cur = _cursor(_queue_row()) + delete_cur = _cursor(None) + circuit_cur = _cursor(("closed", 0)) + + conn.execute = AsyncMock( + side_effect=[poll_cur, delete_cur, circuit_cur, RuntimeError("DB gone")] + ) + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=conn) + ctx.__aexit__ = AsyncMock(return_value=False) + pool = MagicMock() + pool.connection = MagicMock(return_value=ctx) + + sup = _make_supervisor(pool, _make_sender()) + # Should not raise — log errors are swallowed + delivered = await sup._poll_and_process() + assert delivered is True + + +class TestRunWorkerLifecycle: + @pytest.mark.asyncio + async def test_run_worker_sets_worker_started(self) -> None: + conn = _make_conn(None) # empty queue → sleep → cancel + pool = _make_pool(conn) + sup = _make_supervisor(pool, _make_sender()) + + assert not sup._worker_started + + async def _run_briefly() -> None: + task = asyncio.create_task(sup.run_worker(poll_interval=0.01)) + await asyncio.sleep(0.02) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + await _run_briefly() + assert sup._worker_started + + @pytest.mark.asyncio + async def test_run_worker_reraises_cancelled_error(self) -> None: + conn = _make_conn(None) + pool = _make_pool(conn) + sup = _make_supervisor(pool, _make_sender()) + + task = asyncio.create_task(sup.run_worker(poll_interval=0.01)) + await asyncio.sleep(0.01) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task