Run triggers inline with dag test
No need to have a triggerer running -- triggers will just be run inline, in an async event loop.
dstandish committed Nov 13, 2023
1 parent 4cc98ba commit c8b4a5e
Showing 5 changed files with 221 additions and 72 deletions.
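
In practice, this change lets a DAG that contains deferrable tasks be exercised end to end with dag.test() alone, with no `airflow triggerer` process running. A minimal sketch of that usage (the WaitThenFinish operator is hypothetical, modeled on the test added below):

    from __future__ import annotations

    from datetime import datetime, timedelta

    from airflow.models.baseoperator import BaseOperator
    from airflow.models.dag import DAG
    from airflow.triggers.temporal import TimeDeltaTrigger


    class WaitThenFinish(BaseOperator):
        """Hypothetical deferrable operator: defers once, then resumes."""

        def execute(self, context, event=None):
            if event is None:
                # First call: hand control to a trigger instead of blocking a worker.
                self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=1)), method_name="execute")
            # Second call: the trigger event has fired and execution resumes here.
            return "done"


    with DAG(dag_id="example_deferrable", start_date=datetime(2023, 1, 1), schedule=None) as dag:
        WaitThenFinish(task_id="wait_then_finish")

    # Previously dag.test() stopped with an error unless `airflow triggerer` was running;
    # now the trigger is awaited inline (via asyncio.run) and the task resumes in-process.
    dag.test()
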
67 changes: 30 additions & 37 deletions airflow/models/dag.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import asyncio
import collections
import collections.abc
import copy
import functools
@@ -82,11 +84,11 @@
from airflow.exceptions import (
AirflowDagInconsistent,
AirflowException,
AirflowSkipException,
DuplicateTaskIdFound,
FailStopDagInvalidTriggerRule,
ParamValidationError,
RemovedInAirflow3Warning,
TaskDeferred,
TaskNotFound,
)
from airflow.jobs.job import run_job
@@ -101,7 +103,6 @@
Context,
TaskInstance,
TaskInstanceKey,
TaskReturnCode,
clear_task_instances,
)
from airflow.secrets.local_filesystem import LocalFilesystemBackend
@@ -285,12 +286,11 @@ def get_dataset_triggered_next_run_info(
}


-class _StopDagTest(Exception):
-    """
-    Raise when DAG.test should stop immediately.
-
-    :meta private:
-    """
+def _triggerer_is_healthy():
+    from airflow.jobs.triggerer_job_runner import TriggererJobRunner
+
+    job = TriggererJobRunner.most_recent_job()
+    return job and job.is_alive()


@functools.total_ordering
@@ -2844,21 +2844,12 @@ def add_logger_if_needed(ti: TaskInstance):
if not scheduled_tis and ids_unrunnable:
self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
time.sleep(1)
+            triggerer_running = _triggerer_is_healthy()
             for ti in scheduled_tis:
                 try:
                     add_logger_if_needed(ti)
                     ti.task = tasks[ti.task_id]
-                    ret = _run_task(ti, session=session)
-                    if ret is TaskReturnCode.DEFERRED:
-                        if not _triggerer_is_healthy():
-                            raise _StopDagTest(
-                                "Task has deferred but triggerer component is not running. "
-                                "You can start the triggerer by running `airflow triggerer` in a terminal."
-                            )
-                except _StopDagTest:
-                    # Let this exception bubble out and not be swallowed by the
-                    # except block below.
-                    raise
+                    _run_task(ti=ti, inline_trigger=not triggerer_running, session=session)
                 except Exception:
                     self.log.exception("Task failed; ti=%s", ti)
if conn_file_path or variable_file_path:
@@ -3988,14 +3979,15 @@ def get_current_dag(cls) -> DAG | None:
return None


-def _triggerer_is_healthy():
-    from airflow.jobs.triggerer_job_runner import TriggererJobRunner
-
-    job = TriggererJobRunner.most_recent_job()
-    return job and job.is_alive()
+def _run_trigger(trigger):
+    async def _run_trigger_main():
+        async for event in trigger.run():
+            return event
+
+    return asyncio.run(_run_trigger_main())


-def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
+def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
"""
Run a single task instance, and push result to Xcom for downstream tasks.
@@ -4005,20 +3997,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
Args:
ti: TaskInstance to run
"""
-    ret = None
-    log.info("*****************************************************")
-    if ti.map_index > 0:
-        log.info("Running task %s index %d", ti.task_id, ti.map_index)
-    else:
-        log.info("Running task %s", ti.task_id)
-    try:
-        ret = ti._run_raw_task(session=session)
-        session.flush()
-        log.info("%s ran successfully!", ti.task_id)
-    except AirflowSkipException:
-        log.info("Task Skipped, continuing")
-    log.info("*****************************************************")
-    return ret
+    log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
+    while True:
+        try:
+            log.info("[DAG TEST] running task %s", ti)
+            ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
+            break
+        except TaskDeferred as e:
+            log.info("[DAG TEST] running trigger in line")
+            event = _run_trigger(e.trigger)
+            ti.next_method = e.method_name
+            ti.next_kwargs = {"event": event.payload} if event else e.kwargs
+            log.info("[DAG TEST] Trigger completed")
+        session.merge(ti)
+        session.commit()
+    log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index)


def _get_or_create_dagrun(
3 changes: 3 additions & 0 deletions airflow/models/taskinstance.py
@@ -2207,6 +2207,7 @@ def _run_raw_task(
test_mode: bool = False,
job_id: str | None = None,
pool: str | None = None,
raise_on_defer: bool = False,
session: Session = NEW_SESSION,
) -> TaskReturnCode | None:
"""
@@ -2261,6 +2262,8 @@
except TaskDeferred as defer:
# The task has signalled it wants to defer execution based on
# a trigger.
if raise_on_defer:
raise
self._defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
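
The new raise_on_defer flag is what lets dag.test() intercept deferral instead of parking the task in the DEFERRED state. A minimal sketch of how a caller consumes it (assumes an existing TaskInstance `ti` and an open `session`; the surrounding setup is omitted):

    from airflow.exceptions import TaskDeferred

    try:
        # With raise_on_defer=True, the TaskDeferred exception propagates to the caller
        # rather than the task instance being persisted as DEFERRED.
        ti._run_raw_task(session=session, raise_on_defer=True)
    except TaskDeferred as defer:
        # The caller can now run defer.trigger inline and resume via defer.method_name,
        # which is what the new _run_task() in dag.py above does.
        print(defer.trigger, defer.method_name, defer.kwargs)
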
81 changes: 47 additions & 34 deletions tests/cli/commands/test_dag_command.py
@@ -37,9 +37,10 @@
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _StopDagTest
+from airflow.models.dag import _run_trigger
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.triggers.temporal import TimeDeltaTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
@@ -824,35 +825,47 @@ def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs

def test_dag_test_no_triggerer(self, dag_maker):
with dag_maker() as dag:

@task
def one():
return 1

@task
def two(val):
return val + 1

class MyOp(BaseOperator):
template_fields = ("tfield",)

def __init__(self, tfield, **kwargs):
self.tfield = tfield
super().__init__(**kwargs)

def execute(self, context, event=None):
if event is None:
print("I AM DEFERRING")
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute")
return
print("RESUMING")
return self.tfield + 1

task_one = one()
task_two = two(task_one)
op = MyOp(task_id="abc", tfield=str(task_two))
task_two >> op
with pytest.raises(_StopDagTest, match="Task has deferred but triggerer component is not running"):
dag.test()
def test_dag_test_run_trigger(self, dag_maker):
now = timezone.utcnow()
trigger = DateTimeTrigger(moment=now)
e = _run_trigger(trigger)
assert isinstance(e, TriggerEvent)
assert e.payload == now

def test_dag_test_no_triggerer_running(self, dag_maker):
with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
with dag_maker() as dag:

@task
def one():
return 1

@task
def two(val):
return val + 1

trigger = TimeDeltaTrigger(timedelta(seconds=0))

class MyOp(BaseOperator):
template_fields = ("tfield",)

def __init__(self, tfield, **kwargs):
self.tfield = tfield
super().__init__(**kwargs)

def execute(self, context, event=None):
if event is None:
print("I AM DEFERRING")
self.defer(trigger=trigger, method_name="execute")
return
print("RESUMING")
return self.tfield + 1

task_one = one()
task_two = two(task_one)
op = MyOp(task_id="abc", tfield=task_two)
task_two >> op
dr = dag.test()
assert mock_run.call_args_list[0] == ((trigger,), {})
tis = dr.get_task_instances()
assert [x for x in tis if x.task_id == "abc"][0].state == "success"
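
The new _run_trigger helper can also be exercised on its own, much as test_dag_test_run_trigger above does. A small usage sketch (the helper is private to airflow.models.dag, so this is for illustration only):

    from airflow.models.dag import _run_trigger
    from airflow.triggers.temporal import DateTimeTrigger
    from airflow.utils import timezone

    # Drive the trigger's async run() generator to completion in-process and
    # capture the first TriggerEvent it yields -- no triggerer involved.
    moment = timezone.utcnow()
    event = _run_trigger(DateTimeTrigger(moment=moment))
    assert event.payload == moment
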
2 changes: 1 addition & 1 deletion tests/models/test_mappedoperator.py
@@ -95,7 +95,7 @@ def execute(self, context: Context):
mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values)
task1 >> mapped
dag.test()
assert caplog.text.count("task_2 ran successfully") == 2
assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2
assert (
"Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'"
in caplog.text
140 changes: 140 additions & 0 deletions tests/providers/amazon/aws/sensors/test_s3_keys_unchanged.py
@@ -0,0 +1,140 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from unittest import mock

import pytest
import time_machine

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor

TEST_DAG_ID = "unit_tests_aws_sensor"
DEFAULT_DATE = datetime(2015, 1, 1)


class TestS3KeysUnchangedSensor:
def setup_method(self):
self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", start_date=DEFAULT_DATE, schedule="@once")

self.sensor = S3KeysUnchangedSensor(
task_id="sensor_1",
bucket_name="test-bucket",
prefix="test-prefix/path",
inactivity_period=12,
poke_interval=0.1,
min_objects=1,
allow_delete=True,
dag=self.dag,
)

def test_reschedule_mode_not_allowed(self):
with pytest.raises(ValueError):
S3KeysUnchangedSensor(
task_id="sensor_2",
bucket_name="test-bucket",
prefix="test-prefix/path",
poke_interval=0.1,
mode="reschedule",
dag=self.dag,
)

def test_render_template_fields(self):
S3KeysUnchangedSensor(
task_id="sensor_3",
bucket_name="test-bucket",
prefix="test-prefix/path",
inactivity_period=12,
poke_interval=0.1,
min_objects=1,
allow_delete=True,
dag=self.dag,
).render_template_fields({})

@time_machine.travel(DEFAULT_DATE)
def test_files_deleted_between_pokes_throw_error(self):
self.sensor.allow_delete = False
self.sensor.is_keys_unchanged({"a", "b"})
with pytest.raises(AirflowException):
self.sensor.is_keys_unchanged({"a"})

@pytest.mark.parametrize(
"current_objects, expected_returns, inactivity_periods",
[
pytest.param(
({"a"}, {"a", "b"}, {"a", "b", "c"}),
(False, False, False),
(0, 0, 0),
id="resetting inactivity period after key change",
),
pytest.param(
({"a", "b"}, {"a"}, {"a", "c"}),
(False, False, False),
(0, 0, 0),
id="item was deleted with option `allow_delete=True`",
),
pytest.param(
({"a"}, {"a"}, {"a"}), (False, False, True), (0, 10, 20), id="inactivity period was exceeded"
),
pytest.param(
(set(), set(), set()), (False, False, False), (0, 10, 20), id="not pass if empty key is given"
),
],
)
def test_key_changes(self, current_objects, expected_returns, inactivity_periods, time_machine):
time_machine.move_to(DEFAULT_DATE)
for current, expected, period in zip(current_objects, expected_returns, inactivity_periods):
assert self.sensor.is_keys_unchanged(current) == expected
assert self.sensor.inactivity_seconds == period
time_machine.coordinates.shift(10)

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine):
time_machine.move_to(DEFAULT_DATE)
mock_hook.return_value.list_keys.return_value = {"a"}
assert not self.sensor.poke(dict())
time_machine.coordinates.shift(10)
assert not self.sensor.poke(dict())
time_machine.coordinates.shift(10)
assert self.sensor.poke(dict())

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
def test_fail_is_keys_unchanged(self, soft_fail, expected_exception):
op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path")
op.soft_fail = soft_fail
op.previous_objects = {"1", "2", "3"}
current_objects = {"1", "2"}
op.allow_delete = False
message = "Illegal behavior: objects were deleted in"
with pytest.raises(expected_exception, match=message):
op.is_keys_unchanged(current_objects=current_objects)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
def test_fail_execute_complete(self, soft_fail, expected_exception):
op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path")
op.soft_fail = soft_fail
message = "test message"
with pytest.raises(expected_exception, match=message):
op.execute_complete(context={}, event={"status": "error", "message": message})
