Skip to content

Commit

Permalink
Implement | and & operators so that they can be used instead of Datas…
Browse files Browse the repository at this point in the history
…etAll and DatasetAny.
  • Loading branch information
sunank200 committed Jan 30, 2024
1 parent 8de16a8 commit ee9de54
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 6 deletions.
10 changes: 10 additions & 0 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,13 @@ def _check_uri(self, attr, uri: str):

def __fspath__(self):
return self.uri

def __or__(self, other):
from airflow.models.dataset import DatasetAny

return DatasetAny(self, other)

def __and__(self, other):
from airflow.models.dataset import DatasetAll

return DatasetAll(self, other)
31 changes: 31 additions & 0 deletions airflow/example_dags/example_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,34 @@
task_id="consuming_2",
bash_command="sleep 5",
)

with DAG(
dag_id="consume_1_and_2_with_dataset_expressions",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=(dag1_dataset & dag2_dataset),
) as dag5:
BashOperator(
outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
task_id="consume_1_and_2_with_dataset_expressions",
bash_command="sleep 5",
)
with DAG(
dag_id="consume_1_or_2_with_dataset_expressions",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=(dag1_dataset | dag2_dataset),
) as dag6:
BashOperator(
outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
task_id="consume_1_or_2_with_dataset_expressions",
bash_command="sleep 5",
)
with DAG(
dag_id="consume_1_or_-2_and_3_with_dataset_expressions",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
schedule=(dag1_dataset | (dag2_dataset& dag3_dataset)),
) as dag7:
BashOperator(
outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
task_id="consume_1_or_-2_and_3_with_dataset_expressions",
bash_command="sleep 5",
)
8 changes: 6 additions & 2 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.dataset import (
DatasetAll,
DatasetAny,
DatasetBooleanCondition,
DatasetDagRunQueue,
DatasetModel,
DatasetAny,
DatasetsExpression,
extract_datasets,
)
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Expand Down Expand Up @@ -172,7 +174,7 @@
# but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
# See also: https://discuss.python.org/t/9126/7
ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]]
ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"], DatasetsExpression]

SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]

Expand Down Expand Up @@ -585,6 +587,8 @@ def __init__(
self.timetable: Timetable
self.schedule_interval: ScheduleInterval
self.dataset_triggers: DatasetBooleanCondition | None = None
if isinstance(schedule, DatasetsExpression):
self.schedule = extract_datasets(dataset_expression=schedule)
if isinstance(schedule, (DatasetAll, DatasetAny)):
self.dataset_triggers = schedule
if isinstance(schedule, Collection) and not isinstance(schedule, str):
Expand Down
110 changes: 106 additions & 4 deletions airflow/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from typing import Callable, Iterable
from urllib.parse import urlsplit

import sqlalchemy_jsonfield
Expand All @@ -34,7 +35,6 @@
)
from sqlalchemy.orm import relationship

from airflow import Dataset
from airflow.datasets import Dataset
from airflow.models.base import Base, StringID
from airflow.settings import json
Expand Down Expand Up @@ -341,24 +341,38 @@ def __repr__(self) -> str:

class DatasetBooleanCondition:
"""
:meta private:
Base class for boolean conditions on datasets. This class is intended for internal use only.
:param objects: A variable number of Dataset, DatasetAny, or DatasetAll instances.
"""

agg_func = None
type = None
agg_func: Callable[[Iterable[object]], bool] | None = None
type: str | None = None

def __init__(self, *objects):
self.objects = objects

def evaluate(self, statuses: dict[str, bool]):
"""
Evaluates the boolean condition based on the statuses of datasets.
:param statuses: A dictionary mapping dataset URIs to their boolean statuses.
"""
return self.agg_func(self.eval_one(x, statuses) for x in self.objects)

def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses):
"""
Evaluates the status of a single object (Dataset, DatasetAny, or DatasetAll).
:param obj: The Dataset, DatasetAny, or DatasetAll instance to evaluate.
:param statuses: A dictionary mapping dataset URIs to their boolean statuses.
"""
if isinstance(obj, Dataset):
return statuses.get(obj.uri, False)
return obj.evaluate(statuses=statuses)

def all_datasets(self) -> dict[str, Dataset]:
"""Retrieves all unique datasets contained within the boolean condition."""
uris = {}
for x in self.objects:
if isinstance(x, Dataset):
Expand All @@ -373,10 +387,98 @@ def all_datasets(self) -> dict[str, Dataset]:


class DatasetAny(DatasetBooleanCondition):
"""
Represents a logical OR condition of datasets.
Inherits from DatasetBooleanCondition.
"""

type = "OR"
agg_func = any

def __init__(self, *objects: Dataset | DatasetAny | DatasetAll):
"""Initialize with one or more Dataset, DatasetAny, or DatasetAll instances."""
super().__init__(*objects)

def __or__(self, other):
if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
return DatasetAny(*self.objects, other)
return NotImplemented

def __and__(self, other):
if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
return DatasetAll(self, other)
return NotImplemented

def __repr__(self) -> str:
return f"DatasetAny({', '.join(map(str, self.objects))})"


class DatasetAll(DatasetBooleanCondition):
"""Represents a logical AND condition of datasets. Inherits from DatasetBooleanCondition."""

type = "AND"
agg_func = all

def __init__(self, *objects: Dataset | DatasetAny | DatasetAll):
"""Initialize with one or more Dataset, DatasetAny, or DatasetAll instances."""
super().__init__(*objects)

def __or__(self, other):
if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
return DatasetAny(self, other)
return NotImplemented

def __and__(self, other):
if isinstance(other, (Dataset, DatasetAny, DatasetAll)):
return DatasetAll(*self.objects, other)
return NotImplemented

def __repr__(self) -> str:
return f"DatasetAnd({', '.join(map(str, self.objects))})"


class DatasetsExpression:
"""
Represents a node in an expression tree for dataset conditions.
:param value: The value of the node, which can be a 'Dataset', '&', or '|'.
:param left: The left child node.
:param right: The right child node.
"""

def __init__(self, value, left=None, right=None):
self.value = value # value can be 'Dataset', '&', or '|'
self.left = left
self.right = right

def __or__(self, other: Dataset | DatasetsExpression) -> DatasetsExpression:
return DatasetsExpression("|", self, other)

def __and__(self, other: Dataset | DatasetsExpression) -> DatasetsExpression:
return DatasetsExpression("&", self, other)

def __repr__(self):
if isinstance(self.value, Dataset):
return f"Dataset(uri='{self.value.uri}')"
elif self.value == "&":
return repr(DatasetAll(self.left, self.right))
elif self.value == "|":
return repr(DatasetAny(self.left, self.right))


def extract_datasets(dataset_expression: DatasetsExpression | Dataset):
"""
Extracts the dataset(s) from an DatasetsExpression.
:param dataset_expression: The DatasetsExpression to extract from.
"""
if isinstance(dataset_expression, DatasetsExpression):
if dataset_expression.value == "&":
return DatasetAll(dataset_expression.left, dataset_expression.right)
elif dataset_expression.value == "|":
return DatasetAny(dataset_expression.left, dataset_expression.right)
else:
raise ValueError("Invalid Expression node value")
else:
return dataset_expression
87 changes: 87 additions & 0 deletions tests/models/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow.models.dataset import Dataset, DatasetAll, DatasetAny, extract_datasets


def datasets_equal(d1, d2):
if type(d1) != type(d2):
return False

if isinstance(d1, Dataset):
return d1.uri == d2.uri

elif isinstance(d1, (DatasetAny, DatasetAll)):
if len(d1.objects) != len(d2.objects):
return False

# Compare each pair of objects
for obj1, obj2 in zip(d1.objects, d2.objects):
# If obj1 or obj2 is a Dataset, DatasetAny, or DatasetAll instance,
# recursively call datasets_equal
if not datasets_equal(obj1, obj2):
return False
return True

return False


dataset1 = Dataset(uri="s3://bucket1/data1")
dataset2 = Dataset(uri="s3://bucket2/data2")
dataset3 = Dataset(uri="s3://bucket3/data3")
dataset4 = Dataset(uri="s3://bucket4/data4")
dataset5 = Dataset(uri="s3://bucket5/data5")

test_cases = [
(lambda: dataset1, dataset1),
(lambda: dataset1 & dataset2, DatasetAll(dataset1, dataset2)),
(lambda: dataset1 | dataset2, DatasetAny(dataset1, dataset2)),
(lambda: dataset1 | (dataset2 & dataset3), DatasetAny(dataset1, DatasetAll(dataset2, dataset3))),
(lambda: dataset1 | dataset2 & dataset3, DatasetAny(dataset1, DatasetAll(dataset2, dataset3))),
(
lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), DatasetAny(dataset4, dataset5)),
),
(lambda: dataset1 & dataset2 | dataset3, DatasetAny(DatasetAll(dataset1, dataset2), dataset3)),
(
lambda: (dataset1 | dataset2) & (dataset3 | dataset4),
DatasetAll(DatasetAny(dataset1, dataset2), DatasetAny(dataset3, dataset4)),
),
(
lambda: (dataset1 & dataset2) | (dataset3 & (dataset4 | dataset5)),
DatasetAny(DatasetAll(dataset1, dataset2), DatasetAll(dataset3, DatasetAny(dataset4, dataset5))),
),
(
lambda: (dataset1 & dataset2) & (dataset3 & dataset4),
DatasetAll(dataset1, dataset2, DatasetAll(dataset3, dataset4)),
),
(lambda: dataset1 | dataset2 | dataset3, DatasetAny(dataset1, dataset2, dataset3)),
(
lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), DatasetAny(dataset4, dataset5)),
),
]


@pytest.mark.parametrize("expression, expected", test_cases)
def test_extract_datasets(expression, expected):
expr = expression()
result = extract_datasets(expr)
assert datasets_equal(result, expected)

0 comments on commit ee9de54

Please sign in to comment.