Add conditional logic for dataset triggering #37016

Merged
102 changes: 66 additions & 36 deletions airflow/models/dag.py
@@ -18,7 +18,6 @@
from __future__ import annotations

import asyncio
import collections
import copy
import functools
import itertools
@@ -31,7 +30,7 @@
import traceback
import warnings
import weakref
from collections import deque
from collections import abc, defaultdict, deque
from contextlib import ExitStack
from datetime import datetime, timedelta
from inspect import signature
@@ -99,6 +98,13 @@
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.dataset import (
DatasetAll,
DatasetAny,
DatasetBooleanCondition,
DatasetDagRunQueue,
DatasetModel,
)
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -462,7 +468,7 @@ def __init__(
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
params: collections.abc.MutableMapping | None = None,
params: abc.MutableMapping | None = None,
access_control: dict | None = None,
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
@@ -580,25 +586,28 @@ def __init__(

self.timetable: Timetable
self.schedule_interval: ScheduleInterval
self.dataset_triggers: Collection[Dataset] = []

self.dataset_triggers: DatasetBooleanCondition | None = None
if isinstance(schedule, (DatasetAll, DatasetAny)):
self.dataset_triggers = schedule
if isinstance(schedule, Collection) and not isinstance(schedule, str):
from airflow.datasets import Dataset

if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
self.dataset_triggers = list(schedule)
self.dataset_triggers = DatasetAll(*schedule)
elif isinstance(schedule, Timetable):
timetable = schedule
elif schedule is not NOTSET:
schedule_interval = schedule

if self.dataset_triggers:
if isinstance(schedule, DatasetOrTimeSchedule):
self.timetable = schedule
self.dataset_triggers = self.timetable.datasets
self.schedule_interval = self.timetable.summary
elif self.dataset_triggers:
self.timetable = DatasetTriggeredTimetable()
self.schedule_interval = self.timetable.summary
elif timetable:
if isinstance(timetable, DatasetOrTimeSchedule):
self.dataset_triggers = timetable.datasets
self.timetable = timetable
self.schedule_interval = self.timetable.summary
else:
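
The net effect of this branch is that the schedule argument now accepts a boolean dataset condition directly: a DatasetAll or DatasetAny instance is used as-is, while a plain collection of Datasets is normalized to DatasetAll(*schedule). A minimal usage sketch under those assumptions (dag ids, URIs, and start_date are illustrative, not part of this change):

from datetime import datetime

from airflow.datasets import Dataset
from airflow.models.dag import DAG
from airflow.models.dataset import DatasetAll, DatasetAny

raw_a = Dataset("s3://bucket/raw_a")
raw_b = Dataset("s3://bucket/raw_b")
raw_c = Dataset("s3://bucket/raw_c")

# A plain list still works and is wrapped in DatasetAll(*schedule).
wait_for_all = DAG(dag_id="wait_for_all", schedule=[raw_a, raw_b], start_date=datetime(2024, 1, 1))

# Nested conditions: run when raw_a updates, or when both raw_b and raw_c have updated.
wait_for_any = DAG(
    dag_id="wait_for_any",
    schedule=DatasetAny(raw_a, DatasetAll(raw_b, raw_c)),
    start_date=datetime(2024, 1, 1),
)
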
@@ -3156,8 +3165,8 @@ def bulk_write_to_db(
TaskOutletDatasetReference,
)

dag_references = collections.defaultdict(set)
outlet_references = collections.defaultdict(set)
dag_references = defaultdict(set)
outlet_references = defaultdict(set)
# We can't use a set here as we want to preserve order
outlet_datasets: dict[DatasetModel, None] = {}
input_datasets: dict[DatasetModel, None] = {}
@@ -3168,12 +3177,13 @@
# later we'll persist them to the database.
for dag in dags:
curr_orm_dag = existing_dags.get(dag.dag_id)
if not dag.dataset_triggers:
if dag.dataset_triggers is None:
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
for dataset in dag.dataset_triggers:
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
else:
for dataset in dag.dataset_triggers.all_datasets().values():
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
for task in dag.tasks:
dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
@@ -3229,7 +3239,7 @@ def bulk_write_to_db(
for obj in dag_refs_stored - dag_refs_needed:
session.delete(obj)

existing_task_outlet_refs_dict = collections.defaultdict(set)
existing_task_outlet_refs_dict = defaultdict(set)
for dag_id, orm_dag in existing_dags.items():
for todr in orm_dag.task_outlet_dataset_references:
existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
@@ -3512,7 +3522,7 @@ def __repr__(self):

@classmethod
def get_all(cls, session) -> dict[str, dict[str, str]]:
dag_links: dict = collections.defaultdict(dict)
dag_links: dict = defaultdict(dict)
for obj in session.scalars(select(cls)):
dag_links[obj.dag_id].update({obj.owner: obj.link})
return dag_links
@@ -3781,23 +3791,43 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
transaction is committed it will be unlocked.
"""
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ

# these dag ids are triggered by datasets, and they are ready to go.
dataset_triggered_dag_info = {
x.dag_id: (x.first_queued_time, x.last_queued_time)
for x in session.execute(
select(
DagScheduleDatasetReference.dag_id,
func.max(DDRQ.created_at).label("last_queued_time"),
func.min(DDRQ.created_at).label("first_queued_time"),
)
.join(DagScheduleDatasetReference.queue_records, isouter=True)
.group_by(DagScheduleDatasetReference.dag_id)
.having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
)
}
dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
from airflow.models.serialized_dag import SerializedDagModel

def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None:
# if dag was serialized before 2.9 and we *just* upgraded,
# we may be dealing with old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses)
except AttributeError:
log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id)
return None

# this loads all the DDRQ records.... may need to limit num dags
all_records = session.scalars(select(DatasetDagRunQueue)).all()
by_dag = defaultdict(list)
for r in all_records:
by_dag[r.target_dag_id].append(r)
del all_records
dag_statuses = {}
for dag_id, records in by_dag.items():
dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
ser_dags = session.scalars(
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
).all()
for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
statuses = dag_statuses[dag_id]
if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
del by_dag[dag_id]
del dag_statuses[dag_id]
del dag_statuses
dataset_triggered_dag_info = {}
for dag_id, records in by_dag.items():
times = sorted(x.created_at for x in records)
dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
del by_dag
dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
if dataset_triggered_dag_ids:
exclusion_list = set(
session.scalars(
@@ -3908,7 +3938,7 @@ def dag(
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
params: collections.abc.MutableMapping | None = None,
params: abc.MutableMapping | None = None,
access_control: dict | None = None,
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
@@ -4030,7 +4060,7 @@ class DagContext:

"""

_context_managed_dags: collections.deque[DAG] = deque()
_context_managed_dags: deque[DAG] = deque()
autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
current_autoregister_module_name: str | None = None

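For orientation on the dags_needing_dagruns() rework above: readiness is no longer "every referenced dataset has a queue record"; the per-DAG statuses derived from DatasetDagRunQueue rows are fed to the serialized condition's evaluate(). A simplified, self-contained sketch of that check (URIs are illustrative, and the statuses dict stands in for the queue records):

from airflow.datasets import Dataset
from airflow.models.dataset import DatasetAll, DatasetAny

cond = DatasetAny(Dataset("s3://a"), DatasetAll(Dataset("s3://b"), Dataset("s3://c")))

# A URI appears (as True) only if a DatasetDagRunQueue row exists for it and this DAG.
statuses = {"s3://b": True}
print(cond.evaluate(statuses))  # False: neither "a" alone nor both "b" and "c" are ready

statuses["s3://c"] = True
print(cond.evaluate(statuses))  # True: the DatasetAll branch is now satisfied
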
49 changes: 48 additions & 1 deletion airflow/models/dataset.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from typing import Callable, Iterable
from urllib.parse import urlsplit

import sqlalchemy_jsonfield
@@ -208,7 +209,7 @@ class DatasetDagRunQueue(Base):
dataset_id = Column(Integer, primary_key=True, nullable=False)
target_dag_id = Column(StringID(), primary_key=True, nullable=False)
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

dataset = relationship("DatasetModel", viewonly=True)
__tablename__ = "dataset_dag_run_queue"
__table_args__ = (
PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
@@ -336,3 +337,49 @@ def __repr__(self) -> str:
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"


class DatasetBooleanCondition:
"""
Base class for boolean logic for dataset triggers.

:meta private:
"""

agg_func: Callable[[Iterable], bool]

def __init__(self, *objects) -> None:
self.objects = objects

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(self.eval_one(x, statuses) for x in self.objects)

def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool:
if isinstance(obj, Dataset):
return statuses.get(obj.uri, False)
return obj.evaluate(statuses=statuses)

def all_datasets(self) -> dict[str, Dataset]:
uris = {}
for x in self.objects:
if isinstance(x, Dataset):
if x.uri not in uris:
uris[x.uri] = x
else:
# keep the first instance
for k, v in x.all_datasets().items():
if k not in uris:
uris[k] = v
return uris


class DatasetAny(DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "and" relationship."""

agg_func = any


class DatasetAll(DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "or" relationship."""

agg_func = all
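
The "keep the first instance" comment in all_datasets() above means that when the same URI appears more than once in a condition tree, only the first Dataset object encountered is retained. An illustrative check (URIs and extras are made up):

from airflow.datasets import Dataset
from airflow.models.dataset import DatasetAll, DatasetAny

first = Dataset("s3://bucket/x", extra={"first": True})
duplicate = Dataset("s3://bucket/x", extra={"first": False})
cond = DatasetAll(first, DatasetAny(duplicate, Dataset("s3://bucket/y")))

flat = cond.all_datasets()
print(sorted(flat))                 # ['s3://bucket/x', 's3://bucket/y']
print(flat["s3://bucket/x"].extra)  # {'first': True} -- the first instance wins
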
2 changes: 2 additions & 0 deletions airflow/serialization/enums.py
@@ -50,6 +50,8 @@ class DagAttributeTypes(str, Enum):
PARAM = "param"
XCOM_REF = "xcomref"
DATASET = "dataset"
DATASET_ANY = "dataset_any"
DATASET_ALL = "dataset_all"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
36 changes: 33 additions & 3 deletions airflow/serialization/schema.json
@@ -81,6 +81,36 @@
],
"additionalProperties": false
},
"typed_dataset_cond": {
"type": "object",
"properties": {
"__type": {
"anyOf": [{
"type": "string",
"constant": "dataset_or"
},
{
"type": "string",
"constant": "dataset_and"
}
]
},
"__var": {
"type": "array",
"items": {
"anyOf": [
{"$ref": "#/definitions/typed_dataset"},
{ "$ref": "#/definitions/typed_dataset_cond"}
]
}
}
},
"required": [
"__type",
"__var"
],
"additionalProperties": false
},
"dict": {
"description": "A python dictionary containing values of any type",
"type": "object"
@@ -119,9 +149,9 @@
]
},
"dataset_triggers": {
"type": "array",
"items": { "$ref": "#/definitions/typed_dataset" }
},
"$ref": "#/definitions/typed_dataset_cond"

},
"owner_links": { "type": "object" },
"timetable": {
"type": "object",
29 changes: 27 additions & 2 deletions airflow/serialization/serialized_objects.py
@@ -42,6 +42,7 @@
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel, create_timetable
from airflow.models.dagrun import DagRun
from airflow.models.dataset import DatasetAll, DatasetAny
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
@@ -404,6 +405,8 @@ def serialize_to_json(
serialized_object[key] = cls.serialize(value)
elif key == "timetable" and value is not None:
serialized_object[key] = encode_timetable(value)
elif key == "dataset_triggers":
serialized_object[key] = cls.serialize(value)
else:
value = cls.serialize(value)
if isinstance(value, dict) and Encoding.TYPE in value:
@@ -497,6 +500,22 @@ def serialize(
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET)
elif isinstance(var, DatasetAll):
return cls._encode(
[
cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
for x in var.objects
],
type_=DAT.DATASET_ALL,
)
elif isinstance(var, DatasetAny):
return cls._encode(
[
cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
for x in var.objects
],
type_=DAT.DATASET_ANY,
)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -587,6 +606,10 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
elif type_ == DAT.DATASET_ANY:
return DatasetAny(*(cls.deserialize(x) for x in var))
elif type_ == DAT.DATASET_ALL:
return DatasetAll(*(cls.deserialize(x) for x in var))
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
@@ -763,12 +786,14 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
"""Detect dependencies set directly on the DAG object."""
if not dag:
return
for x in dag.dataset_triggers:
if not dag.dataset_triggers:
return
for uri in dag.dataset_triggers.all_datasets().keys():
yield DagDependency(
source="dataset",
target=dag.dag_id,
dependency_type="dataset",
dependency_id=x.uri,
dependency_id=uri,
)


Expand Down
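Taken together, the serialize and deserialize branches added above are symmetric, so a nested condition should survive a round trip. A hedged sketch (URIs are illustrative; the exact __type strings come from the DagAttributeTypes entries added in enums.py):

from airflow.datasets import Dataset
from airflow.models.dataset import DatasetAll, DatasetAny
from airflow.serialization.serialized_objects import BaseSerialization

cond = DatasetAny(Dataset("s3://a"), DatasetAll(Dataset("s3://b"), Dataset("s3://c")))

# Encodes as {"__type": "dataset_any", "__var": [...]}, with the nested DatasetAll
# under "dataset_all" and the leaf Datasets under "dataset".
encoded = BaseSerialization.serialize(cond)

decoded = BaseSerialization.deserialize(encoded)
assert isinstance(decoded, DatasetAny)
assert decoded.objects[0].uri == "s3://a"
assert isinstance(decoded.objects[1], DatasetAll)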