Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,13 @@ class Job(BaseModel):
# N minutes". 10 is conservative; raise if legitimate long-running jobs get
# reaped.
STALLED_JOBS_MAX_MINUTES = 10
# Zombie-stream reaper: age threshold above which a NATS stream for a job
# in a terminal state (or missing from Django) is considered safe to drop.
# Kept well above :attr:`STALLED_JOBS_MAX_MINUTES` so newly-dispatched jobs
# whose stream was created before ``transaction.on_commit`` saved the Job
# row do not get reaped. Tighten only if ``cleanup-on-cancel`` misses are
# still stranding consumer poll cycles after this safety net lands.
ZOMBIE_STREAMS_MAX_AGE_MINUTES = STALLED_JOBS_MAX_MINUTES * 6

name = models.CharField(max_length=255)
queue = models.CharField(max_length=255, default="default")
Expand Down
117 changes: 116 additions & 1 deletion ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable
from typing import TYPE_CHECKING

from asgiref.sync import async_to_sync
from asgiref.sync import async_to_sync, sync_to_async
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction
from redis.exceptions import RedisError
Expand Down Expand Up @@ -558,6 +558,7 @@ class JobsHealthCheckResult:

stale_jobs: IntegrityCheckResult
running_job_snapshots: IntegrityCheckResult
zombie_streams: IntegrityCheckResult


def _run_stale_jobs_check() -> IntegrityCheckResult:
Expand Down Expand Up @@ -635,6 +636,119 @@ async def _snapshot_all() -> None:
return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors)


def _run_zombie_streams_check() -> IntegrityCheckResult:
    """Drain NATS streams that outlived their Django Job.

    Defense-in-depth for the cleanup-on-cancel path: a stream whose Job is in
    a terminal state (or was deleted) is consuming worker poll cycles for no
    reason. The age guard (``Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES``) prevents
    races with freshly-dispatched jobs whose NATS stream is created before
    ``transaction.on_commit`` persists the Job row.

    Observations-only for healthy in-flight jobs; only drains when both
    conditions hold:

    * Job is ``None`` or in :meth:`JobState.final_states`
    * Stream's NATS-reported ``created`` timestamp is older than the threshold

    ``checked`` counts job-shaped streams inspected; ``fixed`` counts those
    actually drained; ``unfixable`` counts per-stream drain failures.

    Returns:
        IntegrityCheckResult with ``checked``/``fixed``/``unfixable`` counts.
        A connection/setup failure (anything raised outside the per-stream
        try block) is reported as ``checked=0, fixed=0, unfixable=1``.
    """
    # Function-scope import — presumably to avoid an import cycle between
    # tasks and models; verify before hoisting to module level.
    from ami.jobs.models import Job, JobState

    threshold = datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES)
    # NOTE(review): naive local time. Age arithmetic below (``now - created``)
    # assumes snapshot ``created`` values are also naive and in the same
    # timezone — subtracting an aware datetime from this would raise
    # TypeError. TODO confirm what list_job_stream_snapshots() returns.
    now = datetime.datetime.now()

    async def _drain_all() -> tuple[int, int, int]:
        # Returns (checked, drained, errored) for the umbrella result.
        async with TaskQueueManager() as manager:
            snapshots = await manager.list_job_stream_snapshots()
            if not snapshots:
                return 0, 0, 0

            # Single bulk ORM lookup for all stream owners; wrapped in
            # sync_to_async because we are inside the event loop here.
            # .only() keeps the query narrow — just pk and status are read.
            job_ids = [s["job_id"] for s in snapshots]
            jobs_by_id = await sync_to_async(
                lambda ids: {j.pk: j for j in Job.objects.filter(pk__in=ids).only("pk", "status")}
            )(job_ids)

            checked = len(snapshots)
            drained = 0
            errored = 0

            # First pass: classify each snapshot. Streams with a missing
            # created timestamp are skipped (unfixable) rather than drained —
            # the age guard exists precisely to protect streams whose Job row
            # hasn't committed yet, so an unparseable timestamp must be treated
            # as "unknown age / unsafe to drain".
            ages: dict[int, datetime.timedelta] = {}
            candidates: list[dict] = []
            for snap in snapshots:
                created = snap["created"]
                if created is None:
                    logger.warning(
                        "Skipping zombie drain for stream %s: created timestamp missing",
                        snap["stream_name"],
                    )
                    errored += 1
                    continue
                age = now - created
                if age < threshold:
                    # Too young to drain safely — could be a fresh dispatch
                    # whose Job row hasn't committed yet.
                    continue
                job = jobs_by_id.get(snap["job_id"])
                job_status = job.status if job else None
                if job is not None and JobState(job_status) not in JobState.final_states():
                    # Job exists and is still in flight: leave its stream alone.
                    continue
                ages[snap["job_id"]] = age
                candidates.append(snap)

            # Populate num_redelivered only for drain candidates to avoid an
            # O(N) NATS round-trip on every beat tick for all historical streams.
            if candidates:
                await manager.populate_redelivered_counts(candidates)

            # Second pass: drain each candidate. Consumer is deleted before its
            # stream; a failure in either is counted per-stream and does not
            # abort the remaining candidates.
            for snap in candidates:
                job = jobs_by_id.get(snap["job_id"])
                job_status = job.status if job else None
                status_label = str(job_status) if job else "missing"
                age = ages[snap["job_id"]]
                try:
                    consumer_deleted = await manager.delete_consumer(snap["job_id"])
                    stream_deleted = await manager.delete_stream(snap["job_id"])
                except Exception:
                    errored += 1
                    logger.exception("Failed draining zombie NATS stream for job %s", snap["job_id"])
                    continue
                if stream_deleted:
                    drained += 1
                    age_hours = age.total_seconds() / 3600.0
                    logger.info(
                        "Drained zombie NATS stream %s (status=%s, age=%.1fh, redelivered=%s, consumer_deleted=%s)",
                        snap["stream_name"],
                        status_label,
                        age_hours,
                        snap["num_redelivered"],
                        consumer_deleted,
                    )
                else:
                    # delete_stream returned falsy without raising: treat as a
                    # drain failure so the umbrella check surfaces it.
                    errored += 1
            return checked, drained, errored

    try:
        checked, drained, errored = async_to_sync(_drain_all)()
    except Exception:
        # Anything raised before/around the per-stream loop (NATS connection,
        # listing, ORM) — report a single unfixable rather than crashing the
        # umbrella task.
        logger.exception("zombie_streams check: connection/setup failed")
        return IntegrityCheckResult(checked=0, fixed=0, unfixable=1)

    # Escalate to warning level only when at least one stream errored.
    log_fn = logger.warning if errored else logger.info
    log_fn(
        "zombie_streams check: %d stream(s) inspected, %d drained, %d error(s)",
        checked,
        drained,
        errored,
    )
    return IntegrityCheckResult(checked=checked, fixed=drained, unfixable=errored)


def _safe_run_sub_check(name: str, fn: Callable[[], IntegrityCheckResult]) -> IntegrityCheckResult:
"""Run one umbrella sub-check, returning an ``unfixable=1`` sentinel on failure.

Expand Down Expand Up @@ -664,6 +778,7 @@ def jobs_health_check() -> dict:
result = JobsHealthCheckResult(
stale_jobs=_safe_run_sub_check("stale_jobs", _run_stale_jobs_check),
running_job_snapshots=_safe_run_sub_check("running_job_snapshots", _run_running_job_snapshot_check),
zombie_streams=_safe_run_sub_check("zombie_streams", _run_zombie_streams_check),
)
return dataclasses.asdict(result)

Expand Down
171 changes: 171 additions & 0 deletions ami/jobs/tests/test_periodic_beat_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def _stub_manager(self, mock_manager_cls) -> AsyncMock:
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.log_consumer_stats_snapshot = AsyncMock()
# Zombie-stream sub-check defaults: no streams to inspect, no drains.
instance.list_job_stream_snapshots = AsyncMock(return_value=[])
instance.populate_redelivered_counts = AsyncMock(return_value=None)
instance.delete_consumer = AsyncMock(return_value=True)
instance.delete_stream = AsyncMock(return_value=True)
return instance

def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup):
Expand All @@ -50,6 +55,7 @@ def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup):
{
"stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0},
"running_job_snapshots": _empty_check_dict(),
"zombie_streams": _empty_check_dict(),
},
)

Expand All @@ -63,6 +69,7 @@ def test_idle_deployment_returns_all_zeros(self, mock_manager_cls, _mock_cleanup
{
"stale_jobs": _empty_check_dict(),
"running_job_snapshots": _empty_check_dict(),
"zombie_streams": _empty_check_dict(),
},
)

Expand Down Expand Up @@ -173,3 +180,167 @@ def __init__(self, task_id):

# checked == 2 (both stale), fixed == 2 (one per branch), unfixable == 0
self.assertEqual(result["stale_jobs"], {"checked": 2, "fixed": 2, "unfixable": 0})

def test_zombie_stream_drained_when_job_is_terminal_and_old(self, mock_manager_cls, _mock_cleanup):
    """A stream past the age threshold whose Job reached a final state gets reaped."""
    import datetime

    finished_job = Job.objects.create(project=self.project, name="zombie owner", status=JobState.SUCCESS)
    mgr = self._stub_manager(mock_manager_cls)
    stale_created = datetime.datetime.now() - datetime.timedelta(
        minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 5
    )
    snapshot = {
        "job_id": finished_job.pk,
        "stream_name": f"job_{finished_job.pk}",
        "created": stale_created,
        "messages": 0,
        "num_redelivered": 7,
    }
    mgr.list_job_stream_snapshots = AsyncMock(return_value=[snapshot])

    result = jobs_health_check()

    self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 1, "unfixable": 0})
    mgr.delete_consumer.assert_awaited_once_with(finished_job.pk)
    mgr.delete_stream.assert_awaited_once_with(finished_job.pk)

def test_zombie_stream_drained_when_job_is_missing_and_old(self, mock_manager_cls, _mock_cleanup):
    """A stream past the age threshold with no matching Job row gets reaped."""
    import datetime

    orphan_pk = 987654  # deliberately does not match any Job row
    mgr = self._stub_manager(mock_manager_cls)
    stale_created = datetime.datetime.now() - datetime.timedelta(
        minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 1
    )
    snapshot = {
        "job_id": orphan_pk,
        "stream_name": "job_987654",
        "created": stale_created,
        "messages": 3,
        "num_redelivered": 0,
    }
    mgr.list_job_stream_snapshots = AsyncMock(return_value=[snapshot])

    result = jobs_health_check()

    self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 1, "unfixable": 0})
    mgr.delete_stream.assert_awaited_once_with(orphan_pk)

def test_zombie_stream_not_drained_when_below_age_threshold(self, mock_manager_cls, _mock_cleanup):
    """A young stream stays untouched even for a terminal Job (on_commit race guard)."""
    import datetime

    finished_job = Job.objects.create(project=self.project, name="fresh zombie?", status=JobState.FAILURE)
    mgr = self._stub_manager(mock_manager_cls)
    recent_created = datetime.datetime.now() - datetime.timedelta(minutes=1)
    snapshot = {
        "job_id": finished_job.pk,
        "stream_name": f"job_{finished_job.pk}",
        "created": recent_created,
        "messages": 0,
        "num_redelivered": 0,
    }
    mgr.list_job_stream_snapshots = AsyncMock(return_value=[snapshot])

    result = jobs_health_check()

    self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 0})
    mgr.delete_stream.assert_not_awaited()

def test_zombie_stream_not_drained_when_job_still_running(self, mock_manager_cls, _mock_cleanup):
    """Age alone is not enough: an in-flight Job keeps its stream."""
    import datetime

    active_job = Job.objects.create(project=self.project, name="still running", status=JobState.STARTED)
    mgr = self._stub_manager(mock_manager_cls)
    stale_created = datetime.datetime.now() - datetime.timedelta(
        minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 10
    )
    snapshot = {
        "job_id": active_job.pk,
        "stream_name": f"job_{active_job.pk}",
        "created": stale_created,
        "messages": 5,
        "num_redelivered": 0,
    }
    mgr.list_job_stream_snapshots = AsyncMock(return_value=[snapshot])

    result = jobs_health_check()

    self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 0})
    mgr.delete_stream.assert_not_awaited()

def test_zombie_stream_drain_failure_counts_as_unfixable(self, mock_manager_cls, _mock_cleanup):
    """A drain that raises is recorded as unfixable and the umbrella task survives."""
    import datetime

    finished_job = Job.objects.create(project=self.project, name="unfixable", status=JobState.SUCCESS)
    mgr = self._stub_manager(mock_manager_cls)
    stale_created = datetime.datetime.now() - datetime.timedelta(
        minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 2
    )
    snapshot = {
        "job_id": finished_job.pk,
        "stream_name": f"job_{finished_job.pk}",
        "created": stale_created,
        "messages": 0,
        "num_redelivered": 0,
    }
    mgr.list_job_stream_snapshots = AsyncMock(return_value=[snapshot])
    # Consumer deletion blows up mid-drain; the check must absorb it.
    mgr.delete_consumer = AsyncMock(side_effect=RuntimeError("nats error"))

    result = jobs_health_check()

    self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 1})

def test_zombie_stream_skipped_when_created_is_none(self, mock_manager_cls, _mock_cleanup):
    """A snapshot with created=None must be skipped (unfixable), never drained.

    The age guard exists to protect streams whose Job row hasn't committed yet.
    With no timestamp the age is unknowable, so the safe default is to leave
    the stream alone and surface it as unfixable.
    """
    mgr = self._stub_manager(mock_manager_cls)
    snapshot = {
        "job_id": 111222,
        "stream_name": "job_111222",
        "created": None,
        "messages": 2,
        "num_redelivered": None,
    }
    mgr.list_job_stream_snapshots = AsyncMock(return_value=[snapshot])

    result = jobs_health_check()

    self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 1})
    mgr.delete_consumer.assert_not_awaited()
    mgr.delete_stream.assert_not_awaited()

def test_list_job_stream_snapshots_raises_on_nats_error_payload(self, mock_manager_cls, _mock_cleanup):
    """list_job_stream_snapshots must raise (not return []) when the NATS server responds with an error.

    An empty list would disguise an outage as "zero zombies"; raising lets the
    umbrella record the failure instead.
    """
    import nats.errors

    mgr = self._stub_manager(mock_manager_cls)
    nats_failure = nats.errors.Error("NATS STREAM.LIST error 503: no responders available for request")
    mgr.list_job_stream_snapshots = AsyncMock(side_effect=nats_failure)

    result = jobs_health_check()

    # _safe_run_sub_check catches the exception and records unfixable=1.
    self.assertEqual(result["zombie_streams"], {"checked": 0, "fixed": 0, "unfixable": 1})
Loading
Loading