Give specific messages if job was killed due to SIGTERM or SIGKILL (ansible#12435)

* Reap jobs on dispatcher startup to increase clarity, replace existing reaping logic

* Exit jobs if receiving SIGTERM signal

* Fix unwanted reaping on shutdown, let subprocess close out

* Add some sanity tests for signal module

* Add a log for an unhandled dispatcher error

* Refine wording of error messages

Co-authored-by: Elijah DeLee <kdelee@redhat.com>
2 people authored and shanemcd committed Aug 8, 2022
1 parent 32a2a3b commit 88ff62e
Showing 8 changed files with 165 additions and 6 deletions.
21 changes: 21 additions & 0 deletions awx/main/dispatch/reaper.py
@@ -10,6 +10,27 @@
logger = logging.getLogger('awx.main.dispatch')


def startup_reaping():
"""
If this particular instance is starting, then we know that any running jobs are invalid
so we will reap those jobs as a special action here
"""
me = Instance.objects.me()
jobs = UnifiedJob.objects.filter(status='running', controller_node=me.hostname)
job_ids = []
for j in jobs:
job_ids.append(j.id)
j.status = 'failed'
j.start_args = ''
j.job_explanation += 'Task was marked as running at system start up. The system must have not shut down properly, so it has been marked as failed.'
j.save(update_fields=['status', 'start_args', 'job_explanation'])
if hasattr(j, 'send_notification_templates'):
j.send_notification_templates('failed')
j.websocket_emit_status('failed')
if job_ids:
logger.error(f'Unified jobs {job_ids} were reaped on dispatch startup')


def reap_job(j, status):
if UnifiedJob.objects.get(id=j.id).status not in ('running', 'waiting'):
# just in case, don't reap jobs that aren't running
9 changes: 7 additions & 2 deletions awx/main/dispatch/worker/base.py
@@ -169,8 +169,9 @@ def run(self, *args, **kwargs):
logger.exception(f"Error consuming new events from postgres, will retry for {self.pg_max_wait} s")
self.pg_down_time = time.time()
self.pg_is_down = True
if time.time() - self.pg_down_time > self.pg_max_wait:
logger.warning(f"Postgres event consumer has not recovered in {self.pg_max_wait} s, exiting")
current_downtime = time.time() - self.pg_down_time
if current_downtime > self.pg_max_wait:
logger.exception(f"Postgres event consumer has not recovered in {current_downtime} s, exiting")
raise
# Wait for a second before next attempt, but still listen for any shutdown signals
for i in range(10):
@@ -179,6 +180,10 @@
time.sleep(0.1)
for conn in db.connections.all():
conn.close_if_unusable_or_obsolete()
except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata
logger.exception('Encountered unhandled error in dispatcher main loop')
raise


class BaseWorker(object):
3 changes: 2 additions & 1 deletion awx/main/management/commands/run_dispatcher.py
@@ -7,7 +7,7 @@
from django.core.management.base import BaseCommand
from django.db import connection as django_connection

from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_local_queuename, reaper
from awx.main.dispatch.control import Control
from awx.main.dispatch.pool import AutoscalePool
from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker
@@ -53,6 +53,7 @@ def handle(self, *arg, **options):
# (like the node heartbeat)
periodic.run_continuously()

reaper.startup_reaping()
consumer = None

try:
9 changes: 8 additions & 1 deletion awx/main/tasks/callback.py
@@ -15,6 +15,7 @@
from awx.main.constants import MINIMAL_EVENTS
from awx.main.utils.update_model import update_model
from awx.main.queue import CallbackQueueDispatcher
from awx.main.tasks.signals import signal_callback

logger = logging.getLogger('awx.main.tasks.callback')

@@ -148,7 +149,13 @@ def cancel_callback(self):
Ansible runner callback to tell the job when/if it is canceled
"""
unified_job_id = self.instance.pk
self.instance = self.update_model(unified_job_id)
if signal_callback():
return True
try:
self.instance = self.update_model(unified_job_id)
except Exception:
logger.exception(f'Encountered error during cancel check for {unified_job_id}, canceling now')
return True
if not self.instance:
logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
return True
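For context, ansible-runner's run() interface accepts a cancel_callback callable and polls it while a playbook executes; when a callback like the one above returns True, the run is cancelled. A minimal standalone sketch of that wiring, with a placeholder playbook and private data directory (not taken from this commit):

import ansible_runner


def cancel_requested():
    # Stand-in for RunnerCallback.cancel_callback above; returning True asks
    # ansible-runner to cancel the playbook run at its next poll.
    return False


# ansible-runner polls cancel_callback periodically for the duration of the run.
result = ansible_runner.run(
    private_data_dir='/tmp/example_pdd',  # placeholder directory
    playbook='example.yml',               # placeholder playbook
    cancel_callback=cancel_requested,
)
print(result.status)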
9 changes: 8 additions & 1 deletion awx/main/tasks/jobs.py
@@ -65,6 +65,7 @@
RunnerCallbackForProjectUpdate,
RunnerCallbackForSystemJob,
)
from awx.main.tasks.signals import with_signal_handling, signal_callback
from awx.main.tasks.receptor import AWXReceptorJob
from awx.main.exceptions import AwxTaskError, PostRunError, ReceptorNodeNotFound
from awx.main.utils.ansible import read_ansible_config
@@ -395,6 +396,7 @@ def final_run_hook(self, instance, status, private_data_dir, fact_modification_t
instance.save(update_fields=['ansible_version'])

@with_path_cleanup
@with_signal_handling
def run(self, pk, **kwargs):
"""
Run the job/task and capture its output.
@@ -427,7 +429,7 @@ def run(self, pk, **kwargs):
private_data_dir = self.build_private_data_dir(self.instance)
self.pre_run_hook(self.instance, private_data_dir)
self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag:
if self.instance.cancel_flag or signal_callback():
self.instance = self.update_model(self.instance.pk, status='canceled')
if self.instance.status != 'running':
# Stop the task chain and prevent starting the job if it has
@@ -555,6 +557,11 @@ def run(self, pk, **kwargs):
# ensure failure notification sends even if playbook_on_stats event is not triggered
handle_success_and_failure_notifications.apply_async([self.instance.id])

elif status == 'canceled':
self.instance = self.update_model(pk)
if (getattr(self.instance, 'cancel_flag', False) is False) and signal_callback():
self.runner_callback.delay_update(job_explanation="Task was canceled due to receiving a shutdown signal.")
status = 'failed'
except ReceptorNodeNotFound as exc:
extra_update_fields['job_explanation'] = str(exc)
except Exception:
63 changes: 63 additions & 0 deletions awx/main/tasks/signals.py
@@ -0,0 +1,63 @@
import signal
import functools
import logging


logger = logging.getLogger('awx.main.tasks.signals')


__all__ = ['with_signal_handling', 'signal_callback']


class SignalState:
def reset(self):
self.sigterm_flag = False
self.is_active = False
self.original_sigterm = None
self.original_sigint = None

def __init__(self):
self.reset()

def set_flag(self, *args):
"""Method to pass into the python signal.signal method to receive signals"""
self.sigterm_flag = True

def connect_signals(self):
self.original_sigterm = signal.getsignal(signal.SIGTERM)
self.original_sigint = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGTERM, self.set_flag)
signal.signal(signal.SIGINT, self.set_flag)
self.is_active = True

def restore_signals(self):
signal.signal(signal.SIGTERM, self.original_sigterm)
signal.signal(signal.SIGINT, self.original_sigint)
self.reset()


signal_state = SignalState()


def signal_callback():
return signal_state.sigterm_flag


def with_signal_handling(f):
"""
Change signal handling to make signal_callback return True in event of SIGTERM or SIGINT.
"""

@functools.wraps(f)
def _wrapped(*args, **kwargs):
try:
this_is_outermost_caller = False
if not signal_state.is_active:
signal_state.connect_signals()
this_is_outermost_caller = True
return f(*args, **kwargs)
finally:
if this_is_outermost_caller:
signal_state.restore_signals()

return _wrapped
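A minimal usage sketch of the new module, assuming it is importable as awx.main.tasks.signals; the task function and its work loop below are illustrative and not part of this commit:

from awx.main.tasks.signals import with_signal_handling, signal_callback


@with_signal_handling
def run_example_task(items):
    # While this function runs, SIGTERM/SIGINT only set a flag rather than
    # interrupting the process; for the outermost decorated call, the original
    # handlers are restored when the function returns.
    completed = []
    for item in items:
        if signal_callback():
            # A shutdown signal was received; stop at a safe point.
            break
        completed.append(item)
    return completed


print(run_example_task(range(5)))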
50 changes: 50 additions & 0 deletions awx/main/tests/unit/tasks/test_signals.py
@@ -0,0 +1,50 @@
import signal

from awx.main.tasks.signals import signal_state, signal_callback, with_signal_handling


def test_outer_inner_signal_handling():
"""
Even if the flag is set in the outer context, its value should persist in the inner context
"""

@with_signal_handling
def f2():
assert signal_callback()

@with_signal_handling
def f1():
assert signal_callback() is False
signal_state.set_flag()
assert signal_callback()
f2()

original_sigterm = signal.getsignal(signal.SIGTERM)
assert signal_callback() is False
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm


def test_inner_outer_signal_handling():
"""
Even if the flag is set in the inner context, its value should persist in the outer context
"""

@with_signal_handling
def f2():
assert signal_callback() is False
signal_state.set_flag()
assert signal_callback()

@with_signal_handling
def f1():
assert signal_callback() is False
f2()
assert signal_callback()

original_sigterm = signal.getsignal(signal.SIGTERM)
assert signal_callback() is False
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm
7 changes: 6 additions & 1 deletion awx/main/utils/update_model.py
@@ -4,6 +4,8 @@
import logging
import time

from awx.main.tasks.signals import signal_callback


logger = logging.getLogger('awx.main.tasks.utils')

@@ -37,7 +39,10 @@ def update_model(model, pk, _attempt=0, _max_attempts=5, **updates):
# Attempt to retry the update, assuming we haven't already
# tried too many times.
if _attempt < _max_attempts:
time.sleep(5)
for i in range(5):
time.sleep(1)
if signal_callback():
raise RuntimeError(f'Could not fetch {pk} because of receiving abort signal')
return update_model(model, pk, _attempt=_attempt + 1, _max_attempts=_max_attempts, **updates)
else:
logger.warning(f'Failed to update {model._meta.object_name} pk={pk} after {_attempt} retries.')
