Skip to content

Commit

Permalink
Simplify dataset serialization (apache#38694)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored and utkarsharma2 committed Apr 22, 2024
1 parent c7c6436 commit 43a6729
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 51 deletions.
31 changes: 31 additions & 0 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def __and__(self, other: BaseDatasetEventInput) -> DatasetAll:
return NotImplemented
return DatasetAll(self, other)

def as_expression(self) -> Any:
    """Serialize the dataset into its scheduling expression.

    The return value is stored in DagModel for display purposes. It must be
    JSON-compatible.

    :meta private:
    """
    raise NotImplementedError

def evaluate(self, statuses: dict[str, bool]) -> bool:
    """Evaluate this condition against per-dataset boolean statuses.

    NOTE(review): keys are presumably dataset URIs (matching ``iter_datasets``
    output) — confirm against callers.

    :meta private:
    """
    raise NotImplementedError

Expand Down Expand Up @@ -135,6 +145,13 @@ def __eq__(self, other: Any) -> bool:
def __hash__(self) -> int:
return hash(self.uri)

def as_expression(self) -> Any:
    """Serialize the dataset into its scheduling expression.

    A single dataset serializes to its URI string.

    :meta private:
    """
    return self.uri

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
    """Yield this dataset as a single ``(uri, dataset)`` pair.

    :meta private:
    """
    yield self.uri, self

Expand Down Expand Up @@ -179,6 +196,13 @@ def __or__(self, other: BaseDatasetEventInput) -> DatasetAny:
def __repr__(self) -> str:
    """Return a debug representation listing the combined conditions."""
    rendered = ", ".join(str(condition) for condition in self.objects)
    return f"DatasetAny({rendered})"

def as_expression(self) -> dict[str, Any]:
    """Serialize this condition into its JSON-compatible scheduling expression.

    :meta private:
    """
    children = [condition.as_expression() for condition in self.objects]
    return {"any": children}


class DatasetAll(_DatasetBooleanCondition):
    """Use to combine datasets schedule references in an "and" relationship."""
Expand All @@ -193,3 +217,10 @@ def __and__(self, other: BaseDatasetEventInput) -> DatasetAll:

def __repr__(self) -> str:
    """Return a debug representation listing the combined conditions."""
    rendered = ", ".join(str(condition) for condition in self.objects)
    return f"DatasetAll({rendered})"

def as_expression(self) -> dict[str, Any]:
    """Serialize the dataset into its scheduling expression.

    :meta private:
    """
    return {"all": [o.as_expression() for o in self.objects]}
19 changes: 4 additions & 15 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3094,16 +3094,6 @@ def bulk_sync_to_db(
)
return cls.bulk_write_to_db(dags=dags, session=session)

def simplify_dataset_expression(self, dataset_expression) -> dict | None:
"""Simplifies a nested dataset expression into a 'any' or 'all' format with URIs."""
if dataset_expression is None:
return None
if dataset_expression.get("__type") == "dataset":
return dataset_expression["__var"]["uri"]

new_key = "any" if dataset_expression["__type"] == "dataset_any" else "all"
return {new_key: [self.simplify_dataset_expression(item) for item in dataset_expression["__var"]]}

@classmethod
@provide_session
def bulk_write_to_db(
Expand All @@ -3123,8 +3113,6 @@ def bulk_write_to_db(
if not dags:
return

from airflow.serialization.serialized_objects import BaseSerialization # Avoid circular import.

log.info("Sync %s DAGs", len(dags))
dag_by_ids = {dag.dag_id: dag for dag in dags}

Expand Down Expand Up @@ -3191,9 +3179,10 @@ def bulk_write_to_db(
)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
orm_dag.dataset_expression = dag.simplify_dataset_expression(
BaseSerialization.serialize(dag.dataset_triggers)
)
if (dataset_triggers := dag.dataset_triggers) is None:
orm_dag.dataset_expression = None
else:
orm_dag.dataset_expression = dataset_triggers.as_expression()

orm_dag.processor_subdir = processor_subdir

Expand Down
61 changes: 25 additions & 36 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from datetime import timedelta
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch

Expand All @@ -39,12 +40,12 @@
import time_machine
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
from sqlalchemy import inspect
from sqlalchemy import inspect, select
from sqlalchemy.exc import SAWarning

from airflow import settings
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.decorators import setup, task as task_decorator, teardown
from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -103,6 +104,9 @@
from tests.test_utils.mock_plugins import mock_plugin_manager
from tests.test_utils.timetables import cron_timetable, delta_timetable

if TYPE_CHECKING:
from sqlalchemy.orm import Session

pytestmark = pytest.mark.db_test

TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)
Expand Down Expand Up @@ -2981,42 +2985,27 @@ def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, sess
assert first_queued_time == DEFAULT_DATE
assert last_queued_time == DEFAULT_DATE + timedelta(hours=1)

def test_dataset_expression(self, session):
dataset_expr = {
"__type": "dataset_any",
"__var": [
{"__type": "dataset", "__var": {"extra": {"hi": "bye"}, "uri": "s3://dag1/output_1.txt"}},
{
"__type": "dataset_all",
"__var": [
{
"__type": "dataset",
"__var": {"extra": {"hi": "bye"}, "uri": "s3://dag2/output_1.txt"},
},
{
"__type": "dataset",
"__var": {"extra": {"hi": "bye"}, "uri": "s3://dag3/output_3.txt"},
},
],
},
],
}
dag_id = "test_dag_dataset_expression"
orm_dag = DagModel(
dag_id=dag_id,
dataset_expression=dataset_expr,
is_active=True,
is_paused=False,
next_dagrun=timezone.utcnow(),
next_dagrun_create_after=timezone.utcnow() + timedelta(days=1),
def test_dataset_expression(self, session: Session) -> None:
    """A dataset schedule written via bulk_write_to_db is stored on DagModel
    in the simplified any/all expression form.

    NOTE: the scraped diff interleaved stale pre-change statements here
    (``session.add(orm_dag)`` etc. referencing undefined names); this is the
    coherent post-change version of the test.
    """
    dag = DAG(
        dag_id="test_dag_dataset_expression",
        schedule=DatasetAny(
            Dataset("s3://dag1/output_1.txt", {"hi": "bye"}),
            DatasetAll(
                Dataset("s3://dag2/output_1.txt", {"hi": "bye"}),
                Dataset("s3://dag3/output_3.txt", {"hi": "bye"}),
            ),
        ),
        start_date=datetime.datetime.min,
    )
    DAG.bulk_write_to_db([dag], session=session)

    expression = session.scalars(
        select(DagModel.dataset_expression).filter_by(dag_id=dag.dag_id)
    ).one()
    assert expression == {
        "any": [
            "s3://dag1/output_1.txt",
            {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]},
        ]
    }


class TestQueries:
Expand Down

0 comments on commit 43a6729

Please sign in to comment.