# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
from collections import defaultdict, namedtuple
from builtins import ImportError as BuiltinImportError, bytes, object, str
from future.standard_library import install_aliases
try:
# Fix Python > 3.7 deprecation
from collections.abc import Hashable
except ImportError:
# Preserve Python < 3.3 compatibility
from collections import Hashable
from datetime import timedelta
import dill
import functools
import getpass
import imp
import importlib
import itertools
import zipfile
import jinja2
import json
import logging
import numbers
import os
import pendulum
import pickle
import re
import signal
import sys
import textwrap
import traceback
import warnings
import hashlib
import uuid
from datetime import datetime
from urllib.parse import urlparse, quote, parse_qsl, unquote
from sqlalchemy import (
Boolean, Column, DateTime, Float, ForeignKey, ForeignKeyConstraint, Index,
Integer, LargeBinary, PickleType, String, Text, UniqueConstraint, MetaData,
and_, asc, func, or_, true as sqltrue
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import reconstructor, relationship, synonym
from croniter import (
croniter, CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError
)
import six
from airflow import settings, utils
from airflow.executors import GetDefaultExecutor, LocalExecutor
from airflow import configuration
from airflow.exceptions import (
AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout,
AirflowRescheduleException
)
from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
from airflow.utils import timezone
from airflow.utils.dag_processing import list_py_file_paths
from airflow.utils.dates import cron_presets, date_range as utils_date_range
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.email import send_email
from airflow.utils.helpers import (
as_tuple, is_container, validate_key, pprinttable)
from airflow.utils.operator_resources import Resources
from airflow.utils.state import State
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.timeout import timeout
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
from airflow.utils.net import get_hostname
from airflow.utils.log.logging_mixin import LoggingMixin
install_aliases()
SQL_ALCHEMY_SCHEMA = configuration.get('core', 'SQL_ALCHEMY_SCHEMA')
if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace():
Base = declarative_base()
else:
Base = declarative_base(metadata=MetaData(schema=SQL_ALCHEMY_SCHEMA))
ID_LEN = 250
XCOM_RETURN_KEY = 'return_value'
Stats = settings.Stats
class InvalidFernetToken(Exception):
# If Fernet isn't loaded we need a valid exception class to catch. If it is
# loaded this will get reset to the actual class once get_fernet() is called
pass
class NullFernet(object):
"""
A "Null" encryptor class that doesn't encrypt or decrypt but that presents
a similar interface to Fernet.
The purpose of this is to make the rest of the code not have to know the
difference, and to only display the message once, not 20 times when
`airflow initdb` is run.
"""
is_encrypted = False
def decrypt(self, b):
return b
def encrypt(self, b):
return b
_fernet = None
def get_fernet():
"""
Deferred load of Fernet key.
This function could fail either because Cryptography is not installed
or because the Fernet key is invalid.
:return: Fernet object
:raises: AirflowException if there's a problem trying to load Fernet
"""
global _fernet
log = LoggingMixin().log
if _fernet:
return _fernet
try:
from cryptography.fernet import Fernet, InvalidToken
global InvalidFernetToken
InvalidFernetToken = InvalidToken
except BuiltinImportError:
log.warning(
"cryptography not found - values will not be stored encrypted."
)
_fernet = NullFernet()
return _fernet
try:
fernet_key = configuration.conf.get('core', 'FERNET_KEY')
if not fernet_key:
log.warning(
"empty cryptography key - values will not be stored encrypted."
)
_fernet = NullFernet()
else:
_fernet = Fernet(fernet_key.encode('utf-8'))
_fernet.is_encrypted = True
except (ValueError, TypeError) as ve:
raise AirflowException("Could not create Fernet object: {}".format(ve))
return _fernet
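# Example for get_fernet() above (illustrative sketch, not executed here): generating a
# Fernet key for the `[core] fernet_key` setting so that get_fernet() returns a real
# encryptor instead of NullFernet. The command is hypothetical shell usage.
#
#     python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
#
# With a valid key configured, Connection passwords and extras (see set_password /
# set_extra further down) are stored encrypted and decrypted transparently on read.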
# Used by DAG context_managers
_CONTEXT_MANAGER_DAG = None
def clear_task_instances(tis,
session,
activate_dag_runs=True,
dag=None,
):
"""
Clears a set of task instances, but makes sure the running ones
get killed.
:param tis: a list of task instances
:param session: current session
:param activate_dag_runs: flag to set the dag runs of the cleared task instances back to the RUNNING state
:param dag: DAG object
"""
job_ids = []
for ti in tis:
if ti.state == State.RUNNING:
if ti.job_id:
ti.state = State.SHUTDOWN
job_ids.append(ti.job_id)
else:
task_id = ti.task_id
if dag and dag.has_task(task_id):
task = dag.get_task(task_id)
task_retries = task.retries
ti.max_tries = ti.try_number + task_retries - 1
else:
# Ignore errors when updating max_tries if dag is None or
# task not found in dag since database records could be
# outdated. We make max_tries the maximum value of its
# original max_tries or the current task try number.
ti.max_tries = max(ti.max_tries, ti.try_number - 1)
ti.state = State.NONE
session.merge(ti)
if job_ids:
from airflow.jobs import BaseJob as BJ
for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all():
job.state = State.SHUTDOWN
if activate_dag_runs and tis:
drs = session.query(DagRun).filter(
DagRun.dag_id.in_({ti.dag_id for ti in tis}),
DagRun.execution_date.in_({ti.execution_date for ti in tis}),
).all()
for dr in drs:
dr.state = State.RUNNING
dr.start_date = timezone.utcnow()
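# Usage sketch for clear_task_instances() (illustrative; `some_dag` and `execution_date`
# are hypothetical names): callers typically query the task instances they want to reset
# and hand them over together with an open session, e.g.
#
#     tis = session.query(TaskInstance).filter(
#         TaskInstance.dag_id == some_dag.dag_id,
#         TaskInstance.execution_date == execution_date).all()
#     clear_task_instances(tis, session, activate_dag_runs=True, dag=some_dag)
#
# Running instances are moved to SHUTDOWN (and their jobs flagged), everything else is
# set back to NONE so the scheduler can pick them up again.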
class DagBag(BaseDagBag, LoggingMixin):
"""
A dagbag is a collection of dags, parsed out of a folder tree and has high
level configuration settings, like what database to use as a backend and
what executor to use to fire off tasks. This makes it easier to run
distinct environments for say production and development, tests, or for
different teams or security profiles. What would have been system level
settings are now dagbag level so that one system can run multiple,
independent settings sets.
:param dag_folder: the folder to scan to find DAGs
:type dag_folder: unicode
:param executor: the executor to use when executing task instances
in this DagBag
:param include_examples: whether to include the examples that ship
with airflow or not
:type include_examples: bool
:param has_logged: an instance boolean that gets flipped from False to True after a
file has been skipped. This is to prevent overloading the user with logging
messages about skipped files. Therefore a skipped file is only logged
once per DagBag.
"""
# static class variables to detect DAG cycles
CYCLE_NEW = 0
CYCLE_IN_PROGRESS = 1
CYCLE_DONE = 2
def __init__(
self,
dag_folder=None,
executor=None,
include_examples=configuration.conf.getboolean('core', 'LOAD_EXAMPLES')):
# do not use default arg in signature, to fix import cycle on plugin load
if executor is None:
executor = GetDefaultExecutor()
dag_folder = dag_folder or settings.DAGS_FOLDER
self.log.info("Filling up the DagBag from %s", dag_folder)
self.dag_folder = dag_folder
self.dags = {}
# the file's last modified timestamp when we last read it
self.file_last_changed = {}
self.executor = executor
self.import_errors = {}
self.has_logged = False
self.collect_dags(dag_folder, include_examples)
def size(self):
"""
:return: the number of DAGs contained in this DagBag
"""
return len(self.dags)
def get_dag(self, dag_id):
"""
Gets the DAG out of the dictionary, and refreshes it if expired
"""
# If asking for a known subdag, we want to refresh the parent
root_dag_id = dag_id
if dag_id in self.dags:
dag = self.dags[dag_id]
if dag.is_subdag:
root_dag_id = dag.parent_dag.dag_id
# If the dag corresponding to root_dag_id is absent or expired
orm_dag = DagModel.get_current(root_dag_id)
if orm_dag and (
root_dag_id not in self.dags or
(
orm_dag.last_expired and
dag.last_loaded < orm_dag.last_expired
)
):
# Reprocess source file
found_dags = self.process_file(
filepath=orm_dag.fileloc, only_if_updated=False)
# If the source file no longer exports `dag_id`, delete it from self.dags
if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
return self.dags[dag_id]
elif dag_id in self.dags:
del self.dags[dag_id]
return self.dags.get(dag_id)
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
"""
Given a path to a python module or zip file, this method imports
the module and looks for DAG objects within it.
"""
found_dags = []
# if the source file no longer exists in the DB or in the filesystem,
# return an empty list
# todo: raise exception?
if filepath is None or not os.path.isfile(filepath):
return found_dags
try:
# This failed before in what may have been a git sync
# race condition
file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
if only_if_updated \
and filepath in self.file_last_changed \
and file_last_changed_on_disk == self.file_last_changed[filepath]:
return found_dags
except Exception as e:
self.log.exception(e)
return found_dags
mods = []
is_zipfile = zipfile.is_zipfile(filepath)
if not is_zipfile:
if safe_mode and os.path.isfile(filepath):
with open(filepath, 'rb') as f:
content = f.read()
if not all([s in content for s in (b'DAG', b'airflow')]):
self.file_last_changed[filepath] = file_last_changed_on_disk
# Don't want to spam user with skip messages
if not self.has_logged:
self.has_logged = True
self.log.info(
"File %s assumed to contain no DAGs. Skipping.",
filepath)
return found_dags
self.log.debug("Importing %s", filepath)
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
mod_name = ('unusual_prefix_' +
hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
'_' + org_mod_name)
if mod_name in sys.modules:
del sys.modules[mod_name]
with timeout(configuration.conf.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
try:
m = imp.load_source(mod_name, filepath)
mods.append(m)
except Exception as e:
self.log.exception("Failed to import: %s", filepath)
self.import_errors[filepath] = str(e)
self.file_last_changed[filepath] = file_last_changed_on_disk
else:
zip_file = zipfile.ZipFile(filepath)
for mod in zip_file.infolist():
head, _ = os.path.split(mod.filename)
mod_name, ext = os.path.splitext(mod.filename)
if not head and (ext == '.py' or ext == '.pyc'):
if mod_name == '__init__':
self.log.warning("Found __init__.%s at root of %s", ext, filepath)
if safe_mode:
with zip_file.open(mod.filename) as zf:
self.log.debug("Reading %s from %s", mod.filename, filepath)
content = zf.read()
if not all([s in content for s in (b'DAG', b'airflow')]):
self.file_last_changed[filepath] = (
file_last_changed_on_disk)
# todo: create ignore list
# Don't want to spam user with skip messages
if not self.has_logged:
self.has_logged = True
self.log.info(
"File %s assumed to contain no DAGs. Skipping.",
filepath)
if mod_name in sys.modules:
del sys.modules[mod_name]
try:
sys.path.insert(0, filepath)
m = importlib.import_module(mod_name)
mods.append(m)
except Exception as e:
self.log.exception("Failed to import: %s", filepath)
self.import_errors[filepath] = str(e)
self.file_last_changed[filepath] = file_last_changed_on_disk
for m in mods:
for dag in list(m.__dict__.values()):
if isinstance(dag, DAG):
if not dag.full_filepath:
dag.full_filepath = filepath
if dag.fileloc != filepath and not is_zipfile:
dag.fileloc = filepath
try:
dag.is_subdag = False
self.bag_dag(dag, parent_dag=dag, root_dag=dag)
if isinstance(dag._schedule_interval, six.string_types):
croniter(dag._schedule_interval)
found_dags.append(dag)
found_dags += dag.subdags
except (CroniterBadCronError,
CroniterBadDateError,
CroniterNotAlphaError) as cron_e:
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = \
"Invalid Cron expression: " + str(cron_e)
self.file_last_changed[dag.full_filepath] = \
file_last_changed_on_disk
except AirflowDagCycleException as cycle_exception:
self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
self.import_errors[dag.full_filepath] = str(cycle_exception)
self.file_last_changed[dag.full_filepath] = \
file_last_changed_on_disk
self.file_last_changed[filepath] = file_last_changed_on_disk
return found_dags
@provide_session
def kill_zombies(self, zombies, session=None):
"""
Fail given zombie tasks, which are tasks that haven't
had a heartbeat for too long, in the current DagBag.
:param zombies: zombie task instances to kill.
:type zombies: SimpleTaskInstance
:param session: DB session.
:type session: Session
"""
for zombie in zombies:
if zombie.dag_id in self.dags:
dag = self.dags[zombie.dag_id]
if zombie.task_id in dag.task_ids:
task = dag.get_task(zombie.task_id)
ti = TaskInstance(task, zombie.execution_date)
# Get properties needed for failure handling from SimpleTaskInstance.
ti.start_date = zombie.start_date
ti.end_date = zombie.end_date
ti.try_number = zombie.try_number
ti.state = zombie.state
ti.test_mode = configuration.getboolean('core', 'unit_test_mode')
ti.handle_failure("{} detected as zombie".format(ti),
ti.test_mode, ti.get_template_context())
self.log.info(
'Marked zombie job %s as %s', ti, ti.state)
Stats.incr('zombies_killed')
session.commit()
def bag_dag(self, dag, parent_dag, root_dag):
"""
Adds the DAG into the bag, recurses into sub dags.
Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags
"""
dag.test_cycle() # throws if a task cycle is found
dag.resolve_template_files()
dag.last_loaded = timezone.utcnow()
for task in dag.tasks:
settings.policy(task)
subdags = dag.subdags
try:
for subdag in subdags:
subdag.full_filepath = dag.full_filepath
subdag.parent_dag = dag
subdag.is_subdag = True
self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag)
self.dags[dag.dag_id] = dag
self.log.debug('Loaded DAG {dag}'.format(**locals()))
except AirflowDagCycleException as cycle_exception:
# There was an error in bagging the dag. Remove it from the list of dags
self.log.exception('Exception bagging dag: {dag.dag_id}'.format(**locals()))
# Only necessary at the root level since DAG.subdags automatically
# performs DFS to search through all subdags
if dag == root_dag:
for subdag in subdags:
if subdag.dag_id in self.dags:
del self.dags[subdag.dag_id]
raise cycle_exception
def collect_dags(
self,
dag_folder=None,
only_if_updated=True,
include_examples=configuration.conf.getboolean('core', 'LOAD_EXAMPLES')):
"""
Given a file path or a folder, this method looks for python modules,
imports them and adds them to the dagbag collection.
Note that if a ``.airflowignore`` file is found while processing
the directory, it will behave much like a ``.gitignore``,
ignoring files that match any of the regex patterns specified
in the file.
**Note**: The patterns in .airflowignore are treated as
un-anchored regexes, not shell-like glob patterns.
"""
start_dttm = timezone.utcnow()
dag_folder = dag_folder or self.dag_folder
# Used to store stats around DagBag processing
stats = []
FileLoadStat = namedtuple(
'FileLoadStat', "file duration dag_num task_num dags")
for filepath in list_py_file_paths(dag_folder, include_examples):
try:
ts = timezone.utcnow()
found_dags = self.process_file(
filepath, only_if_updated=only_if_updated)
td = timezone.utcnow() - ts
td = td.total_seconds() + (
float(td.microseconds) / 1000000)
stats.append(FileLoadStat(
filepath.replace(dag_folder, ''),
td,
len(found_dags),
sum([len(dag.tasks) for dag in found_dags]),
str([dag.dag_id for dag in found_dags]),
))
except Exception as e:
self.log.exception(e)
Stats.gauge(
'collect_dags', (timezone.utcnow() - start_dttm).total_seconds(), 1)
Stats.gauge(
'dagbag_size', len(self.dags), 1)
Stats.gauge(
'dagbag_import_errors', len(self.import_errors), 1)
self.dagbag_stats = sorted(
stats, key=lambda x: x.duration, reverse=True)
def dagbag_report(self):
"""Prints a report around DagBag loading stats"""
report = textwrap.dedent("""\n
-------------------------------------------------------------------
DagBag loading stats for {dag_folder}
-------------------------------------------------------------------
Number of DAGs: {dag_num}
Total task number: {task_num}
DagBag parsing time: {duration}
{table}
""")
stats = self.dagbag_stats
return report.format(
dag_folder=self.dag_folder,
duration=sum([o.duration for o in stats]),
dag_num=sum([o.dag_num for o in stats]),
task_num=sum([o.task_num for o in stats]),
table=pprinttable(stats),
)
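# Usage sketch for DagBag (illustrative; the folder path and dag_id are hypothetical):
#
#     dagbag = DagBag(dag_folder='/path/to/dags')      # parses .py and .zip files
#     dag = dagbag.get_dag('example_dag_id')           # reloads the file if it expired
#     print(dagbag.size(), dagbag.import_errors)
#     print(dagbag.dagbag_report())                    # per-file parsing stats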
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
username = Column(String(ID_LEN), unique=True)
email = Column(String(500))
superuser = False
def __repr__(self):
return self.username
def get_id(self):
return str(self.id)
def is_superuser(self):
return self.superuser
class Connection(Base, LoggingMixin):
"""
Placeholder to store information about different database instances
connection information. The idea here is that scripts use references to
database instances (conn_id) instead of hard coding hostname, logins and
passwords when using operators or hooks.
"""
__tablename__ = "connection"
id = Column(Integer(), primary_key=True)
conn_id = Column(String(ID_LEN))
conn_type = Column(String(500))
host = Column(String(500))
schema = Column(String(500))
login = Column(String(500))
_password = Column('password', String(5000))
port = Column(Integer())
is_encrypted = Column(Boolean, unique=False, default=False)
is_extra_encrypted = Column(Boolean, unique=False, default=False)
_extra = Column('extra', String(5000))
_types = [
('docker', 'Docker Registry',),
('fs', 'File (path)'),
('ftp', 'FTP',),
('google_cloud_platform', 'Google Cloud Platform'),
('hdfs', 'HDFS',),
('http', 'HTTP',),
('hive_cli', 'Hive Client Wrapper',),
('hive_metastore', 'Hive Metastore Thrift',),
('hiveserver2', 'Hive Server 2 Thrift',),
('jdbc', 'Jdbc Connection',),
('jenkins', 'Jenkins'),
('mysql', 'MySQL',),
('postgres', 'Postgres',),
('oracle', 'Oracle',),
('vertica', 'Vertica',),
('presto', 'Presto',),
('s3', 'S3',),
('samba', 'Samba',),
('sqlite', 'Sqlite',),
('ssh', 'SSH',),
('cloudant', 'IBM Cloudant',),
('mssql', 'Microsoft SQL Server'),
('mesos_framework-id', 'Mesos Framework ID'),
('jira', 'JIRA',),
('redis', 'Redis',),
('wasb', 'Azure Blob Storage'),
('databricks', 'Databricks',),
('aws', 'Amazon Web Services',),
('emr', 'Elastic MapReduce',),
('snowflake', 'Snowflake',),
('segment', 'Segment',),
('azure_data_lake', 'Azure Data Lake'),
('azure_cosmos', 'Azure CosmosDB'),
('cassandra', 'Cassandra',),
('qubole', 'Qubole'),
('mongo', 'MongoDB'),
('gcpcloudsql', 'Google Cloud SQL'),
]
def __init__(
self, conn_id=None, conn_type=None,
host=None, login=None, password=None,
schema=None, port=None, extra=None,
uri=None):
self.conn_id = conn_id
if uri:
self.parse_from_uri(uri)
else:
self.conn_type = conn_type
self.host = host
self.login = login
self.password = password
self.schema = schema
self.port = port
self.extra = extra
def parse_from_uri(self, uri):
temp_uri = urlparse(uri)
hostname = temp_uri.hostname or ''
conn_type = temp_uri.scheme
if conn_type == 'postgresql':
conn_type = 'postgres'
self.conn_type = conn_type
self.host = unquote(hostname) if hostname else hostname
quoted_schema = temp_uri.path[1:]
self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
self.login = unquote(temp_uri.username) \
if temp_uri.username else temp_uri.username
self.password = unquote(temp_uri.password) \
if temp_uri.password else temp_uri.password
self.port = temp_uri.port
if temp_uri.query:
self.extra = json.dumps(dict(parse_qsl(temp_uri.query)))
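# Example of the URI form parsed above (illustrative values): a URI like
#
#     postgres://user:p%40ss@db.example.com:5432/mydb?sslmode=require
#
# yields conn_type='postgres', login='user', password='p@ss' (percent-decoded),
# host='db.example.com', port=5432, schema='mydb' and extra='{"sslmode": "require"}'.
# Note that a 'postgresql' scheme is normalized to 'postgres'.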
def get_password(self):
if self._password and self.is_encrypted:
fernet = get_fernet()
if not fernet.is_encrypted:
raise AirflowException(
"Can't decrypt encrypted password for login={}, \
FERNET_KEY configuration is missing".format(self.login))
return fernet.decrypt(bytes(self._password, 'utf-8')).decode()
else:
return self._password
def set_password(self, value):
if value:
fernet = get_fernet()
self._password = fernet.encrypt(bytes(value, 'utf-8')).decode()
self.is_encrypted = fernet.is_encrypted
@declared_attr
def password(cls):
return synonym('_password',
descriptor=property(cls.get_password, cls.set_password))
def get_extra(self):
if self._extra and self.is_extra_encrypted:
fernet = get_fernet()
if not fernet.is_encrypted:
raise AirflowException(
"Can't decrypt `extra` params for login={},\
FERNET_KEY configuration is missing".format(self.login))
return fernet.decrypt(bytes(self._extra, 'utf-8')).decode()
else:
return self._extra
def set_extra(self, value):
if value:
fernet = get_fernet()
self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode()
self.is_extra_encrypted = fernet.is_encrypted
else:
self._extra = value
self.is_extra_encrypted = False
@declared_attr
def extra(cls):
return synonym('_extra',
descriptor=property(cls.get_extra, cls.set_extra))
def get_hook(self):
try:
if self.conn_type == 'mysql':
from airflow.hooks.mysql_hook import MySqlHook
return MySqlHook(mysql_conn_id=self.conn_id)
elif self.conn_type == 'google_cloud_platform':
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
return BigQueryHook(bigquery_conn_id=self.conn_id)
elif self.conn_type == 'postgres':
from airflow.hooks.postgres_hook import PostgresHook
return PostgresHook(postgres_conn_id=self.conn_id)
elif self.conn_type == 'hive_cli':
from airflow.hooks.hive_hooks import HiveCliHook
return HiveCliHook(hive_cli_conn_id=self.conn_id)
elif self.conn_type == 'presto':
from airflow.hooks.presto_hook import PrestoHook
return PrestoHook(presto_conn_id=self.conn_id)
elif self.conn_type == 'hiveserver2':
from airflow.hooks.hive_hooks import HiveServer2Hook
return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
elif self.conn_type == 'sqlite':
from airflow.hooks.sqlite_hook import SqliteHook
return SqliteHook(sqlite_conn_id=self.conn_id)
elif self.conn_type == 'jdbc':
from airflow.hooks.jdbc_hook import JdbcHook
return JdbcHook(jdbc_conn_id=self.conn_id)
elif self.conn_type == 'mssql':
from airflow.hooks.mssql_hook import MsSqlHook
return MsSqlHook(mssql_conn_id=self.conn_id)
elif self.conn_type == 'oracle':
from airflow.hooks.oracle_hook import OracleHook
return OracleHook(oracle_conn_id=self.conn_id)
elif self.conn_type == 'vertica':
from airflow.contrib.hooks.vertica_hook import VerticaHook
return VerticaHook(vertica_conn_id=self.conn_id)
elif self.conn_type == 'cloudant':
from airflow.contrib.hooks.cloudant_hook import CloudantHook
return CloudantHook(cloudant_conn_id=self.conn_id)
elif self.conn_type == 'jira':
from airflow.contrib.hooks.jira_hook import JiraHook
return JiraHook(jira_conn_id=self.conn_id)
elif self.conn_type == 'redis':
from airflow.contrib.hooks.redis_hook import RedisHook
return RedisHook(redis_conn_id=self.conn_id)
elif self.conn_type == 'wasb':
from airflow.contrib.hooks.wasb_hook import WasbHook
return WasbHook(wasb_conn_id=self.conn_id)
elif self.conn_type == 'docker':
from airflow.hooks.docker_hook import DockerHook
return DockerHook(docker_conn_id=self.conn_id)
elif self.conn_type == 'azure_data_lake':
from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
elif self.conn_type == 'azure_cosmos':
from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
elif self.conn_type == 'cassandra':
from airflow.contrib.hooks.cassandra_hook import CassandraHook
return CassandraHook(cassandra_conn_id=self.conn_id)
elif self.conn_type == 'mongo':
from airflow.contrib.hooks.mongo_hook import MongoHook
return MongoHook(conn_id=self.conn_id)
elif self.conn_type == 'gcpcloudsql':
from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
except Exception:
pass
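# Usage sketch for get_hook() (illustrative conn_id and URI): a Connection can hand back
# the matching hook without the caller importing it explicitly, e.g.
#
#     conn = Connection(conn_id='my_pg', uri='postgres://user:pass@host:5432/mydb')
#     hook = conn.get_hook()        # -> PostgresHook(postgres_conn_id='my_pg')
#
# Unknown conn_types (or hook import failures) fall through and return None.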
def __repr__(self):
return self.conn_id
def debug_info(self):
return ("id: {}. Host: {}, Port: {}, Schema: {}, "
"Login: {}, Password: {}, extra: {}".
format(self.conn_id,
self.host,
self.port,
self.schema,
self.login,
"XXXXXXXX" if self.password else None,
self.extra_dejson))
@property
def extra_dejson(self):
"""Returns the extra property by deserializing json."""
obj = {}
if self.extra:
try:
obj = json.loads(self.extra)
except Exception as e:
self.log.exception(e)
self.log.error("Failed parsing the json for conn_id %s", self.conn_id)
return obj
class DagPickle(Base):
"""
Dags can originate from different places (user repos, master repo, ...)
and also get executed in different places (different executors). This
object represents a version of a DAG and becomes a source of truth for
a BackfillJob execution. A pickle is a native python serialized object,
and in this case gets stored in the database for the duration of the job.
The executors pick up the DagPickle id and read the dag definition from
the database.
"""
id = Column(Integer, primary_key=True)
pickle = Column(PickleType(pickler=dill))
created_dttm = Column(UtcDateTime, default=timezone.utcnow)
pickle_hash = Column(Text)
__tablename__ = "dag_pickle"
def __init__(self, dag):
self.dag_id = dag.dag_id
if hasattr(dag, 'template_env'):
dag.template_env = None
self.pickle_hash = hash(dag)
self.pickle = dag
class TaskInstance(Base, LoggingMixin):
"""
Task instances store the state of a task instance. This table is the
authority and single source of truth around what tasks have run and the
state they are in.
The SqlAlchemy model doesn't have a SqlAlchemy foreign key to the task or
dag model deliberately to have more control over transactions.
Database transactions on this table should guard against double triggers and
any confusion around which task instances are or aren't ready to run,
even while multiple schedulers may be firing task instances.
"""
__tablename__ = "task_instance"
task_id = Column(String(ID_LEN), primary_key=True)
dag_id = Column(String(ID_LEN), primary_key=True)
execution_date = Column(UtcDateTime, primary_key=True)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Float)
state = Column(String(20))
_try_number = Column('try_number', Integer, default=0)
max_tries = Column(Integer)
hostname = Column(String(1000))
unixname = Column(String(1000))
job_id = Column(Integer)
pool = Column(String(50))
queue = Column(String(50))
priority_weight = Column(Integer)
operator = Column(String(1000))
queued_dttm = Column(UtcDateTime)
pid = Column(Integer)
executor_config = Column(PickleType(pickler=dill))
__table_args__ = (
Index('ti_dag_state', dag_id, state),
Index('ti_dag_date', dag_id, execution_date),
Index('ti_state', state),
Index('ti_state_lkp', dag_id, task_id, execution_date, state),
Index('ti_pool', pool, state, priority_weight),
Index('ti_job_id', job_id),
)
def __init__(self, task, execution_date, state=None):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.task = task
self._log = logging.getLogger("airflow.task")
# make sure we have a localized execution_date stored in UTC
if execution_date and not timezone.is_localized(execution_date):
self.log.warning("execution date %s has no timezone information. Using "
"default from dag or system", execution_date)
if self.task.has_dag():
execution_date = timezone.make_aware(execution_date,
self.task.dag.timezone)
else:
execution_date = timezone.make_aware(execution_date)
execution_date = timezone.convert_to_utc(execution_date)
self.execution_date = execution_date
self.queue = task.queue
self.pool = task.pool
self.priority_weight = task.priority_weight_total
self.try_number = 0
self.max_tries = self.task.retries
self.unixname = getpass.getuser()
self.run_as_user = task.run_as_user
if state:
self.state = state
self.hostname = ''
self.executor_config = task.executor_config
self.init_on_load()
# Is this TaskInstance currently being run within `airflow run --raw`?
# Not persisted to the database so only valid for the current process
self.raw = False
@reconstructor
def init_on_load(self):
""" Initialize the attributes that aren't stored in the DB. """
self.test_mode = False # can be changed when calling 'run'
@property
def try_number(self):
"""
Return the try number that this task instance will have when it is actually
run.
If the TI is currently running, this will match the column in the
database; in all other cases this will be incremented by one.
"""
# This is designed so that task logs end up in the right file.
if self.state == State.RUNNING:
return self._try_number
return self._try_number + 1
@try_number.setter
def try_number(self, value):
self._try_number = value
@property
def next_try_number(self):
return self._try_number + 1
def command(
self,
mark_success=False,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
local=False,
pickle_id=None,
raw=False,
job_id=None,
pool=None,
cfg_path=None):
"""
Returns a command that can be executed anywhere that Airflow is
installed. This command is part of the message sent to executors by
the orchestrator.
"""
return " ".join(self.command_as_list(
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
local=local,
pickle_id=pickle_id,
raw=raw,
job_id=job_id,
pool=pool,
cfg_path=cfg_path))
def command_as_list(
self,
mark_success=False,
ignore_all_deps=False,
ignore_task_deps=False,
ignore_depends_on_past=False,
ignore_ti_state=False,
local=False,
pickle_id=None,
raw=False,
job_id=None,
pool=None,
cfg_path=None):
"""
Returns a command that can be executed anywhere that Airflow is
installed. This command is part of the message sent to executors by
the orchestrator.
"""
dag = self.task.dag
should_pass_filepath = not pickle_id and dag
if should_pass_filepath and dag.full_filepath != dag.filepath:
path = "DAGS_FOLDER/{}".format(dag.filepath)
elif should_pass_filepath and dag.full_filepath:
path = dag.full_filepath
else:
path = None
return TaskInstance.generate_command(
self.dag_id,
self.task_id,
self.execution_date,
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_task_deps=ignore_task_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_ti_state=ignore_ti_state,
local=local,
pickle_id=pickle_id,
file_path=path,
raw=raw,
job_id=job_id,
pool=pool,
cfg_path=cfg_path)
@staticmethod
def generate_command(dag_id,
task_id,
execution_date,
mark_success=False,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
local=False,
pickle_id=None,
file_path=None,
raw=False,
job_id=None,
pool=None,
cfg_path=None
):
"""
Generates the shell command required to execute this task instance.
:param dag_id: DAG ID
:type dag_id: unicode
:param task_id: Task ID
:type task_id: unicode
:param execution_date: Execution date for the task
:type execution_date: datetime
:param mark_success: Whether to mark the task as successful
:type mark_success: bool
:param ignore_all_deps: Ignore all ignorable dependencies.
Overrides the other ignore_* parameters.
:type ignore_all_deps: bool
:param ignore_depends_on_past: Ignore depends_on_past parameter of DAGs
(e.g. for Backfills)
:type ignore_depends_on_past: bool
:param ignore_task_deps: Ignore task-specific dependencies such as depends_on_past
and trigger rule
:type ignore_task_deps: bool
:param ignore_ti_state: Ignore the task instance's previous failure/success
:type ignore_ti_state: bool
:param local: Whether to run the task locally
:type local: bool
:param pickle_id: If the DAG was serialized to the DB, the ID
associated with the pickled DAG
:type pickle_id: unicode
:param file_path: path to the file containing the DAG definition
:param raw: whether to pass the ``--raw`` flag, running the task without extra supervision
:param job_id: ID of the job running the task instance, passed through as ``--job_id``
:param pool: the Airflow pool that the task should run in
:type pool: unicode
:param cfg_path: the Path to the configuration file
:type cfg_path: basestring
:return: shell command that can be used to run the task instance
"""
iso = execution_date.isoformat()
cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)]
cmd.extend(["--mark_success"]) if mark_success else None
cmd.extend(["--pickle", str(pickle_id)]) if pickle_id else None
cmd.extend(["--job_id", str(job_id)]) if job_id else None
cmd.extend(["-A"]) if ignore_all_deps else None
cmd.extend(["-i"]) if ignore_task_deps else None
cmd.extend(["-I"]) if ignore_depends_on_past else None
cmd.extend(["--force"]) if ignore_ti_state else None
cmd.extend(["--local"]) if local else None
cmd.extend(["--pool", pool]) if pool else None
cmd.extend(["--raw"]) if raw else None
cmd.extend(["-sd", file_path]) if file_path else None
cmd.extend(["--cfg_path", cfg_path]) if cfg_path else None
return cmd
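# Example of a generated command (illustrative IDs and path):
#
#     TaskInstance.generate_command(
#         'example_dag', 'example_task', datetime(2018, 1, 1),
#         local=True, file_path='DAGS_FOLDER/example_dag.py')
#     # -> ['airflow', 'run', 'example_dag', 'example_task',
#     #     '2018-01-01T00:00:00', '--local', '-sd', 'DAGS_FOLDER/example_dag.py']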
@property
def log_filepath(self):
iso = self.execution_date.isoformat()
log = os.path.expanduser(configuration.conf.get('core', 'BASE_LOG_FOLDER'))
return (
"{log}/{self.dag_id}/{self.task_id}/{iso}.log".format(**locals()))
@property
def log_url(self):
iso = quote(self.execution_date.isoformat())
BASE_URL = configuration.conf.get('webserver', 'BASE_URL')
if settings.RBAC:
return BASE_URL + (
"/log?"
"execution_date={iso}"
"&task_id={self.task_id}"
"&dag_id={self.dag_id}"
).format(**locals())
else:
return BASE_URL + (
"/admin/airflow/log"
"?dag_id={self.dag_id}"
"&task_id={self.task_id}"
"&execution_date={iso}"
).format(**locals())
@property
def mark_success_url(self):
iso = quote(self.execution_date.isoformat())
BASE_URL = configuration.conf.get('webserver', 'BASE_URL')
if settings.RBAC:
return BASE_URL + (
"/success"
"?task_id={self.task_id}"
"&dag_id={self.dag_id}"
"&execution_date={iso}"
"&upstream=false"
"&downstream=false"
).format(**locals())
else:
return BASE_URL + (
"/admin/airflow/success"
"?task_id={self.task_id}"
"&dag_id={self.dag_id}"
"&execution_date={iso}"
"&upstream=false"
"&downstream=false"
).format(**locals())
@provide_session
def current_state(self, session=None):
"""
Get the very latest state from the database. If a session is passed, we
use it and looking up the state becomes part of that session; otherwise
a new session is used.
"""
TI = TaskInstance
ti = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.execution_date == self.execution_date,
).all()
if ti:
state = ti[0].state
else:
state = None
return state
@provide_session
def error(self, session=None):
"""
Forces the task instance's state to FAILED in the database.
"""
self.log.error("Recording the task instance as FAILED")
self.state = State.FAILED
session.merge(self)
session.commit()
@provide_session
def refresh_from_db(self, session=None, lock_for_update=False):
"""
Refreshes the task instance from the database based on the primary key
:param lock_for_update: if True, indicates that the database should
lock the TaskInstance (issuing a FOR UPDATE clause) until the
session is committed.
"""
TI = TaskInstance
qry = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.execution_date == self.execution_date)
if lock_for_update:
ti = qry.with_for_update().first()
else:
ti = qry.first()
if ti:
self.state = ti.state
self.start_date = ti.start_date
self.end_date = ti.end_date
# Get the raw value of try_number column, don't read through the
# accessor here otherwise it will be incremented by one already.
self.try_number = ti._try_number
self.max_tries = ti.max_tries
self.hostname = ti.hostname
self.pid = ti.pid
self.executor_config = ti.executor_config
else:
self.state = None
@provide_session
def clear_xcom_data(self, session=None):
"""
Clears all XCom data from the database for the task instance
"""
session.query(XCom).filter(
XCom.dag_id == self.dag_id,
XCom.task_id == self.task_id,
XCom.execution_date == self.execution_date
).delete()
session.commit()
@property
def key(self):
"""
Returns a tuple that identifies the task instance uniquely
"""
return self.dag_id, self.task_id, self.execution_date, self.try_number
@provide_session
def set_state(self, state, session=None):
self.state = state
self.start_date = timezone.utcnow()
self.end_date = timezone.utcnow()
session.merge(self)
session.commit()
@property
def is_premature(self):
"""
Returns whether a task is in UP_FOR_RETRY state and its retry interval
has not yet elapsed (i.e. it is still within the retry waiting period).
"""
# is the task still in the retry waiting period?
return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
@provide_session
def are_dependents_done(self, session=None):
"""
Checks whether the dependents of this task instance have all succeeded.
This is meant to be used by wait_for_downstream.
This is useful when you do not want to start processing the next
schedule of a task until the dependents are done. For instance,
if the task DROPs and recreates a table.
"""
task = self.task
if not task.downstream_task_ids:
return True
ti = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.execution_date == self.execution_date,
TaskInstance.state == State.SUCCESS,
)
count = ti[0][0]
return count == len(task.downstream_task_ids)
@property
@provide_session
def previous_ti(self, session=None):
""" The task instance for the task that ran before this task instance """
dag = self.task.dag
if dag:
dr = self.get_dagrun(session=session)
# LEGACY: most likely running from unit tests
if not dr:
# Means that this TI is NOT being run from a DR, but from a catchup
previous_scheduled_date = dag.previous_schedule(self.execution_date)
if not previous_scheduled_date:
return None
return TaskInstance(task=self.task,
execution_date=previous_scheduled_date)
dr.dag = dag
if dag.catchup:
last_dagrun = dr.get_previous_scheduled_dagrun(session=session)
else:
last_dagrun = dr.get_previous_dagrun(session=session)
if last_dagrun:
return last_dagrun.get_task_instance(self.task_id, session=session)
return None
@provide_session
def are_dependencies_met(
self,
dep_context=None,
session=None,
verbose=False):
"""
Returns whether or not all the conditions are met for this task instance to be run
given the context for the dependencies (e.g. a task instance being force run from
the UI will ignore some dependencies).
:param dep_context: The execution context that determines the dependencies that
should be evaluated.
:type dep_context: DepContext
:param session: database session
:type session: Session
:param verbose: whether to log details of failed dependencies at
info or debug log level
:type verbose: bool
"""
dep_context = dep_context or DepContext()
failed = False
verbose_aware_logger = self.log.info if verbose else self.log.debug
for dep_status in self.get_failed_dep_statuses(
dep_context=dep_context,
session=session):
failed = True
verbose_aware_logger(
"Dependencies not met for %s, dependency '%s' FAILED: %s",
self, dep_status.dep_name, dep_status.reason
)
if failed:
return False
verbose_aware_logger("Dependencies all met for %s", self)
return True
@provide_session
def get_failed_dep_statuses(
self,
dep_context=None,
session=None):
dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
for dep_status in dep.get_dep_statuses(
self,
session,
dep_context):
self.log.debug(
"%s dependency '%s' PASSED: %s, %s",
self, dep_status.dep_name, dep_status.passed, dep_status.reason
)
if not dep_status.passed:
yield dep_status
def __repr__(self):
return (
"<TaskInstance: {ti.dag_id}.{ti.task_id} "
"{ti.execution_date} [{ti.state}]>"
).format(ti=self)
def next_retry_datetime(self):
"""
Get datetime of the next retry if the task instance fails. For exponential
backoff, retry_delay is used as base and will be converted to seconds.
"""
delay = self.task.retry_delay
if self.task.retry_exponential_backoff:
min_backoff = int(delay.total_seconds() * (2 ** (self.try_number - 2)))
# deterministic per task instance
hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id,
self.task_id,
self.execution_date,
self.try_number)
.encode('utf-8')).hexdigest(), 16)
# between 0.5 * delay * (2^retry_number) and 1.0 * delay * (2^retry_number)
modded_hash = min_backoff + hash % min_backoff
# timedelta has a maximum representable value. The exponentiation
# here means this value can be exceeded after a certain number
# of tries (around 50 if the initial delay is 1s, even fewer if
# the delay is larger). Cap the value here before creating a
# timedelta object so the operation doesn't fail.
delay_backoff_in_seconds = min(
modded_hash,
timedelta.max.total_seconds() - 1
)
delay = timedelta(seconds=delay_backoff_in_seconds)
if self.task.max_retry_delay:
delay = min(self.task.max_retry_delay, delay)
return self.end_date + delay
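# Worked example of the exponential backoff above (illustrative numbers): with
# retry_delay=300s and try_number=3, min_backoff = 300 * 2**(3 - 2) = 600s, and the
# hash-based jitter places the delay somewhere in [600s, 1200s) before it is capped by
# max_retry_delay (if set) and added to end_date.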
def ready_for_retry(self):
"""
Checks on whether the task instance is in the right state and timeframe
to be retried.
"""
return (self.state == State.UP_FOR_RETRY and
self.next_retry_datetime() < timezone.utcnow())
@provide_session
def pool_full(self, session):
"""
Returns a boolean as to whether the slot pool has room for this
task to run
"""
if not self.task.pool:
return False
pool = (
session
.query(Pool)
.filter(Pool.pool == self.task.pool)
.first()
)
if not pool:
return False
open_slots = pool.open_slots(session=session)
return open_slots <= 0
@provide_session
def get_dagrun(self, session):
"""
Returns the DagRun for this TaskInstance
:param session:
:return: DagRun
"""
dr = session.query(DagRun).filter(
DagRun.dag_id == self.dag_id,
DagRun.execution_date == self.execution_date
).first()
return dr
@provide_session
def _check_and_change_state_before_execution(
self,
verbose=True,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
mark_success=False,
test_mode=False,
job_id=None,
pool=None,
session=None):
"""
Checks dependencies and then sets state to RUNNING if they are met. Returns
True if and only if state is set to RUNNING, which implies that task should be
executed, in preparation for _run_raw_task
:param verbose: whether to turn on more verbose logging
:type verbose: bool
:param ignore_all_deps: Ignore all of the non-critical dependencies and just run the task
:type ignore_all_deps: bool
:param ignore_depends_on_past: Ignore depends_on_past DAG attribute
:type ignore_depends_on_past: bool
:param ignore_task_deps: Don't check the dependencies of this TI's task
:type ignore_task_deps: bool
:param ignore_ti_state: Disregards previous task instance state
:type ignore_ti_state: bool
:param mark_success: Don't run the task, mark its state as success
:type mark_success: bool
:param test_mode: Doesn't record success or failure in the DB
:type test_mode: bool
:param pool: specifies the pool to use to run the task instance
:type pool: str
:return: whether the state was changed to running or not
:rtype: bool
"""
task = self.task
self.pool = pool or task.pool
self.test_mode = test_mode
self.refresh_from_db(session=session, lock_for_update=True)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__
if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
Stats.incr('previously_succeeded', 1, 1)
queue_dep_context = DepContext(
deps=QUEUE_DEPS,
ignore_all_deps=ignore_all_deps,
ignore_ti_state=ignore_ti_state,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps)
if not self.are_dependencies_met(
dep_context=queue_dep_context,
session=session,
verbose=True):
session.commit()
return False
# TODO: Logging needs cleanup, not clear what is being printed
hr = "\n" + ("-" * 80) + "\n" # Line break
# For reporting purposes, we report based on 1-indexed,
# not 0-indexed lists (i.e. Attempt 1 instead of
# Attempt 0 for the first attempt).
msg = "Starting attempt {attempt} of {total}".format(
attempt=self.try_number,
total=self.max_tries + 1)
self.start_date = timezone.utcnow()
dep_context = DepContext(
deps=RUN_DEPS - QUEUE_DEPS,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state)
runnable = self.are_dependencies_met(
dep_context=dep_context,
session=session,
verbose=True)
if not runnable and not mark_success:
# FIXME: we might have hit concurrency limits, which means we probably
# have been running prematurely. This should be handled in the
# scheduling mechanism.
self.state = State.NONE
msg = ("FIXME: Rescheduling due to concurrency limits reached at task "
"runtime. Attempt {attempt} of {total}. State set to NONE.").format(
attempt=self.try_number,
total=self.max_tries + 1)
self.log.warning(hr + msg + hr)
self.queued_dttm = timezone.utcnow()
self.log.info("Queuing into pool %s", self.pool)
session.merge(self)
session.commit()
return False
# Another worker might have started running this task instance while
# the current worker process was blocked on refresh_from_db
if self.state == State.RUNNING:
msg = "Task Instance already running {}".format(self)
self.log.warning(msg)
session.commit()
return False
# print status message
self.log.info(hr + msg + hr)
self._try_number += 1
if not test_mode:
session.add(Log(State.RUNNING, self))
self.state = State.RUNNING
self.pid = os.getpid()
self.end_date = None
if not test_mode:
session.merge(self)
session.commit()
# Closing all pooled connections to prevent
# "max number of connections reached"
settings.engine.dispose()
if verbose:
if mark_success:
msg = "Marking success for {} on {}".format(self.task,
self.execution_date)
self.log.info(msg)
else:
msg = "Executing {} on {}".format(self.task, self.execution_date)
self.log.info(msg)
return True
@provide_session
def _run_raw_task(
self,
mark_success=False,
test_mode=False,
job_id=None,
pool=None,
session=None):
"""
Immediately runs the task (without checking or changing db state
before execution) and then sets the appropriate final state after
completion and runs any post-execute callbacks. Meant to be called
only after another function changes the state to running.
:param mark_success: Don't run the task, mark its state as success
:type mark_success: bool
:param test_mode: Doesn't record success or failure in the DB
:type test_mode: bool
:param pool: specifies the pool to use to run the task instance
:type pool: str
"""
task = self.task
self.pool = pool or task.pool
self.test_mode = test_mode
self.refresh_from_db(session=session)
self.job_id = job_id
self.hostname = get_hostname()
self.operator = task.__class__.__name__
context = {}
try:
if not mark_success:
context = self.get_template_context()
task_copy = copy.copy(task)
self.task = task_copy
def signal_handler(signum, frame):
self.log.error("Received SIGTERM. Terminating subprocesses.")
task_copy.on_kill()
raise AirflowException("Task received SIGTERM signal")
signal.signal(signal.SIGTERM, signal_handler)
# Don't clear Xcom until the task is certain to execute
self.clear_xcom_data()
self.render_templates()
task_copy.pre_execute(context=context)
# If a timeout is specified for the task, make it fail
# if it goes beyond
result = None
if task_copy.execution_timeout:
try:
with timeout(int(
task_copy.execution_timeout.total_seconds())):
result = task_copy.execute(context=context)
except AirflowTaskTimeout:
task_copy.on_kill()
raise
else:
result = task_copy.execute(context=context)
# If the task returns a result, push an XCom containing it
if result is not None:
self.xcom_push(key=XCOM_RETURN_KEY, value=result)
task_copy.post_execute(context=context, result=result)
Stats.incr('operator_successes_{}'.format(
self.task.__class__.__name__), 1, 1)
Stats.incr('ti_successes')
self.refresh_from_db(lock_for_update=True)
self.state = State.SUCCESS
except AirflowSkipException:
self.refresh_from_db(lock_for_update=True)
self.state = State.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self.refresh_from_db()
self._handle_reschedule(reschedule_exception, test_mode, context)
return
except AirflowException as e:
self.refresh_from_db()
# for case when task is marked as success/failed externally
# current behavior doesn't hit the success callback
if self.state in {State.SUCCESS, State.FAILED}:
return
else:
self.handle_failure(e, test_mode, context)
raise
except (Exception, KeyboardInterrupt) as e:
self.handle_failure(e, test_mode, context)
raise
# Success callback
try:
if task.on_success_callback:
task.on_success_callback(context)
except Exception as e3:
self.log.error("Failed when executing success callback")
self.log.exception(e3)
# Recording SUCCESS
self.end_date = timezone.utcnow()
self.set_duration()
if not test_mode:
session.add(Log(self.state, self))
session.merge(self)
session.commit()
@provide_session
def run(
self,
verbose=True,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
mark_success=False,
test_mode=False,
job_id=None,
pool=None,
session=None):
res = self._check_and_change_state_before_execution(
verbose=verbose,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
mark_success=mark_success,
test_mode=test_mode,
job_id=job_id,
pool=pool,
session=session)
if res:
self._run_raw_task(
mark_success=mark_success,
test_mode=test_mode,
job_id=job_id,
pool=pool,
session=session)
def dry_run(self):
task = self.task
task_copy = copy.copy(task)
self.task = task_copy
self.render_templates()
task_copy.dry_run()
@provide_session
def _handle_reschedule(self, reschedule_exception, test_mode=False, context=None,
session=None):
# Don't record reschedule request in test mode
if test_mode:
return
self.end_date = timezone.utcnow()
self.set_duration()
# Log reschedule request
session.add(TaskReschedule(self.task, self.execution_date, self._try_number,
self.start_date, self.end_date,
reschedule_exception.reschedule_date))
# set state
self.state = State.NONE
# Decrement try_number so subsequent runs will use the same try number and write
# to same log file.
self._try_number -= 1
session.merge(self)
session.commit()
self.log.info('Rescheduling task, marking task as NONE')
@provide_session
def handle_failure(self, error, test_mode=False, context=None, session=None):
self.log.exception(error)
task = self.task
self.end_date = timezone.utcnow()
self.set_duration()
Stats.incr('operator_failures_{}'.format(task.__class__.__name__), 1, 1)
Stats.incr('ti_failures')
if not test_mode:
session.add(Log(State.FAILED, self))
# Log failure duration
session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))
if context is not None:
context['exception'] = error
# Let's go deeper
try:
# Since this function is called only when the TI state is running,
# try_number contains the current try_number (not the next). We
# only mark task instance as FAILED if the next task instance
# try_number exceeds the max_tries.
if self.is_eligible_to_retry():
self.state = State.UP_FOR_RETRY
self.log.info('Marking task as UP_FOR_RETRY')
if task.email_on_retry and task.email:
self.email_alert(error)
else:
self.state = State.FAILED
if task.retries:
self.log.info('All retries failed; marking task as FAILED')
else:
self.log.info('Marking task as FAILED.')
if task.email_on_failure and task.email:
self.email_alert(error)
except Exception as e2:
self.log.error('Failed to send email to: %s', task.email)
self.log.exception(e2)
# Handling callbacks pessimistically
try:
if self.state == State.UP_FOR_RETRY and task.on_retry_callback:
task.on_retry_callback(context)
if self.state == State.FAILED and task.on_failure_callback:
task.on_failure_callback(context)
except Exception as e3:
self.log.error("Failed at executing callback")
self.log.exception(e3)
if not test_mode:
session.merge(self)
session.commit()
def is_eligible_to_retry(self):
"""Is task instance is eligible for retry"""
return self.task.retries and self.try_number <= self.max_tries
@provide_session
def get_template_context(self, session=None):
task = self.task
from airflow import macros
tables = None
if 'tables' in task.params:
tables = task.params['tables']
ds = self.execution_date.strftime('%Y-%m-%d')
ts = self.execution_date.isoformat()
yesterday_ds = (self.execution_date - timedelta(1)).strftime('%Y-%m-%d')
tomorrow_ds = (self.execution_date + timedelta(1)).strftime('%Y-%m-%d')
prev_execution_date = task.dag.previous_schedule(self.execution_date)
next_execution_date = task.dag.following_schedule(self.execution_date)
next_ds = None
next_ds_nodash = None
if next_execution_date:
next_ds = next_execution_date.strftime('%Y-%m-%d')
next_ds_nodash = next_ds.replace('-', '')
prev_ds = None
prev_ds_nodash = None
if prev_execution_date:
prev_ds = prev_execution_date.strftime('%Y-%m-%d')
prev_ds_nodash = prev_ds.replace('-', '')
ds_nodash = ds.replace('-', '')
ts_nodash = ts.replace('-', '').replace(':', '')
yesterday_ds_nodash = yesterday_ds.replace('-', '')
tomorrow_ds_nodash = tomorrow_ds.replace('-', '')
ti_key_str = "{task.dag_id}__{task.task_id}__{ds_nodash}"
ti_key_str = ti_key_str.format(**locals())
params = {}
run_id = ''
dag_run = None
if hasattr(task, 'dag'):
if task.dag.params:
params.update(task.dag.params)
dag_run = (
session.query(DagRun)
.filter_by(
dag_id=task.dag.dag_id,
execution_date=self.execution_date)
.first()
)
run_id = dag_run.run_id if dag_run else None
session.expunge_all()
session.commit()
if task.params:
params.update(task.params)
if configuration.getboolean('core', 'dag_run_conf_overrides_params'):
self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run)
class VariableAccessor:
"""
Wrapper around Variable. This way you can get variables in templates by using
{{ var.value.your_variable_name }}.
"""
def __init__(self):
self.var = None
def __getattr__(self, item):
self.var = Variable.get(item)
return self.var
def __repr__(self):
return str(self.var)
class VariableJsonAccessor:
"""
Wrapper around deserialized Variables. This way you can get variables
in templates by using {{ var.json.your_variable_name }}.
"""
def __init__(self):
self.var = None
def __getattr__(self, item):
self.var = Variable.get(item, deserialize_json=True)
return self.var
def __repr__(self):
return str(self.var)
return {
'dag': task.dag,
'ds': ds,
'next_ds': next_ds,
'next_ds_nodash': next_ds_nodash,
'prev_ds': prev_ds,
'prev_ds_nodash': prev_ds_nodash,
'ds_nodash': ds_nodash,
'ts': ts,
'ts_nodash': ts_nodash,
'yesterday_ds': yesterday_ds,
'yesterday_ds_nodash': yesterday_ds_nodash,
'tomorrow_ds': tomorrow_ds,
'tomorrow_ds_nodash': tomorrow_ds_nodash,
'END_DATE': ds,
'end_date': ds,
'dag_run': dag_run,
'run_id': run_id,
'execution_date': self.execution_date,
'prev_execution_date': prev_execution_date,
'next_execution_date': next_execution_date,
'latest_date': ds,
'macros': macros,
'params': params,
'tables': tables,
'task': task,
'task_instance': self,
'ti': self,
'task_instance_key_str': ti_key_str,
'conf': configuration,
'test_mode': self.test_mode,
'var': {
'value': VariableAccessor(),
'json': VariableJsonAccessor()
},
'inlets': task.inlets,
'outlets': task.outlets,
}
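# Example of how the context above is consumed (illustrative template): any operator
# attribute listed in template_fields is rendered with this dict as the Jinja context,
# so a templated command such as
#
#     "echo {{ ds }} run_id={{ run_id }} my_var={{ var.value.my_var }} p={{ params.foo }}"
#
# expands using the 'ds', 'run_id', 'var' and 'params' entries returned here.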
def overwrite_params_with_dag_run_conf(self, params, dag_run):
if dag_run and dag_run.conf:
params.update(dag_run.conf)
def render_templates(self):
task = self.task
jinja_context = self.get_template_context()
if hasattr(self, 'task') and hasattr(self.task, 'dag'):
if self.task.dag.user_defined_macros:
jinja_context.update(
self.task.dag.user_defined_macros)
rt = self.task.render_template # shortcut to method
for attr in task.__class__.template_fields:
content = getattr(task, attr)
if content:
rendered_content = rt(attr, content, jinja_context)
setattr(task, attr, rendered_content)
def email_alert(self, exception):
exception_html = str(exception).replace('\n', '<br>')
jinja_context = self.get_template_context()
jinja_context.update(dict(
exception=exception,
exception_html=exception_html,
try_number=self.try_number,
max_tries=self.max_tries))
jinja_env = self.task.get_template_env()
default_subject = 'Airflow alert: {{ti}}'
# For reporting purposes, we report based on 1-indexed,
# not 0-indexed lists (i.e. Try 1 instead of
# Try 0 for the first attempt).
default_html_content = (
'Try {{try_number}} out of {{max_tries + 1}}<br>'
'Exception:<br>{{exception_html}}<br>'
'Log: <a href="{{ti.log_url}}">Link</a><br>'
'Host: {{ti.hostname}}<br>'
'Log file: {{ti.log_filepath}}<br>'
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
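# The subject and body templates can be overridden via the [email] section
# of airflow.cfg (illustrative snippet; the file paths are placeholders):
#
#   [email]
#   subject_template = /path/to/my_subject_template_file
#   html_content_template = /path/to/my_html_content_template_file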
def render(key, content):
if configuration.has_option('email', key):
path = configuration.get('email', key)
with open(path) as f:
content = f.read()
return jinja_env.from_string(content).render(**jinja_context)
subject = render('subject_template', default_subject)
html_content = render('html_content_template', default_html_content)
send_email(self.task.email, subject, html_content)
def set_duration(self):
if self.end_date and self.start_date:
self.duration = (self.end_date - self.start_date).total_seconds()
else:
self.duration = None
def xcom_push(
self,
key,
value,
execution_date=None):
"""
Make an XCom available for tasks to pull.
:param key: A key for the XCom
:type key: str
:param value: A value for the XCom. The value is pickled and stored
in the database.
:type value: any pickleable object
:param execution_date: if provided, the XCom will not be visible until
this date. This can be used, for example, to send a message to a
task on a future date without it being immediately visible.
:type execution_date: datetime
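**Example**: illustrative push of a value under an explicit key
(``ti`` is assumed to be the current TaskInstance) ::

    ti.xcom_push(key='row_count', value=42)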
"""
if execution_date and execution_date < self.execution_date:
raise ValueError(
'execution_date can not be in the past (current '
'execution_date is {}; received {})'.format(
self.execution_date, execution_date))
XCom.set(
key=key,
value=value,
task_id=self.task_id,
dag_id=self.dag_id,
execution_date=execution_date or self.execution_date)
def xcom_pull(
self,
task_ids=None,
dag_id=None,
key=XCOM_RETURN_KEY,
include_prior_dates=False):
"""
Pull XComs that optionally meet certain criteria.
The default value for `key` limits the search to XComs
that were returned by other tasks (as opposed to those that were pushed
manually). To remove this filter, pass key=None (or any desired value).
If a single task_id string is provided, the result is the value of the
most recent matching XCom from that task_id. If multiple task_ids are
provided, a tuple of matching values is returned. None is returned
whenever no matches are found.
:param key: A key for the XCom. If provided, only XComs with matching
keys will be returned. The default key is 'return_value', also
available as a constant XCOM_RETURN_KEY. This key is automatically
given to XComs returned by tasks (as opposed to being pushed
manually). To remove the filter, pass key=None.
:type key: str
:param task_ids: Only XComs from tasks with matching ids will be
pulled. Can pass None to remove the filter.
:type task_ids: str or iterable of strings (representing task_ids)
:param dag_id: If provided, only pulls XComs from this DAG.
If None (default), the DAG of the calling task is used.
:type dag_id: str
:param include_prior_dates: If False, only XComs from the current
execution_date are returned. If True, XComs from previous dates
are returned as well.
:type include_prior_dates: bool
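**Example**: illustrative pulls (``ti`` is assumed to be the current
TaskInstance; the task ids and key are hypothetical) ::

    # most recent value returned by the task 'extract'
    rows = ti.xcom_pull(task_ids='extract')
    # values pushed under an explicit key by two upstream tasks
    a, b = ti.xcom_pull(task_ids=['load_a', 'load_b'], key='row_count')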
"""
if dag_id is None:
dag_id = self.dag_id
pull_fn = functools.partial(
XCom.get_one,
execution_date=self.execution_date,
key=key,
dag_id=dag_id,
include_prior_dates=include_prior_dates)
if is_container(task_ids):
return tuple(pull_fn(task_id=t) for t in task_ids)
else:
return pull_fn(task_id=task_ids)
@provide_session
def get_num_running_task_instances(self, session):
TI = TaskInstance
return session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.state == State.RUNNING
).count()
def init_run_context(self, raw=False):
"""
Sets the log context.
"""
self.raw = raw
self._set_context(self)
class TaskFail(Base):
"""
TaskFail tracks the failed run durations of each task instance.
"""
__tablename__ = "task_fail"
id = Column(Integer, primary_key=True)
task_id = Column(String(ID_LEN), nullable=False)
dag_id = Column(String(ID_LEN), nullable=False)
execution_date = Column(UtcDateTime, nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Integer)
__table_args__ = (
Index('idx_task_fail_dag_task_date', dag_id, task_id, execution_date,
unique=False),
)
def __init__(self, task, execution_date, start_date, end_date):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.execution_date = execution_date
self.start_date = start_date
self.end_date = end_date
if self.end_date and self.start_date:
self.duration = (self.end_date - self.start_date).total_seconds()
else:
self.duration = None
class TaskReschedule(Base):
"""
TaskReschedule tracks rescheduled task instances.
"""
__tablename__ = "task_reschedule"
id = Column(Integer, primary_key=True)
task_id = Column(String(ID_LEN), nullable=False)
dag_id = Column(String(ID_LEN), nullable=False)
execution_date = Column(UtcDateTime, nullable=False)
try_number = Column(Integer, nullable=False)
start_date = Column(UtcDateTime, nullable=False)
end_date = Column(UtcDateTime, nullable=False)
duration = Column(Integer, nullable=False)
reschedule_date = Column(UtcDateTime, nullable=False)
__table_args__ = (
Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date,
unique=False),
ForeignKeyConstraint([task_id, dag_id, execution_date],
[TaskInstance.task_id, TaskInstance.dag_id,
TaskInstance.execution_date],
name='task_reschedule_dag_task_date_fkey')
)
def __init__(self, task, execution_date, try_number, start_date, end_date,
reschedule_date):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.execution_date = execution_date
self.try_number = try_number
self.start_date = start_date
self.end_date = end_date
self.reschedule_date = reschedule_date
self.duration = (self.end_date - self.start_date).total_seconds()
@staticmethod
@provide_session
def find_for_task_instance(task_instance, session):
"""
Returns all task reschedules for the task instance and try number,
in ascending order.
:param task_instance: the task instance to find task reschedules for
:type task_instance: TaskInstance
"""
TR = TaskReschedule
return (
session
.query(TR)
.filter(TR.dag_id == task_instance.dag_id,
TR.task_id == task_instance.task_id,
TR.execution_date == task_instance.execution_date,
TR.try_number == task_instance.try_number)
.order_by(asc(TR.id))
.all()
)
class Log(Base):
"""
Used to actively log events to the database.
"""
__tablename__ = "log"
id = Column(Integer, primary_key=True)
dttm = Column(UtcDateTime)
dag_id = Column(String(ID_LEN))
task_id = Column(String(ID_LEN))
event = Column(String(30))
execution_date = Column(UtcDateTime)
owner = Column(String(500))
extra = Column(Text)
__table_args__ = (
Index('idx_log_dag', dag_id),
)
def __init__(self, event, task_instance, owner=None, extra=None, **kwargs):
self.dttm = timezone.utcnow()
self.event = event
self.extra = extra
task_owner = None
if task_instance:
self.dag_id = task_instance.dag_id
self.task_id = task_instance.task_id
self.execution_date = task_instance.execution_date
task_owner = task_instance.task.owner
if 'task_id' in kwargs:
self.task_id = kwargs['task_id']
if 'dag_id' in kwargs:
self.dag_id = kwargs['dag_id']
if 'execution_date' in kwargs:
if kwargs['execution_date']:
self.execution_date = kwargs['execution_date']
self.owner = owner or task_owner
class SkipMixin(LoggingMixin):
@provide_session
def skip(self, dag_run, execution_date, tasks, session=None):
"""
Sets task instances from the same dag run to the skipped state.
:param dag_run: the DagRun for which to set the tasks to skipped
:param execution_date: execution_date of the dag run (used when dag_run is not provided)
:param tasks: tasks to skip (not task_ids)
:param session: db session to use
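**Example**: illustrative call from inside an operator that mixes in
SkipMixin (``dag_run`` and ``tasks_to_skip`` are assumed to exist) ::

    self.skip(dag_run, dag_run.execution_date, tasks_to_skip)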
"""
if not tasks:
return
task_ids = [d.task_id for d in tasks]
now = timezone.utcnow()
if dag_run:
session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.execution_date == dag_run.execution_date,
TaskInstance.task_id.in_(task_ids)
).update({TaskInstance.state: State.SKIPPED,
TaskInstance.start_date: now,
TaskInstance.end_date: now},
synchronize_session=False)
session.commit()
else:
assert execution_date is not None, "Execution date is None and no dag run"
self.log.warning("No DAG RUN present this should not happen")
# this is defensive against dag runs that are not complete
for task in tasks:
ti = TaskInstance(task, execution_date=execution_date)
ti.state = State.SKIPPED
ti.start_date = now
ti.end_date = now
session.merge(ti)
session.commit()
@functools.total_ordering
class BaseOperator(LoggingMixin):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
dag crawling behavior. When deriving from this class, you are expected to override
the constructor as well as the 'execute' method.
Operators derived from this class should perform or trigger certain tasks
synchronously (wait for completion). Examples of operators include an
operator that runs a Pig job (PigOperator), a sensor operator that
waits for a partition to land in Hive (HiveSensorOperator), or one that
moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
operators (tasks) target specific operations, running specific scripts,
functions or data transfers.
This class is abstract and shouldn't be instantiated. Instantiating a
class derived from this one results in the creation of a task object,
which ultimately becomes a node in DAG objects. Task dependencies should
be set by using the set_upstream and/or set_downstream methods (see the
illustrative example at the end of this docstring).
:param task_id: a unique, meaningful id for the task
:type task_id: str
:param owner: the owner of the task, using the unix username is recommended
:type owner: str
:param retries: the number of retries that should be performed before
failing the task
:type retries: int
:param retry_delay: delay between retries
:type retry_delay: timedelta
:param retry_exponential_backoff: allow progressively longer waits between
retries by using exponential backoff algorithm on retry delay (delay
will be converted into seconds)
:type retry_exponential_backoff: bool
:param max_retry_delay: maximum delay interval between retries
:type max_retry_delay: timedelta
:param start_date: The ``start_date`` for the task, determines
the ``execution_date`` for the first task instance. The best practice
is to have the start_date rounded
to your DAG's ``schedule_interval``. Daily jobs have their start_date
some day at 00:00:00, hourly jobs have their start_date at 00:00
of a specific hour. Note that Airflow simply looks at the latest
``execution_date`` and adds the ``schedule_interval`` to determine
the next ``execution_date``. It is also very important
to note that different tasks' dependencies
need to line up in time. If task A depends on task B and their
start_dates are offset in a way that their execution_dates don't line
up, A's dependencies will never be met. If you are looking to delay
a task, for example running a daily task at 2AM, look into the
``TimeSensor`` and ``TimeDeltaSensor``. We advise against using
dynamic ``start_date`` and recommend using fixed ones. Read the
FAQ entry about start_date for more information.
:type start_date: datetime
:param end_date: if specified, the scheduler won't go beyond this date
:type end_date: datetime
:param depends_on_past: when set to true, task instances will run
sequentially, relying on the previous instance of the same task (from the
prior schedule) to have succeeded. The task instance for the start_date
is allowed to run.
:type depends_on_past: bool
:param wait_for_downstream: when set to true, an instance of task
X will wait for tasks immediately downstream of the previous instance
of task X to finish successfully before it runs. This is useful if the
different instances of a task X alter the same asset, and this asset
is used by tasks downstream of task X. Note that depends_on_past
is forced to True wherever wait_for_downstream is used.
:type wait_for_downstream: bool
:param queue: which queue to target when running this job. Not
all executors implement queue management; the CeleryExecutor, for
example, does support targeting specific queues.
:type queue: str
:param dag: a reference to the dag the task is attached to (if any)
:type dag: DAG
:param priority_weight: priority weight of this task against other tasks.
This allows the executor to trigger higher priority tasks before
others when things get backed up. Set priority_weight as a higher
number for more important tasks.
:type priority_weight: int
:param weight_rule: weighting method used for the effective total
priority weight of the task. Options are:
``{ downstream | upstream | absolute }``, default is ``downstream``.
When set to ``downstream`` the effective weight of the task is the
aggregate sum of all downstream descendants. As a result, upstream
tasks will have higher weight and will be scheduled more aggressively
when using positive weight values. This is useful when you have
multiple dag run instances and desire to have all upstream tasks
complete for all runs before each dag can continue processing
downstream tasks. When set to ``upstream`` the effective weight is the
aggregate sum of all upstream ancestors. This is the opposite: downstream
tasks have higher weight and will be scheduled more
aggressively when using positive weight values. This is useful when you
have multiple dag run instances and prefer to have each dag complete
before starting upstream tasks of other dags. When set to
``absolute``, the effective weight is the exact ``priority_weight``
specified without additional weighting. You may want to do this when
you know exactly what priority weight each task should have.
Additionally, when set to ``absolute``, there is a bonus effect of
significantly speeding up the task creation process for very large
DAGs. Options can be set as string or using the constants defined in
the static class ``airflow.utils.WeightRule``
:type weight_rule: str
:param pool: the slot pool this task should run in, slot pools are a
way to limit concurrency for certain tasks
:type pool: str
:param sla: time by which the job is expected to succeed. Note that
this represents the ``timedelta`` after the period is closed. For
example if you set an SLA of 1 hour, the scheduler would send an email
soon after 1:00AM on the ``2016-01-02`` if the ``2016-01-01`` instance
has not succeeded yet.
The scheduler pays special attention to jobs with an SLA and
sends alert emails for SLA misses. SLA misses are also recorded in the
database for future reference. All tasks that share the same SLA time
get bundled in a single email, sent soon after that time. SLA
notifications are sent once and only once for each task instance.
:type sla: datetime.timedelta
:param execution_timeout: max time allowed for the execution of
this task instance; if the task runs beyond it, it will raise and fail.
:type execution_timeout: datetime.timedelta
:param on_failure_callback: a function to be called when a task instance
of this task fails. A context dictionary is passed as a single
parameter to this function. Context contains references to objects
related to the task instance and is documented under the macros
section of the API.
:type on_failure_callback: callable
:param on_retry_callback: much like the ``on_failure_callback`` except
that it is executed when retries occur.
:type on_retry_callback: callable
:param on_success_callback: much like the ``on_failure_callback`` except
that it is executed when the task succeeds.
:type on_success_callback: callable
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
``{ all_success | all_failed | all_done | one_success |
one_failed | none_failed | dummy}``
default is ``all_success``. Options can be set as string or
using the constants defined in the static class
``airflow.utils.TriggerRule``
:type trigger_rule: str
:param resources: A map of resource parameter names (the argument names of the
Resources constructor) to their values.
:type resources: dict
:param run_as_user: unix username to impersonate while running the task
:type run_as_user: str
:param task_concurrency: When set, limits the number of concurrent runs of
this task across execution_dates
:type task_concurrency: int
:param executor_config: Additional task-level configuration parameters that are
interpreted by a specific executor. Parameters are namespaced by the name
of the executor.
**Example**: to run this task in a specific docker container through
the KubernetesExecutor ::
MyOperator(...,
executor_config={
"KubernetesExecutor":
{"image": "myCustomDockerImage"}
}
)
:type executor_config: dict
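**Example**: illustrative sketch of instantiating tasks and setting
dependencies (``MyOperator`` and the ids are hypothetical) ::

    with DAG('example_dag', start_date=datetime(2018, 1, 1),
             schedule_interval='@daily') as dag:
        first = MyOperator(task_id='first', retries=2)
        second = MyOperator(task_id='second')
        first.set_downstream(second)  # equivalent to: first >> second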
"""
# For derived classes to define which fields will get jinjaified
template_fields = []
# Defines which files extensions to look for in the templated fields
template_ext = []
# Defines the color in the UI
ui_color = '#fff'
ui_fgcolor = '#000'
# base list which includes all the attrs that don't need deep copy.
_base_operator_shallow_copy_attrs = ('user_defined_macros',
'user_defined_filters',
'params',
'_log',)
# each operator should override this class attr for shallow copy attrs.
shallow_copy_attrs = ()
@apply_defaults
def __init__(
self,
task_id,
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'),
email=None,
email_on_retry=True,
email_on_failure=True,
retries=0,
retry_delay=timedelta(seconds=300),
retry_exponential_backoff=False,
max_retry_delay=None,
start_date=None,
end_date=None,
schedule_interval=None, # not hooked as of now
depends_on_past=False,
wait_for_downstream=False,
dag=None,
params=None,
default_args=None,
adhoc=False,
priority_weight=1,
weight_rule=WeightRule.DOWNSTREAM,
queue=configuration.conf.get('celery', 'default_queue'),
pool=None,
sla=None,
execution_timeout=None,
on_failure_callback=None,
on_success_callback=None,
on_retry_callback=None,
trigger_rule=TriggerRule.ALL_SUCCESS,
resources=None,
run_as_user=None,
task_concurrency=None,
executor_config=None,
inlets=None,
outlets=None,
*args,
**kwargs):
if args or kwargs:
# TODO remove *args and **kwargs in Airflow 2.0
warnings.warn(
'Invalid arguments were passed to {c} (task_id: {t}). '
'Support for passing such arguments will be dropped in '
'Airflow 2.0. Invalid arguments were:'
'\n*args: {a}\n**kwargs: {k}'.format(
c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
category=PendingDeprecationWarning,
stacklevel=3
)
validate_key(task_id)
self.task_id = task_id
self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
self.email_on_failure = email_on_failure
self.start_date = start_date
if start_date and not isinstance(start_date, datetime):
self.log.warning("start_date for %s isn't datetime.datetime", self)
elif start_date:
self.start_date = timezone.convert_to_utc(start_date)
self.end_date = end_date
if end_date:
self.end_date = timezone.convert_to_utc(end_date)
if not TriggerRule.is_valid(trigger_rule):
raise AirflowException(
"The trigger_rule must be one of {all_triggers},"
"'{d}.{t}'; received '{tr}'."
.format(all_triggers=TriggerRule.all_triggers,
d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))
self.trigger_rule = trigger_rule
self.depends_on_past = depends_on_past
self.wait_for_downstream = wait_for_downstream
if wait_for_downstream:
self.depends_on_past = True
if schedule_interval:
self.log.warning(
"schedule_interval is used for %s, though it has "
"been deprecated as a task parameter, you need to "
"specify it as a DAG parameter instead",
self
)
self._schedule_interval = schedule_interval
self.retries = retries
self.queue = queue
self.pool = pool
self.sla = sla
self.execution_timeout = execution_timeout
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
if isinstance(retry_delay, timedelta):
self.retry_delay = retry_delay
else:
self.log.debug("Retry_delay isn't timedelta object, assuming secs")
self.retry_delay = timedelta(seconds=retry_delay)
self.retry_exponential_backoff = retry_exponential_backoff
self.max_retry_delay = max_retry_delay
self.params = params or {} # Available in templates!
self.adhoc = adhoc
self.priority_weight = priority_weight
if not WeightRule.is_valid(weight_rule):
raise AirflowException(
"The weight_rule must be one of {all_weight_rules},"
"'{d}.{t}'; received '{tr}'."
.format(all_weight_rules=WeightRule.all_weight_rules,
d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
self.weight_rule = weight_rule
self.resources = Resources(**(resources or {}))
self.run_as_user = run_as_user
self.task_concurrency = task_concurrency
self.executor_config = executor_config or {}
# Private attributes
self._upstream_task_ids = set()
self._downstream_task_ids = set()
if not dag and _CONTEXT_MANAGER_DAG:
dag = _CONTEXT_MANAGER_DAG
if dag:
self.dag = dag
self._log = logging.getLogger("airflow.task.operators")
# lineage
self.inlets = []
self.outlets = []
self.lineage_data = None
self._inlets = {
"auto": False,
"task_ids": [],
"datasets": [],
}
self._outlets = {
"datasets": [],
}
if inlets:
self._inlets.update(inlets)
if outlets:
self._outlets.update(outlets)
self._comps = {
'task_id',
'dag_id',
'owner',
'email',
'email_on_retry',
'retry_delay',
'retry_exponential_backoff',
'max_retry_delay',
'start_date',
'schedule_interval',
'depends_on_past',
'wait_for_downstream',
'adhoc',
'priority_weight',
'sla',
'execution_timeout',
'on_failure_callback',
'on_success_callback',
'on_retry_callback',
}
def __eq__(self, other):
return (
type(self) == type(other) and
all(self.__dict__.get(c, None) == other.__dict__.get(c, None)
for c in self._comps))
def __ne__(self, other):
return not self == other
def __lt__(self, other):
return self.task_id < other.task_id
def __hash__(self):
hash_components = [type(self)]
for c in self._comps:
val = getattr(self, c, None)
try:
hash(val)
hash_components.append(val)
except TypeError:
hash_components.append(repr(val))
return hash(tuple(hash_components))
# Composing Operators -----------------------------------------------
def __rshift__(self, other):
"""
Implements Self >> Other == self.set_downstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator.
"""
if isinstance(other, DAG):
# if this dag is already assigned, do nothing
# otherwise, do normal dag assignment
if not (self.has_dag() and self.dag is other):
self.dag = other
else:
self.set_downstream(other)
return other
def __lshift__(self, other):
"""
Implements Self << Other == self.set_upstream(other)
If "Other" is a DAG, the DAG is assigned to the Operator.
"""
if isinstance(other, DAG):
# if this dag is already assigned, do nothing
# otherwise, do normal dag assignment
if not (self.has_dag() and self.dag is other):
self.dag = other
else:
self.set_upstream(other)
return other
def __rrshift__(self, other):
"""
Called for [DAG] >> [Operator] because DAGs don't have
__rshift__ operators.
"""
self.__lshift__(other)
return self
def __rlshift__(self, other):
"""
Called for [DAG] << [Operator] because DAGs don't have
__lshift__ operators.
"""
self.__rshift__(other)
return self
# /Composing Operators ---------------------------------------------
@property
def dag(self):
"""
Returns the Operator's DAG if set, otherwise raises an error
"""
if self.has_dag():
return self._dag
else:
raise AirflowException(
'Operator {} has not been assigned to a DAG yet'.format(self))
@dag.setter
def dag(self, dag):
"""
Operators can be assigned to one DAG, one time. Repeated assignments to
that same DAG are ok.
"""
if not isinstance(dag, DAG):
raise TypeError(
'Expected DAG; received {}'.format(dag.__class__.__name__))
elif self.has_dag() and self.dag is not dag:
raise AirflowException(
"The DAG assigned to {} can not be changed.".format(self))
elif self.task_id not in dag.task_dict:
dag.add_task(self)
self._dag = dag
def has_dag(self):
"""
Returns True if the Operator has been assigned to a DAG.
"""
return getattr(self, '_dag', None) is not None
@property
def dag_id(self):
if self.has_dag():
return self.dag.dag_id
else:
return 'adhoc_' + self.owner
@property
def deps(self):
"""
Returns the list of dependencies for the operator. These differ from execution
context dependencies in that they are specific to tasks and can be
extended/overridden by subclasses.
"""
return {
NotInRetryPeriodDep(),
PrevDagrunDep(),
TriggerRuleDep(),
}
@property
def schedule_interval(self):
"""
The schedule interval of the DAG always wins over individual tasks so
that tasks within a DAG always line up. The task still needs its own
schedule_interval, as it may not be attached to a DAG.
"""
if self.has_dag():
return self.dag._schedule_interval
else:
return self._schedule_interval
@property
def priority_weight_total(self):
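"""
Total priority weight for the task, aggregated according to ``weight_rule``.
As a worked sketch (assuming the default ``priority_weight=1``): for a
chain ``a >> b >> c``, ``downstream`` yields a=3, b=2, c=1; ``upstream``
yields a=1, b=2, c=3; ``absolute`` leaves each at 1.
"""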
if self.weight_rule == WeightRule.ABSOLUTE:
return self.priority_weight
elif self.weight_rule == WeightRule.DOWNSTREAM:
upstream = False
elif self.weight_rule == WeightRule.UPSTREAM:
upstream = True
else:
upstream = False
return self.priority_weight + sum(
map(lambda task_id: self._dag.task_dict[task_id].priority_weight,
self.get_flat_relative_ids(upstream=upstream))
)
@prepare_lineage
def pre_execute(self, context):
"""
This hook is triggered right before self.execute() is called.
"""
pass
def execute(self, context):
"""
This is the main method to derive when creating an operator.
Context is the same dictionary used as when rendering jinja templates.
Refer to get_template_context for more context.
"""
raise NotImplementedError()
@apply_lineage
def post_execute(self, context, result=None):
"""
This hook is triggered right after self.execute() is called.
It is passed the execution context and any results returned by the
operator.
"""
pass
def on_kill(self):
"""
Override this method to cleanup subprocesses when a task instance
gets killed. Any use of the threading, subprocess or multiprocessing
module within an operator needs to be cleaned up or it will leave
ghost processes behind.
"""
pass
def __deepcopy__(self, memo):
"""
Hack sorting double chained task lists by task_id to avoid hitting
max_depth on deepcopy operations.
"""
sys.setrecursionlimit(5000) # TODO fix this in a better way
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
for k, v in list(self.__dict__.items()):
if k not in shallow_copy:
setattr(result, k, copy.deepcopy(v, memo))
else:
setattr(result, k, copy.copy(v))
return result
def __getstate__(self):
state = dict(self.__dict__)
del state['_log']
return state
def __setstate__(self, state):
self.__dict__ = state
self._log = logging.getLogger("airflow.task.operators")
def render_template_from_field(self, attr, content, context, jinja_env):
"""
Renders a template from a field. If the field is a string, it will
simply render the string and return the result. If it is a collection or
nested set of collections, it will traverse the structure and render
all strings in it.
"""
rt = self.render_template
if isinstance(content, six.string_types):
result = jinja_env.from_string(content).render(**context)
elif isinstance(content, (list, tuple)):
result = [rt(attr, e, context) for e in content]
elif isinstance(content, numbers.Number):
result = content
elif isinstance(content, dict):
result = {
k: rt("{}[{}]".format(attr, k), v, context)
for k, v in list(content.items())}
else:
param_type = type(content)
msg = (
"Type '{param_type}' used for parameter '{attr}' is "
"not supported for templating").format(**locals())
raise AirflowException(msg)
return result
def render_template(self, attr, content, context):
"""
Renders a template either from a file or directly in a field, and returns
the rendered result.
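Illustrative example (``context`` is the dict returned by
``get_template_context``; the attr and content are hypothetical)::

    rendered = task.render_template('bash_command', 'echo {{ ds }}', context)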
"""
jinja_env = self.get_template_env()
exts = self.__class__.template_ext
if (
isinstance(content, six.string_types) and
any([content.endswith(ext) for ext in exts])):
return jinja_env.get_template(content).render(**context)
else:
return self.render_template_from_field(attr, content, context, jinja_env)
def get_template_env(self):
return self.dag.get_template_env() \
if hasattr(self, 'dag') \
else jinja2.Environment(cache_size=0)
def prepare_template(self):
"""
Hook that is triggered after the templated fields get replaced
by their content. If you need your operator to alter the
content of the file before the template is rendered,
it should override this method to do so.
"""
pass
def resolve_template_files(self):
# Getting the content of files for template_field / template_ext
for attr in self.template_fields:
content = getattr(self, attr)
if content is None:
continue
elif isinstance(content, six.string_types) and \
any([content.endswith(ext) for ext in self.template_ext]):
env = self.get_template_env()
try:
setattr(self, attr, env.loader.get_source(env, content)[0])
except Exception as e:
self.log.exception(e)
elif isinstance(content, list):
env = self.dag.get_template_env()
for i in range(len(content)):
if isinstance(content[i], six.string_types) and \
any([content[i].endswith(ext) for ext in self.template_ext]):
try:
content[i] = env.loader.get_source(env, content[i])[0]
except Exception as e:
self.log.exception(e)
self.prepare_template()
@property
def upstream_list(self):
"""@property: list of tasks directly upstream"""
return [self.dag.get_task(tid) for tid in self._upstream_task_ids]
@property
def upstream_task_ids(self):
return self._upstream_task_ids
@property
def downstream_list(self):
"""@property: list of tasks directly downstream"""
return [self.dag.get_task(tid) for tid in self._downstream_task_ids]
@property
def downstream_task_ids(self):
return self._downstream_task_ids
@provide_session
def clear(self,
start_date=None,
end_date=None,
upstream=False,
downstream=False,
session=None):
"""
Clears the state of task instances associated with the task, following
the parameters specified.
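Illustrative usage (the date range is hypothetical)::

    cleared = task.clear(start_date=datetime(2018, 1, 1),
                         end_date=datetime(2018, 1, 31),
                         downstream=True)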
"""
TI = TaskInstance
qry = session.query(TI).filter(TI.dag_id == self.dag_id)
if start_date:
qry = qry.filter(TI.execution_date >= start_date)
if end_date:
qry = qry.filter(TI.execution_date <= end_date)
tasks = [self.task_id]
if upstream:
tasks += [
t.task_id for t in self.get_flat_relatives(upstream=True)]
if downstream:
tasks += [
t.task_id for t in self.get_flat_relatives(upstream=False)]
qry = qry.filter(TI.task_id.in_(tasks))
count = qry.count()
clear_task_instances(qry.all(), session, dag=self.dag)
session.commit()
return count
def get_task_instances(self, session, start_date=None, end_date=None):
"""
Get the set of task instances related to this task for a specific date
range.
"""
TI = TaskInstance
end_date = end_date or timezone.utcnow()
return session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.execution_date >= start_date,
TI.execution_date <= end_date,
).order_by(TI.execution_date).all()
def get_flat_relative_ids(self, upstream=False, found_descendants=None):
"""
Get a flat set of relatives' ids, either upstream or downstream.
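For example, for a chain ``a >> b >> c`` with those task ids (an
illustrative sketch), ``a.get_flat_relative_ids()`` returns ``{'b', 'c'}``
and ``c.get_flat_relative_ids(upstream=True)`` returns ``{'a', 'b'}``.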
"""
if not found_descendants:
found_descendants = set()
relative_ids = self.get_direct_relative_ids(upstream)
for relative_id in relative_ids:
if relative_id not in found_descendants:
found_descendants.add(relative_id)
relative_task = self._dag.task_dict[relative_id]
relative_task.get_flat_relative_ids(upstream,
found_descendants)
return found_descendants
def get_flat_relatives(self, upstream=False):
"""
Get a flat list of relatives, either upstream or downstream.
"""
return list(map(lambda task_id: self._dag.task_dict[task_id],
self.get_flat_relative_ids(upstream)))
def run(
self,
start_date=None,
end_date=None,
ignore_first_depends_on_past=False,
ignore_ti_state=False,
mark_success=False):
"""
Run a set of task instances for a date range.
"""
start_date = start_date or self.start_date
end_date = end_date or self.end_date or timezone.utcnow()
for dt in self.dag.date_range(start_date, end_date=end_date):
TaskInstance(self, dt).run(
mark_success=mark_success,
ignore_depends_on_past=(
dt == start_date and ignore_first_depends_on_past),
ignore_ti_state=ignore_ti_state)
def dry_run(self):
self.log.info('Dry run')
for attr in self.template_fields:
content = getattr(self, attr)
if content and isinstance(content, six.string_types):
self.log.info('Rendering template for %s', attr)
self.log.info(content)
def get_direct_relative_ids(self, upstream=False):
"""
Get the direct relative ids to the current task, upstream or
downstream.
"""
if upstream:
return self._upstream_task_ids
else:
return self._downstream_task_ids
def get_direct_relatives(self, upstream=False):
"""
Get the direct relatives to the current task, upstream or
downstream.
"""
if upstream:
return self.upstream_list
else:
return self.downstream_list
def __repr__(self):
return "<Task({self.__class__.__name__}): {self.task_id}>".format(
self=self)
@property
def task_type(self):
return self.__class__.__name__
def add_only_new(self, item_set, item):
if item in item_set:
raise AirflowException(
'Dependency {self}, {item} already registered'
''.format(**locals()))
else:
item_set.add(item)
def _set_relatives(self, task_or_task_list, upstream=False):
try:
task_list = list(task_or_task_list)
except TypeError:
task_list = [task_or_task_list]
for t in task_list:
if not isinstance(t, BaseOperator):
raise AirflowException(
"Relationships can only be set between "
"Operators; received {}".format(t.__class__.__name__))
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
dags = {t._dag.dag_id: t._dag for t in [self] + task_list if t.has_dag()}
if len(dags) > 1:
raise AirflowException(
'Tried to set relationships between tasks in '
'more than one DAG: {}'.format(dags.values()))
elif len(dags) == 1:
dag = dags.popitem()[1]
else:
raise AirflowException(
"Tried to create relationships between tasks that don't have "
"DAGs yet. Set the DAG for at least one "
"task and try again: {}".format([self] + task_list))
if dag and not self.has_dag():