diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index f9090a6ece031..fe4696efbb5a0 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -31,7 +31,11 @@ PartitionedDagRunDetailResponse, PartitionedDagRunResponse, ) -from airflow.api_fastapi.core_api.security import requires_access_asset +from airflow.api_fastapi.core_api.security import ( + ReadableDagsFilterDep, + requires_access_asset, + requires_access_dag, +) from airflow.models import DagModel from airflow.models.asset import ( AssetModel, @@ -63,6 +67,7 @@ def _build_response(row, required_count: int) -> PartitionedDagRunResponse: ) def get_partitioned_dag_runs( session: SessionDep, + readable_dags_filter: ReadableDagsFilterDep, dag_id: QueryPartitionedDagRunDagIdFilter, has_created_dag_run_id: QueryPartitionedDagRunHasCreatedDagRunIdFilter, ) -> PartitionedDagRunCollectionResponse: @@ -123,6 +128,9 @@ def get_partitioned_dag_runs( received_subq.label("total_received"), ).outerjoin(DagRun, AssetPartitionDagRun.created_dag_run_id == DagRun.id) query = apply_filters_to_select(statement=query, filters=[dag_id, has_created_dag_run_id]) + readable_dag_ids = readable_dags_filter.value + if readable_dag_ids is not None: + query = query.where(AssetPartitionDagRun.target_dag_id.in_(readable_dag_ids)) query = query.order_by(AssetPartitionDagRun.created_at.desc()) if not (rows := session.execute(query).all()): @@ -162,7 +170,7 @@ def get_partitioned_dag_runs( @partitioned_dag_runs_router.get( "/pending_partitioned_dag_run/{dag_id}/{partition_key}", - dependencies=[Depends(requires_access_asset(method="GET"))], + dependencies=[Depends(requires_access_asset(method="GET")), Depends(requires_access_dag(method="GET"))], ) def get_pending_partitioned_dag_run( dag_id: str, diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py index 658d23abb214f..a2a86da327b86 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest import mock + import pendulum import pytest from sqlalchemy import select @@ -57,7 +59,7 @@ def test_should_response_200_non_partitioned_dag_returns_empty(self, test_client dag_maker.create_dagrun() dag_maker.sync_dagbag_to_db() - with assert_queries_count(2): + with assert_queries_count(3): resp = test_client.get("/partitioned_dag_runs?dag_id=normal&has_created_dag_run_id=false") assert resp.status_code == 200 assert resp.json() == {"partitioned_dag_runs": [], "total": 0, "asset_expressions": None} @@ -144,7 +146,7 @@ def test_should_response_200( ) session.commit() - with assert_queries_count(2): + with assert_queries_count(3): resp = test_client.get( f"/partitioned_dag_runs?dag_id=list_dag" f"&has_created_dag_run_id={str(has_created_dag_run_id).lower()}" @@ -218,6 +220,24 @@ def _make_schedule(prefix, count): assert pdr_resp["total_required"] == num_target_assets assert pdr_resp["total_received"] == received_count + @mock.patch( + "airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_dag_ids", + return_value={"other_dag"}, + ) + def test_partitioned_dag_runs_filters_unreadable_dags(self, _, test_client, dag_maker, session): + schedule = PartitionedAssetTimetable(assets=Asset(uri="s3://bucket/a", name="a")) + with dag_maker(dag_id="restricted_dag", schedule=schedule, serialized=True): + EmptyOperator(task_id="t") + dag_maker.sync_dagbag_to_db() + session.add(AssetPartitionDagRun(target_dag_id="restricted_dag", partition_key="2024-06-01")) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?has_created_dag_run_id=false") + assert resp.status_code == 200 + body = resp.json() + dag_ids = {r["dag_id"] for r in body["partitioned_dag_runs"]} + assert "restricted_dag" not in dag_ids + class TestGetPendingPartitionedDagRun: def test_should_response_401(self, unauthenticated_test_client):