Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
PartitionedDagRunDetailResponse,
PartitionedDagRunResponse,
)
from airflow.api_fastapi.core_api.security import requires_access_asset
from airflow.api_fastapi.core_api.security import (
ReadableDagsFilterDep,
requires_access_asset,
requires_access_dag,
)
from airflow.models import DagModel
from airflow.models.asset import (
AssetModel,
Expand Down Expand Up @@ -63,6 +67,7 @@ def _build_response(row, required_count: int) -> PartitionedDagRunResponse:
)
def get_partitioned_dag_runs(
session: SessionDep,
readable_dags_filter: ReadableDagsFilterDep,
dag_id: QueryPartitionedDagRunDagIdFilter,
has_created_dag_run_id: QueryPartitionedDagRunHasCreatedDagRunIdFilter,
) -> PartitionedDagRunCollectionResponse:
Expand Down Expand Up @@ -123,6 +128,9 @@ def get_partitioned_dag_runs(
received_subq.label("total_received"),
).outerjoin(DagRun, AssetPartitionDagRun.created_dag_run_id == DagRun.id)
query = apply_filters_to_select(statement=query, filters=[dag_id, has_created_dag_run_id])
readable_dag_ids = readable_dags_filter.value
if readable_dag_ids is not None:
query = query.where(AssetPartitionDagRun.target_dag_id.in_(readable_dag_ids))
query = query.order_by(AssetPartitionDagRun.created_at.desc())

if not (rows := session.execute(query).all()):
Expand Down Expand Up @@ -162,7 +170,7 @@ def get_partitioned_dag_runs(

@partitioned_dag_runs_router.get(
"/pending_partitioned_dag_run/{dag_id}/{partition_key}",
dependencies=[Depends(requires_access_asset(method="GET"))],
dependencies=[Depends(requires_access_asset(method="GET")), Depends(requires_access_dag(method="GET"))],
)
def get_pending_partitioned_dag_run(
dag_id: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pendulum
import pytest
from sqlalchemy import select
Expand Down Expand Up @@ -57,7 +59,7 @@ def test_should_response_200_non_partitioned_dag_returns_empty(self, test_client
dag_maker.create_dagrun()
dag_maker.sync_dagbag_to_db()

with assert_queries_count(2):
with assert_queries_count(3):
resp = test_client.get("/partitioned_dag_runs?dag_id=normal&has_created_dag_run_id=false")
assert resp.status_code == 200
assert resp.json() == {"partitioned_dag_runs": [], "total": 0, "asset_expressions": None}
Expand Down Expand Up @@ -144,7 +146,7 @@ def test_should_response_200(
)
session.commit()

with assert_queries_count(2):
with assert_queries_count(3):
resp = test_client.get(
f"/partitioned_dag_runs?dag_id=list_dag"
f"&has_created_dag_run_id={str(has_created_dag_run_id).lower()}"
Expand Down Expand Up @@ -218,6 +220,24 @@ def _make_schedule(prefix, count):
assert pdr_resp["total_required"] == num_target_assets
assert pdr_resp["total_received"] == received_count

@mock.patch(
"airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_dag_ids",
return_value={"other_dag"},
)
def test_partitioned_dag_runs_filters_unreadable_dags(self, _, test_client, dag_maker, session):
schedule = PartitionedAssetTimetable(assets=Asset(uri="s3://bucket/a", name="a"))
with dag_maker(dag_id="restricted_dag", schedule=schedule, serialized=True):
EmptyOperator(task_id="t")
dag_maker.sync_dagbag_to_db()
session.add(AssetPartitionDagRun(target_dag_id="restricted_dag", partition_key="2024-06-01"))
session.commit()

resp = test_client.get("/partitioned_dag_runs?has_created_dag_run_id=false")
assert resp.status_code == 200
body = resp.json()
dag_ids = {r["dag_id"] for r in body["partitioned_dag_runs"]}
assert "restricted_dag" not in dag_ids


class TestGetPendingPartitionedDagRun:
def test_should_response_401(self, unauthenticated_test_client):
Expand Down
Loading