[AIRFLOW-2747] Explicit re-schedule of sensors (apache#3596)
* [AIRFLOW-2747] Explicit re-schedule of sensors

Add `mode` property to sensors. If set to `reschedule`, an
AirflowRescheduleException is raised instead of sleeping, which sets
the task back to state `NONE`. Reschedules are recorded in the new
`task_reschedule` table and visualized in the Gantt view. A new TI
dependency checks whether a sensor task is ready to be re-scheduled.
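
As a usage illustration (editorial addition, not part of the commit; the DAG id,
file path and sensor class are hypothetical), a sensor opts in via the new
``mode`` argument:

    import os
    from datetime import datetime

    from airflow import DAG
    from airflow.sensors.base_sensor_operator import BaseSensorOperator


    class FileExistsSensor(BaseSensorOperator):
        """Hypothetical sensor: succeeds once /tmp/trigger_file appears."""

        def poke(self, context):
            return os.path.exists('/tmp/trigger_file')


    with DAG('example_reschedule',
             start_date=datetime(2018, 12, 1),
             schedule_interval=None) as dag:
        wait_for_file = FileExistsSensor(
            task_id='wait_for_file',
            mode='reschedule',    # frees the worker slot between pokes
            poke_interval=300,    # keep well above one minute in this mode
            timeout=60 * 60 * 12,
        )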

* Reformat sqlalchemy imports

* Make `_handle_reschedule` private

* Remove print

* Add comment

* Add comment

* Don't record reschedule request in test mode
seelmann authored and Fokko Driesprong committed Dec 6, 2018
1 parent 23ff340 commit e1e014e
Showing 12 changed files with 810 additions and 32 deletions.
11 changes: 11 additions & 0 deletions airflow/exceptions.py
@@ -47,6 +47,17 @@ class AirflowSensorTimeout(AirflowException):
    pass


class AirflowRescheduleException(AirflowException):
    """
    Raise when the task should be re-scheduled at a later time.

    :param reschedule_date: The date when the task should be rescheduled
    :type reschedule_date: datetime
    """
    def __init__(self, reschedule_date):
        self.reschedule_date = reschedule_date


class AirflowTaskTimeout(AirflowException):
    pass
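
An editorial sketch of the exception's contract (not part of the diff): raising
it with a target datetime ends the current run and asks for a new one no earlier
than that date. The operator class and ``_data_ready()`` helper below are
hypothetical; note that the wait itself is enforced by ReadyToRescheduleDep,
which this commit only wires up for sensors.

    from datetime import timedelta

    from airflow.exceptions import AirflowRescheduleException
    from airflow.models import BaseOperator
    from airflow.utils import timezone


    class WaitForDataOperator(BaseOperator):
        """Hypothetical operator illustrating the reschedule contract."""

        def execute(self, context):
            if not self._data_ready():  # hypothetical readiness check
                # Give up the worker slot; ask to run again in five minutes.
                raise AirflowRescheduleException(
                    timezone.utcnow() + timedelta(minutes=5))

        def _data_ready(self):
            return False  # placeholder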

@@ -0,0 +1,83 @@
# flake8: noqa
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""add task_reschedule table
Revision ID: 0a2a5b66e19d
Revises: 9635ae0956e7
Create Date: 2018-06-17 22:50:00.053620
"""

# revision identifiers, used by Alembic.
revision = '0a2a5b66e19d'
down_revision = '9635ae0956e7'
branch_labels = None
depends_on = None

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql


TABLE_NAME = 'task_reschedule'
INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date'

def mysql_timestamp():
    return mysql.TIMESTAMP(fsp=6)

def sa_timestamp():
    return sa.TIMESTAMP(timezone=True)

def upgrade():
    # See 0e2a74e0fc9f_add_time_zone_awareness
    conn = op.get_bind()
    if conn.dialect.name == 'mysql':
        timestamp = mysql_timestamp
    else:
        timestamp = sa_timestamp

    op.create_table(
        TABLE_NAME,
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('task_id', sa.String(length=250), nullable=False),
        sa.Column('dag_id', sa.String(length=250), nullable=False),
        # use explicit server_default=None otherwise mysql implies defaults
        # for the first timestamp column
        sa.Column('execution_date', timestamp(), nullable=False, server_default=None),
        sa.Column('try_number', sa.Integer(), nullable=False),
        sa.Column('start_date', timestamp(), nullable=False),
        sa.Column('end_date', timestamp(), nullable=False),
        sa.Column('duration', sa.Integer(), nullable=False),
        sa.Column('reschedule_date', timestamp(), nullable=False),
        sa.PrimaryKeyConstraint('id'),
        sa.ForeignKeyConstraint(['task_id', 'dag_id', 'execution_date'],
                                ['task_instance.task_id', 'task_instance.dag_id',
                                 'task_instance.execution_date'],
                                name='task_reschedule_dag_task_date_fkey')
    )
    op.create_index(
        INDEX_NAME,
        TABLE_NAME,
        ['dag_id', 'task_id', 'execution_date'],
        unique=False
    )


def downgrade():
    op.drop_index(INDEX_NAME, table_name=TABLE_NAME)
    op.drop_table(TABLE_NAME)
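
Once upgraded, each reschedule is recorded as a row in this table. A quick
editorial sanity check through the ORM model added below (the dag id is
hypothetical; assumes a configured Airflow environment):

    from airflow import settings
    from airflow.models import TaskReschedule

    session = settings.Session()
    count = (session.query(TaskReschedule)
             .filter(TaskReschedule.dag_id == 'example_reschedule')
             .count())
    print('reschedule records:', count)
    session.close()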
105 changes: 99 additions & 6 deletions airflow/models.py
@@ -56,9 +56,10 @@
from urllib.parse import urlparse, quote, parse_qsl, unquote

from sqlalchemy import (
    Boolean, Column, DateTime, Float, ForeignKey, ForeignKeyConstraint, Index,
    Integer, LargeBinary, PickleType, String, Text, UniqueConstraint,
    and_, asc, func, or_, true as sqltrue
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import reconstructor, relationship, synonym

@@ -71,7 +72,8 @@
from airflow.executors import GetDefaultExecutor, LocalExecutor
from airflow import configuration
from airflow.exceptions import (
    AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout,
    AirflowRescheduleException
)
from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.lineage import apply_lineage, prepare_lineage
@@ -1677,6 +1679,10 @@ def signal_handler(signum, frame):
        except AirflowSkipException:
            self.refresh_from_db(lock_for_update=True)
            self.state = State.SKIPPED
        except AirflowRescheduleException as reschedule_exception:
            self.refresh_from_db()
            self._handle_reschedule(reschedule_exception, test_mode, context)
            return
        except AirflowException as e:
            self.refresh_from_db()
            # for case when task is marked as success externally
@@ -1746,6 +1752,32 @@ def dry_run(self):
        self.render_templates()
        task_copy.dry_run()

    @provide_session
    def _handle_reschedule(self, reschedule_exception, test_mode=False, context=None,
                           session=None):
        # Don't record reschedule request in test mode
        if test_mode:
            return

        self.end_date = timezone.utcnow()
        self.set_duration()

        # Log reschedule request
        session.add(TaskReschedule(self.task, self.execution_date, self._try_number,
                                   self.start_date, self.end_date,
                                   reschedule_exception.reschedule_date))

        # set state
        self.state = State.NONE

        # Decrement try_number so subsequent runs will use the same try number
        # and write to the same log file.
        self._try_number -= 1

        session.merge(self)
        session.commit()
        self.log.info('Rescheduling task, marking task as NONE')
    @provide_session
    def handle_failure(self, error, test_mode=False, context=None, session=None):
        self.log.exception(error)
@@ -2106,6 +2138,66 @@ def __init__(self, task, execution_date, start_date, end_date):
        self.duration = None


class TaskReschedule(Base):
    """
    TaskReschedule tracks rescheduled task instances.
    """

    __tablename__ = "task_reschedule"

    id = Column(Integer, primary_key=True)
    task_id = Column(String(ID_LEN), nullable=False)
    dag_id = Column(String(ID_LEN), nullable=False)
    execution_date = Column(UtcDateTime, nullable=False)
    try_number = Column(Integer, nullable=False)
    start_date = Column(UtcDateTime, nullable=False)
    end_date = Column(UtcDateTime, nullable=False)
    duration = Column(Integer, nullable=False)
    reschedule_date = Column(UtcDateTime, nullable=False)

    __table_args__ = (
        Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date,
              unique=False),
        ForeignKeyConstraint([task_id, dag_id, execution_date],
                             [TaskInstance.task_id, TaskInstance.dag_id,
                              TaskInstance.execution_date],
                             name='task_reschedule_dag_task_date_fkey')
    )

    def __init__(self, task, execution_date, try_number, start_date, end_date,
                 reschedule_date):
        self.dag_id = task.dag_id
        self.task_id = task.task_id
        self.execution_date = execution_date
        self.try_number = try_number
        self.start_date = start_date
        self.end_date = end_date
        self.reschedule_date = reschedule_date
        self.duration = (self.end_date - self.start_date).total_seconds()

    @staticmethod
    @provide_session
    def find_for_task_instance(task_instance, session):
        """
        Returns all task reschedules for the task instance and try number,
        in ascending order.

        :param task_instance: the task instance to find task reschedules for
        :type task_instance: TaskInstance
        """
        TR = TaskReschedule
        return (
            session
            .query(TR)
            .filter(TR.dag_id == task_instance.dag_id,
                    TR.task_id == task_instance.task_id,
                    TR.execution_date == task_instance.execution_date,
                    TR.try_number == task_instance.try_number)
            .order_by(asc(TR.id))
            .all()
        )
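
An illustrative call (editorial sketch; ``context`` is the dict passed to an
operator's ``execute()``), mirroring how BaseSensorOperator recovers the start
of the first attempt of the current try:

    task_reschedules = TaskReschedule.find_for_task_instance(context['ti'])
    if task_reschedules:
        # The first record carries the start date of the first run of this
        # try, so a timeout can span multiple reschedules.
        started_at = task_reschedules[0].start_date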


class Log(Base):
"""
Used to actively log events to the database
@@ -5108,12 +5200,13 @@ def update_state(self, session=None):
        no_dependencies_met = True
        for ut in unfinished_tasks:
            # We need to flag upstream and check for changes because upstream
            # failures/re-schedules can result in deadlock false positives
            old_state = ut.state
            deps_met = ut.are_dependencies_met(
                dep_context=DepContext(
                    flag_upstream_failed=True,
                    ignore_in_retry_period=True,
                    ignore_in_reschedule_period=True),
                session=session)
            if deps_met or old_state != ut.current_state(session=session):
                no_dependencies_met = False
56 changes: 51 additions & 5 deletions airflow/sensors/base_sensor_operator.py
@@ -19,20 +19,22 @@


from time import sleep
from datetime import timedelta

from airflow.exceptions import AirflowException, AirflowSensorTimeout, \
    AirflowSkipException, AirflowRescheduleException
from airflow.models import BaseOperator, SkipMixin, TaskReschedule
from airflow.utils import timezone
from airflow.utils.decorators import apply_defaults
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep


class BaseSensorOperator(BaseOperator, SkipMixin):
    """
    Sensor operators are derived from this class and inherit these attributes.

    Sensor operators keep executing at a time interval and succeed when
    a criteria is met and fail if and when they time out.

    :param soft_fail: Set to true to mark the task as SKIPPED on failure
    :type soft_fail: bool
@@ -41,20 +43,42 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
    :type poke_interval: int
    :param timeout: Time, in seconds before the task times out and fails.
    :type timeout: int
    :param mode: How the sensor operates.
        Options are: ``{ poke | reschedule }``, default is ``poke``.
        When set to ``poke`` the sensor takes up a worker slot for its
        whole execution time and sleeps between pokes. Use this mode if the
        expected runtime of the sensor is short or if a short poke interval
        is required.
        When set to ``reschedule`` the sensor task frees the worker slot when
        the criteria is not yet met and it is rescheduled at a later time. Use
        this mode if the time before the criteria is met is expected to be
        quite long. The poke interval should be more than one minute to
        prevent too much load on the scheduler.
    :type mode: str
    """
    ui_color = '#e6f1f2'
    valid_modes = ['poke', 'reschedule']

    @apply_defaults
    def __init__(self,
                 poke_interval=60,
                 timeout=60 * 60 * 24 * 7,
                 soft_fail=False,
                 mode='poke',
                 *args,
                 **kwargs):
        super(BaseSensorOperator, self).__init__(*args, **kwargs)
        self.poke_interval = poke_interval
        self.soft_fail = soft_fail
        self.timeout = timeout
        if mode not in self.valid_modes:
            raise AirflowException(
                "The mode must be one of {valid_modes} for task "
                "'{d}.{t}'; received '{m}'."
                .format(valid_modes=self.valid_modes,
                        d=self.dag.dag_id if self.dag else "",
                        t=self.task_id, m=mode))
        self.mode = mode

    def poke(self, context):
        """
@@ -65,6 +89,11 @@ def poke(self, context):

    def execute(self, context):
        started_at = timezone.utcnow()
        if self.reschedule:
            # If reschedule, use first start date of current try
            task_reschedules = TaskReschedule.find_for_task_instance(context['ti'])
            if task_reschedules:
                started_at = task_reschedules[0].start_date
        while not self.poke(context):
            if (timezone.utcnow() - started_at).total_seconds() > self.timeout:
                # If sensor is in soft fail mode but will be retried then
@@ -75,11 +104,28 @@ def execute(self, context):
                    raise AirflowSkipException('Snap. Time is OUT.')
                else:
                    raise AirflowSensorTimeout('Snap. Time is OUT.')
            if self.reschedule:
                reschedule_date = timezone.utcnow() + timedelta(
                    seconds=self.poke_interval)
                raise AirflowRescheduleException(reschedule_date)
            else:
                sleep(self.poke_interval)
self.log.info("Success criteria met. Exiting.")

    def _do_skip_downstream_tasks(self, context):
        downstream_tasks = context['task'].get_flat_relatives(upstream=False)
        self.log.debug("Downstream task_ids %s", downstream_tasks)
        if downstream_tasks:
            self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)

    @property
    def reschedule(self):
        return self.mode == 'reschedule'

    @property
    def deps(self):
        """
        Adds one additional dependency for all sensor operators that
        checks if a sensor task instance can be rescheduled.
        """
        return BaseOperator.deps.fget(self) | {ReadyToRescheduleDep()}
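
A small editorial check of this wiring (the trivial sensor class is
hypothetical): every sensor instance now carries the extra dependency, so the
scheduler will not queue a rescheduled sensor before its reschedule_date.

    from airflow.sensors.base_sensor_operator import BaseSensorOperator
    from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep


    class AlwaysDone(BaseSensorOperator):
        """Hypothetical trivial sensor used only to inspect deps."""

        def poke(self, context):
            return True


    sensor = AlwaysDone(task_id='probe', mode='reschedule')
    assert any(isinstance(d, ReadyToRescheduleDep) for d in sensor.deps)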
6 changes: 5 additions & 1 deletion airflow/ti_deps/dep_context.py
@@ -57,7 +57,9 @@ class DepContext(object):
        Backfills)
    :type ignore_depends_on_past: boolean
    :param ignore_in_retry_period: Ignore the retry period for task instances
    :type ignore_in_retry_period: bool
    :param ignore_in_reschedule_period: Ignore the reschedule period for task instances
    :type ignore_in_reschedule_period: bool
    :param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past and
        trigger rule
    :type ignore_task_deps: boolean
@@ -71,13 +73,15 @@ def __init__(
            ignore_all_deps=False,
            ignore_depends_on_past=False,
            ignore_in_retry_period=False,
            ignore_in_reschedule_period=False,
            ignore_task_deps=False,
            ignore_ti_state=False):
        self.deps = deps or set()
        self.flag_upstream_failed = flag_upstream_failed
        self.ignore_all_deps = ignore_all_deps
        self.ignore_depends_on_past = ignore_depends_on_past
        self.ignore_in_retry_period = ignore_in_retry_period
        self.ignore_in_reschedule_period = ignore_in_reschedule_period
        self.ignore_task_deps = ignore_task_deps
        self.ignore_ti_state = ignore_ti_state
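
As an illustration (editorial sketch; ``ti`` and ``session`` are assumed to be
a TaskInstance and an open SQLAlchemy session), dependencies can be checked
while ignoring the reschedule wait, as DagRun.update_state() above now does:

    from airflow.ti_deps.dep_context import DepContext

    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_in_retry_period=True,
        ignore_in_reschedule_period=True,  # a waiting sensor is not blocked
    )
    deps_met = ti.are_dependencies_met(dep_context=dep_context, session=session)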

