Commit

working!

dstandish committed Jan 25, 2024
1 parent 1f63d6b commit cd01fde
Showing 7 changed files with 116 additions and 96 deletions.
78 changes: 27 additions & 51 deletions airflow/example_dags/example_datasets.py
@@ -21,16 +21,16 @@
Turn on all the dags.
DAG dataset_produces_1 should run because it's on a schedule.
DAG produce_1 should run because it's on a schedule.
After dataset_produces_1 runs, dataset_consumes_1 should be triggered immediately
because its only dataset dependency is managed by dataset_produces_1.
After produce_1 runs, dataset_consumes_1 should be triggered immediately
because its only dataset dependency is managed by produce_1.
No other dags should be triggered. Note that even though dataset_consumes_1_and_2 depends on
the dataset in dataset_produces_1, it will not be triggered until dataset_produces_2 runs
(and dataset_produces_2 is left with no schedule so that we can trigger it manually).
the dataset in produce_1, it will not be triggered until produce_2 runs
(and produce_2 is left with no schedule so that we can trigger it manually).
Next, trigger dataset_produces_2. After dataset_produces_2 finishes,
Next, trigger produce_2. After produce_2 finishes,
dataset_consumes_1_and_2 should run.
Dags dataset_consumes_1_never_scheduled and dataset_consumes_unknown_never_scheduled should not run because
@@ -42,90 +42,66 @@

from airflow.datasets import Dataset
from airflow.models.dag import DAG
from airflow.models.dataset import DatasetOr, DatasetAnd
from airflow.models.dataset import DatasetAnd, DatasetOr
from airflow.operators.bash import BashOperator

# [START dataset_def]
dag1_dataset = Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"})
# [END dataset_def]
dag2_dataset = Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"})
dag3_dataset = Dataset("s3://dag3/output_1.txt", extra={"hi": "bye"})

with DAG(
dag_id="dataset_produces_1",
catchup=False,
dag_id="produce_1",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule="@daily",
tags=["produces", "dataset-scheduled"],
schedule=None,
) as dag1:
# [START task_outlet]
BashOperator(outlets=[dag1_dataset], task_id="producing_task_1", bash_command="sleep 5")
# [END task_outlet]

with DAG(
dag_id="dataset_produces_2",
catchup=False,
dag_id="produce_2",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=None,
tags=["produces", "dataset-scheduled"],
) as dag2:
BashOperator(outlets=[dag2_dataset], task_id="producing_task_2", bash_command="sleep 5")

# [START dag_dep]

with DAG(
dag_id="dataset_consumes_1",
catchup=False,
dag_id="produce_3",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=[dag1_dataset],
tags=["consumes", "dataset-scheduled"],
) as dag3:
# [END dag_dep]
BashOperator(
outlets=[Dataset("s3://consuming_1_task/dataset_other.txt")],
task_id="consuming_1",
bash_command="sleep 5",
)
schedule=None,
) as dag3:
BashOperator(outlets=[dag3_dataset], task_id="producing_task_3", bash_command="sleep 5")

with DAG(
dag_id="dataset_consumes_1_and_2",
catchup=False,
dag_id="consume_1_and_2",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=DatasetOr(dag1_dataset, dag2_dataset),
tags=["consumes", "dataset-scheduled"],
schedule=DatasetAnd(dag1_dataset, dag2_dataset),
) as dag4:
BashOperator(
outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
task_id="consuming_2",
bash_command="sleep 5",
)

with DAG(
dag_id="dataset_consumes_1_never_scheduled",
catchup=False,
dag_id="consume_1_or_2",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=[
dag1_dataset,
Dataset("s3://this-dataset-doesnt-get-triggered"),
],
tags=["consumes", "dataset-scheduled"],
) as dag5:
schedule=DatasetOr(dag1_dataset, dag2_dataset),
) as dag5:
BashOperator(
outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
task_id="consuming_3",
task_id="consuming_2",
bash_command="sleep 5",
)

with DAG(
dag_id="dataset_consumes_unknown_never_scheduled",
catchup=False,
dag_id="consume_1_or_-2_and_3-",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=[
Dataset("s3://unrelated/dataset3.txt"),
Dataset("s3://unrelated/dataset_other_unknown.txt"),
],
tags=["dataset-scheduled"],
) as dag6:
schedule=DatasetOr(dag1_dataset, DatasetAnd(dag2_dataset, dag3_dataset)),
) as dag6:
BashOperator(
task_id="unrelated_task",
outlets=[Dataset("s3://unrelated_task/dataset_other_unknown.txt")],
outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
task_id="consuming_2",
bash_command="sleep 5",
)
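
For orientation, a minimal sketch (not part of this commit) of how the condition classes used above compose into a DAG schedule. The imports and constructor shapes come from the example file itself; the dag_id, task_id, and dataset URIs below are placeholders invented for illustration.

import pendulum

from airflow.datasets import Dataset
from airflow.models.dag import DAG
from airflow.models.dataset import DatasetAnd, DatasetOr
from airflow.operators.bash import BashOperator

raw = Dataset("s3://example/raw.txt")  # placeholder URIs
clean = Dataset("s3://example/clean.txt")
report = Dataset("s3://example/report.txt")

with DAG(
    dag_id="example_conditional_consumer",  # hypothetical dag_id
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    # Trigger when `raw` is updated, or when both `clean` and `report` are.
    schedule=DatasetOr(raw, DatasetAnd(clean, report)),
) as example_dag:
    BashOperator(task_id="consume", bash_command="sleep 5")
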
72 changes: 42 additions & 30 deletions airflow/models/dag.py
@@ -18,7 +18,6 @@
from __future__ import annotations

import asyncio
import collections
import copy
import functools
import itertools
@@ -31,7 +30,7 @@
import traceback
import warnings
import weakref
from collections import deque
from collections import abc, defaultdict, deque
from contextlib import ExitStack
from datetime import datetime, timedelta
from inspect import signature
@@ -98,7 +97,13 @@
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.dataset import DatasetAnd, DatasetOr, DatasetBooleanCondition
from airflow.models.dataset import (
DatasetAnd,
DatasetBooleanCondition,
DatasetDagRunQueue,
DatasetModel,
DatasetOr,
)
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -461,7 +466,7 @@ def __init__(
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
params: collections.abc.MutableMapping | None = None,
params: abc.MutableMapping | None = None,
access_control: dict | None = None,
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
@@ -3144,8 +3149,8 @@ def bulk_write_to_db(
TaskOutletDatasetReference,
)

dag_references = collections.defaultdict(set)
outlet_references = collections.defaultdict(set)
dag_references = defaultdict(set)
outlet_references = defaultdict(set)
# We can't use a set here as we want to preserve order
outlet_datasets: dict[DatasetModel, None] = {}
input_datasets: dict[DatasetModel, None] = {}
@@ -3218,7 +3223,7 @@ def bulk_write_to_db(
for obj in dag_refs_stored - dag_refs_needed:
session.delete(obj)

existing_task_outlet_refs_dict = collections.defaultdict(set)
existing_task_outlet_refs_dict = defaultdict(set)
for dag_id, orm_dag in existing_dags.items():
for todr in orm_dag.task_outlet_dataset_references:
existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
@@ -3501,7 +3506,7 @@ def __repr__(self):

@classmethod
def get_all(cls, session) -> dict[str, dict[str, str]]:
dag_links: dict = collections.defaultdict(dict)
dag_links: dict = defaultdict(dict)
for obj in session.scalars(select(cls)):
dag_links[obj.dag_id].update({obj.owner: obj.link})
return dag_links
@@ -3770,26 +3775,33 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
transaction is committed it will be unlocked.
"""
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ

# this loads all the DDRQ records.... may not be the best idea
queue_records = session.execute(select(DDRQ, DAG))

# these dag ids are triggered by datasets, and they are ready to go.
dataset_triggered_dag_info = {
x.dag_id: (x.first_queued_time, x.last_queued_time)
for x in session.execute(
select(
DagScheduleDatasetReference.dag_id,
func.max(DDRQ.created_at).label("last_queued_time"),
func.min(DDRQ.created_at).label("first_queued_time"),
)
.join(DagScheduleDatasetReference.queue_records, isouter=True)
.group_by(DagScheduleDatasetReference.dag_id)
.having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
)
}
dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
from airflow.models.serialized_dag import SerializedDagModel

# This loads all the DDRQ records; we may need to limit the number of DAGs considered.
all_records = session.scalars(select(DatasetDagRunQueue)).all()
by_dag = defaultdict(list)
for r in all_records:
by_dag[r.target_dag_id].append(r)
del all_records
dag_statuses = {}
for dag_id, records in by_dag.items():
dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
ser_dags = session.scalars(
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
).all()
for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
statuses = dag_statuses[dag_id]
if not ser_dag.dag.dataset_triggers.evaluate(statuses):
del by_dag[dag_id]
del dag_statuses[dag_id]
del dag_statuses
dataset_triggered_dag_info = {}
for dag_id, records in by_dag.items():
times = sorted(x.created_at for x in records)
dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
del by_dag
dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
if dataset_triggered_dag_ids:
exclusion_list = set(
session.scalars(
@@ -3900,7 +3912,7 @@ def dag(
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
params: collections.abc.MutableMapping | None = None,
params: abc.MutableMapping | None = None,
access_control: dict | None = None,
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
@@ -4022,7 +4034,7 @@ class DagContext:
"""

_context_managed_dags: collections.deque[DAG] = deque()
_context_managed_dags: deque[DAG] = deque()
autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
current_autoregister_module_name: str | None = None

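
To make the new scheduling flow in dags_needing_dagruns easier to follow, here is a standalone sketch of its grouping-and-evaluation step. Plain tuples stand in for DatasetDagRunQueue rows and a callable stands in for dataset_triggers.evaluate(); none of this is the actual ORM code in the diff above.

from collections import defaultdict
from datetime import datetime
from typing import Callable

# (target_dag_id, dataset_uri, created_at) stands in for a DatasetDagRunQueue row.
QueueRecord = tuple[str, str, datetime]

def ready_dataset_dags(
    records: list[QueueRecord],
    triggers: dict[str, Callable[[dict[str, bool]], bool]],
) -> dict[str, tuple[datetime, datetime]]:
    # Group queued records by target DAG, mirroring the by_dag defaultdict above.
    by_dag: dict[str, list[QueueRecord]] = defaultdict(list)
    for rec in records:
        by_dag[rec[0]].append(rec)
    info: dict[str, tuple[datetime, datetime]] = {}
    for dag_id, recs in by_dag.items():
        # One True entry per queued dataset URI, like dag_statuses above.
        statuses = {uri: True for _, uri, _ in recs}
        if not triggers[dag_id](statuses):
            continue  # condition not satisfied yet; this DAG is skipped this round
        times = sorted(created_at for _, _, created_at in recs)
        info[dag_id] = (times[0], times[-1])  # (first_queued_time, last_queued_time)
    return info
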
3 changes: 1 addition & 2 deletions airflow/models/dataset.py
@@ -209,7 +209,7 @@ class DatasetDagRunQueue(Base):
dataset_id = Column(Integer, primary_key=True, nullable=False)
target_dag_id = Column(StringID(), primary_key=True, nullable=False)
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

dataset = relationship("DatasetModel")
__tablename__ = "dataset_dag_run_queue"
__table_args__ = (
PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
@@ -361,7 +361,6 @@ def eval_one(self, obj: Dataset | DatasetOr | DatasetAnd, statuses):
def all_datasets(self) -> dict[str, Dataset]:
uris = {}
for x in self.objects:
x.uri
if isinstance(x, Dataset):
if x.uri not in uris:
uris[x.uri] = x
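
As a rough mental model for the condition objects touched in this file, an illustrative re-implementation (not the Airflow classes) that evaluates a {uri: seen} mapping the same way dags_needing_dagruns passes statuses into dataset_triggers.evaluate().

from dataclasses import dataclass

@dataclass
class FakeDataset:
    uri: str

    def evaluate(self, statuses: dict[str, bool]) -> bool:
        # A single dataset is "ready" once an update for its URI has been queued.
        return statuses.get(self.uri, False)

@dataclass
class FakeAnd:
    objects: tuple

    def evaluate(self, statuses: dict[str, bool]) -> bool:
        return all(o.evaluate(statuses) for o in self.objects)

@dataclass
class FakeOr:
    objects: tuple

    def evaluate(self, statuses: dict[str, bool]) -> bool:
        return any(o.evaluate(statuses) for o in self.objects)

# True once s3://a has updated, or once both s3://b and s3://c have.
cond = FakeOr((FakeDataset("s3://a"), FakeAnd((FakeDataset("s3://b"), FakeDataset("s3://c")))))
assert cond.evaluate({"s3://a": True})
assert not cond.evaluate({"s3://b": True})
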
3 changes: 3 additions & 0 deletions airflow/providers/common/sql/hooks/sql.pyi
@@ -98,3 +98,6 @@ class DbApiHook(BaseForDbApiHook):
def bulk_dump(self, table, tmp_file) -> None: ...
def bulk_load(self, table, tmp_file) -> None: ...
def test_connection(self): ...

def _make_common_data_structure(self, param):
pass
2 changes: 1 addition & 1 deletion airflow/serialization/enums.py
@@ -51,7 +51,7 @@ class DagAttributeTypes(str, Enum):
XCOM_REF = "xcomref"
DATASET = "dataset"
DATASET_OR = "dataset_or"
DATASET_COND = "dataset_and"
DATASET_AND = "dataset_and"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
9 changes: 7 additions & 2 deletions airflow/serialization/schema.json
@@ -97,7 +97,12 @@
},
"__var": {
"type": "array",
"items": { "$ref": "#/definitions/typed_dataset" }
"items": {
"anyOf": [
{"$ref": "#/definitions/typed_dataset"},
{ "$ref": "#/definitions/typed_dataset_cond"}
]
}
}
},
"required": [
@@ -144,7 +149,7 @@
]
},
"dataset_triggers": {
"$ref": "#/definitions/typed_dataset_cond"
"$ref": "#/definitions/typed_dataset_cond"

},
"owner_links": { "type": "object" },
45 changes: 35 additions & 10 deletions tests/datasets/test_dataset.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import os
from collections import defaultdict

import pytest
from sqlalchemy.sql import select
@@ -139,18 +140,42 @@ def test_this(session, dag_maker):
assert isinstance(dtr, dict)
deser_dtr = SerializedDAG.deserialize(dtr)
assert isinstance(deser_dtr, DatasetOr)
assert deser_dtr.objects == dag.dataset_triggers
assert deser_dtr.objects == dag.dataset_triggers.objects
SerializedDagModel.write_dag(dag)
session.commit()
with dag_maker(dag_id="dag2") as dag2:
op = EmptyOperator(task_id="hello2")
records = session.execute(
select(DatasetDagRunQueue, DagModel).join(
DagModel, DagModel.dag_id == DatasetDagRunQueue.target_dag_id
)

# here we start the scheduling logic
records = session.scalars(select(DatasetDagRunQueue)).all()
dag_statuses = defaultdict(dict)
ddrq_times = defaultdict(list)
for ddrq in records:
dag_statuses[ddrq.target_dag_id][ddrq.dataset.uri] = True
ddrq_times[ddrq.target_dag_id].append(ddrq.created_at)
dataset_triggered_dag_info = {dag_id: (min(times), max(times)) for dag_id, times in ddrq_times.items()}
ser_dags = session.execute(
select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
).all()
session.execute(select(SerializedDagModel)).all()
round_trip_triggers = records[0][1].serialized_dag.dag.dataset_triggers
assert isinstance(round_trip_triggers, DatasetOr)
assert round_trip_triggers.objects == dag.dataset_triggers.objects

for (ser_dag,) in ser_dags:
print(ser_dag)
statuses = dag_statuses[ser_dag.dag_id]
ser_dag.dag.dataset_triggers.evaluate(statuses)

def test_this2(session, dag_maker):
# A DAG scheduled on a nested dataset condition should serialize and round-trip.
d1 = Dataset(uri="hello1")
dm1 = DatasetModel(uri=d1.uri)
d2 = Dataset(uri="hello2")
dm2 = DatasetModel(uri=d2.uri)
session.add(dm1)
session.add(dm2)
session.commit()
with dag_maker(schedule=DatasetOr(d1, DatasetAnd(d2, d1))) as dag:
op = EmptyOperator(task_id="hello")
SerializedDAG.serialize_to_json(dag, SerializedDAG._decorated_fields)
dtr = SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"]
assert isinstance(dtr, dict)

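For context, a compact sketch of the serialization round-trip the tests above exercise. The helper name is invented here, and the SerializedDAG import path is the conventional one rather than something shown in this diff.

from airflow.serialization.serialized_objects import SerializedDAG

def roundtrip_dataset_triggers(dag):
    # Serialize the whole DAG, pull out its dataset_triggers entry, and
    # deserialize that entry back into a condition object, as test_this does.
    dtr = SerializedDAG.to_dict(dag)["dag"]["dataset_triggers"]
    cond = SerializedDAG.deserialize(dtr)
    # test_this then asserts isinstance(cond, DatasetOr) and compares
    # cond.objects with dag.dataset_triggers.objects after the round trip.
    return cond
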