Skip to content

Commit

Permalink
Add pre-commit check to keep partial() up-to-date
Browse files Browse the repository at this point in the history
Also caught and fixed a few missing arguments!
  • Loading branch information
uranusjr committed Jun 2, 2022
1 parent 07edbea commit c7f62ce
Show file tree
Hide file tree
Showing 40 changed files with 1,389 additions and 1,178 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,12 @@ repos:
^.*RELEASE_NOTES\.rst$|
^.*CHANGELOG\.txt$|^.*CHANGELOG\.rst$|
git
- id: check-base-operator-partial-arguments
name: Check BaseOperator and partial() arguments
language: python
entry: ./scripts/ci/pre_commit/pre_commit_base_operator_partial_arguments.py
pass_filenames: false
files: ^airflow/models/(?:base|mapped)operator.py$
- id: check-base-operator-usage
language: pygrep
name: Check BaseOperator[Link] core imports
Expand Down
2 changes: 2 additions & 0 deletions STATIC_CODE_CHECKS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ require Breeze Docker image to be built locally.
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-apache-license-rat | Check if licenses are OK for Apache | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-base-operator-partial-arguments | Check BaseOperator and partial() arguments | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
| check-base-operator-usage | * Check BaseOperator[Link] core imports | |
| | * Check BaseOperator[Link] other imports | |
+--------------------------------------------------------+------------------------------------------------------------------+---------+
Expand Down
11 changes: 9 additions & 2 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from airflow.models.baseoperator import (
BaseOperator,
coerce_resources,
coerce_retry_delay,
coerce_timedelta,
get_merged_defaults,
parse_retries,
)
Expand Down Expand Up @@ -344,8 +344,15 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
if partial_kwargs.get("pool") is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
partial_kwargs["retry_delay"] = coerce_retry_delay(
partial_kwargs["retry_delay"] = coerce_timedelta(
partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
key="retry_delay",
)
max_retry_delay = partial_kwargs.get("max_retry_delay")
partial_kwargs["max_retry_delay"] = (
max_retry_delay
if max_retry_delay is None
else coerce_timedelta(max_retry_delay, key="max_retry_delay")
)
partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources"))
partial_kwargs.setdefault("executor_config", {})
Expand Down
52 changes: 32 additions & 20 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ def parse_retries(retries: Any) -> Optional[int]:
return parsed_retries


def coerce_retry_delay(retry_delay: Union[float, timedelta]) -> timedelta:
    """Normalize *retry_delay* to a ``timedelta``.

    An existing ``timedelta`` is returned unchanged; a bare number is
    interpreted as a count of seconds (logged at debug level).
    """
    if not isinstance(retry_delay, timedelta):
        logger.debug("retry_delay isn't a timedelta object, assuming secs")
        retry_delay = timedelta(seconds=retry_delay)
    return retry_delay
def coerce_timedelta(value: Union[float, timedelta], *, key: str) -> timedelta:
    """Normalize *value* to a ``timedelta``.

    An existing ``timedelta`` is returned unchanged; a bare number is
    interpreted as seconds. ``key`` names the argument being coerced and
    is only used in the debug message for the numeric case.
    """
    if not isinstance(value, timedelta):
        logger.debug("%s isn't a timedelta object, assuming secs", key)
        value = timedelta(seconds=value)
    return value


def coerce_resources(resources: Optional[Dict[str, Any]]) -> Optional[Resources]:
Expand Down Expand Up @@ -205,6 +205,7 @@ def partial(
pool: Optional[str] = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
execution_timeout: Optional[timedelta] = DEFAULT_TASK_EXECUTION_TIMEOUT,
max_retry_delay: Union[None, timedelta, float] = None,
retry_delay: Union[timedelta, float] = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
Expand All @@ -219,6 +220,11 @@ def partial(
executor_config: Optional[Dict] = None,
inlets: Optional[Any] = None,
outlets: Optional[Any] = None,
doc: Optional[str] = None,
doc_md: Optional[str] = None,
doc_json: Optional[str] = None,
doc_yaml: Optional[str] = None,
doc_rst: Optional[str] = None,
**kwargs,
) -> OperatorPartial:
from airflow.models.dag import DagContext
Expand Down Expand Up @@ -259,6 +265,7 @@ def partial(
partial_kwargs.setdefault("pool", pool)
partial_kwargs.setdefault("pool_slots", pool_slots)
partial_kwargs.setdefault("execution_timeout", execution_timeout)
partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
partial_kwargs.setdefault("retry_delay", retry_delay)
partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff)
partial_kwargs.setdefault("priority_weight", priority_weight)
Expand All @@ -274,6 +281,11 @@ def partial(
partial_kwargs.setdefault("inlets", inlets)
partial_kwargs.setdefault("outlets", outlets)
partial_kwargs.setdefault("resources", resources)
partial_kwargs.setdefault("doc", doc)
partial_kwargs.setdefault("doc_json", doc_json)
partial_kwargs.setdefault("doc_md", doc_md)
partial_kwargs.setdefault("doc_rst", doc_rst)
partial_kwargs.setdefault("doc_yaml", doc_yaml)

# Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
if "task_concurrency" in kwargs: # Reject deprecated option.
Expand All @@ -285,7 +297,12 @@ def partial(
if partial_kwargs["pool"] is None:
partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"])
partial_kwargs["retry_delay"] = coerce_retry_delay(partial_kwargs["retry_delay"])
partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay")
if partial_kwargs["max_retry_delay"] is not None:
partial_kwargs["max_retry_delay"] = coerce_timedelta(
partial_kwargs["max_retry_delay"],
key="max_retry_delay",
)
partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])

Expand Down Expand Up @@ -757,10 +774,7 @@ def __init__(
dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)

if task_group:
self.task_id = task_group.child_id(task_id)
else:
self.task_id = task_id
self.task_id = task_group.child_id(task_id) if task_group else task_id
if not self.__from_mapped and task_group:
task_group.add(self)

Expand Down Expand Up @@ -826,20 +840,18 @@ def __init__(

self.trigger_rule = TriggerRule(trigger_rule)
self.depends_on_past: bool = depends_on_past
self.ignore_first_depends_on_past = ignore_first_depends_on_past
self.wait_for_downstream = wait_for_downstream
self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
self.wait_for_downstream: bool = wait_for_downstream
if wait_for_downstream:
self.depends_on_past = True

self.retry_delay = coerce_retry_delay(retry_delay)
self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
self.retry_exponential_backoff = retry_exponential_backoff
self.max_retry_delay = max_retry_delay
if max_retry_delay:
if isinstance(max_retry_delay, timedelta):
self.max_retry_delay = max_retry_delay
else:
self.log.debug("max_retry_delay isn't a timedelta object, assuming secs")
self.max_retry_delay = timedelta(seconds=max_retry_delay)
self.max_retry_delay = (
max_retry_delay
if max_retry_delay is None
else coerce_timedelta(max_retry_delay, key="max_retry_delay")
)

# At execution_time this becomes a normal dict
self.params: Union[ParamsDict, dict] = ParamsDict(params)
Expand Down
24 changes: 24 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ def pool_slots(self) -> Optional[str]:
def execution_timeout(self) -> Optional[datetime.timedelta]:
return self.partial_kwargs.get("execution_timeout")

@property
def max_retry_delay(self) -> Optional[datetime.timedelta]:
    """Return the ``max_retry_delay`` captured by ``partial()``, or ``None`` if unset."""
    return self.partial_kwargs.get("max_retry_delay", None)

@property
def retry_delay(self) -> datetime.timedelta:
    """Return the ``retry_delay`` captured by ``partial()``, falling back to ``DEFAULT_RETRY_DELAY``."""
    kwargs = self.partial_kwargs
    return kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)
Expand Down Expand Up @@ -477,6 +481,26 @@ def inlets(self) -> Optional[Any]:
def outlets(self) -> Optional[Any]:
return self.partial_kwargs.get("outlets", None)

@property
def doc(self) -> Optional[str]:
    """Return the ``doc`` value captured by ``partial()``, or ``None`` if unset."""
    return self.partial_kwargs.get("doc", None)

@property
def doc_md(self) -> Optional[str]:
    """Return the ``doc_md`` value captured by ``partial()``, or ``None`` if unset."""
    return self.partial_kwargs.get("doc_md", None)

@property
def doc_json(self) -> Optional[str]:
    """Return the ``doc_json`` value captured by ``partial()``, or ``None`` if unset."""
    return self.partial_kwargs.get("doc_json", None)

@property
def doc_yaml(self) -> Optional[str]:
    """Return the ``doc_yaml`` value captured by ``partial()``, or ``None`` if unset."""
    return self.partial_kwargs.get("doc_yaml", None)

@property
def doc_rst(self) -> Optional[str]:
    """Return the ``doc_rst`` value captured by ``partial()``, or ``None`` if unset."""
    return self.partial_kwargs.get("doc_rst", None)

def get_dag(self) -> Optional["DAG"]:
    """Implementing Operator: expose this task's ``dag`` attribute."""
    dag = self.dag
    return dag
Expand Down
4 changes: 2 additions & 2 deletions airflow/ti_deps/deps/prev_dagrun_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PrevDagrunDep(BaseTIDep):
IS_TASK_DEP = True

@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
def _get_dep_statuses(self, ti: TI, session, dep_context):
if dep_context.ignore_depends_on_past:
reason = "The context specified that the state of past DAGs could be ignored."
yield self._passing_status(reason=reason)
Expand All @@ -50,7 +50,7 @@ def _get_dep_statuses(self, ti, session, dep_context):
return

# Don't depend on the previous task instance if we are the first task.
catchup = ti.task.dag.catchup
catchup = ti.task.dag and ti.task.dag.catchup
if catchup:
last_dagrun = dr.get_previous_scheduled_dagrun(session)
else:
Expand Down
1 change: 1 addition & 0 deletions dev/breeze/src/airflow_breeze/pre_commit_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'check-airflow-config-yaml-consistent',
'check-airflow-providers-have-extras',
'check-apache-license-rat',
'check-base-operator-partial-arguments',
'check-base-operator-usage',
'check-boring-cyborg-configuration',
'check-breeze-top-dependencies-limited',
Expand Down

0 comments on commit c7f62ce

Please sign in to comment.