Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from airflow.models.dag_favorite import DagFavorite
from airflow.models.dagrun import DagRun
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk.definitions.callback import AsyncCallback
from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

Expand All @@ -36,6 +38,8 @@
clear_db_assets,
clear_db_connections,
clear_db_dags,
clear_db_deadline,
clear_db_deadline_alert,
clear_db_runs,
clear_db_serialized_dags,
)
Expand All @@ -58,6 +62,18 @@
ASSET_DEP_DAG2_ID = "test_asset_dep_dag2"
TASK_ID = "op1"
UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')"

_CALLBACK_PATH = "tests.unit.api_fastapi.core_api.routes.public.test_dags._noop_callback"


async def _noop_callback(**kwargs):
"""No-op async callback for deadline alert tests."""


def _deadline_callback() -> AsyncCallback:
return AsyncCallback(_CALLBACK_PATH)


API_PREFIX = "/dags"
DAG3_START_DATE_1 = datetime(2018, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
DAG3_START_DATE_2 = datetime(2019, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
Expand All @@ -68,6 +84,8 @@ class TestDagEndpoint:

@staticmethod
def _clear_db():
clear_db_deadline()
clear_db_deadline_alert()
clear_db_connections()
clear_db_runs()
clear_db_dags()
Expand Down Expand Up @@ -627,6 +645,57 @@ def test_get_dags_no_n_plus_one_queries(self, session, test_client):
f"({first_query_count} → {second_query_count}), suggesting n+1 queries for tags"
)

def test_get_dags_includes_dag_with_deadline(self, dag_maker, test_client, session):
"""DAGs created with a deadline appear in the list endpoint."""
deadline_dag_id = "test_dag_with_deadline"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=DAG1_START_DATE,
deadline=DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

response = test_client.get("/dags")
assert response.status_code == 200
body = response.json()
dag_ids = [dag["dag_id"] for dag in body["dags"]]
assert deadline_dag_id in dag_ids

def test_get_dags_includes_dag_with_multiple_deadlines(self, dag_maker, test_client, session):
"""DAGs created with multiple deadline alerts appear in the list endpoint."""
deadline_dag_id = "test_dag_multi_deadline"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=DAG1_START_DATE,
deadline=[
DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=2),
callback=_deadline_callback(),
),
],
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

response = test_client.get("/dags")
assert response.status_code == 200
body = response.json()
dag_ids = [dag["dag_id"] for dag in body["dags"]]
assert deadline_dag_id in dag_ids


class TestPatchDag(TestDagEndpoint):
"""Unit tests for Patch DAG."""
Expand Down Expand Up @@ -725,6 +794,30 @@ def test_patch_dag_audit_log_payload(self, test_client, is_paused_value, session
session, dag_id=DAG1_ID, event="patch_dag", logical_date=None, expected_extra=expected_extra
)

def test_patch_dag_with_deadline(self, dag_maker, test_client, session):
"""Pausing a DAG that has deadline alerts succeeds."""
deadline_dag_id = "test_patch_deadline_dag"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=DAG1_START_DATE,
deadline=DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

response = test_client.patch(f"/dags/{deadline_dag_id}", json={"is_paused": True})
assert response.status_code == 200
assert response.json()["is_paused"] is True

response = test_client.patch(f"/dags/{deadline_dag_id}", json={"is_paused": False})
assert response.status_code == 200
assert response.json()["is_paused"] is False


class TestPatchDags(TestDagEndpoint):
"""Unit tests for Patch DAGs."""
Expand Down Expand Up @@ -870,6 +963,33 @@ def test_patch_dags_should_response_403(self, unauthorized_test_client):
response = unauthorized_test_client.patch("/dags", json={"is_paused": True})
assert response.status_code == 403

def test_patch_dags_includes_dag_with_deadline(self, dag_maker, test_client, session):
"""Bulk-patching DAGs includes DAGs that have deadline alerts."""
deadline_dag_id = "test_bulk_patch_deadline"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=DAG1_START_DATE,
deadline=DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

response = test_client.patch(
"/dags",
json={"is_paused": True},
params={"dag_id_pattern": "~"},
)
assert response.status_code == 200
body = response.json()
patched_ids = {dag["dag_id"] for dag in body["dags"]}
assert deadline_dag_id in patched_ids
assert all(dag["is_paused"] for dag in body["dags"] if dag["dag_id"] == deadline_dag_id)


class TestFavoriteDag(TestDagEndpoint):
"""Unit tests for favoriting a DAG."""
Expand Down Expand Up @@ -1251,6 +1371,28 @@ def test_dag_details_includes_active_runs_count(self, session, test_client):
assert isinstance(body["active_runs_count"], int)
assert body["active_runs_count"] == 0

@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_dag_details_with_deadline(self, dag_maker, test_client, session):
"""DAG details endpoint returns 200 for a DAG with a deadline alert."""
deadline_dag_id = "test_details_deadline"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=DAG1_START_DATE,
deadline=DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

response = test_client.get(f"/dags/{deadline_dag_id}/details")
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == deadline_dag_id


class TestGetDag(TestDagEndpoint):
"""Unit tests for Get DAG."""
Expand Down Expand Up @@ -1351,6 +1493,28 @@ def test_get_dag_should_response_403(self, unauthorized_test_client):
response = unauthorized_test_client.get(f"/dags/{DAG1_ID}")
assert response.status_code == 403

def test_get_dag_with_deadline(self, dag_maker, test_client, session):
"""Single DAG endpoint returns 200 for a DAG with a deadline alert."""
deadline_dag_id = "test_get_deadline_dag"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=DAG1_START_DATE,
deadline=DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

response = test_client.get(f"/dags/{deadline_dag_id}")
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == deadline_dag_id
assert body["is_paused"] is False


class TestDeleteDAG(TestDagEndpoint):
"""Unit tests for Delete DAG."""
Expand Down Expand Up @@ -1421,6 +1585,34 @@ def test_delete_dag(
if details_response.status_code == 204:
check_last_log(session, dag_id=dag_id, event="delete_dag", logical_date=None)

def test_delete_dag_with_deadline(self, dag_maker, test_client, session):
"""Deleting a DAG with deadline alerts succeeds without FK constraint errors."""
deadline_dag_id = "test_delete_deadline_dag"
with dag_maker(
deadline_dag_id,
schedule=None,
start_date=datetime(2024, 10, 10, tzinfo=timezone.utc),
deadline=DeadlineAlert(
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
interval=timedelta(hours=1),
callback=_deadline_callback(),
),
):
EmptyOperator(task_id="task1")
dag_maker.sync_dagbag_to_db()

# Verify DAG exists first
response = test_client.get(f"{API_PREFIX}/{deadline_dag_id}")
assert response.status_code == 200

# Delete the DAG
delete_response = test_client.delete(f"{API_PREFIX}/{deadline_dag_id}")
assert delete_response.status_code == 204

# Verify it's gone
details_response = test_client.get(f"{API_PREFIX}/{deadline_dag_id}/details")
assert details_response.status_code == 404

def test_delete_dag_should_response_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.delete(f"{API_PREFIX}/{DAG1_ID}")
assert response.status_code == 401
Expand Down
Loading