Skip to content

Commit

Permalink
Filter Datasets by associated dag_ids (GET /datasets) (apache#37512)
Browse files Browse the repository at this point in the history
Co-authored-by: Brent Bovenzi <brent.bovenzi@gmail.com>
  • Loading branch information
2 people authored and abhishekbhakat committed Mar 5, 2024
1 parent 0fd8096 commit 1ca508e
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 1 deletion.
10 changes: 10 additions & 0 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dataset_schema import (
DagScheduleDatasetReference,
DatasetCollection,
DatasetEventCollection,
QueuedEvent,
QueuedEventCollection,
TaskOutletDatasetReference,
dataset_collection_schema,
dataset_event_collection_schema,
dataset_schema,
Expand Down Expand Up @@ -73,6 +75,7 @@ def get_datasets(
limit: int,
offset: int = 0,
uri_pattern: str | None = None,
dag_ids: str | None = None,
order_by: str = "id",
session: Session = NEW_SESSION,
) -> APIResponse:
Expand All @@ -81,6 +84,13 @@ def get_datasets(

total_entries = session.scalars(select(func.count(DatasetModel.id))).one()
query = select(DatasetModel)

if dag_ids:
dags_list = dag_ids.split(",")
query = query.filter(
(DatasetModel.consuming_dags.any(DagScheduleDatasetReference.dag_id.in_(dags_list)))
| (DatasetModel.producing_tasks.any(TaskOutletDatasetReference.dag_id.in_(dags_list)))
)
if uri_pattern:
query = query.where(DatasetModel.uri.ilike(f"%{uri_pattern}%"))
query = apply_sorting(query, order_by, {}, allowed_attrs)
Expand Down
9 changes: 9 additions & 0 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,15 @@ paths:
required: false
description: |
If set, only return datasets with uris matching this pattern.
- name: dag_ids
in: query
schema:
type: string
required: false
description: |
One or more DAG IDs separated by commas to filter datasets by associated DAGs that either consume or produce them.
*New in version 2.9.0*
responses:
"200":
description: Success.
Expand Down
6 changes: 6 additions & 0 deletions airflow/www/static/js/types/api-generated.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4504,6 +4504,12 @@ export interface operations {
order_by?: components["parameters"]["OrderBy"];
/** If set, only return datasets with uris matching this pattern. */
uri_pattern?: string;
/**
* One or more DAG IDs separated by commas to filter datasets by associated DAGs either consuming or producing.
*
* *New in version 2.9.0*
*/
dag_ids?: string;
};
};
responses: {
Expand Down
60 changes: 59 additions & 1 deletion tests/api_connexion/endpoints/test_dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,15 @@
import time_machine

from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import DagModel
from airflow.models.dagrun import DagRun
from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
DatasetDagRunQueue,
DatasetEvent,
DatasetModel,
TaskOutletDatasetReference,
)
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -252,6 +259,57 @@ def test_filter_datasets_by_uri_pattern_works(self, url, expected_datasets, sess
dataset_urls = {dataset["uri"] for dataset in response.json["datasets"]}
assert expected_datasets == dataset_urls

@pytest.mark.parametrize("dag_ids, expected_num", [("dag1,dag2", 2), ("dag3", 1), ("dag2,dag3", 2)])
@provide_session
def test_filter_datasets_by_dag_ids_works(self, dag_ids, expected_num, session):
    """Filtering GET /datasets by ``dag_ids`` returns only datasets tied to those DAGs."""
    # Start from a clean slate so leftover DAG rows cannot skew the counts.
    session.query(DagModel).delete()
    session.commit()
    dags = [DagModel(dag_id=f"dag{i}") for i in (1, 2, 3)]
    datasets = [
        DatasetModel("s3://folder/key"),
        DatasetModel("gcp://bucket/key"),
        DatasetModel("somescheme://dataset/key"),
    ]
    # dag1/dag2 are linked as consumers (schedule refs); dag3 as a producer (task outlet).
    references = [
        DagScheduleDatasetReference(dag_id="dag1", dataset=datasets[0]),
        DagScheduleDatasetReference(dag_id="dag2", dataset=datasets[1]),
        TaskOutletDatasetReference(dag_id="dag3", task_id="task1", dataset=datasets[2]),
    ]
    session.add_all(datasets + dags + references)
    session.commit()
    response = self.client.get(
        f"/api/v1/datasets?dag_ids={dag_ids}", environ_overrides={"REMOTE_USER": "test"}
    )
    assert response.status_code == 200
    assert len(response.json["datasets"]) == expected_num

@pytest.mark.parametrize(
    "dag_ids, uri_pattern,expected_num",
    [("dag1,dag2", "folder", 1), ("dag3", "nothing", 0), ("dag2,dag3", "key", 2)],
)
def test_filter_datasets_by_dag_ids_and_uri_pattern_works(
    self, dag_ids, uri_pattern, expected_num, session
):
    """Combining ``dag_ids`` and ``uri_pattern`` filters intersects both conditions."""
    # Remove any pre-existing DAG rows so only the fixtures below influence results.
    session.query(DagModel).delete()
    session.commit()
    dags = [DagModel(dag_id=f"dag{i}") for i in (1, 2, 3)]
    datasets = [
        DatasetModel("s3://folder/key"),
        DatasetModel("gcp://bucket/key"),
        DatasetModel("somescheme://dataset/key"),
    ]
    # dag1/dag2 consume their datasets via schedule refs; dag3 produces via a task outlet.
    references = [
        DagScheduleDatasetReference(dag_id="dag1", dataset=datasets[0]),
        DagScheduleDatasetReference(dag_id="dag2", dataset=datasets[1]),
        TaskOutletDatasetReference(dag_id="dag3", task_id="task1", dataset=datasets[2]),
    ]
    session.add_all(datasets + dags + references)
    session.commit()
    response = self.client.get(
        f"/api/v1/datasets?dag_ids={dag_ids}&uri_pattern={uri_pattern}",
        environ_overrides={"REMOTE_USER": "test"},
    )
    assert response.status_code == 200
    assert len(response.json["datasets"]) == expected_num


class TestGetDatasetsEndpointPagination(TestDatasetEndpoint):
@pytest.mark.parametrize(
Expand Down

0 comments on commit 1ca508e

Please sign in to comment.