Skip to content

Commit

Permalink
[AIRFLOW-6956] Extract kill_child_processes_by_pids from DagFileProcessorManager
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamil Breguła committed Mar 2, 2020
1 parent 1d16de7 commit 7d83211
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 31 deletions.
33 changes: 2 additions & 31 deletions airflow/utils/dag_processing.py
Expand Up @@ -29,7 +29,6 @@
from importlib import import_module
from typing import Any, Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple

import psutil
from setproctitle import setproctitle # pylint: disable=no-name-in-module
from sqlalchemy import or_
from tabulate import tabulate
Expand All @@ -46,7 +45,7 @@
from airflow.utils import timezone
from airflow.utils.file import list_py_file_paths
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.process_utils import reap_process_group
from airflow.utils.process_utils import kill_child_processes_by_pids, reap_process_group
from airflow.utils.session import provide_session
from airflow.utils.state import State

Expand Down Expand Up @@ -1136,35 +1135,7 @@ def end(self):
"""
pids_to_kill = self.get_all_pids()
if pids_to_kill:
# First try SIGTERM
this_process = psutil.Process(os.getpid())
# Only check child processes to ensure that we don't have a case
# where we kill the wrong process because a child process died
# but the PID got reused.
child_processes = [x for x in this_process.children(recursive=True)
if x.is_running() and x.pid in pids_to_kill]
for child in child_processes:
self.log.info("Terminating child PID: %s", child.pid)
child.terminate()
# TODO: Remove magic number
timeout = 5
self.log.info("Waiting up to %s seconds for processes to exit...", timeout)
try:
psutil.wait_procs(
child_processes, timeout=timeout,
callback=lambda x: self.log.info('Terminated PID %s', x.pid))
except psutil.TimeoutExpired:
self.log.debug("Ran out of time while waiting for processes to exit")

# Then SIGKILL
child_processes = [x for x in this_process.children(recursive=True)
if x.is_running() and x.pid in pids_to_kill]
if child_processes:
self.log.info("SIGKILL processes that did not terminate gracefully")
for child in child_processes:
self.log.info("Killing child PID: %s", child.pid)
child.kill()
child.wait()
kill_child_processes_by_pids(pids_to_kill)

def emit_metrics(self):
"""
Expand Down
45 changes: 45 additions & 0 deletions airflow/utils/process_utils.py
Expand Up @@ -139,3 +139,48 @@ def execute_in_subprocess(cmd: List[str]):
exit_code = proc.wait()
if exit_code != 0:
raise subprocess.CalledProcessError(exit_code, cmd)


def _select_running_children(parent, pids):
    """Return the running descendants of ``parent`` whose PID is contained in ``pids``."""
    return [
        child for child in parent.children(recursive=True)
        if child.is_running() and child.pid in pids
    ]


def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> None:
    """
    Kills child processes for the current process.

    First, it sends the SIGTERM signal, and after the time specified by the `timeout` parameter, sends
    the SIGKILL signal, if the process is still alive.

    Only descendants (direct or transitive) of the current process are signalled; PIDs in
    ``pids_to_kill`` that do not match a running descendant are ignored. A child that exits
    on its own while this function is running is skipped rather than treated as an error.

    :param pids_to_kill: List of PID to be killed.
    :type pids_to_kill: List[int]
    :param timeout: The time (in seconds) to wait before sending the SIGKILL signal.
    :type timeout: int
    """
    # Build the lookup once: membership is tested for every descendant, twice.
    target_pids = set(pids_to_kill)
    this_process = psutil.Process(os.getpid())

    # Only check child processes to ensure that we don't have a case
    # where we kill the wrong process because a child process died
    # but the PID got reused.
    child_processes = _select_running_children(this_process, target_pids)

    # First try SIGTERM
    for child in child_processes:
        log.info("Terminating child PID: %s", child.pid)
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            # The child exited between the snapshot above and the signal; nothing left to do.
            log.debug("Child PID %s exited before SIGTERM could be sent", child.pid)

    log.info("Waiting up to %s seconds for processes to exit...", timeout)
    try:
        psutil.wait_procs(
            child_processes, timeout=timeout, callback=lambda x: log.info("Terminated PID %s", x.pid)
        )
    except psutil.TimeoutExpired:
        # NOTE(review): psutil documents wait_procs as returning (gone, alive) instead of
        # raising on timeout; this handler is retained defensively.
        log.debug("Ran out of time while waiting for processes to exit")

    # Then SIGKILL
    # Take a fresh snapshot so that only children that survived SIGTERM are force-killed.
    child_processes = _select_running_children(this_process, target_pids)
    if child_processes:
        log.info("SIGKILL processes that did not terminate gracefully")
        for child in child_processes:
            log.info("Killing child PID: %s", child.pid)
            try:
                child.kill()
                # Block until the child is gone so the caller does not observe a leftover process.
                child.wait()
            except psutil.NoSuchProcess:
                # One vanished child must not prevent the remaining ones from being killed.
                log.debug("Child PID %s exited before SIGKILL could be sent", child.pid)
46 changes: 46 additions & 0 deletions tests/utils/test_process_utils.py
Expand Up @@ -20,9 +20,11 @@
import multiprocessing
import os
import signal
import subprocess
import time
import unittest
from subprocess import CalledProcessError
from time import sleep

import psutil

Expand Down Expand Up @@ -104,3 +106,47 @@ def test_should_print_all_messages1(self):
def test_should_raise_exception(self):
with self.assertRaises(CalledProcessError):
process_utils.execute_in_subprocess(["bash", "-c", "exit 1"])


def my_sleep_subprocess():
    """Process target that just blocks for 100 seconds, leaving default signal handling in place."""
    duration_seconds = 100
    sleep(duration_seconds)


def my_sleep_subprocess_with_signals():
    """Process target that replaces the SIGINT and SIGTERM handlers with no-ops, then blocks for 100 seconds."""

    def _do_nothing(signum, frame):
        # Handler that deliberately takes no action when the signal arrives.
        return None

    for signal_number in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signal_number, _do_nothing)

    duration_seconds = 100
    sleep(duration_seconds)


class TestKillChildProcessesByPids(unittest.TestCase):
    """Verifies ``process_utils.kill_child_processes_by_pids`` by counting processes before and after."""

    @staticmethod
    def _count_system_processes():
        """Return the number of newline-terminated PID entries printed by ``ps -ax -o pid=``."""
        # NOTE(review): this is a system-wide count, so unrelated processes starting or
        # stopping while a test runs can shift it.
        ps_output = subprocess.check_output(["ps", "-ax", "-o", "pid="])
        return ps_output.decode().count("\n")

    def test_should_kill_process(self):
        """A child with default signal handling disappears after the call."""
        baseline = self._count_system_processes()

        child = multiprocessing.Process(target=my_sleep_subprocess, args=())
        child.start()
        sleep(0)

        # Exactly one extra process is expected once the child has been started.
        self.assertEqual(baseline + 1, self._count_system_processes())

        process_utils.kill_child_processes_by_pids([child.pid])

        self.assertEqual(baseline, self._count_system_processes())

    def test_should_force_kill_process(self):
        """A child that ignores SIGTERM is removed via the SIGKILL path, which is visible in the logs."""
        baseline = self._count_system_processes()

        child = multiprocessing.Process(target=my_sleep_subprocess_with_signals, args=())
        child.start()
        sleep(0)

        self.assertEqual(baseline + 1, self._count_system_processes())

        # timeout=0 means no grace period is granted between SIGTERM and SIGKILL.
        with self.assertLogs(process_utils.log) as captured:
            process_utils.kill_child_processes_by_pids([child.pid], timeout=0)
        self.assertTrue(any("Killing child PID" in line for line in captured.output))

        self.assertEqual(baseline, self._count_system_processes())

0 comments on commit 7d83211

Please sign in to comment.