[AIRFLOW-4104] Add type annotations to common classes. #4926

Merged · 1 commit · Mar 27, 2019
9 changes: 5 additions & 4 deletions airflow/hooks/base_hook.py
@@ -24,6 +24,7 @@

import os
import random
from typing import Iterable

from airflow.models.connection import Connection
from airflow.exceptions import AirflowException
@@ -67,7 +68,7 @@ def _get_connection_from_env(cls, conn_id):
return conn

@classmethod
def get_connections(cls, conn_id):
def get_connections(cls, conn_id): # type: (str) -> Iterable[Connection]
conn = cls._get_connection_from_env(conn_id)
if conn:
conns = [conn]
@@ -76,15 +77,15 @@ def get_connections(cls, conn_id):
return conns

@classmethod
def get_connection(cls, conn_id):
conn = random.choice(cls.get_connections(conn_id))
def get_connection(cls, conn_id): # type: (str) -> Connection
conn = random.choice(list(cls.get_connections(conn_id)))
if conn.host:
log = LoggingMixin().log
log.info("Using connection to: %s", conn.debug_info())
return conn

@classmethod
def get_hook(cls, conn_id):
def get_hook(cls, conn_id): # type: (str) -> BaseHook
connection = cls.get_connection(conn_id)
return connection.get_hook()

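The annotations in this PR use the PEP 484 comment style (`# type: ...`), which keeps the code valid Python 2 syntax while remaining checkable by mypy. The `list()` call added in `get_connection` is needed because `random.choice` requires a sequence and `get_connections` is now declared to return an arbitrary `Iterable`. A minimal, self-contained sketch of the same pattern (illustrative names, not the Airflow code itself):

```python
# Minimal sketch (not Airflow code) of comment-style annotations and the
# list() wrapper used in base_hook.py.
import random
from typing import Iterable


class Connection(object):
    def __init__(self, host):
        # type: (str) -> None
        self.host = host


def get_connections(conn_id):
    # type: (str) -> Iterable[Connection]
    # May return any iterable, e.g. a generator.
    return (Connection(h) for h in ("db-1", "db-2"))


def get_connection(conn_id):
    # type: (str) -> Connection
    # random.choice() needs a sequence, so the Iterable is materialised first,
    # the same reason the PR wraps get_connections() in list().
    return random.choice(list(get_connections(conn_id)))


print(get_connection("my_db").host)
```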
176 changes: 94 additions & 82 deletions airflow/models/__init__.py
@@ -25,20 +25,13 @@
from builtins import ImportError as BuiltinImportError, bytes, object, str
from collections import defaultdict, namedtuple, OrderedDict
import copy
from typing import Iterable
from datetime import timedelta
from typing import Optional, Union, Type, Callable, Iterable, Set, Dict, Any

from future.standard_library import install_aliases

from airflow.models.base import Base, ID_LEN

try:
# Fix Python > 3.7 deprecation
from collections.abc import Hashable
except ImportError:
# Preserve Python < 3.3 compatibility
from collections import Hashable
from datetime import timedelta

import dill
import functools
import getpass
@@ -74,6 +67,7 @@
croniter, CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError
)
import six
from dateutil.relativedelta import relativedelta

from airflow import settings, utils
from airflow.executors import get_default_executor, LocalExecutor
@@ -83,7 +77,7 @@
AirflowRescheduleException
)
from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.lineage import apply_lineage, prepare_lineage, DataSet
from airflow.models.dagpickle import DagPickle
from airflow.models.kubernetes import KubeWorkerIdentifier, KubeResourceVersion # noqa: F401
from airflow.models.log import Log
@@ -116,6 +110,8 @@

XCOM_RETURN_KEY = 'return_value'

ScheduleInterval = Union[str, timedelta, relativedelta]


class InvalidFernetToken(Exception):
# If Fernet isn't loaded we need a valid exception class to catch. If it is
@@ -2064,43 +2060,44 @@ class derived from this one results in the creation of a task object,

@apply_defaults
def __init__(
self,
task_id,
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'),
email=None,
email_on_retry=True,
email_on_failure=True,
retries=0,
retry_delay=timedelta(seconds=300),
retry_exponential_backoff=False,
max_retry_delay=None,
start_date=None,
end_date=None,
schedule_interval=None, # not hooked as of now
depends_on_past=False,
wait_for_downstream=False,
dag=None,
params=None,
default_args=None,
priority_weight=1,
weight_rule=WeightRule.DOWNSTREAM,
queue=configuration.conf.get('celery', 'default_queue'),
pool=None,
sla=None,
execution_timeout=None,
on_failure_callback=None,
on_success_callback=None,
on_retry_callback=None,
trigger_rule=TriggerRule.ALL_SUCCESS,
resources=None,
run_as_user=None,
task_concurrency=None,
executor_config=None,
do_xcom_push=True,
inlets=None,
outlets=None,
*args,
**kwargs):
self,
task_id, # type: str
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'), # type: str
email=None, # type: Optional[str]
email_on_retry=True, # type: bool
email_on_failure=True, # type: bool
retries=0, # type: int
retry_delay=timedelta(seconds=300), # type: timedelta
retry_exponential_backoff=False, # type: bool
max_retry_delay=None, # type: Optional[datetime]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
schedule_interval=None, # not hooked as of now
depends_on_past=False, # type: bool
wait_for_downstream=False, # type: bool
dag=None, # type: Optional[DAG]
params=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
priority_weight=1, # type: int
weight_rule=WeightRule.DOWNSTREAM, # type: str
queue=configuration.conf.get('celery', 'default_queue'), # type: str
pool=None, # type: Optional[str]
sla=None, # type: Optional[timedelta]
execution_timeout=None, # type: Optional[timedelta]
on_failure_callback=None, # type: Optional[Callable]
on_success_callback=None, # type: Optional[Callable]
on_retry_callback=None, # type: Optional[Callable]
trigger_rule=TriggerRule.ALL_SUCCESS, # type: str
resources=None, # type: Optional[Dict]
run_as_user=None, # type: Optional[str]
task_concurrency=None, # type: Optional[int]
executor_config=None, # type: Optional[Dict]
do_xcom_push=True, # type: bool
inlets=None, # type: Optional[Dict]
outlets=None, # type: Optional[Dict]
*args,
**kwargs
):

if args or kwargs:
# TODO remove *args and **kwargs in Airflow 2.0
@@ -2183,8 +2180,8 @@ def __init__(
self.do_xcom_push = do_xcom_push

# Private attributes
self._upstream_task_ids = set()
self._downstream_task_ids = set()
self._upstream_task_ids = set() # type: Set[str]
self._downstream_task_ids = set() # type: Set[str]

if not dag and _CONTEXT_MANAGER_DAG:
dag = _CONTEXT_MANAGER_DAG
@@ -2194,8 +2191,8 @@
self._log = logging.getLogger("airflow.task.operators")

# lineage
self.inlets = []
self.outlets = []
self.inlets = [] # type: Iterable[DataSet]
self.outlets = [] # type: Iterable[DataSet]
self.lineage_data = None

self._inlets = {
@@ -2206,7 +2203,7 @@

self._outlets = {
"datasets": [],
}
} # type: Dict

if inlets:
self._inlets.update(inlets)
@@ -2977,29 +2974,32 @@ class DAG(BaseDag, LoggingMixin):
"""

def __init__(
self, dag_id,
description='',
schedule_interval=timedelta(days=1),
start_date=None, end_date=None,
full_filepath=None,
template_searchpath=None,
template_undefined=jinja2.Undefined,
user_defined_macros=None,
user_defined_filters=None,
default_args=None,
concurrency=configuration.conf.getint('core', 'dag_concurrency'),
max_active_runs=configuration.conf.getint(
'core', 'max_active_runs_per_dag'),
dagrun_timeout=None,
sla_miss_callback=None,
default_view=None,
orientation=configuration.conf.get('webserver', 'dag_orientation'),
catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'),
on_success_callback=None, on_failure_callback=None,
doc_md=None,
params=None,
access_control=None):

self,
dag_id, # type: str
description='', # type: str
schedule_interval=timedelta(days=1), # type: Optional[ScheduleInterval]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
full_filepath=None, # type: Optional[str]
template_searchpath=None, # type: Optional[Union[str, Iterable[str]]]
template_undefined=jinja2.Undefined, # type: Type[jinja2.Undefined]
user_defined_macros=None, # type: Optional[Dict]
user_defined_filters=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
concurrency=configuration.conf.getint('core', 'dag_concurrency'), # type: int
max_active_runs=configuration.conf.getint(
'core', 'max_active_runs_per_dag'), # type: int
dagrun_timeout=None, # type: Optional[timedelta]
sla_miss_callback=None, # type: Optional[Callable]
default_view=None, # type: Optional[str]
orientation=configuration.conf.get('webserver', 'dag_orientation'), # type: str
catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'), # type: bool
on_success_callback=None, # type: Optional[Callable]
on_failure_callback=None, # type: Optional[Callable]
doc_md=None, # type: Optional[str]
params=None, # type: Optional[Dict]
access_control=None # type: Optional[Dict]
):
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = default_args or {}
@@ -3021,7 +3021,7 @@ def __init__(
self._description = description
# set file location to caller source path
self.fileloc = sys._getframe().f_back.f_code.co_filename
self.task_dict = dict()
self.task_dict = dict() # type: Dict[str, TaskInstance]

# set timezone
if start_date and start_date.tzinfo:
@@ -3050,8 +3050,8 @@ def __init__(
)

self.schedule_interval = schedule_interval
if isinstance(schedule_interval, Hashable) and schedule_interval in cron_presets:
self._schedule_interval = cron_presets.get(schedule_interval)
if isinstance(schedule_interval, six.string_types) and schedule_interval in cron_presets:
Member:

This looks like a possible code change. Please double check with the PR that introduced the "Hashable" here for the reasoning.

Contributor Author:

It looks like the Hashable check was added because relativedelta objects aren't hashable in python-dateutil before version 2.7. It turns out that the keys of cron_presets are always strings, so we can check that schedule_interval is a string instead of checking that it's hashable. We could also require python-dateutil>=2.7 and drop both checks.

self._schedule_interval = cron_presets.get(schedule_interval) # type: Optional[ScheduleInterval]
elif schedule_interval == '@once':
self._schedule_interval = None
else:
@@ -3076,7 +3076,7 @@ def __init__(
self.on_failure_callback = on_failure_callback
self.doc_md = doc_md

self._old_context_manager_dags = []
self._old_context_manager_dags = [] # type: Iterable[DAG]
self._access_control = access_control

self._comps = {
@@ -4283,7 +4283,13 @@ def setdefault(cls, key, default, deserialize_json=False):

@classmethod
@provide_session
def get(cls, key, default_var=__NO_DEFAULT_SENTINEL, deserialize_json=False, session=None):
def get(
cls,
key, # type: str
default_var=__NO_DEFAULT_SENTINEL, # type: Any
deserialize_json=False, # type: bool
session=None
):
obj = session.query(cls).filter(cls.key == key).first()
if obj is None:
if default_var is not cls.__NO_DEFAULT_SENTINEL:
@@ -4298,15 +4304,21 @@ def get(cls, key, default_var=__NO_DEFAULT_SENTINEL, deserialize_json=False, ses

@classmethod
@provide_session
def set(cls, key, value, serialize_json=False, session=None):
def set(
cls,
key, # type: str
value, # type: Any
serialize_json=False, # type: bool
session=None
):

if serialize_json:
stored_value = json.dumps(value)
else:
stored_value = str(value)

Variable.delete(key)
session.add(Variable(key=key, val=stored_value))
session.add(Variable(key=key, val=stored_value)) # type: ignore
session.flush()

@classmethod
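As the review thread above concludes, checking `isinstance(schedule_interval, six.string_types)` is enough because the keys of `cron_presets` are always strings, and the short-circuit also avoids the `TypeError` that an unhashable `relativedelta` (python-dateutil < 2.7) would raise when used as a dict key. A minimal sketch of that dispatch, using an illustrative subset of presets rather than Airflow's real `cron_presets`, and assuming `six` and `python-dateutil` are installed:

```python
# Minimal sketch (not Airflow code) of the schedule_interval dispatch discussed
# in the review thread above.
from datetime import timedelta

import six
from dateutil.relativedelta import relativedelta

# Illustrative subset; Airflow's real cron_presets has more entries.
cron_presets = {"@hourly": "0 * * * *", "@daily": "0 0 * * *"}


def resolve_schedule(schedule_interval):
    # The string check short-circuits the dict lookup, so an unhashable
    # relativedelta (python-dateutil < 2.7) never reaches `in cron_presets`.
    if isinstance(schedule_interval, six.string_types) and schedule_interval in cron_presets:
        return cron_presets[schedule_interval]  # preset name -> cron expression
    elif schedule_interval == "@once":
        return None                             # run exactly once
    else:
        return schedule_interval                # cron string, timedelta or relativedelta


print(resolve_schedule("@daily"))                 # 0 0 * * *
print(resolve_schedule(timedelta(days=1)))        # 1 day, 0:00:00
print(resolve_schedule(relativedelta(months=1)))  # relativedelta(months=+1)
```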
1 change: 1 addition & 0 deletions airflow/models/connection.py
@@ -264,6 +264,7 @@ def get_hook(self):
elif self.conn_type == 'grpc':
from airflow.contrib.hooks.grpc_hook import GrpcHook
return GrpcHook(grpc_conn_id=self.conn_id)
raise AirflowException("Unknown hook type {}".format(self.conn_type))

def __repr__(self):
return self.conn_id
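Since `BaseHook.get_hook` is now annotated as returning `BaseHook` rather than `Optional[BaseHook]`, `Connection.get_hook` should not silently fall through and return `None` for an unrecognised `conn_type`; the added `raise` makes that failure explicit. A minimal sketch of the same idea, with illustrative hook classes in place of Airflow's full dispatch chain:

```python
# Minimal sketch (not Airflow code): raising on an unknown conn_type keeps the
# declared BaseHook return type honest instead of implicitly returning None.
class AirflowException(Exception):
    pass


class BaseHook(object):
    pass


class PostgresHook(BaseHook):
    pass


class GrpcHook(BaseHook):
    pass


def get_hook(conn_type):
    # type: (str) -> BaseHook
    if conn_type == 'postgres':
        return PostgresHook()
    elif conn_type == 'grpc':
        return GrpcHook()
    # Without this raise, the implicit `return None` would violate the
    # BaseHook annotation above.
    raise AirflowException("Unknown hook type {}".format(conn_type))


print(get_hook('grpc'))
```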
35 changes: 24 additions & 11 deletions airflow/operators/check_operator.py
@@ -19,7 +19,7 @@

from builtins import zip
from builtins import str
from typing import Iterable
from typing import Optional, Any, Iterable, Dict

from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
@@ -69,9 +69,12 @@ class CheckOperator(BaseOperator):

@apply_defaults
def __init__(
self, sql,
conn_id=None,
*args, **kwargs):
self,
sql, # type: str
conn_id=None, # type: Optional[str]
*args,
**kwargs
):
super(CheckOperator, self).__init__(*args, **kwargs)
self.conn_id = conn_id
self.sql = sql
@@ -127,9 +130,14 @@ class ValueCheckOperator(BaseOperator):

@apply_defaults
def __init__(
self, sql, pass_value, tolerance=None,
conn_id=None,
*args, **kwargs):
self,
sql, # type: str
pass_value, # type: Any
tolerance=None, # type: Any
conn_id=None, # type: Optional[str]
*args,
**kwargs
):
super(ValueCheckOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
@@ -203,10 +211,15 @@ class IntervalCheckOperator(BaseOperator):

@apply_defaults
def __init__(
self, table, metrics_thresholds,
date_filter_column='ds', days_back=-7,
conn_id=None,
*args, **kwargs):
self,
table, # type: str
metrics_thresholds, # type: Dict
date_filter_column='ds', # type: str
days_back=-7, # type: int
conn_id=None, # type: Optional[str]
*args,
**kwargs
):
super(IntervalCheckOperator, self).__init__(*args, **kwargs)
self.table = table
self.metrics_thresholds = metrics_thresholds
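The multi-line constructor signatures in this file use per-argument type comments. A minimal sketch of that long form as described by PEP 484, with an illustrative class rather than the Airflow operator, including the `# type: (...) -> None` return comment that PEP 484 pairs with the per-argument style:

```python
# Minimal sketch (not Airflow code): PEP 484 per-argument type comments for a
# signature too long to annotate with a single one-line `# type:` comment.
from typing import Any, Dict, Optional


class CheckOperator(object):
    def __init__(
        self,
        sql,           # type: str
        conn_id=None,  # type: Optional[str]
        params=None,   # type: Optional[Dict[str, Any]]
    ):
        # type: (...) -> None
        self.sql = sql
        self.conn_id = conn_id
        self.params = params or {}


op = CheckOperator(sql="SELECT COUNT(*) FROM my_table", conn_id="my_db")
print(op.sql, op.conn_id)
```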