From 68e342f21dbcb024352d48b410d5472a03476da1 Mon Sep 17 00:00:00 2001 From: Stefan Seelmann Date: Fri, 21 Sep 2018 07:00:29 +0200 Subject: [PATCH] [AIRFLOW-2747] Explicit re-schedule of sensors (#3596) * [AIRFLOW-2747] Explicit re-schedule of sensors Add `mode` property to sensors. If set to `reschedule` an AirflowRescheduleException is raised instead of sleeping which sets the task back to state `NONE`. Reschedules are recorded in new `task_schedule` table and visualized in the Gantt view. New TI dependency checks if a sensor task is ready to be re-scheduled. * Reformat sqlalchemy imports * Make `_handle_reschedule` private * Remove print * Add comment * Add comment * Don't record reschule request in test mode --- airflow/exceptions.py | 11 + .../0a2a5b66e19d_add_task_reschedule_table.py | 83 +++++ airflow/models.py | 105 +++++- airflow/sensors/base_sensor_operator.py | 56 +++- airflow/ti_deps/dep_context.py | 4 + airflow/ti_deps/deps/ready_to_reschedule.py | 69 ++++ airflow/www/static/gantt-chart-d3v2.js | 2 +- airflow/www/views.py | 55 +++- .../www_rbac/static/js/gantt-chart-d3v2.js | 2 +- airflow/www_rbac/views.py | 58 +++- tests/sensors/test_base_sensor.py | 303 +++++++++++++++++- .../deps/test_ready_to_reschedule_dep.py | 94 ++++++ 12 files changed, 810 insertions(+), 32 deletions(-) create mode 100644 airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py create mode 100644 airflow/ti_deps/deps/ready_to_reschedule.py create mode 100644 tests/ti_deps/deps/test_ready_to_reschedule_dep.py diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 89f3d0e048da2..d4098c4a32435 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -47,6 +47,17 @@ class AirflowSensorTimeout(AirflowException): pass +class AirflowRescheduleException(AirflowException): + """ + Raise when the task should be re-scheduled at a later time. + + :param reschedule_date: The date when the task should be rescheduled + :type reschedule: datetime + """ + def __init__(self, reschedule_date): + self.reschedule_date = reschedule_date + + class AirflowTaskTimeout(AirflowException): pass diff --git a/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py new file mode 100644 index 0000000000000..6eef6a9437544 --- /dev/null +++ b/airflow/migrations/versions/0a2a5b66e19d_add_task_reschedule_table.py @@ -0,0 +1,83 @@ +# flake8: noqa +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""add task_reschedule table + +Revision ID: 0a2a5b66e19d +Revises: 9635ae0956e7 +Create Date: 2018-06-17 22:50:00.053620 + +""" + +# revision identifiers, used by Alembic. +revision = '0a2a5b66e19d' +down_revision = '9635ae0956e7' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + + +TABLE_NAME = 'task_reschedule' +INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date' + +def mysql_timestamp(): + return mysql.TIMESTAMP(fsp=6) + +def sa_timestamp(): + return sa.TIMESTAMP(timezone=True) + +def upgrade(): + # See 0e2a74e0fc9f_add_time_zone_awareness + conn = op.get_bind() + if conn.dialect.name == 'mysql': + timestamp = mysql_timestamp + else: + timestamp = sa_timestamp + + op.create_table( + TABLE_NAME, + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('task_id', sa.String(length=250), nullable=False), + sa.Column('dag_id', sa.String(length=250), nullable=False), + # use explicit server_default=None otherwise mysql implies defaults for first timestamp column + sa.Column('execution_date', timestamp(), nullable=False, server_default=None), + sa.Column('try_number', sa.Integer(), nullable=False), + sa.Column('start_date', timestamp(), nullable=False), + sa.Column('end_date', timestamp(), nullable=False), + sa.Column('duration', sa.Integer(), nullable=False), + sa.Column('reschedule_date', timestamp(), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.ForeignKeyConstraint(['task_id', 'dag_id', 'execution_date'], + ['task_instance.task_id', 'task_instance.dag_id','task_instance.execution_date'], + name='task_reschedule_dag_task_date_fkey') + ) + op.create_index( + INDEX_NAME, + TABLE_NAME, + ['dag_id', 'task_id', 'execution_date'], + unique=False + ) + + +def downgrade(): + op.drop_index(INDEX_NAME, table_name=TABLE_NAME) + op.drop_table(TABLE_NAME) diff --git a/airflow/models.py b/airflow/models.py index 43916f7cf5460..1e7848d04e305 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -56,9 +56,10 @@ from urllib.parse import urlparse, quote, parse_qsl, unquote from sqlalchemy import ( - Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType, - Index, Float, LargeBinary, UniqueConstraint) -from sqlalchemy import func, or_, and_, true as sqltrue + Boolean, Column, DateTime, Float, ForeignKey, ForeignKeyConstraint, Index, + Integer, LargeBinary, PickleType, String, Text, UniqueConstraint, + and_, asc, func, or_, true as sqltrue +) from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import reconstructor, relationship, synonym @@ -71,7 +72,8 @@ from airflow.executors import GetDefaultExecutor, LocalExecutor from airflow import configuration from airflow.exceptions import ( - AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout + AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout, + AirflowRescheduleException ) from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.lineage import apply_lineage, prepare_lineage @@ -1682,6 +1684,10 @@ def signal_handler(signum, frame): except AirflowSkipException: self.refresh_from_db(lock_for_update=True) self.state = State.SKIPPED + except AirflowRescheduleException as reschedule_exception: + self.refresh_from_db() + self._handle_reschedule(reschedule_exception, test_mode, context) + return except AirflowException as e: self.refresh_from_db() # for case when task is marked as success externally @@ -1751,6 +1757,32 @@ def dry_run(self): self.render_templates() task_copy.dry_run() + @provide_session + def _handle_reschedule(self, reschedule_exception, test_mode=False, context=None, + session=None): + # Don't record reschedule request in test mode + if test_mode: + return + + self.end_date = timezone.utcnow() + self.set_duration() + + # Log reschedule request + session.add(TaskReschedule(self.task, self.execution_date, self._try_number, + self.start_date, self.end_date, + reschedule_exception.reschedule_date)) + + # set state + self.state = State.NONE + + # Decrement try_number so subsequent runs will use the same try number and write + # to same log file. + self._try_number -= 1 + + session.merge(self) + session.commit() + self.log.info('Rescheduling task, marking task as NONE') + @provide_session def handle_failure(self, error, test_mode=False, context=None, session=None): self.log.exception(error) @@ -2111,6 +2143,66 @@ def __init__(self, task, execution_date, start_date, end_date): self.duration = None +class TaskReschedule(Base): + """ + TaskReschedule tracks rescheduled task instances. + """ + + __tablename__ = "task_reschedule" + + id = Column(Integer, primary_key=True) + task_id = Column(String(ID_LEN), nullable=False) + dag_id = Column(String(ID_LEN), nullable=False) + execution_date = Column(UtcDateTime, nullable=False) + try_number = Column(Integer, nullable=False) + start_date = Column(UtcDateTime, nullable=False) + end_date = Column(UtcDateTime, nullable=False) + duration = Column(Integer, nullable=False) + reschedule_date = Column(UtcDateTime, nullable=False) + + __table_args__ = ( + Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date, + unique=False), + ForeignKeyConstraint([task_id, dag_id, execution_date], + [TaskInstance.task_id, TaskInstance.dag_id, + TaskInstance.execution_date], + name='task_reschedule_dag_task_date_fkey') + ) + + def __init__(self, task, execution_date, try_number, start_date, end_date, + reschedule_date): + self.dag_id = task.dag_id + self.task_id = task.task_id + self.execution_date = execution_date + self.try_number = try_number + self.start_date = start_date + self.end_date = end_date + self.reschedule_date = reschedule_date + self.duration = (self.end_date - self.start_date).total_seconds() + + @staticmethod + @provide_session + def find_for_task_instance(task_instance, session): + """ + Returns all task reschedules for the task instance and try number, + in ascending order. + + :param task_instance: the task instance to find task reschedules for + :type task_instance: TaskInstance + """ + TR = TaskReschedule + return ( + session + .query(TR) + .filter(TR.dag_id == task_instance.dag_id, + TR.task_id == task_instance.task_id, + TR.execution_date == task_instance.execution_date, + TR.try_number == task_instance.try_number) + .order_by(asc(TR.id)) + .all() + ) + + class Log(Base): """ Used to actively log events to the database @@ -5118,12 +5210,13 @@ def update_state(self, session=None): no_dependencies_met = True for ut in unfinished_tasks: # We need to flag upstream and check for changes because upstream - # failures can result in deadlock false positives + # failures/re-schedules can result in deadlock false positives old_state = ut.state deps_met = ut.are_dependencies_met( dep_context=DepContext( flag_upstream_failed=True, - ignore_in_retry_period=True), + ignore_in_retry_period=True, + ignore_in_reschedule_period=True), session=session) if deps_met or old_state != ut.current_state(session=session): no_dependencies_met = False diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py index 74b0e0fe1ca5b..1dc59dd230b0c 100644 --- a/airflow/sensors/base_sensor_operator.py +++ b/airflow/sensors/base_sensor_operator.py @@ -19,20 +19,22 @@ from time import sleep +from datetime import timedelta from airflow.exceptions import AirflowException, AirflowSensorTimeout, \ - AirflowSkipException -from airflow.models import BaseOperator, SkipMixin + AirflowSkipException, AirflowRescheduleException +from airflow.models import BaseOperator, SkipMixin, TaskReschedule from airflow.utils import timezone from airflow.utils.decorators import apply_defaults +from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep class BaseSensorOperator(BaseOperator, SkipMixin): """ - Sensor operators are derived from this class an inherit these attributes. + Sensor operators are derived from this class and inherit these attributes. Sensor operators keep executing at a time interval and succeed when - a criteria is met and fail if and when they time out. + a criteria is met and fail if and when they time out. :param soft_fail: Set to true to mark the task as SKIPPED on failure :type soft_fail: bool @@ -41,20 +43,42 @@ class BaseSensorOperator(BaseOperator, SkipMixin): :type poke_interval: int :param timeout: Time, in seconds before the task times out and fails. :type timeout: int + :param mode: How the sensor operates. + Options are: ``{ poke | reschedule }``, default is ``poke``. + When set to ``poke`` the sensor is taking up a worker slot for its + whole execution time and sleeps between pokes. Use this mode if the + expected runtime of the sensor is short or if a short poke interval + is requried. + When set to ``reschedule`` the sensor task frees the worker slot when + the criteria is not yet met and it's rescheduled at a later time. Use + this mode if the expected time until the criteria is met is. The poke + inteval should be more than one minute to prevent too much load on + the scheduler. + :type mode: str """ ui_color = '#e6f1f2' + valid_modes = ['poke', 'reschedule'] @apply_defaults def __init__(self, poke_interval=60, timeout=60 * 60 * 24 * 7, soft_fail=False, + mode='poke', *args, **kwargs): super(BaseSensorOperator, self).__init__(*args, **kwargs) self.poke_interval = poke_interval self.soft_fail = soft_fail self.timeout = timeout + if mode not in self.valid_modes: + raise AirflowException( + "The mode must be one of {valid_modes}," + "'{d}.{t}'; received '{m}'." + .format(valid_modes=self.valid_modes, + d=self.dag.dag_id if self.dag else "", + t=self.task_id, m=mode)) + self.mode = mode def poke(self, context): """ @@ -65,6 +89,11 @@ def poke(self, context): def execute(self, context): started_at = timezone.utcnow() + if self.reschedule: + # If reschedule, use first start date of current try + task_reschedules = TaskReschedule.find_for_task_instance(context['ti']) + if task_reschedules: + started_at = task_reschedules[0].start_date while not self.poke(context): if (timezone.utcnow() - started_at).total_seconds() > self.timeout: # If sensor is in soft fail mode but will be retried then @@ -75,7 +104,12 @@ def execute(self, context): raise AirflowSkipException('Snap. Time is OUT.') else: raise AirflowSensorTimeout('Snap. Time is OUT.') - sleep(self.poke_interval) + if self.reschedule: + reschedule_date = timezone.utcnow() + timedelta( + seconds=self.poke_interval) + raise AirflowRescheduleException(reschedule_date) + else: + sleep(self.poke_interval) self.log.info("Success criteria met. Exiting.") def _do_skip_downstream_tasks(self, context): @@ -83,3 +117,15 @@ def _do_skip_downstream_tasks(self, context): self.log.debug("Downstream task_ids %s", downstream_tasks) if downstream_tasks: self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks) + + @property + def reschedule(self): + return self.mode == 'reschedule' + + @property + def deps(self): + """ + Adds one additional dependency for all sensor operators that + checks if a sensor task instance can be rescheduled. + """ + return BaseOperator.deps.fget(self) | {ReadyToRescheduleDep()} diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index 6d39998988f83..a347983f3b9f1 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -58,6 +58,8 @@ class DepContext(object): :type ignore_depends_on_past: boolean :param ignore_in_retry_period: Ignore the retry period for task instances :type ignore_in_retry_period: boolean + :param ignore_in_reschedule_period: Ignore the reschedule period for task instances + :type ignore_in_reschedule_period: boolean :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past and trigger rule :type ignore_task_deps: boolean @@ -71,6 +73,7 @@ def __init__( ignore_all_deps=False, ignore_depends_on_past=False, ignore_in_retry_period=False, + ignore_in_reschedule_period=False, ignore_task_deps=False, ignore_ti_state=False): self.deps = deps or set() @@ -78,6 +81,7 @@ def __init__( self.ignore_all_deps = ignore_all_deps self.ignore_depends_on_past = ignore_depends_on_past self.ignore_in_retry_period = ignore_in_retry_period + self.ignore_in_reschedule_period = ignore_in_reschedule_period self.ignore_task_deps = ignore_task_deps self.ignore_ti_state = ignore_ti_state diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py new file mode 100644 index 0000000000000..e0f5f8fdfe410 --- /dev/null +++ b/airflow/ti_deps/deps/ready_to_reschedule.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep +from airflow.utils import timezone +from airflow.utils.db import provide_session +from airflow.utils.state import State + + +class ReadyToRescheduleDep(BaseTIDep): + NAME = "Ready To Reschedule" + IGNOREABLE = True + IS_TASK_DEP = True + + @provide_session + def _get_dep_statuses(self, ti, session, dep_context): + """ + Determines whether a task is ready to be rescheduled. Only tasks in + NONE state with at least one row in task_reschedule table are + handled by this dependency class, otherwise this dependency is + considered as passed. This dependency fails if the latest reschedule + request's reschedule date is still in future. + """ + if dep_context.ignore_in_reschedule_period: + yield self._passing_status( + reason="The context specified that being in a reschedule period was " + "permitted.") + return + + if ti.state != State.NONE: + yield self._passing_status( + reason="The task instance is not in NONE state.") + return + + # Lazy import to avoid circular dependency + from airflow.models import TaskReschedule + task_reschedules = TaskReschedule.find_for_task_instance(task_instance=ti) + if not task_reschedules: + yield self._passing_status( + reason="There is no reschedule request for this task instance.") + return + + now = timezone.utcnow() + next_reschedule_date = task_reschedules[-1].reschedule_date + if now >= next_reschedule_date: + yield self._passing_status( + reason="Task instance id ready for reschedule.") + return + + yield self._failing_status( + reason="Task is not ready for reschedule yet but will be rescheduled " + "automatically. Current date is {0} and task will be rescheduled " + "at {1}.".format(now.isoformat(), next_reschedule_date.isoformat())) diff --git a/airflow/www/static/gantt-chart-d3v2.js b/airflow/www/static/gantt-chart-d3v2.js index d21311a1c541d..245a0147e9f72 100644 --- a/airflow/www/static/gantt-chart-d3v2.js +++ b/airflow/www/static/gantt-chart-d3v2.js @@ -129,7 +129,7 @@ d3.gantt = function() { call_modal(d.taskName, d.executionDate); }) .attr("class", function(d){ - if(taskStatus[d.status] == null){ return "bar";} + if(taskStatus[d.status] == null){ return "null";} return taskStatus[d.status]; }) .attr("y", 0) diff --git a/airflow/www/views.py b/airflow/www/views.py index 6bf60d63878fc..41465c78c00b6 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -22,6 +22,7 @@ import ast import datetime as dt +import itertools import logging import os import pkg_resources @@ -1851,19 +1852,57 @@ def gantt(self, session=None): ti for ti in dag.get_task_instances(session, dttm, dttm) if ti.start_date] tis = sorted(tis, key=lambda ti: ti.start_date) + TF = models.TaskFail + ti_fails = list(itertools.chain(*[( + session + .query(TF) + .filter(TF.dag_id == ti.dag_id, + TF.task_id == ti.task_id, + TF.execution_date == ti.execution_date) + .all() + ) for ti in tis])) + TR = models.TaskReschedule + ti_reschedules = list(itertools.chain(*[( + session + .query(TR) + .filter(TR.dag_id == ti.dag_id, + TR.task_id == ti.task_id, + TR.execution_date == ti.execution_date) + .all() + ) for ti in tis])) + # determine bars to show in the gantt chart + # all reschedules of one attempt are combinded into one bar + gantt_bar_items = [] + for task_id, items in itertools.groupby( + sorted(tis + ti_fails + ti_reschedules, key=lambda ti: ti.task_id), + key=lambda ti: ti.task_id): + start_date = None + for i in sorted(items, key=lambda ti: ti.start_date): + start_date = start_date or i.start_date + end_date = i.end_date or timezone.utcnow() + if type(i) == models.TaskInstance: + gantt_bar_items.append((task_id, start_date, end_date, i.state)) + start_date = None + elif type(i) == TF and (len(gantt_bar_items) == 0 or + end_date != gantt_bar_items[-1][2]): + gantt_bar_items.append((task_id, start_date, end_date, State.FAILED)) + start_date = None tasks = [] - for ti in tis: - end_date = ti.end_date if ti.end_date else timezone.utcnow() + for gantt_bar_item in gantt_bar_items: + task_id = gantt_bar_item[0] + start_date = gantt_bar_item[1] + end_date = gantt_bar_item[2] + state = gantt_bar_item[3] tasks.append({ - 'startDate': wwwutils.epoch(ti.start_date), + 'startDate': wwwutils.epoch(start_date), 'endDate': wwwutils.epoch(end_date), - 'isoStart': ti.start_date.isoformat()[:-4], + 'isoStart': start_date.isoformat()[:-4], 'isoEnd': end_date.isoformat()[:-4], - 'taskName': ti.task_id, - 'duration': "{}".format(end_date - ti.start_date)[:-4], - 'status': ti.state, - 'executionDate': ti.execution_date.isoformat(), + 'taskName': task_id, + 'duration': "{}".format(end_date - start_date)[:-4], + 'status': state, + 'executionDate': dttm.isoformat(), }) states = {ti.state: ti.state for ti in tis} data = { diff --git a/airflow/www_rbac/static/js/gantt-chart-d3v2.js b/airflow/www_rbac/static/js/gantt-chart-d3v2.js index d21311a1c541d..245a0147e9f72 100644 --- a/airflow/www_rbac/static/js/gantt-chart-d3v2.js +++ b/airflow/www_rbac/static/js/gantt-chart-d3v2.js @@ -129,7 +129,7 @@ d3.gantt = function() { call_modal(d.taskName, d.executionDate); }) .attr("class", function(d){ - if(taskStatus[d.status] == null){ return "bar";} + if(taskStatus[d.status] == null){ return "null";} return taskStatus[d.status]; }) .attr("y", 0) diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index 39d9fb58a3633..163377d8e9176 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -23,7 +23,8 @@ import logging import os import socket -from datetime import datetime, timedelta +from datetime import timedelta +import itertools import copy import math import json @@ -1530,19 +1531,58 @@ def gantt(self, session=None): ti for ti in dag.get_task_instances(session, dttm, dttm) if ti.start_date] tis = sorted(tis, key=lambda ti: ti.start_date) + TF = models.TaskFail + ti_fails = list(itertools.chain(*[( + session + .query(TF) + .filter(TF.dag_id == ti.dag_id, + TF.task_id == ti.task_id, + TF.execution_date == ti.execution_date) + .all() + ) for ti in tis])) + TR = models.TaskReschedule + ti_reschedules = list(itertools.chain(*[( + session + .query(TR) + .filter(TR.dag_id == ti.dag_id, + TR.task_id == ti.task_id, + TR.execution_date == ti.execution_date) + .all() + ) for ti in tis])) + + # determine bars to show in the gantt chart + # all reschedules of one attempt are combinded into one bar + gantt_bar_items = [] + for task_id, items in itertools.groupby( + sorted(tis + ti_fails + ti_reschedules, key=lambda ti: ti.task_id), + key=lambda ti: ti.task_id): + start_date = None + for i in sorted(items, key=lambda ti: ti.start_date): + start_date = start_date or i.start_date + end_date = i.end_date or timezone.utcnow() + if type(i) == models.TaskInstance: + gantt_bar_items.append((task_id, start_date, end_date, i.state)) + start_date = None + elif type(i) == TF and (len(gantt_bar_items) == 0 or + end_date != gantt_bar_items[-1][2]): + gantt_bar_items.append((task_id, start_date, end_date, State.FAILED)) + start_date = None tasks = [] - for ti in tis: - end_date = ti.end_date if ti.end_date else timezone.utcnow() + for gantt_bar_item in gantt_bar_items: + task_id = gantt_bar_item[0] + start_date = gantt_bar_item[1] + end_date = gantt_bar_item[2] + state = gantt_bar_item[3] tasks.append({ - 'startDate': wwwutils.epoch(ti.start_date), + 'startDate': wwwutils.epoch(start_date), 'endDate': wwwutils.epoch(end_date), - 'isoStart': ti.start_date.isoformat()[:-4], + 'isoStart': start_date.isoformat()[:-4], 'isoEnd': end_date.isoformat()[:-4], - 'taskName': ti.task_id, - 'duration': "{}".format(end_date - ti.start_date)[:-4], - 'status': ti.state, - 'executionDate': ti.execution_date.isoformat(), + 'taskName': task_id, + 'duration': "{}".format(end_date - start_date)[:-4], + 'status': state, + 'executionDate': dttm.isoformat(), }) states = {ti.state: ti.state for ti in tis} data = { diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py index adb7a5d1e31f7..353f4447b1008 100644 --- a/tests/sensors/test_base_sensor.py +++ b/tests/sensors/test_base_sensor.py @@ -18,17 +18,21 @@ # under the License. import unittest +from mock import Mock from airflow import DAG, configuration, settings -from airflow.exceptions import AirflowSensorTimeout -from airflow.models import DagRun, TaskInstance +from airflow.exceptions import (AirflowSensorTimeout, AirflowException, + AirflowRescheduleException) +from airflow.models import DagRun, TaskInstance, TaskReschedule from airflow.operators.dummy_operator import DummyOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.timezone import datetime from datetime import timedelta from time import sleep +from freezegun import freeze_time configuration.load_test_config() @@ -57,6 +61,7 @@ def setUp(self): self.dag = DAG(TEST_DAG_ID, default_args=args) session = settings.Session() + session.query(TaskReschedule).delete() session.query(DagRun).delete() session.query(TaskInstance).delete() session.commit() @@ -158,3 +163,297 @@ def test_soft_fail_with_retries(self): self.assertEquals(len(tis), 2) for ti in tis: self.assertEquals(ti.state, State.SKIPPED) + + def test_ok_with_reschedule(self): + sensor = self._make_sensor( + return_value=None, + poke_interval=10, + timeout=25, + mode='reschedule') + sensor.poke = Mock(side_effect=[False, False, True]) + dr = self._make_dag_run() + + # first poke returns False and task is re-scheduled + date1 = timezone.utcnow() + with freeze_time(date1): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + # verify task is re-scheduled, i.e. state set to NONE + self.assertEquals(ti.state, State.NONE) + # verify one row in task_reschedule table + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 1) + self.assertEquals(task_reschedules[0].start_date, date1) + self.assertEquals(task_reschedules[0].reschedule_date, + date1 + timedelta(seconds=sensor.poke_interval)) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # second poke returns False and task is re-scheduled + date2 = date1 + timedelta(seconds=sensor.poke_interval) + with freeze_time(date2): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + # verify task is re-scheduled, i.e. state set to NONE + self.assertEquals(ti.state, State.NONE) + # verify two rows in task_reschedule table + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 2) + self.assertEquals(task_reschedules[1].start_date, date2) + self.assertEquals(task_reschedules[1].reschedule_date, + date2 + timedelta(seconds=sensor.poke_interval)) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # third poke returns True and task succeeds + date3 = date2 + timedelta(seconds=sensor.poke_interval) + with freeze_time(date3): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.SUCCESS) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + def test_fail_with_reschedule(self): + sensor = self._make_sensor( + return_value=False, + poke_interval=10, + timeout=5, + mode='reschedule') + dr = self._make_dag_run() + + # first poke returns False and task is re-scheduled + date1 = timezone.utcnow() + with freeze_time(date1): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.NONE) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # second poke returns False, timeout occurs + date2 = date1 + timedelta(seconds=sensor.poke_interval) + with freeze_time(date2): + with self.assertRaises(AirflowSensorTimeout): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.FAILED) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + def test_soft_fail_with_reschedule(self): + sensor = self._make_sensor( + return_value=False, + poke_interval=10, + timeout=5, + soft_fail=True, + mode='reschedule') + dr = self._make_dag_run() + + # first poke returns False and task is re-scheduled + date1 = timezone.utcnow() + with freeze_time(date1): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.NONE) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # second poke returns False, timeout occurs + date2 = date1 + timedelta(seconds=sensor.poke_interval) + with freeze_time(date2): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + self.assertEquals(ti.state, State.SKIPPED) + + def test_ok_with_reschedule_and_retry(self): + sensor = self._make_sensor( + return_value=None, + poke_interval=10, + timeout=5, + retries=1, + retry_delay=timedelta(seconds=10), + mode='reschedule') + sensor.poke = Mock(side_effect=[False, False, False, True]) + dr = self._make_dag_run() + + # first poke returns False and task is re-scheduled + date1 = timezone.utcnow() + with freeze_time(date1): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.NONE) + # verify one row in task_reschedule table + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 1) + self.assertEquals(task_reschedules[0].start_date, date1) + self.assertEquals(task_reschedules[0].reschedule_date, + date1 + timedelta(seconds=sensor.poke_interval)) + self.assertEqual(task_reschedules[0].try_number, 1) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # second poke fails and task instance is marked up to retry + date2 = date1 + timedelta(seconds=sensor.poke_interval) + with freeze_time(date2): + with self.assertRaises(AirflowSensorTimeout): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.UP_FOR_RETRY) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # third poke returns False and task is rescheduled again + date3 = date2 + timedelta(seconds=sensor.poke_interval) + sensor.retry_delay + with freeze_time(date3): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.NONE) + # verify one row in task_reschedule table + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 1) + self.assertEquals(task_reschedules[0].start_date, date3) + self.assertEquals(task_reschedules[0].reschedule_date, + date3 + timedelta(seconds=sensor.poke_interval)) + self.assertEqual(task_reschedules[0].try_number, 2) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # fourth poke return True and task succeeds + date4 = date3 + timedelta(seconds=sensor.poke_interval) + with freeze_time(date4): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.SUCCESS) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + def test_should_include_ready_to_reschedule_dep(self): + sensor = self._make_sensor(True) + deps = sensor.deps + self.assertTrue(ReadyToRescheduleDep() in deps) + + def test_invalid_mode(self): + with self.assertRaises(AirflowException): + self._make_sensor( + return_value=True, + mode='foo') + + def test_ok_with_custom_reschedule_exception(self): + sensor = self._make_sensor( + return_value=None, + mode='reschedule') + date1 = timezone.utcnow() + date2 = date1 + timedelta(seconds=60) + date3 = date1 + timedelta(seconds=120) + sensor.poke = Mock(side_effect=[ + AirflowRescheduleException(date2), + AirflowRescheduleException(date3), + True, + ]) + dr = self._make_dag_run() + + # first poke returns False and task is re-scheduled + with freeze_time(date1): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + # verify task is re-scheduled, i.e. state set to NONE + self.assertEquals(ti.state, State.NONE) + # verify one row in task_reschedule table + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 1) + self.assertEquals(task_reschedules[0].start_date, date1) + self.assertEquals(task_reschedules[0].reschedule_date, date2) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # second poke returns False and task is re-scheduled + with freeze_time(date2): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + # verify task is re-scheduled, i.e. state set to NONE + self.assertEquals(ti.state, State.NONE) + # verify two rows in task_reschedule table + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 2) + self.assertEquals(task_reschedules[1].start_date, date2) + self.assertEquals(task_reschedules[1].reschedule_date, date3) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + # third poke returns True and task succeeds + with freeze_time(date3): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.SUCCESS) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + def test_reschedule_with_test_mode(self): + sensor = self._make_sensor( + return_value=None, + poke_interval=10, + timeout=25, + mode='reschedule') + sensor.poke = Mock(side_effect=[False]) + dr = self._make_dag_run() + + # poke returns False and AirflowRescheduleException is raised + date1 = timezone.utcnow() + with freeze_time(date1): + for dt in self.dag.date_range(DEFAULT_DATE, end_date=DEFAULT_DATE): + TaskInstance(sensor, dt).run( + ignore_ti_state=True, + test_mode=True) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + # in test mode state is not modified + self.assertEquals(ti.state, State.NONE) + # in test mode no reschedule request is recorded + task_reschedules = TaskReschedule.find_for_task_instance(ti) + self.assertEquals(len(task_reschedules), 0) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) diff --git a/tests/ti_deps/deps/test_ready_to_reschedule_dep.py b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py new file mode 100644 index 0000000000000..898850f8b7bfb --- /dev/null +++ b/tests/ti_deps/deps/test_ready_to_reschedule_dep.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from datetime import timedelta +from mock import Mock, patch + +from airflow.models import TaskInstance, DAG, TaskReschedule +from airflow.ti_deps.dep_context import DepContext +from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep +from airflow.utils.state import State +from airflow.utils.timezone import utcnow + + +class NotInReschedulePeriodDepTest(unittest.TestCase): + + def _get_task_instance(self, state): + dag = DAG('test_dag') + task = Mock(dag=dag) + ti = TaskInstance(task=task, state=state, execution_date=None) + return ti + + def _get_task_reschedule(self, reschedule_date): + task = Mock(dag_id='test_dag', task_id='test_task') + tr = TaskReschedule(task=task, execution_date=None, try_number=None, + start_date=reschedule_date, end_date=reschedule_date, + reschedule_date=reschedule_date) + return tr + + def test_should_pass_if_ignore_in_reschedule_period_is_set(self): + ti = self._get_task_instance(State.NONE) + dep_context = DepContext(ignore_in_reschedule_period=True) + self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context)) + + def test_should_pass_if_not_in_none_state(self): + ti = self._get_task_instance(State.UP_FOR_RETRY) + self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti)) + + @patch('airflow.models.TaskReschedule.find_for_task_instance', return_value=[]) + def test_should_pass_if_no_reschedule_record_exists(self, find_for_task_instance): + ti = self._get_task_instance(State.NONE) + self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti)) + + @patch('airflow.models.TaskReschedule.find_for_task_instance') + def test_should_pass_after_reschedule_date_one(self, find_for_task_instance): + find_for_task_instance.return_value = [ + self._get_task_reschedule(utcnow() - timedelta(minutes=1)), + ] + ti = self._get_task_instance(State.NONE) + self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti)) + + @patch('airflow.models.TaskReschedule.find_for_task_instance') + def test_should_pass_after_reschedule_date_multiple(self, find_for_task_instance): + find_for_task_instance.return_value = [ + self._get_task_reschedule(utcnow() - timedelta(minutes=21)), + self._get_task_reschedule(utcnow() - timedelta(minutes=11)), + self._get_task_reschedule(utcnow() - timedelta(minutes=1)), + ] + ti = self._get_task_instance(State.NONE) + self.assertTrue(ReadyToRescheduleDep().is_met(ti=ti)) + + @patch('airflow.models.TaskReschedule.find_for_task_instance') + def test_should_fail_before_reschedule_date_one(self, find_for_task_instance): + find_for_task_instance.return_value = [ + self._get_task_reschedule(utcnow() + timedelta(minutes=1)), + ] + ti = self._get_task_instance(State.NONE) + self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti)) + + @patch('airflow.models.TaskReschedule.find_for_task_instance') + def test_should_fail_before_reschedule_date_multiple(self, find_for_task_instance): + find_for_task_instance.return_value = [ + self._get_task_reschedule(utcnow() - timedelta(minutes=19)), + self._get_task_reschedule(utcnow() - timedelta(minutes=9)), + self._get_task_reschedule(utcnow() + timedelta(minutes=1)), + ] + ti = self._get_task_instance(State.NONE) + self.assertFalse(ReadyToRescheduleDep().is_met(ti=ti))