Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run triggers inline with dag test #34642

Merged
merged 3 commits into from Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 30 additions & 37 deletions airflow/models/dag.py
Expand Up @@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import asyncio
import collections
dstandish marked this conversation as resolved.
Show resolved Hide resolved
import collections.abc
import copy
import functools
Expand Down Expand Up @@ -82,11 +84,11 @@
from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
AirflowSkipException,
DuplicateTaskIdFound,
FailStopDagInvalidTriggerRule,
ParamValidationError,
RemovedInAirflow3Warning,
TaskDeferred,
TaskNotFound,
)
from airflow.jobs.job import run_job
Expand All @@ -101,7 +103,6 @@
Context,
TaskInstance,
TaskInstanceKey,
TaskReturnCode,
clear_task_instances,
)
from airflow.secrets.local_filesystem import LocalFilesystemBackend
Expand Down Expand Up @@ -285,12 +286,11 @@ def get_dataset_triggered_next_run_info(
}


class _StopDagTest(Exception):
"""
Raise when DAG.test should stop immediately.
def _triggerer_is_healthy():
from airflow.jobs.triggerer_job_runner import TriggererJobRunner

:meta private:
"""
job = TriggererJobRunner.most_recent_job()
return job and job.is_alive()


@functools.total_ordering
Expand Down Expand Up @@ -2844,21 +2844,12 @@ def add_logger_if_needed(ti: TaskInstance):
if not scheduled_tis and ids_unrunnable:
self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
time.sleep(1)
triggerer_running = _triggerer_is_healthy()
for ti in scheduled_tis:
try:
add_logger_if_needed(ti)
ti.task = tasks[ti.task_id]
ret = _run_task(ti, session=session)
if ret is TaskReturnCode.DEFERRED:
if not _triggerer_is_healthy():
raise _StopDagTest(
"Task has deferred but triggerer component is not running. "
"You can start the triggerer by running `airflow triggerer` in a terminal."
)
except _StopDagTest:
# Let this exception bubble out and not be swallowed by the
# except block below.
raise
_run_task(ti=ti, inline_trigger=not triggerer_running, session=session)
except Exception:
self.log.exception("Task failed; ti=%s", ti)
if conn_file_path or variable_file_path:
Expand Down Expand Up @@ -3988,14 +3979,15 @@ def get_current_dag(cls) -> DAG | None:
return None


def _triggerer_is_healthy():
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
def _run_trigger(trigger):
async def _run_trigger_main():
async for event in trigger.run():
return event

job = TriggererJobRunner.most_recent_job()
return job and job.is_alive()
return asyncio.run(_run_trigger_main())


def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
"""
Run a single task instance, and push result to Xcom for downstream tasks.

Expand All @@ -4005,20 +3997,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
Args:
ti: TaskInstance to run
"""
ret = None
log.info("*****************************************************")
if ti.map_index > 0:
log.info("Running task %s index %d", ti.task_id, ti.map_index)
else:
log.info("Running task %s", ti.task_id)
try:
ret = ti._run_raw_task(session=session)
session.flush()
log.info("%s ran successfully!", ti.task_id)
except AirflowSkipException:
log.info("Task Skipped, continuing")
log.info("*****************************************************")
return ret
log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
while True:
try:
log.info("[DAG TEST] running task %s", ti)
ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
break
except TaskDeferred as e:
log.info("[DAG TEST] running trigger in line")
event = _run_trigger(e.trigger)
ti.next_method = e.method_name
ti.next_kwargs = {"event": event.payload} if event else e.kwargs
log.info("[DAG TEST] Trigger completed")
session.merge(ti)
session.commit()
log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index)


def _get_or_create_dagrun(
Expand Down
3 changes: 3 additions & 0 deletions airflow/models/taskinstance.py
Expand Up @@ -2207,6 +2207,7 @@ def _run_raw_task(
test_mode: bool = False,
job_id: str | None = None,
pool: str | None = None,
raise_on_defer: bool = False,
session: Session = NEW_SESSION,
) -> TaskReturnCode | None:
"""
Expand Down Expand Up @@ -2261,6 +2262,8 @@ def _run_raw_task(
except TaskDeferred as defer:
# The task has signalled it wants to defer execution based on
# a trigger.
if raise_on_defer:
raise
self._defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
Expand Down
81 changes: 47 additions & 34 deletions tests/cli/commands/test_dag_command.py
Expand Up @@ -37,9 +37,10 @@
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import _StopDagTest
from airflow.models.dag import _run_trigger
from airflow.models.serialized_dag import SerializedDagModel
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
Expand Down Expand Up @@ -824,35 +825,47 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs

def test_dag_test_no_triggerer(self, dag_maker):
with dag_maker() as dag:

@task
def one():
return 1

@task
def two(val):
return val + 1

class MyOp(BaseOperator):
template_fields = ("tfield",)

def __init__(self, tfield, **kwargs):
self.tfield = tfield
super().__init__(**kwargs)

def execute(self, context, event=None):
if event is None:
print("I AM DEFERRING")
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute")
return
print("RESUMING")
return self.tfield + 1

task_one = one()
task_two = two(task_one)
op = MyOp(task_id="abc", tfield=str(task_two))
task_two >> op
with pytest.raises(_StopDagTest, match="Task has deferred but triggerer component is not running"):
dag.test()
def test_dag_test_run_trigger(self, dag_maker):
now = timezone.utcnow()
trigger = DateTimeTrigger(moment=now)
e = _run_trigger(trigger)
assert isinstance(e, TriggerEvent)
assert e.payload == now

def test_dag_test_no_triggerer_running(self, dag_maker):
with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
with dag_maker() as dag:

@task
def one():
return 1

@task
def two(val):
return val + 1

trigger = TimeDeltaTrigger(timedelta(seconds=0))

class MyOp(BaseOperator):
template_fields = ("tfield",)

def __init__(self, tfield, **kwargs):
self.tfield = tfield
super().__init__(**kwargs)

def execute(self, context, event=None):
if event is None:
print("I AM DEFERRING")
self.defer(trigger=trigger, method_name="execute")
return
print("RESUMING")
return self.tfield + 1

task_one = one()
task_two = two(task_one)
op = MyOp(task_id="abc", tfield=task_two)
task_two >> op
dr = dag.test()
assert mock_run.call_args_list[0] == ((trigger,), {})
tis = dr.get_task_instances()
assert [x for x in tis if x.task_id == "abc"][0].state == "success"
2 changes: 1 addition & 1 deletion tests/models/test_mappedoperator.py
Expand Up @@ -95,7 +95,7 @@ def execute(self, context: Context):
mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values)
task1 >> mapped
dag.test()
assert caplog.text.count("task_2 ran successfully") == 2
assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2
assert (
"Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'"
in caplog.text
Expand Down