diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 7ed2f3fc02314..279fe411cefe2 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -29,7 +29,6 @@ from importlib import import_module from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple -import psutil from setproctitle import setproctitle # pylint: disable=no-name-in-module from sqlalchemy import or_ from tabulate import tabulate @@ -46,7 +45,7 @@ from airflow.utils import timezone from airflow.utils.file import list_py_file_paths from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.process_utils import reap_process_group +from airflow.utils.process_utils import kill_child_processes_by_pids, reap_process_group from airflow.utils.session import provide_session from airflow.utils.state import State @@ -1136,35 +1135,7 @@ def end(self): """ pids_to_kill = self.get_all_pids() if pids_to_kill: - # First try SIGTERM - this_process = psutil.Process(os.getpid()) - # Only check child processes to ensure that we don't have a case - # where we kill the wrong process because a child process died - # but the PID got reused. - child_processes = [x for x in this_process.children(recursive=True) - if x.is_running() and x.pid in pids_to_kill] - for child in child_processes: - self.log.info("Terminating child PID: %s", child.pid) - child.terminate() - # TODO: Remove magic number - timeout = 5 - self.log.info("Waiting up to %s seconds for processes to exit...", timeout) - try: - psutil.wait_procs( - child_processes, timeout=timeout, - callback=lambda x: self.log.info('Terminated PID %s', x.pid)) - except psutil.TimeoutExpired: - self.log.debug("Ran out of time while waiting for processes to exit") - - # Then SIGKILL - child_processes = [x for x in this_process.children(recursive=True) - if x.is_running() and x.pid in pids_to_kill] - if child_processes: - self.log.info("SIGKILL processes that did not terminate gracefully") - for child in child_processes: - self.log.info("Killing child PID: %s", child.pid) - child.kill() - child.wait() + kill_child_processes_by_pids(pids_to_kill) def emit_metrics(self): """ diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py index fcec493258409..7b473b50c08a4 100644 --- a/airflow/utils/process_utils.py +++ b/airflow/utils/process_utils.py @@ -139,3 +139,48 @@ def execute_in_subprocess(cmd: List[str]): exit_code = proc.wait() if exit_code != 0: raise subprocess.CalledProcessError(exit_code, cmd) + + +def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> None: + """ + Kills child processes for the current process. +  + First, it sends the SIGTERM signal, and after the time specified by the `timeout` parameter, sends + the SIGKILL signal, if the process is still alive. + + :param pids_to_kill: List of PID to be killed. + :type pids_to_kill: List[int] + :param timeout: The time to wait before sending the SIGKILL signal. + :type timeout: Optional[int] + """ + this_process = psutil.Process(os.getpid()) + # Only check child processes to ensure that we don't have a case + # where we kill the wrong process because a child process died + # but the PID got reused. + child_processes = [ + x for x in this_process.children(recursive=True) if x.is_running() and x.pid in pids_to_kill + ] + + # First try SIGTERM + for child in child_processes: + log.info("Terminating child PID: %s", child.pid) + child.terminate() + + log.info("Waiting up to %s seconds for processes to exit...", timeout) + try: + psutil.wait_procs( + child_processes, timeout=timeout, callback=lambda x: log.info("Terminated PID %s", x.pid) + ) + except psutil.TimeoutExpired: + log.debug("Ran out of time while waiting for processes to exit") + + # Then SIGKILL + child_processes = [ + x for x in this_process.children(recursive=True) if x.is_running() and x.pid in pids_to_kill + ] + if child_processes: + log.info("SIGKILL processes that did not terminate gracefully") + for child in child_processes: + log.info("Killing child PID: %s", child.pid) + child.kill() + child.wait() diff --git a/tests/utils/test_process_utils.py b/tests/utils/test_process_utils.py index 5d5bcab3d351a..3b89282f21a94 100644 --- a/tests/utils/test_process_utils.py +++ b/tests/utils/test_process_utils.py @@ -20,9 +20,11 @@ import multiprocessing import os import signal +import subprocess import time import unittest from subprocess import CalledProcessError +from time import sleep import psutil @@ -104,3 +106,47 @@ def test_should_print_all_messages1(self): def test_should_raise_exception(self): with self.assertRaises(CalledProcessError): process_utils.execute_in_subprocess(["bash", "-c", "exit 1"]) + + +def my_sleep_subprocess(): + sleep(100) + + +def my_sleep_subprocess_with_signals(): + signal.signal(signal.SIGINT, lambda signum, frame: None) + signal.signal(signal.SIGTERM, lambda signum, frame: None) + sleep(100) + + +class TestKillChildProcessesByPids(unittest.TestCase): + def test_should_kill_process(self): + before_num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n") + + process = multiprocessing.Process(target=my_sleep_subprocess, args=()) + process.start() + sleep(0) + + num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n") + self.assertEqual(before_num_process + 1, num_process) + + process_utils.kill_child_processes_by_pids([process.pid]) + + num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n") + self.assertEqual(before_num_process, num_process) + + def test_should_force_kill_process(self): + before_num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n") + + process = multiprocessing.Process(target=my_sleep_subprocess_with_signals, args=()) + process.start() + sleep(0) + + num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n") + self.assertEqual(before_num_process + 1, num_process) + + with self.assertLogs(process_utils.log) as cm: + process_utils.kill_child_processes_by_pids([process.pid], timeout=0) + self.assertTrue(any("Killing child PID" in line for line in cm.output)) + + num_process = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode().count("\n") + self.assertEqual(before_num_process, num_process)