Skip to content

Commit

Permalink
[AIRFLOW-3607] Only query DB once per DAG run for TriggerRuleDep (#4751)
Browse files Browse the repository at this point in the history
This decreases scheduler delay between tasks by about 20% for larger DAGs,
sometimes more for larger or more complex DAGs.

The delay between tasks can be a major issue, especially when we have dags with 
many subdags, figures out that the scheduling process spends plenty of time in
dependency checking, we took the trigger rule dependency which calls the db for
each task instance, we made it call the db just once for each dag_run
  • Loading branch information
amichai07 authored and ashb committed Jan 16, 2020
1 parent e54fba5 commit 50efda5
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 101 deletions.
26 changes: 4 additions & 22 deletions airflow/jobs/scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from airflow.models import DAG, DagRun, SlaMiss, errors
from airflow.models.taskinstance import SimpleTaskInstance
from airflow.stats import Stats
from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, SCHEDULED_DEPS, DepContext
from airflow.ti_deps.dep_context import SCHEDULED_DEPS, DepContext
from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING
from airflow.utils import asciiart, helpers, timezone
from airflow.utils.dag_processing import (
Expand Down Expand Up @@ -648,28 +648,10 @@ def _process_task_instances(self, dag, task_instances_list, session=None):
run.dag = dag
# todo: preferably the integrity check happens at dag collection time
run.verify_integrity(session=session)
run.update_state(session=session)
ready_tis = run.update_state(session=session)
if run.state == State.RUNNING:
make_transient(run)
active_dag_runs.append(run)

for run in active_dag_runs:
self.log.debug("Examining active DAG run: %s", run)
tis = run.get_task_instances(state=SCHEDULEABLE_STATES)

# this loop is quite slow as it uses are_dependencies_met for
# every task (in ti.is_runnable). This is also called in
# update_state above which has already checked these tasks
for ti in tis:
task = dag.get_task(ti.task_id)

# fixme: ti.task is transient but needs to be set
ti.task = task

if ti.are_dependencies_met(
dep_context=DepContext(flag_upstream_failed=True),
session=session
):
self.log.debug("Examining active DAG run: %s", run)
for ti in ready_tis:
self.log.debug('Queuing task: %s', ti)
task_instances_list.append(ti.key)

Expand Down
81 changes: 46 additions & 35 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airflow.exceptions import AirflowException
from airflow.models.base import ID_LEN, Base
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, DepContext
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -201,7 +201,6 @@ def get_task_instances(self, state=None, session=None):

if self.dag and self.dag.partial:
tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))

return tis.all()

@provide_session
Expand Down Expand Up @@ -268,49 +267,33 @@ def update_state(self, session=None):
Determines the overall state of the DagRun based on the state
of its TaskInstances.
:return: State
:return: ready_tis: the tis that can be scheduled in the current loop
:rtype ready_tis: list[airflow.models.TaskInstance]
"""

dag = self.get_dag()

tis = self.get_task_instances(session=session)
self.log.debug("Updating state for %s considering %s task(s)", self, len(tis))

ready_tis = []
tis = [ti for ti in self.get_task_instances(session=session,
state=State.task_states + (State.SHUTDOWN,))]
self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
for ti in list(tis):
# skip in db?
if ti.state == State.REMOVED:
tis.remove(ti)
else:
ti.task = dag.get_task(ti.task_id)
ti.task = dag.get_task(ti.task_id)

# pre-calculate
# db is faster
start_dttm = timezone.utcnow()
unfinished_tasks = self.get_task_instances(
state=State.unfinished(),
session=session
)
unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
none_task_concurrency = all(t.task.task_concurrency is None
for t in unfinished_tasks)
# small speed up
if unfinished_tasks and none_depends_on_past and none_task_concurrency:
# todo: this can actually get pretty slow: one task costs between 0.01-015s
no_dependencies_met = True
for ut in unfinished_tasks:
# We need to flag upstream and check for changes because upstream
# failures/re-schedules can result in deadlock false positives
old_state = ut.state
deps_met = ut.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True),
session=session)
if deps_met or old_state != ut.current_state(session=session):
no_dependencies_met = False
break
scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]

self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks))
ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
are_runnable_tasks = ready_tis or self._are_premature_tis(
unfinished_tasks, finished_tasks, session) or changed_tis
duration = (timezone.utcnow() - start_dttm)
Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)

Expand All @@ -335,7 +318,7 @@ def update_state(self, session=None):

# if *all tasks* are deadlocked, the run failed
elif (unfinished_tasks and none_depends_on_past and
none_task_concurrency and no_dependencies_met):
none_task_concurrency and not are_runnable_tasks):
self.log.info('Deadlock; marking run %s failed', self)
self.set_state(State.FAILED)
dag.handle_callback(self, success=False, reason='all_tasks_deadlocked',
Expand All @@ -351,7 +334,35 @@ def update_state(self, session=None):
session.merge(self)
session.commit()

return self.state
return ready_tis

def _get_ready_tis(self, scheduleable_tasks, finished_tasks, session):
ready_tis = []
changed_tis = False
for st in scheduleable_tasks:
st_old_state = st.state
if st.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
finished_tasks=finished_tasks),
session=session):
ready_tis.append(st)
elif st_old_state != st.current_state(session=session):
changed_tis = True
return ready_tis, changed_tis

def _are_premature_tis(self, unfinished_tasks, finished_tasks, session):
# there might be runnable tasks that are up for retry and from some reason(retry delay, etc) are
# not ready yet so we set the flags to count them in
for ut in unfinished_tasks:
if ut.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
ignore_in_retry_period=True,
ignore_in_reschedule_period=True,
finished_tasks=finished_tasks),
session=session):
return True

def _emit_duration_stats_for_finished_state(self):
if self.state == State.RUNNING:
Expand Down
6 changes: 5 additions & 1 deletion airflow/ti_deps/dep_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class DepContext:
:type ignore_task_deps: bool
:param ignore_ti_state: Ignore the task instance's previous failure/success
:type ignore_ti_state: bool
:param finished_tasks: A list of all the finished tasks of this run
:type finished_tasks: list[airflow.models.TaskInstance]
"""
def __init__(
self,
Expand All @@ -77,7 +79,8 @@ def __init__(
ignore_in_retry_period=False,
ignore_in_reschedule_period=False,
ignore_task_deps=False,
ignore_ti_state=False):
ignore_ti_state=False,
finished_tasks=None):
self.deps = deps or set()
self.flag_upstream_failed = flag_upstream_failed
self.ignore_all_deps = ignore_all_deps
Expand All @@ -86,6 +89,7 @@ def __init__(
self.ignore_in_reschedule_period = ignore_in_reschedule_period
self.ignore_task_deps = ignore_task_deps
self.ignore_ti_state = ignore_ti_state
self.finished_tasks = finished_tasks


# In order to be able to get queued a task must have one of these states
Expand Down
58 changes: 28 additions & 30 deletions airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.

from sqlalchemy import case, func
from collections import Counter

import airflow
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
Expand All @@ -34,11 +34,32 @@ class TriggerRuleDep(BaseTIDep):
IGNOREABLE = True
IS_TASK_DEP = True

@staticmethod
@provide_session
def _get_states_count_upstream_ti(ti, finished_tasks, session):
"""
This function returns the states of the upstream tis for a specific ti in order to determine
whether this ti can run in this iteration
:param ti: the ti that we want to calculate deps for
:type ti: airflow.models.TaskInstance
:param finished_tasks: all the finished tasks of the dag_run
:type finished_tasks: list[airflow.models.TaskInstance]
"""
if finished_tasks is None:
# this is for the strange feature of running tasks without dag_run
finished_tasks = ti.task.dag.get_task_instances(
start_date=ti.execution_date,
end_date=ti.execution_date,
state=State.finished() + [State.UPSTREAM_FAILED],
session=session)
counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids)
return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \
counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values())

@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
TI = airflow.models.TaskInstance
TR = airflow.utils.trigger_rule.TriggerRule

# Checking that all upstream dependencies have succeeded
if not ti.task.upstream_list:
yield self._passing_status(
Expand All @@ -48,34 +69,11 @@ def _get_dep_statuses(self, ti, session, dep_context):
if ti.task.trigger_rule == TR.DUMMY:
yield self._passing_status(reason="The task had a dummy trigger rule set.")
return
# see if the task name is in the task upstream for our task
successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
ti=ti,
finished_tasks=dep_context.finished_tasks)

# TODO(unknown): this query becomes quite expensive with dags that have many
# tasks. It should be refactored to let the task report to the dag run and get the
# aggregates from there.
qry = (
session
.query(
func.coalesce(func.sum(
case([(TI.state == State.SUCCESS, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.SKIPPED, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.FAILED, 1)], else_=0)), 0),
func.coalesce(func.sum(
case([(TI.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0),
func.count(TI.task_id),
)
.filter(
TI.dag_id == ti.dag_id,
TI.task_id.in_(ti.task.upstream_task_ids),
TI.execution_date == ti.execution_date,
TI.state.in_([
State.SUCCESS, State.FAILED,
State.UPSTREAM_FAILED, State.SKIPPED]),
)
)

successes, skipped, failed, upstream_failed, done = qry.first()
yield from self._evaluate_trigger_rule(
ti=ti,
successes=successes,
Expand Down
1 change: 1 addition & 0 deletions tests/jobs/test_backfill_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,7 @@ def test_backfill_execute_subdag_with_removed_task(self):

session = settings.Session()
session.merge(removed_task_ti)
session.commit()

with timeout(seconds=30):
job.run()
Expand Down
4 changes: 2 additions & 2 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,8 +2034,8 @@ def test_dagrun_root_fail_unfinished(self):
ti = dr.get_task_instance('test_dagrun_unfinished', session=session)
ti.state = State.NONE
session.commit()
dr_state = dr.update_state()
self.assertEqual(dr_state, State.RUNNING)
dr.update_state()
self.assertEqual(dr.state, State.RUNNING)

def test_dagrun_root_after_dagrun_unfinished(self):
"""
Expand Down
20 changes: 10 additions & 10 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def test_dagrun_success_when_all_skipped(self):
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.SUCCESS, updated_dag_state)
dag_run.update_state()
self.assertEqual(State.SUCCESS, dag_run.state)

def test_dagrun_success_conditions(self):
session = settings.Session()
Expand Down Expand Up @@ -198,15 +198,15 @@ def test_dagrun_success_conditions(self):
ti_op4 = dr.get_task_instance(task_id=op4.task_id)

# root is successful, but unfinished tasks
state = dr.update_state()
self.assertEqual(State.RUNNING, state)
dr.update_state()
self.assertEqual(State.RUNNING, dr.state)

# one has failed, but root is successful
ti_op2.set_state(state=State.FAILED, session=session)
ti_op3.set_state(state=State.SUCCESS, session=session)
ti_op4.set_state(state=State.SUCCESS, session=session)
state = dr.update_state()
self.assertEqual(State.SUCCESS, state)
dr.update_state()
self.assertEqual(State.SUCCESS, dr.state)

def test_dagrun_deadlock(self):
session = settings.Session()
Expand Down Expand Up @@ -321,8 +321,8 @@ def on_success_callable(context):
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.SUCCESS, updated_dag_state)
dag_run.update_state()
self.assertEqual(State.SUCCESS, dag_run.state)

def test_dagrun_failure_callback(self):
def on_failure_callable(context):
Expand Down Expand Up @@ -352,8 +352,8 @@ def on_failure_callable(context):
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.FAILED, updated_dag_state)
dag_run.update_state()
self.assertEqual(State.FAILED, dag_run.state)

def test_dagrun_set_state_end_date(self):
session = settings.Session()
Expand Down
Loading

0 comments on commit 50efda5

Please sign in to comment.