From f723ea16444c71d4660196e03e4347f501bc6caf Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 15 Apr 2026 15:34:17 -0700 Subject: [PATCH 1/2] feat(jobs): drain zombie NATS streams in periodic health check A NATS stream whose Django Job is already terminal (or whose Job row has been deleted) is a zombie: it consumes worker poll cycles and redelivery-advisory traffic for no reason, yet the existing cleanup-on-cancel path does not always run. Add a defense-in-depth sub-check to the 15-min jobs_health_check umbrella. _run_zombie_streams_check: - Enumerates every job_{N} stream currently in JetStream via the raw $JS.API.STREAM.LIST endpoint (nats.py's streams_info() drops the server-side "created" timestamp we need for the age guard). - Skips streams younger than Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES (STALLED_JOBS_MAX_MINUTES * 6 by default) to avoid racing with freshly-dispatched jobs whose stream is created before transaction.on_commit persists the Job row. - Drains (delete_consumer + delete_stream) only when the backing Job is in JobState.final_states or the row is missing entirely. Running jobs are left alone regardless of age. - Surfaces results through an IntegrityCheckResult with checked/fixed/unfixable counters and logs a per-stream line carrying status, age, redelivered count, and consumer-drain outcome. New Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES class attribute for symmetry with existing STALLED_JOBS_MAX_MINUTES and FAILED_JOBS_DISPLAY_MAX_HOURS constants. Tighten this if zombie streams start stranding poll cycles faster than the 60-minute default catches them. Tests cover the four outcomes (drain on terminal+old, drain on missing+old, skip on fresh-but-terminal, skip on old-but-running) plus the unfixable path when the NATS drain itself raises. 
Co-Authored-By: Claude --- ami/jobs/models.py | 7 ++ ami/jobs/tasks.py | 89 ++++++++++++++- ami/jobs/tests/test_periodic_beat_tasks.py | 126 +++++++++++++++++++++ ami/ml/orchestration/nats_queue.py | 96 ++++++++++++++++ 4 files changed, 317 insertions(+), 1 deletion(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 17549b3e9..94c1fdbe8 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -824,6 +824,13 @@ class Job(BaseModel): # N minutes". 10 is conservative; raise if legitimate long-running jobs get # reaped. STALLED_JOBS_MAX_MINUTES = 10 + # Zombie-stream reaper: age threshold above which a NATS stream for a job + # in a terminal state (or missing from Django) is considered safe to drop. + # Kept well above :attr:`STALLED_JOBS_MAX_MINUTES` so newly-dispatched jobs + # whose stream was created before ``transaction.on_commit`` saved the Job + # row do not get reaped. Tighten only if ``cleanup-on-cancel`` misses are + # still stranding consumer poll cycles after this safety net lands. 
+ ZOMBIE_STREAMS_MAX_AGE_MINUTES = STALLED_JOBS_MAX_MINUTES * 6 name = models.CharField(max_length=255) queue = models.CharField(max_length=255, default="default") diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index fa695111d..4eeab7f3d 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -6,7 +6,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING -from asgiref.sync import async_to_sync +from asgiref.sync import async_to_sync, sync_to_async from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction from redis.exceptions import RedisError @@ -558,6 +558,7 @@ class JobsHealthCheckResult: stale_jobs: IntegrityCheckResult running_job_snapshots: IntegrityCheckResult + zombie_streams: IntegrityCheckResult def _run_stale_jobs_check() -> IntegrityCheckResult: @@ -635,6 +636,91 @@ async def _snapshot_all() -> None: return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors) +def _run_zombie_streams_check() -> IntegrityCheckResult: + """Drain NATS streams that outlived their Django Job. + + Defense-in-depth for the cleanup-on-cancel path: a stream whose Job is in + a terminal state (or was deleted) is consuming worker poll cycles for no + reason. The age guard (``Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES``) prevents + races with freshly-dispatched jobs whose NATS stream is created before + ``transaction.on_commit`` persists the Job row. + + Observations-only for healthy in-flight jobs; only drains when both + conditions hold: + + * Job is ``None`` or in :meth:`JobState.final_states` + * Stream's NATS-reported ``created`` timestamp is older than the threshold + + ``checked`` counts job-shaped streams inspected; ``fixed`` counts those + actually drained; ``unfixable`` counts per-stream drain failures. 
+ """ + from ami.jobs.models import Job, JobState + + threshold = datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES) + now = datetime.datetime.now() + + async def _drain_all() -> tuple[int, int, int]: + async with TaskQueueManager() as manager: + snapshots = await manager.list_job_stream_snapshots() + if not snapshots: + return 0, 0, 0 + + job_ids = [s["job_id"] for s in snapshots] + jobs_by_id = await sync_to_async( + lambda ids: {j.pk: j for j in Job.objects.filter(pk__in=ids).only("pk", "status")} + )(job_ids) + + checked = len(snapshots) + drained = 0 + errored = 0 + for snap in snapshots: + created = snap["created"] + age = now - created if created else threshold + datetime.timedelta(minutes=1) + if age < threshold: + continue + job = jobs_by_id.get(snap["job_id"]) + job_status = job.status if job else None + if job is not None and JobState(job_status) not in JobState.final_states(): + continue + status_label = str(job_status) if job else "missing" + try: + consumer_deleted = await manager.delete_consumer(snap["job_id"]) + stream_deleted = await manager.delete_stream(snap["job_id"]) + except Exception: + errored += 1 + logger.exception("Failed draining zombie NATS stream for job %s", snap["job_id"]) + continue + if stream_deleted: + drained += 1 + age_hours = age.total_seconds() / 3600.0 + logger.info( + "Drained zombie NATS stream %s (status=%s, age=%.1fh, redelivered=%s, consumer_deleted=%s)", + snap["stream_name"], + status_label, + age_hours, + snap["num_redelivered"], + consumer_deleted, + ) + else: + errored += 1 + return checked, drained, errored + + try: + checked, drained, errored = async_to_sync(_drain_all)() + except Exception: + logger.exception("zombie_streams check: connection/setup failed") + return IntegrityCheckResult(checked=0, fixed=0, unfixable=1) + + log_fn = logger.warning if errored else logger.info + log_fn( + "zombie_streams check: %d stream(s) inspected, %d drained, %d error(s)", + checked, + drained, + errored, + ) + 
return IntegrityCheckResult(checked=checked, fixed=drained, unfixable=errored) + + def _safe_run_sub_check(name: str, fn: Callable[[], IntegrityCheckResult]) -> IntegrityCheckResult: """Run one umbrella sub-check, returning an ``unfixable=1`` sentinel on failure. @@ -664,6 +750,7 @@ def jobs_health_check() -> dict: result = JobsHealthCheckResult( stale_jobs=_safe_run_sub_check("stale_jobs", _run_stale_jobs_check), running_job_snapshots=_safe_run_sub_check("running_job_snapshots", _run_running_job_snapshot_check), + zombie_streams=_safe_run_sub_check("zombie_streams", _run_zombie_streams_check), ) return dataclasses.asdict(result) diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py index 383b3af81..7805cd00b 100644 --- a/ami/jobs/tests/test_periodic_beat_tasks.py +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -36,6 +36,10 @@ def _stub_manager(self, mock_manager_cls) -> AsyncMock: instance.__aenter__ = AsyncMock(return_value=instance) instance.__aexit__ = AsyncMock(return_value=False) instance.log_consumer_stats_snapshot = AsyncMock() + # Zombie-stream sub-check defaults: no streams to inspect, no drains. 
+ instance.list_job_stream_snapshots = AsyncMock(return_value=[]) + instance.delete_consumer = AsyncMock(return_value=True) + instance.delete_stream = AsyncMock(return_value=True) return instance def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup): @@ -50,6 +54,7 @@ def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup): { "stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0}, "running_job_snapshots": _empty_check_dict(), + "zombie_streams": _empty_check_dict(), }, ) @@ -63,6 +68,7 @@ def test_idle_deployment_returns_all_zeros(self, mock_manager_cls, _mock_cleanup { "stale_jobs": _empty_check_dict(), "running_job_snapshots": _empty_check_dict(), + "zombie_streams": _empty_check_dict(), }, ) @@ -173,3 +179,123 @@ def __init__(self, task_id): # checked == 2 (both stale), fixed == 2 (one per branch), unfixable == 0 self.assertEqual(result["stale_jobs"], {"checked": 2, "fixed": 2, "unfixable": 0}) + + def test_zombie_stream_drained_when_job_is_terminal_and_old(self, mock_manager_cls, _mock_cleanup): + """An old stream whose Job is in a final state should be drained.""" + import datetime + + terminal_job = Job.objects.create(project=self.project, name="zombie owner", status=JobState.SUCCESS) + instance = self._stub_manager(mock_manager_cls) + old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 5) + instance.list_job_stream_snapshots = AsyncMock( + return_value=[ + { + "job_id": terminal_job.pk, + "stream_name": f"job_{terminal_job.pk}", + "created": old_ts, + "messages": 0, + "num_redelivered": 7, + } + ] + ) + + result = jobs_health_check() + + self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 1, "unfixable": 0}) + instance.delete_consumer.assert_awaited_once_with(terminal_job.pk) + instance.delete_stream.assert_awaited_once_with(terminal_job.pk) + + def test_zombie_stream_drained_when_job_is_missing_and_old(self, mock_manager_cls, _mock_cleanup): + 
"""An old stream whose Job row no longer exists should be drained.""" + import datetime + + instance = self._stub_manager(mock_manager_cls) + old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 1) + instance.list_job_stream_snapshots = AsyncMock( + return_value=[ + { + "job_id": 987654, # no Job row with this pk + "stream_name": "job_987654", + "created": old_ts, + "messages": 3, + "num_redelivered": 0, + } + ] + ) + + result = jobs_health_check() + + self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 1, "unfixable": 0}) + instance.delete_stream.assert_awaited_once_with(987654) + + def test_zombie_stream_not_drained_when_below_age_threshold(self, mock_manager_cls, _mock_cleanup): + """A fresh stream for a terminal job must NOT be drained (on_commit race guard).""" + import datetime + + terminal_job = Job.objects.create(project=self.project, name="fresh zombie?", status=JobState.FAILURE) + instance = self._stub_manager(mock_manager_cls) + fresh_ts = datetime.datetime.now() - datetime.timedelta(minutes=1) + instance.list_job_stream_snapshots = AsyncMock( + return_value=[ + { + "job_id": terminal_job.pk, + "stream_name": f"job_{terminal_job.pk}", + "created": fresh_ts, + "messages": 0, + "num_redelivered": 0, + } + ] + ) + + result = jobs_health_check() + + self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 0}) + instance.delete_stream.assert_not_awaited() + + def test_zombie_stream_not_drained_when_job_still_running(self, mock_manager_cls, _mock_cleanup): + """An old stream for a still-running job must NOT be drained.""" + import datetime + + running_job = Job.objects.create(project=self.project, name="still running", status=JobState.STARTED) + instance = self._stub_manager(mock_manager_cls) + old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 10) + instance.list_job_stream_snapshots = AsyncMock( + return_value=[ + { + 
"job_id": running_job.pk, + "stream_name": f"job_{running_job.pk}", + "created": old_ts, + "messages": 5, + "num_redelivered": 0, + } + ] + ) + + result = jobs_health_check() + + self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 0}) + instance.delete_stream.assert_not_awaited() + + def test_zombie_stream_drain_failure_counts_as_unfixable(self, mock_manager_cls, _mock_cleanup): + """A drain that raises should be counted as unfixable without crashing the umbrella.""" + import datetime + + terminal_job = Job.objects.create(project=self.project, name="unfixable", status=JobState.SUCCESS) + instance = self._stub_manager(mock_manager_cls) + old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 2) + instance.list_job_stream_snapshots = AsyncMock( + return_value=[ + { + "job_id": terminal_job.pk, + "stream_name": f"job_{terminal_job.pk}", + "created": old_ts, + "messages": 0, + "num_redelivered": 0, + } + ] + ) + instance.delete_consumer = AsyncMock(side_effect=RuntimeError("nats error")) + + result = jobs_health_check() + + self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 1}) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index c5ec8705c..23db93a06 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -11,8 +11,10 @@ """ import asyncio +import datetime import json import logging +import re import nats from asgiref.sync import sync_to_async @@ -56,6 +58,27 @@ async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]: ADVISORY_STREAM_NAME = "advisories" # Shared stream for max delivery advisories across all jobs +def _parse_nats_timestamp(raw: str) -> datetime.datetime: + """Parse an RFC3339-ish NATS timestamp, tolerating sub-microsecond precision. 
+ + NATS servers emit nanoseconds (``...20494325Z``); Python's ``fromisoformat`` + rejects anything beyond 6 fractional digits, so we truncate before parsing. + Returns a naive datetime in local time to match the rest of the codebase + (``settings.USE_TZ = False``). + """ + cleaned = raw.rstrip("Z") + if "." in cleaned: + head, frac = cleaned.split(".", 1) + cleaned = f"{head}.{frac[:6]}" + parsed = datetime.datetime.fromisoformat(cleaned) + # NATS emits UTC; attach UTC tzinfo if none is present, then convert to the + # local zone and drop tzinfo to match the naive-local datetimes used + # throughout the codebase (``settings.USE_TZ = False``). + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=datetime.timezone.utc) + return parsed.astimezone().replace(tzinfo=None) + + class TaskQueueManager: """ Manager for NATS JetStream task queue operations. @@ -573,6 +596,79 @@ async def delete_stream(self, job_id: int) -> bool: await self.log_async(logging.ERROR, f"Failed to delete NATS stream for job '{job_id}': {e}") return False + async def list_job_stream_snapshots(self) -> list[dict]: + """Return a snapshot of every ``job_{N}`` stream currently in JetStream. + + Each entry: ``{"job_id": int, "stream_name": str, "created": datetime, + "messages": int, "num_redelivered": int | None}``. ``num_redelivered`` + is pulled from the matching consumer when present and is ``None`` when + the consumer has already been removed (stream-only zombies are still + worth reporting). + + Uses the raw ``$JS.API.STREAM.LIST`` endpoint because + ``JetStreamContext.streams_info`` in the currently pinned nats.py drops + the server-side ``created`` timestamp from :class:`StreamInfo` — we need + it here to age zombies out with a safety margin. + """ + if self.nc is None or self.js is None: + raise RuntimeError("Connection is not open. 
Use TaskQueueManager as an async context manager.") + + snapshots: list[dict] = [] + offset = 0 + # $JS.API.STREAM.LIST pages at 256 streams per response; loop so a + # deployment with a long tail of zombies is still fully enumerated. + while True: + resp = await asyncio.wait_for( + self.nc.request("$JS.API.STREAM.LIST", json.dumps({"offset": offset}).encode()), + timeout=NATS_JETSTREAM_TIMEOUT, + ) + payload = json.loads(resp.data) + streams = payload.get("streams") or [] + if not streams: + break + for stream in streams: + config = stream.get("config") or {} + name = config.get("name") or "" + match = re.match(r"^job_(\d+)$", name) + if not match: + continue + job_id = int(match.group(1)) + created_raw = stream.get("created") + try: + created = _parse_nats_timestamp(created_raw) if created_raw else None + except ValueError: + created = None + state = stream.get("state") or {} + snapshots.append( + { + "job_id": job_id, + "stream_name": name, + "created": created, + "messages": int(state.get("messages") or 0), + "num_redelivered": await self._consumer_redelivered_count(job_id), + } + ) + total = int(payload.get("total") or 0) + offset += len(streams) + if offset >= total: + break + return snapshots + + async def _consumer_redelivered_count(self, job_id: int) -> int | None: + """Return ``num_redelivered`` from the job's consumer, or ``None`` if gone.""" + if self.js is None: + return None + stream_name = self._get_stream_name(job_id) + consumer_name = self._get_consumer_name(job_id) + try: + info = await asyncio.wait_for( + self.js.consumer_info(stream_name, consumer_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) + except Exception: + return None + return getattr(info, "num_redelivered", None) + async def _setup_advisory_stream(self): """Ensure the shared advisory stream exists to capture max-delivery events. 
From 1e24d512b170e60b61f88e395efea6e5149e3ea8 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 15 Apr 2026 17:17:21 -0700 Subject: [PATCH 2/2] fix(jobs): address review feedback on zombie stream drain - Invert created=None safety guard: skip (unfixable) rather than drain when a stream's created timestamp is missing or unparseable, preserving the intent of the age guard for unknown-age streams - Raise nats.errors.Error on NATS STREAM.LIST error payloads so an outage surfaces as unfixable=1 instead of masking as "zero streams" - Defer _consumer_redelivered_count to drain candidates only via a new populate_redelivered_counts() helper, reducing O(N) NATS round-trips to O(candidates) per beat tick Co-Authored-By: Claude --- ami/jobs/tasks.py | 30 +++++++++- ami/jobs/tests/test_periodic_beat_tasks.py | 45 +++++++++++++++ ami/ml/orchestration/nats_queue.py | 36 ++++++++++-- ami/ml/orchestration/tests/test_nats_queue.py | 55 +++++++++++++++++++ 4 files changed, 159 insertions(+), 7 deletions(-) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 4eeab7f3d..f3989594b 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -673,16 +673,44 @@ async def _drain_all() -> tuple[int, int, int]: checked = len(snapshots) drained = 0 errored = 0 + + # First pass: classify each snapshot. Streams with a missing + # created timestamp are skipped (unfixable) rather than drained — + # the age guard exists precisely to protect streams whose Job row + # hasn't committed yet, so an unparseable timestamp must be treated + # as "unknown age / unsafe to drain". 
+ ages: dict[int, datetime.timedelta] = {} + candidates: list[dict] = [] for snap in snapshots: created = snap["created"] - age = now - created if created else threshold + datetime.timedelta(minutes=1) + if created is None: + logger.warning( + "Skipping zombie drain for stream %s: created timestamp missing", + snap["stream_name"], + ) + errored += 1 + continue + age = now - created if age < threshold: continue job = jobs_by_id.get(snap["job_id"]) job_status = job.status if job else None if job is not None and JobState(job_status) not in JobState.final_states(): continue + ages[snap["job_id"]] = age + candidates.append(snap) + + # Populate num_redelivered only for drain candidates to avoid an + # O(N) NATS round-trip on every beat tick for all historical streams. + if candidates: + await manager.populate_redelivered_counts(candidates) + + # Second pass: drain each candidate. + for snap in candidates: + job = jobs_by_id.get(snap["job_id"]) + job_status = job.status if job else None status_label = str(job_status) if job else "missing" + age = ages[snap["job_id"]] try: consumer_deleted = await manager.delete_consumer(snap["job_id"]) stream_deleted = await manager.delete_stream(snap["job_id"]) diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py index 7805cd00b..9e1592872 100644 --- a/ami/jobs/tests/test_periodic_beat_tasks.py +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -38,6 +38,7 @@ def _stub_manager(self, mock_manager_cls) -> AsyncMock: instance.log_consumer_stats_snapshot = AsyncMock() # Zombie-stream sub-check defaults: no streams to inspect, no drains. 
instance.list_job_stream_snapshots = AsyncMock(return_value=[]) + instance.populate_redelivered_counts = AsyncMock(return_value=None) instance.delete_consumer = AsyncMock(return_value=True) instance.delete_stream = AsyncMock(return_value=True) return instance @@ -299,3 +300,47 @@ def test_zombie_stream_drain_failure_counts_as_unfixable(self, mock_manager_cls, result = jobs_health_check() self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 1}) + + def test_zombie_stream_skipped_when_created_is_none(self, mock_manager_cls, _mock_cleanup): + """A snapshot with created=None must be skipped (unfixable), never drained. + + The age guard exists to protect streams whose Job row hasn't committed yet. + A missing/unparseable timestamp means we cannot determine age, so the safe + default is to leave the stream alone and report it as unfixable. + """ + instance = self._stub_manager(mock_manager_cls) + instance.list_job_stream_snapshots = AsyncMock( + return_value=[ + { + "job_id": 111222, + "stream_name": "job_111222", + "created": None, + "messages": 2, + "num_redelivered": None, + } + ] + ) + + result = jobs_health_check() + + self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 1}) + instance.delete_consumer.assert_not_awaited() + instance.delete_stream.assert_not_awaited() + + def test_list_job_stream_snapshots_raises_on_nats_error_payload(self, mock_manager_cls, _mock_cleanup): + """list_job_stream_snapshots must raise (not return []) when the NATS server responds with an error. + + Returning [] would mask an outage — the caller would silently interpret + "zero zombies" when NATS is actually unavailable. 
+ """ + import nats.errors + + instance = self._stub_manager(mock_manager_cls) + instance.list_job_stream_snapshots = AsyncMock( + side_effect=nats.errors.Error("NATS STREAM.LIST error 503: no responders available for request") + ) + + result = jobs_health_check() + + # _run_zombie_streams_check's own try/except catches the exception and reports unfixable=1. + self.assertEqual(result["zombie_streams"], {"checked": 0, "fixed": 0, "unfixable": 1}) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 23db93a06..119d28a8c 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -599,16 +599,20 @@ async def delete_stream(self, job_id: int) -> bool: async def list_job_stream_snapshots(self) -> list[dict]: """Return a snapshot of every ``job_{N}`` stream currently in JetStream. - Each entry: ``{"job_id": int, "stream_name": str, "created": datetime, - "messages": int, "num_redelivered": int | None}``. ``num_redelivered`` - is pulled from the matching consumer when present and is ``None`` when - the consumer has already been removed (stream-only zombies are still - worth reporting). + Each entry: ``{"job_id": int, "stream_name": str, "created": datetime | None, + "messages": int, "num_redelivered": None}``. ``num_redelivered`` is always + ``None`` here — call :meth:`populate_redelivered_counts` on the subset of + interest (e.g. drain candidates) to fill it in, avoiding O(N) consumer-info + round-trips for every stream on every beat tick. Uses the raw ``$JS.API.STREAM.LIST`` endpoint because ``JetStreamContext.streams_info`` in the currently pinned nats.py drops the server-side ``created`` timestamp from :class:`StreamInfo` — we need it here to age zombies out with a safety margin. + + Raises :class:`nats.errors.Error` if the NATS server responds with an error + payload (e.g. 503 no-responders), so the caller can surface it as an outage + rather than silently treating it as "zero streams." 
""" if self.nc is None or self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") @@ -623,6 +627,11 @@ async def list_job_stream_snapshots(self) -> list[dict]: timeout=NATS_JETSTREAM_TIMEOUT, ) payload = json.loads(resp.data) + if payload.get("error"): + err = payload["error"] + raise nats.errors.Error( + f"NATS STREAM.LIST error {err.get('code')}: {err.get('description', 'unknown')}" + ) streams = payload.get("streams") or [] if not streams: break @@ -645,7 +654,7 @@ async def list_job_stream_snapshots(self) -> list[dict]: "stream_name": name, "created": created, "messages": int(state.get("messages") or 0), - "num_redelivered": await self._consumer_redelivered_count(job_id), + "num_redelivered": None, } ) total = int(payload.get("total") or 0) @@ -654,6 +663,21 @@ async def list_job_stream_snapshots(self) -> list[dict]: break return snapshots + async def populate_redelivered_counts(self, snapshots: list[dict], concurrency: int = 8) -> None: + """Fill in ``num_redelivered`` in-place for the given snapshot dicts. + + Fetches consumer info concurrently (bounded by *concurrency*) so that + callers can limit per-consumer round-trips to a filtered subset rather + than fetching for all streams on every beat tick. 
+ """ + sem = asyncio.Semaphore(concurrency) + + async def _fetch_one(snap: dict) -> None: + async with sem: + snap["num_redelivered"] = await self._consumer_redelivered_count(snap["job_id"]) + + await asyncio.gather(*(_fetch_one(s) for s in snapshots)) + async def _consumer_redelivered_count(self, job_id: int) -> int | None: """Return ``num_redelivered`` from the job's consumer, or ``None`` if gone.""" if self.js is None: diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 9c35a4dae..d92e9be0b 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -544,3 +544,58 @@ async def test_no_job_logger_falls_back_to_module_logger_only(self): async with TaskQueueManager() as manager: # no job_logger passed # Must not raise. await manager.publish_task(1, self._create_sample_task()) + + async def test_list_job_stream_snapshots_raises_on_nats_error_payload(self): + """list_job_stream_snapshots must raise nats.errors.Error when the NATS + server returns an error payload instead of a stream list. + + Returning [] in this case would mask an outage — the caller would + silently interpret "zero zombies" while NATS is actually unavailable. + """ + nc, js = self._create_mock_nats_connection() + + error_payload = json.dumps( + {"error": {"code": 503, "description": "no responders available for request"}} + ).encode() + mock_response = MagicMock() + mock_response.data = error_payload + nc.request = AsyncMock(return_value=mock_response) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + with self.assertRaises(nats.errors.Error): + await manager.list_job_stream_snapshots() + + async def test_list_job_stream_snapshots_returns_none_for_num_redelivered(self): + """list_job_stream_snapshots returns num_redelivered=None for all snapshots. 
+ + Per-consumer info is deferred to populate_redelivered_counts() so that + the O(N) fetch only runs for drain candidates, not every stream. + """ + nc, js = self._create_mock_nats_connection() + + stream_payload = json.dumps( + { + "total": 1, + "streams": [ + { + "config": {"name": "job_42"}, + "created": "2024-01-01T00:00:00Z", + "state": {"messages": 0}, + } + ], + } + ).encode() + mock_response = MagicMock() + mock_response.data = stream_payload + nc.request = AsyncMock(return_value=mock_response) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + snapshots = await manager.list_job_stream_snapshots() + + self.assertEqual(len(snapshots), 1) + self.assertEqual(snapshots[0]["job_id"], 42) + self.assertIsNone(snapshots[0]["num_redelivered"]) + # consumer_info must NOT have been called — no redelivered fetch during list + js.consumer_info.assert_not_called()