Skip to content

Commit

Permalink
Enable "airflow tasks test" to run deferrable operator (apache#37542)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W authored and abhishekbhakat committed Mar 5, 2024
1 parent f4c3921 commit a26e2ae
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 18 deletions.
25 changes: 21 additions & 4 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Task sub-commands."""
from __future__ import annotations

import functools
import importlib
import json
import logging
Expand All @@ -34,13 +35,13 @@
from airflow import settings
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.dag import DAG
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.operator import needs_expansion
from airflow.models.param import ParamsDict
Expand Down Expand Up @@ -588,7 +589,8 @@ def format_task_instance(ti: TaskInstance) -> dict[str, str]:


@cli_utils.action_cli(check_db=False)
def task_test(args, dag: DAG | None = None) -> None:
@provide_session
def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
"""Test task for a given dag_id."""
# We want to log output from operators etc to show up here. Normally
# airflow.task would redirect to a file, but here we want it to propagate
Expand Down Expand Up @@ -632,7 +634,22 @@ def task_test(args, dag: DAG | None = None) -> None:
if args.dry_run:
ti.dry_run()
else:
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
except TaskDeferred as defer:
ti.defer_task(defer=defer, session=session)
log.info("[TASK TEST] running trigger in line")

event = _run_inline_trigger(defer.trigger)
ti.next_method = defer.method_name
ti.next_kwargs = {"event": event.payload} if event else defer.kwargs

execute_callable = getattr(task, ti.next_method)
if ti.next_kwargs:
execute_callable = functools.partial(execute_callable, **ti.next_kwargs)
context = ti.get_template_context(ignore_param_exceptions=False)
execute_callable(context)

log.info("[TASK TEST] Trigger completed")
except Exception:
if args.post_mortem:
debugger = _guess_debugger()
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4057,12 +4057,12 @@ def get_current_dag(cls) -> DAG | None:
return None


def _run_trigger(trigger):
async def _run_trigger_main():
def _run_inline_trigger(trigger):
async def _run_inline_trigger_main():
async for event in trigger.run():
return event

return asyncio.run(_run_trigger_main())
return asyncio.run(_run_inline_trigger_main())


def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
Expand All @@ -4083,7 +4083,7 @@ def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Sessio
break
except TaskDeferred as e:
log.info("[DAG TEST] running trigger in line")
event = _run_trigger(e.trigger)
event = _run_inline_trigger(e.trigger)
ti.next_method = e.method_name
ti.next_kwargs = {"event": event.payload} if event else e.kwargs
log.info("[DAG TEST] Trigger completed")
Expand Down
17 changes: 13 additions & 4 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2378,7 +2378,7 @@ def _run_raw_task(
# a trigger.
if raise_on_defer:
raise
self._defer_task(defer=defer, session=session)
self.defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
self.dag_id,
Expand Down Expand Up @@ -2565,8 +2565,11 @@ def _execute_task(self, context, task_orig):
return _execute_task(self, context, task_orig)

@provide_session
def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
"""Mark the task as deferred and sets up the trigger that is needed to resume it."""
def defer_task(self, session: Session, defer: TaskDeferred) -> None:
"""Mark the task as deferred and sets up the trigger that is needed to resume it.
:meta: private
"""
from airflow.models.trigger import Trigger

# First, make the trigger entry
Expand Down Expand Up @@ -2625,6 +2628,7 @@ def run(
job_id: str | None = None,
pool: str | None = None,
session: Session = NEW_SESSION,
raise_on_defer: bool = False,
) -> None:
"""Run TaskInstance."""
res = self.check_and_change_state_before_execution(
Expand All @@ -2644,7 +2648,12 @@ def run(
return

self._run_raw_task(
mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
mark_success=mark_success,
test_mode=test_mode,
job_id=job_id,
pool=pool,
session=session,
raise_on_defer=raise_on_defer,
)

def dry_run(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import _run_trigger
from airflow.models.dag import _run_inline_trigger
from airflow.models.serialized_dag import SerializedDagModel
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
Expand Down Expand Up @@ -878,15 +878,15 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs

def test_dag_test_run_trigger(self, dag_maker):
def test_dag_test_run_inline_trigger(self, dag_maker):
now = timezone.utcnow()
trigger = DateTimeTrigger(moment=now)
e = _run_trigger(trigger)
e = _run_inline_trigger(trigger)
assert isinstance(e, TriggerEvent)
assert e.payload == now

def test_dag_test_no_triggerer_running(self, dag_maker):
with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
with mock.patch("airflow.models.dag._run_inline_trigger", wraps=_run_inline_trigger) as mock_run:
with dag_maker() as dag:

@task
Expand Down
21 changes: 21 additions & 0 deletions tests/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,27 @@ def test_cli_test_with_env_vars(self):
assert "foo=bar" in output
assert "AIRFLOW_TEST_MODE=True" in output

@pytest.mark.asyncio
@mock.patch("airflow.triggers.file.os.path.getmtime", return_value=0)
@mock.patch("airflow.triggers.file.glob", return_value=["/tmp/test"])
@mock.patch("airflow.triggers.file.os.path.isfile", return_value=True)
@mock.patch("airflow.sensors.filesystem.FileSensor.poke", return_value=False)
def test_cli_test_with_deferrable_operator(self, mock_pock, mock_is_file, mock_glob, mock_getmtime):
with redirect_stdout(StringIO()) as stdout:
task_command.task_test(
self.parser.parse_args(
[
"tasks",
"test",
"example_sensors",
"wait_for_file_async",
DEFAULT_DATE.isoformat(),
]
)
)
output = stdout.getvalue()
assert "wait_for_file_async completed successfully as /tmp/temporary_file_for_testing found" in output

@pytest.mark.parametrize(
"option",
[
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_trigger_lifecycle(session):
class TestTriggerRunner:
@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
async def test_run_trigger_canceled(self, session) -> None:
async def test_run_inline_trigger_canceled(self, session) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
Expand All @@ -278,7 +278,7 @@ async def test_run_trigger_canceled(self, session) -> None:

@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
async def test_run_trigger_timeout(self, session, caplog) -> None:
async def test_run_inline_trigger_timeout(self, session, caplog) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
Expand Down

0 comments on commit a26e2ae

Please sign in to comment.