Resolve internal warnings for TestLocalTaskJob and TestSigTermOnRunner
Owen-CH-Leung committed Apr 10, 2024
1 parent 87acf61 commit 7976130
Showing 2 changed files with 87 additions and 73 deletions.
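Both files implement the same modernization: the tests stop using the deprecated `execution_date`-based APIs, creating DAG runs with an explicit `data_interval` and addressing task instances by `run_id` instead. A minimal before/after sketch of the pattern — illustrative only; `dag`, `task`, and `session` stand in for objects the real tests obtain from fixtures such as `get_test_dag` and `create_session`:

    from airflow.models.taskinstance import TaskInstance
    from airflow.utils import timezone
    from airflow.utils.state import State
    from airflow.utils.types import DagRunType

    DEFAULT_DATE = timezone.datetime(2016, 1, 1)
    DEFAULT_LOGICAL_DATE = timezone.coerce_datetime(DEFAULT_DATE)

    # Before: an implicit data interval and an execution_date-keyed
    # TaskInstance both emit deprecation warnings.
    # dr = dag.create_dagrun(state=State.RUNNING, execution_date=DEFAULT_DATE,
    #                        run_type=DagRunType.SCHEDULED, session=session)
    # ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

    # After: derive the interval from the logical date and key the TI by run_id.
    data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
    dr = dag.create_dagrun(
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        run_type=DagRunType.SCHEDULED,
        session=session,
        data_interval=data_interval,
    )
    ti = TaskInstance(task=task, run_id=dr.run_id)
    ti.refresh_from_db()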
13 changes: 0 additions & 13 deletions tests/deprecations_ignore.yml
@@ -92,19 +92,6 @@
- tests/jobs/test_backfill_job.py::TestBackfillJob::test_subdag_clear_parentdag_downstream_clear
- tests/jobs/test_backfill_job.py::TestBackfillJob::test_update_counters
- tests/jobs/test_backfill_job.py::TestBackfillJob::test_backfilling_dags
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_dagrun_timeout_logged_in_task_logs
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_failure_callback_called_by_airflow_run_raw_process
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_fast_follow
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_heartbeat_failed_fast
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_local_task_return_code_metric
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_localtaskjob_double_trigger
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_localtaskjob_maintain_heart_rate
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_mark_failure_on_failure_callback
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_mark_success_no_kill
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_mark_success_on_success_callback
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_mini_scheduler_works_with_wait_for_upstream
- tests/jobs/test_local_task_job.py::TestLocalTaskJob::test_process_os_signal_calls_on_failure_callback
- tests/jobs/test_local_task_job.py::TestSigtermOnRunner::test_process_sigterm_works_with_retries
- tests/jobs/test_scheduler_job.py::TestSchedulerJob::test_adopt_or_reset_orphaned_tasks
- tests/jobs/test_scheduler_job.py::TestSchedulerJob::test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run
- tests/jobs/test_scheduler_job.py::TestSchedulerJob::test_dagrun_deadlock_ignore_depends_on_past
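With these entries removed from the ignore list, the listed TestLocalTaskJob and TestSigtermOnRunner tests are now expected to run without triggering the deprecation warnings that this filter previously suppressed; the test-file changes below are what make that possible.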
147 changes: 87 additions & 60 deletions tests/jobs/test_local_task_job.py
@@ -23,10 +23,10 @@
import os
import re
import signal
import tempfile
import threading
import time
import uuid
import warnings
from unittest import mock
from unittest.mock import patch

@@ -60,6 +60,7 @@
pytestmark = pytest.mark.db_test

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
DEFAULT_LOGICAL_DATE = timezone.coerce_datetime(DEFAULT_DATE)
TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"]


@@ -293,13 +294,14 @@ def test_heartbeat_failed_fast(self):
task_id = "test_heartbeat_failed_fast_op"
dag = self.dagbag.get_dag(dag_id)
task = dag.get_task(task_id)

data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dr = dag.create_dagrun(
run_id="test_heartbeat_failed_fast_run",
state=State.RUNNING,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
session=session,
data_interval=data_interval,
)

ti = dr.task_instances[0]
@@ -327,11 +329,13 @@ def test_mark_success_no_kill(self, caplog, get_test_dag, session):
the task to fail, and that the task exits
"""
dag = get_test_dag("test_mark_state")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="test_mark_success_no_kill")

@@ -352,6 +356,7 @@ def test_mark_success_no_kill(self, caplog, get_test_dag, session):
def test_localtaskjob_double_trigger(self):
dag = self.dagbag.dags.get("test_localtaskjob_double_trigger")
task = dag.get_task("test_localtaskjob_double_trigger_task")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)

session = settings.Session()

@@ -362,6 +367,7 @@ def test_localtaskjob_double_trigger(self):
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
session=session,
data_interval=data_interval,
)

ti = dr.get_task_instance(task_id=task.task_id, session=session)
@@ -388,9 +394,10 @@ def test_localtaskjob_double_trigger(self):
@patch.object(StandardTaskRunner, "return_code")
@mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr", autospec=True)
def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, create_dummy_dag):
_, task = create_dummy_dag("test_localtaskjob_code")
dag, task = create_dummy_dag("test_localtaskjob_code")
dag_run = dag.get_last_dagrun()

ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti_run = TaskInstance(task=task, run_id=dag_run.run_id)
ti_run.refresh_from_db()
job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor())
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run)
@@ -418,9 +425,10 @@ def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code,

@patch.object(StandardTaskRunner, "return_code")
def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create_dummy_dag):
_, task = create_dummy_dag("test_localtaskjob_double_trigger")
dag, task = create_dummy_dag("test_localtaskjob_double_trigger")
dag_run = dag.get_last_dagrun()

ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti_run = TaskInstance(task=task, run_id=dag_run.run_id)
ti_run.refresh_from_db()
job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor())
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run)
@@ -453,12 +461,14 @@ def test_mark_failure_on_failure_callback(self, caplog, get_test_dag):
the task, and executes on_failure_callback
"""
dag = get_test_dag("test_mark_state")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="test_mark_failure_externally")
ti = dr.get_task_instance(task.task_id)
@@ -484,13 +494,15 @@ def test_dagrun_timeout_logged_in_task_logs(self, caplog, get_test_dag):
"""
dag = get_test_dag("test_mark_state")
dag.dagrun_timeout = datetime.timedelta(microseconds=1)
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dr = dag.create_dagrun(
state=State.RUNNING,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="test_mark_skipped_externally")
ti = dr.get_task_instance(task.task_id)
@@ -515,15 +527,17 @@ def test_failure_callback_called_by_airflow_run_raw_process(self, monkeypatch, t
callback_file.touch()
monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file))
dag = get_test_dag("test_on_failure_callback")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dag.create_dagrun(
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="test_on_failure_callback_task")
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti = TaskInstance(task=task, run_id=dr.run_id)
ti.refresh_from_db()

job1 = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
@@ -546,12 +560,14 @@ def test_mark_success_on_success_callback(self, caplog, get_test_dag):
on_success_callback gets executed
"""
dag = get_test_dag("test_mark_state")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dr = dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="test_mark_success_no_kill")

@@ -583,15 +599,18 @@ def test_process_os_signal_calls_on_failure_callback(
# callback_file will be created by the task: bash_sleep
monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file))
dag = get_test_dag("test_on_failure_callback")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
with create_session() as session:
dag.create_dagrun(
state=State.RUNNING,
execution_date=DEFAULT_DATE,
run_type=DagRunType.SCHEDULED,
session=session,
data_interval=data_interval,
)
task = dag.get_task(task_id="bash_sleep")
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
dag_run = dag.get_last_dagrun()
ti = TaskInstance(task=task, run_id=dag_run.run_id)
ti.refresh_from_db()

signal_sent_status = {"sent": False}
@@ -724,7 +743,7 @@ def test_fast_follow(
session.merge(ti)
ti_by_task_id[task_id] = ti

ti = TaskInstance(task=dag.get_task(task_ids_to_run[0]), execution_date=dag_run.execution_date)
ti = TaskInstance(task=dag.get_task(task_ids_to_run[0]), run_id=dag_run.run_id)
ti.refresh_from_db()
job1 = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti, ignore_ti_state=True)
@@ -733,9 +752,7 @@ def test_fast_follow(
run_job(job=job1, execute_callable=job_runner._execute)
self.validate_ti_states(dag_run, first_run_state, error_message)
if second_run_state:
ti = TaskInstance(
task=dag.get_task(task_ids_to_run[1]), execution_date=dag_run.execution_date
)
ti = TaskInstance(task=dag.get_task(task_ids_to_run[1]), run_id=dag_run.run_id)
ti.refresh_from_db()
job2 = Job(dag_id=ti.dag_id, executor=SequentialExecutor())
job_runner = LocalTaskJobRunner(job=job2, task_instance=ti, ignore_ti_state=True)
@@ -748,12 +765,18 @@
@conf_vars({("scheduler", "schedule_after_task_execution"): "True"})
def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag):
dag = get_test_dag("test_dagrun_fast_follow")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dag.catchup = False
SerializedDagModel.write_dag(dag)

dr = dag.create_dagrun(run_id="test_1", state=State.RUNNING, execution_date=DEFAULT_DATE)
dr = dag.create_dagrun(
run_id="test_1", state=State.RUNNING, execution_date=DEFAULT_DATE, data_interval=data_interval
)
dr2 = dag.create_dagrun(
run_id="test_2", state=State.RUNNING, execution_date=DEFAULT_DATE + datetime.timedelta(hours=1)
run_id="test_2",
state=State.RUNNING,
execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
data_interval=data_interval,
)
task_k = dag.get_task("K")
task_l = dag.get_task("L")
@@ -870,9 +893,7 @@ class TestSigtermOnRunner:
pytest.param("spawn", 30, id="spawn"),
],
)
def test_process_sigterm_works_with_retries(
self, mp_method, wait_timeout, daemon, clear_db, request, capfd
):
def test_process_sigterm_works_with_retries(self, mp_method, wait_timeout, daemon, clear_db):
"""Test that ensures that task runner sets tasks to retry when task runner receive SIGTERM."""
mp_context = mp.get_context(mp_method)

@@ -886,53 +907,53 @@ def test_process_sigterm_works_with_retries(
execution_date = DEFAULT_DATE
run_id = f"test-{execution_date.date().isoformat()}"

# Run LocalTaskJob in separate process
proc = mp_context.Process(
target=self._sigterm_local_task_runner,
args=(dag_id, task_id, run_id, execution_date, task_started, retry_callback_called),
name="LocalTaskJob-TestProcess",
daemon=daemon,
)
proc.start()

try:
with timeout(wait_timeout, "Timeout during waiting start LocalTaskJob"):
while task_started.value == 0:
time.sleep(0.2)
os.kill(proc.pid, signal.SIGTERM)

with timeout(wait_timeout, "Timeout during waiting callback"):
while retry_callback_called.value == 0:
time.sleep(0.2)
finally:
proc.kill()

assert retry_callback_called.value == 1
# Internally callback finished before TaskInstance commit changes in DB (as of Jan 2022).
# So we can't easily check TaskInstance.state without any race conditions drawbacks,
# and fact that process with LocalTaskJob could be already killed.
# We could add state validation (`UP_FOR_RETRY`) if callback mechanism changed.

pytest_capture = request.config.option.capture
if pytest_capture == "no":
# Since we run `LocalTaskJob` in the separate process we can not grab it easily by `caplog`.
# However, we could grab it from stdout/stderr but only if `-s` flag set, see:
# https://github.com/pytest-dev/pytest/issues/5997
captured = capfd.readouterr()
with tempfile.NamedTemporaryFile() as tmpfile:
# Run LocalTaskJob in separate process
proc = mp_context.Process(
target=self._sigterm_local_task_runner,
args=(
tmpfile.name,
dag_id,
task_id,
run_id,
execution_date,
task_started,
retry_callback_called,
),
name="LocalTaskJob-TestProcess",
daemon=daemon,
)
proc.start()

try:
with timeout(wait_timeout, "Timeout during waiting start LocalTaskJob"):
while task_started.value == 0:
time.sleep(0.2)
os.kill(proc.pid, signal.SIGTERM)

with timeout(wait_timeout, "Timeout during waiting callback"):
while retry_callback_called.value == 0:
time.sleep(0.2)
finally:
proc.kill()

assert retry_callback_called.value == 1
# Internally callback finished before TaskInstance commit changes in DB (as of Jan 2022).
# So we can't easily check TaskInstance.state without any race conditions drawbacks,
# and fact that process with LocalTaskJob could be already killed.
# We could add state validation (`UP_FOR_RETRY`) if callback mechanism changed.

captured = tmpfile.read().decode()
for msg in [
"Received SIGTERM. Terminating subprocesses",
"Task exited with return code 143",
]:
assert msg in captured.out or msg in captured.err
else:
warnings.warn(
f"Skip test logs in stdout/stderr when capture enabled: {pytest_capture}, "
f"please pass `-s` option.",
UserWarning,
)
# assert msg in captured.out or msg in captured.err
assert msg in captured

@staticmethod
def _sigterm_local_task_runner(
tmpfile_path,
dag_id,
task_id,
run_id,
@@ -963,9 +984,15 @@ def task_function():
retries=1,
on_retry_callback=retry_callback,
)
logger = logging.getLogger()
tmpfile_handler = logging.FileHandler(tmpfile_path)
logger.addHandler(tmpfile_handler)

dag.create_dagrun(state=State.RUNNING, run_id=run_id, execution_date=execution_date)
ti = TaskInstance(task=task, execution_date=execution_date)
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dag_run = dag.create_dagrun(
state=State.RUNNING, run_id=run_id, execution_date=execution_date, data_interval=data_interval
)
ti = TaskInstance(task=task, run_id=dag_run.run_id)
ti.refresh_from_db()
job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
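Beyond the `run_id`/`data_interval` migration, `TestSigtermOnRunner` drops the `capfd`-based log check, which only worked when pytest ran with `-s`, in favor of a `logging.FileHandler` pointed at a shared temporary file inside the child process. A condensed sketch of that capture mechanism, assuming a POSIX tempfile that can be reopened by path while held; `run_local_task_job` is a hypothetical stand-in for the spawned `LocalTaskJob` target:

    import logging
    import tempfile

    def run_local_task_job(tmpfile_path: str) -> None:
        # In the real test this body runs in a separate multiprocessing.Process.
        # Mirroring root-logger records into the temp file lets the parent read
        # them back regardless of pytest's capture mode.
        logging.getLogger().addHandler(logging.FileHandler(tmpfile_path))
        logging.getLogger().error("Received SIGTERM. Terminating subprocesses")  # stand-in record

    with tempfile.NamedTemporaryFile() as tmpfile:
        run_local_task_job(tmpfile.name)
        captured = tmpfile.read().decode()
        assert "Received SIGTERM. Terminating subprocesses" in captured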
