Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TriggerRuleDep typing and readability #27810

Merged
merged 4 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions airflow/ti_deps/deps/base_ti_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
# under the License.
from __future__ import annotations

from typing import NamedTuple
from typing import TYPE_CHECKING, Any, Iterator, NamedTuple

from airflow.ti_deps.dep_context import DepContext
from airflow.utils.session import provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.taskinstance import TaskInstance


class BaseTIDep:
"""
Expand All @@ -38,27 +43,29 @@ class BaseTIDep:
# to some tasks (e.g. depends_on_past is not specified by all tasks).
IS_TASK_DEP = False

def __init__(self):
pass

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return isinstance(self, type(other))

def __hash__(self):
def __hash__(self) -> int:
return hash(type(self))

def __repr__(self):
def __repr__(self) -> str:
return f"<TIDep({self.name})>"

@property
def name(self):
"""
The human-readable name for the dependency. Use the classname as the default name
if this method is not overridden in the subclass.
def name(self) -> str:
"""The human-readable name for the dependency.

Use the class name as the default if ``NAME`` is not provided.
"""
return getattr(self, "NAME", self.__class__.__name__)

def _get_dep_statuses(self, ti, session, dep_context):
def _get_dep_statuses(
self,
ti: TaskInstance,
session: Session,
dep_context: DepContext,
) -> Iterator[TIDepStatus]:
"""
Abstract method that returns an iterable of TIDepStatus objects that describe
whether the given task instance has this dependency met.
Expand All @@ -73,7 +80,12 @@ def _get_dep_statuses(self, ti, session, dep_context):
raise NotImplementedError

@provide_session
def get_dep_statuses(self, ti, session, dep_context=None):
def get_dep_statuses(
self,
ti: TaskInstance,
session: Session,
dep_context: DepContext | None = None,
) -> Iterator[TIDepStatus]:
"""
Wrapper around the private _get_dep_statuses method that contains some global
checks for all dependencies.
Expand All @@ -82,21 +94,20 @@ def get_dep_statuses(self, ti, session, dep_context=None):
:param session: database session
:param dep_context: the context for which this dependency should be evaluated for
"""
if dep_context is None:
dep_context = DepContext()
cxt = DepContext() if dep_context is None else dep_context

if self.IGNORABLE and dep_context.ignore_all_deps:
if self.IGNORABLE and cxt.ignore_all_deps:
yield self._passing_status(reason="Context specified all dependencies should be ignored.")
return

if self.IS_TASK_DEP and dep_context.ignore_task_deps:
if self.IS_TASK_DEP and cxt.ignore_task_deps:
yield self._passing_status(reason="Context specified all task dependencies should be ignored.")
return

yield from self._get_dep_statuses(ti, session, dep_context)
yield from self._get_dep_statuses(ti, session, cxt)

@provide_session
def is_met(self, ti, session, dep_context=None):
def is_met(self, ti: TaskInstance, session: Session, dep_context: DepContext | None = None) -> bool:
"""
Returns whether or not this dependency is met for a given task instance. A
dependency is considered met if all of the dependency statuses it reports are
Expand All @@ -110,7 +121,12 @@ def is_met(self, ti, session, dep_context=None):
return all(status.passed for status in self.get_dep_statuses(ti, session, dep_context))

@provide_session
def get_failure_reasons(self, ti, session, dep_context=None):
def get_failure_reasons(
self,
ti: TaskInstance,
session: Session,
dep_context: DepContext | None = None,
) -> Iterator[str]:
"""
Returns an iterable of strings that explain why this dependency wasn't met.

Expand All @@ -123,10 +139,10 @@ def get_failure_reasons(self, ti, session, dep_context=None):
if not dep_status.passed:
yield dep_status.reason

def _failing_status(self, reason=""):
def _failing_status(self, reason: str = "") -> TIDepStatus:
return TIDepStatus(self.name, False, reason)

def _passing_status(self, reason=""):
def _passing_status(self, reason: str = "") -> TIDepStatus:
return TIDepStatus(self.name, True, reason)


Expand Down
Loading