D401 Support - Sensors, Serialization, and Triggers #34932

Merged 1 commit on Oct 17, 2023
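For context, D401 is the pydocstyle rule "First line should be in imperative mood": every hunk in this PR rewrites a docstring summary from third person ("Returns ...") to imperative ("Return ..."). A minimal sketch of what the rule flags, using a hypothetical module d401_demo.py (the exact lint invocation in Airflow's CI may differ):

def fetch():
    """Fetches the data."""  # flagged by D401: first line not in imperative mood

def fetch_ok():
    """Fetch the data."""  # imperative summary, passes D401

Checked locally with, for example: pydocstyle --select=D401 d401_demo.py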
2 changes: 1 addition & 1 deletion airflow/sensors/external_task.py
@@ -329,7 +329,7 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
return count_allowed == len(dttm_filter)

def execute(self, context: Context) -> None:
"""Runs on the worker and defers using the triggers if deferrable is set to True."""
"""Run on the worker and defer using the triggers if deferrable is set to True."""
if not self.deferrable:
super().execute(context)
else:
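The execute() change above documents the defer-or-run split: run synchronously on a worker, or hand the wait to the triggerer. A usage sketch, assuming a version where ExternalTaskSensor accepts deferrable (the ids are illustrative):

from airflow.sensors.external_task import ExternalTaskSensor

wait = ExternalTaskSensor(
    task_id="wait_for_upstream",     # illustrative task id
    external_dag_id="upstream_dag",  # illustrative upstream DAG id
    deferrable=True,  # defer to the triggerer instead of holding a worker slot while polling
)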
4 changes: 2 additions & 2 deletions airflow/serialization/pydantic/dag_run.py
@@ -92,7 +92,7 @@ def get_task_instances(
session: Session = NEW_SESSION,
) -> list[TI]:
"""
-Returns the task instances for this dag run.
+Return the task instances for this dag run.

TODO: make it work for AIP-44
"""
@@ -107,7 +107,7 @@ def get_task_instance(
map_index: int = -1,
) -> TI | None:
"""
-Returns the task instance specified by task_id for this dag run.
+Return the task instance specified by task_id for this dag run.

:param task_id: the task id
:param session: Sqlalchemy ORM Session
14 changes: 7 additions & 7 deletions airflow/serialization/pydantic/taskinstance.py
@@ -154,7 +154,7 @@ def xcom_push(
@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRunPydantic:
"""
-Returns the DagRun for this TaskInstance.
+Return the DagRun for this TaskInstance.

:param session: SQLAlchemy ORM Session

@@ -166,7 +166,7 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRunPydantic:

def _execute_task(self, context, task_orig):
"""
-Executes Task (optionally with a Timeout) and pushes Xcom results.
+Execute Task (optionally with a Timeout) and push Xcom results.

:param context: Jinja2 context
:param task_orig: origin task
@@ -178,7 +178,7 @@ def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
"""
-Refreshes the task instance from the database based on the primary key.
+Refresh the task instance from the database based on the primary key.

:param session: SQLAlchemy ORM Session
:param lock_for_update: if True, indicates that the database should
@@ -197,7 +197,7 @@ def set_duration(self) -> None:

@property
def stats_tags(self) -> dict[str, str]:
"""Returns task instance tags."""
"""Return task instance tags."""
from airflow.models.taskinstance import _stats_tags

return _stats_tags(task_instance=self)
@@ -280,7 +280,7 @@ def get_previous_dagrun(
session: Session | None = None,
) -> DagRun | None:
"""
-The DagRun that ran before this task instance's DagRun.
+Return the DagRun that ran before this task instance's DagRun.

:param state: If passed, it only takes into account instances of a specific state.
:param session: SQLAlchemy ORM Session.
@@ -296,7 +296,7 @@ def get_previous_execution_date(
session: Session = NEW_SESSION,
) -> pendulum.DateTime | None:
"""
-The execution date from property previous_ti_success.
+Return the execution date from property previous_ti_success.

:param state: If passed, it only takes into account instances of a specific state.
:param session: SQLAlchemy ORM Session
@@ -336,7 +336,7 @@ def get_previous_ti(
session: Session = NEW_SESSION,
) -> TaskInstance | None:
"""
-The task instance for the task that ran before this task instance.
+Return the task instance for the task that ran before this task instance.

:param session: SQLAlchemy ORM Session
:param state: If passed, it only takes into account instances of a specific state.
35 changes: 22 additions & 13 deletions airflow/serialization/serialized_objects.py
@@ -100,7 +100,8 @@

@cache
def get_operator_extra_links() -> set[str]:
"""Get the operator extra links.
"""
Get the operator extra links.

This includes both the built-in ones and those that come from the providers.
"""
@@ -110,7 +111,8 @@ def get_operator_extra_links() -> set[str]:

@cache
def _get_default_mapped_partial() -> dict[str, Any]:
"""Get default partial kwargs in a mapped operator.
"""
Get default partial kwargs in a mapped operator.

This is used to simplify a serialized mapped operator by excluding default
values supplied in the implementation from the serialized dict. Since those
@@ -141,7 +143,8 @@ def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta:


def encode_timezone(var: Timezone) -> str | int:
"""Encode a Pendulum Timezone for serialization.
"""
Encode a Pendulum Timezone for serialization.

Airflow only supports timezone objects that implement Pendulum's Timezone
interface. We try to keep as much information as possible to make conversion
@@ -192,7 +195,8 @@ def __str__(self) -> str:


def _encode_timetable(var: Timetable) -> dict[str, Any]:
"""Encode a timetable instance.
"""
Encode a timetable instance.

This delegates most of the serialization work to the type, so the behavior
can be completely controlled by a custom subclass.
@@ -205,7 +209,8 @@


def _decode_timetable(var: dict[str, Any]) -> Timetable:
"""Decode a previously serialized timetable.
"""
Decode a previously serialized timetable.

Most of the deserialization logic is delegated to the actual type, which
we import from string.
@@ -218,7 +223,8 @@ def _decode_timetable(var: dict[str, Any]) -> Timetable:


class _XComRef(NamedTuple):
"""Used to store info needed to create XComArg.
"""
Store info needed to create XComArg.

We can't turn it into an XComArg until we've loaded _all_ the tasks, so when
deserializing an operator, we need to create something in its place, and
@@ -252,7 +258,8 @@ def deref(self, dag: DAG) -> XComArg:


class _ExpandInputRef(NamedTuple):
"""Used to store info needed to create a mapped operator's expand input.
"""
Store info needed to create a mapped operator's expand input.

This references a ``ExpandInput`` type, but replaces ``XComArg`` objects
with ``_XComRef`` (see documentation on the latter type for reasoning).
@@ -263,15 +270,17 @@ class _ExpandInputRef(NamedTuple):

@classmethod
def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
"""Validate we've covered all ``ExpandInput.value`` types.
"""
Validate we've covered all ``ExpandInput.value`` types.

This function does not actually do anything, but is called during
serialization so Mypy will *statically* check we have handled all
possible ExpandInput cases.
"""

def deref(self, dag: DAG) -> ExpandInput:
"""De-reference into a concrete ExpandInput object.
"""
De-reference into a concrete ExpandInput object.

If you add more cases here, be sure to update _ExpandInputOriginalValue
and _ExpandInputSerializedValue to match the logic.
@@ -311,19 +320,19 @@ class BaseSerialization:

@classmethod
def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str:
"""Stringifies DAGs and operators contained by var and returns a JSON string of var."""
"""Stringify DAGs and operators contained by var and returns a JSON string of var."""
return json.dumps(cls.to_dict(var), ensure_ascii=True)

@classmethod
def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict:
"""Stringifies DAGs and operators contained by var and returns a dict of var."""
"""Stringify DAGs and operators contained by var and returns a dict of var."""
# Don't call on this class directly - only SerializedDAG or
# SerializedBaseOperator should be used as the "entrypoint"
raise NotImplementedError()

@classmethod
def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple:
"""Deserializes json_str and reconstructs all DAGs and operators it contains."""
"""Deserialize json_str and reconstructs all DAGs and operators it contains."""
return cls.from_dict(json.loads(serialized_obj))

@classmethod
@@ -356,7 +365,7 @@ def _is_primitive(cls, var: Any) -> bool:

@classmethod
def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
"""Types excluded from serialization."""
"""Check if type is excluded from serialization."""
if var is None:
if not cls._is_constructor_param(attrname, instance):
# Any instance attribute, that is not a constructor argument, we exclude None as the default
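The to_json/to_dict/from_json docstrings above describe a serialization round trip whose concrete entrypoint is SerializedDAG, as the inline comment notes. A usage sketch with an illustrative DAG (assuming Airflow 2.x):

from datetime import datetime

from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import SerializedDAG

with DAG(dag_id="demo", start_date=datetime(2023, 10, 1), schedule=None) as dag:
    EmptyOperator(task_id="noop")

payload = SerializedDAG.to_dict(dag)         # JSON-safe dict form
restored = SerializedDAG.from_dict(payload)  # reconstructs the DAG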
2 changes: 1 addition & 1 deletion airflow/triggers/temporal.py
@@ -49,7 +49,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

async def run(self):
"""
-Simple time delay loop until the relevant time is met.
+Loop until the relevant time is met.

We do have a two-phase delay to save some cycles, but sleeping is so
cheap anyway that it's pretty loose. We also don't just sleep for
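The docstring above describes the trigger's two-phase wait. A minimal sketch of that idea, not the trigger's actual implementation (which also yields a TriggerEvent once the moment passes); step is an assumed short-poll interval:

import asyncio
from datetime import datetime, timezone

async def sleep_until(moment: datetime, step: float = 1.0) -> None:
    # Phase 1: one coarse sleep that wakes shortly before the target time.
    remaining = (moment - datetime.now(timezone.utc)).total_seconds()
    if remaining > 2 * step:
        await asyncio.sleep(remaining - 2 * step)
    # Phase 2: cheap short sleeps until the moment has actually passed.
    while datetime.now(timezone.utc) < moment:
        await asyncio.sleep(step)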