From 023d89bcb7e6d5f66cefe82aee16aedc749aaa9f Mon Sep 17 00:00:00 2001 From: Diogo Pan Date: Thu, 4 Jun 2026 13:56:04 +0100 Subject: [PATCH 1/2] Implement DagTaskGroupsExistence and DagTasksExistence endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two GET endpoints to the Execution API for batched existence checks against a Dag's tasks and task groups. Each takes a list of ids and returns them partitioned into 'existing' and 'missing' with a 200 response; a 404 is returned only when the Dag itself is missing. Passing an empty list works as a Dag existence probe. These endpoints let clients query the current definition of a Dag and get correct information even when the Dag has not been run yet. Related: #40745. Co-authored-by: Diogo Callado --- .../execution_api/datamodels/dags.py | 14 +++ .../api_fastapi/execution_api/routes/dags.py | 75 +++++++++++- .../execution_api/versions/__init__.py | 8 +- .../execution_api/versions/v2026_06_30.py | 11 ++ .../execution_api/versions/head/test_dags.py | 84 +++++++++++++ .../unit/dag_processing/test_processor.py | 3 + .../tests/unit/jobs/test_triggerer_job.py | 3 + .../standard/sensors/external_task.py | 55 +++++++++ .../sensors/test_external_task_sensor.py | 62 ++++++++++ task-sdk/src/airflow/sdk/api/client.py | 52 ++++++++ .../airflow/sdk/api/datamodels/_generated.py | 18 +++ task-sdk/src/airflow/sdk/exceptions.py | 1 + .../src/airflow/sdk/execution_time/comms.py | 52 ++++++++ .../sdk/execution_time/schema/schema.json | 115 ++++++++++++++++++ .../airflow/sdk/execution_time/supervisor.py | 24 ++++ .../airflow/sdk/execution_time/task_runner.py | 30 +++++ task-sdk/src/airflow/sdk/types.py | 14 ++- task-sdk/tests/task_sdk/api/test_client.py | 76 ++++++++++++ .../execution_time/test_supervisor.py | 50 ++++++++ .../execution_time/test_task_runner.py | 46 +++++++ 20 files changed, 789 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py index 1334e99069f3b..06b31a44d0883 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py @@ -33,3 +33,17 @@ class DagResponse(BaseModel): owners: str | None tags: list[str] next_dagrun: datetime | None + + +class DagTaskGroupsExistenceResponse(BaseModel): + """Schema for batch Dag task group existence response.""" + + existing: list[str] + missing: list[str] + + +class DagTasksExistenceResponse(BaseModel): + """Schema for batch Dag task existence response.""" + + existing: list[str] + missing: list[str] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py index 9061b4862161c..0e9c50d9ec30e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -17,10 +17,15 @@ from __future__ import annotations -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, Query, status +from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag from airflow.api_fastapi.common.db.common import SessionDep -from airflow.api_fastapi.execution_api.datamodels.dags import DagResponse +from airflow.api_fastapi.execution_api.datamodels.dags import ( + DagResponse, + DagTaskGroupsExistenceResponse, + DagTasksExistenceResponse, +) from airflow.models.dag import DagModel router = APIRouter() @@ -57,3 +62,69 @@ def get_dag( tags=sorted(tag.name for tag in dag_model.tags), next_dagrun=dag_model.next_dagrun, ) + + +@router.get( + "/{dag_id}/task-groups/existence", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, + }, +) +def get_dag_task_groups_existence( + dag_id: str, + session: SessionDep, + dag_bag: DagBagDep, + 
task_group_ids: list[str] = Query( + default_factory=list, description="Task group ids to check for existence" + ), +) -> DagTaskGroupsExistenceResponse: + """Get the list of existing and missing Dag task group ids from the given ids.""" + if not session.get(DagModel, dag_id): + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"The Dag with dag_id: `{dag_id}` was not found", + }, + ) + + dag = get_latest_version_of_dag(dag_bag, dag_id, session, include_reason=True) + + existing: list[str] = [] + missing: list[str] = [] + for task_group_id in task_group_ids: + (existing if task_group_id in dag.task_group_dict else missing).append(task_group_id) + + return DagTaskGroupsExistenceResponse(existing=existing, missing=missing) + + +@router.get( + "/{dag_id}/tasks/existence", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, + }, +) +def get_dag_tasks_existence( + dag_id: str, + session: SessionDep, + dag_bag: DagBagDep, + task_ids: list[str] = Query(default_factory=list, description="Task ids to check for existence"), +) -> DagTasksExistenceResponse: + """Get the list of existing and missing Dag task ids from the given ids.""" + if not session.get(DagModel, dag_id): + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"The Dag with dag_id: `{dag_id}` was not found", + }, + ) + + dag = get_latest_version_of_dag(dag_bag, dag_id, session, include_reason=True) + + existing: list[str] = [] + missing: list[str] = [] + for task_id in task_ids: + (existing if dag.has_task(task_id) else missing).append(task_id) + + return DagTasksExistenceResponse(existing=existing, missing=missing) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 50ddb6985e890..0aebe8f5aadcc 100644 --- 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -48,12 +48,18 @@ ) from airflow.api_fastapi.execution_api.versions.v2026_06_30 import ( AddConnectionTestEndpoint, + AddDagTaskDetailsExistenceEndpoints, AddVariableKeysEndpoint, ) bundle = VersionBundle( HeadVersion(), - Version("2026-06-30", AddVariableKeysEndpoint, AddConnectionTestEndpoint), + Version( + "2026-06-30", + AddVariableKeysEndpoint, + AddConnectionTestEndpoint, + AddDagTaskDetailsExistenceEndpoints, + ), Version( "2026-06-16", AddRetryPolicyFields, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py index cc751bcc79765..f81031a3b3d27 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py @@ -37,3 +37,14 @@ class AddConnectionTestEndpoint(VersionChange): endpoint("/connection-tests/{connection_test_id}", ["PATCH"]).didnt_exist, endpoint("/connection-tests/{connection_test_id}/connection", ["GET"]).didnt_exist, ) + + +class AddDagTaskDetailsExistenceEndpoints(VersionChange): + """Add Dag task and task group existence endpoints.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + endpoint("/dags/{dag_id}/tasks/existence", ["GET"]).didnt_exist, + endpoint("/dags/{dag_id}/task-groups/existence", ["GET"]).didnt_exist, + ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py index 78b8f74a1d6ea..85d141b3ddb03 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -25,6 +25,7 @@ from airflow.models import DagModel from 
airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk import TaskGroup from tests_common.test_utils.db import clear_db_runs @@ -126,3 +127,86 @@ def test_get_dag_defaults(self, client, session, dag_maker): "tags": [], "next_dagrun": ANY, } + + def test_get_dag_task_groups_existence_partitions(self, client, session, dag_maker): + """Test partitioning task groups into existing and missing.""" + + dag_id = "test_dag_task_groups_existence" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + with TaskGroup(group_id="grp"): + EmptyOperator(task_id="a") + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}/task-groups/existence", + params={"task_group_ids": ["grp", "ghost_group"]}, + ) + + assert response.status_code == 200 + assert response.json() == {"existing": ["grp"], "missing": ["ghost_group"]} + + def test_get_dag_task_groups_existence_dag_not_found(self, client, session, dag_maker): + """Test missing Dag when checking task group existence.""" + + response = client.get( + "/execution/dags/no_such_dag/task-groups/existence", + params={"task_group_ids": ["grp"]}, + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "The Dag with dag_id: `no_such_dag` was not found", + "reason": "not_found", + } + } + + def test_get_dag_tasks_existence_partitions(self, client, session, dag_maker): + """Test partitioning tasks into existing and missing.""" + + dag_id = "test_dag_tasks_existence" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="a") + EmptyOperator(task_id="b") + + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}/tasks/existence", + params={"task_ids": ["a", "b", "ghost"]}, + ) + + assert response.status_code == 200 + assert response.json() == {"existing": ["a", "b"], "missing": ["ghost"]} + + def test_get_dag_tasks_existence_empty_list(self, client, session, dag_maker): + """Test empty 
task_ids returns empty partition.""" + + dag_id = "test_dag_tasks_existence_empty" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="a") + session.commit() + + response = client.get(f"/execution/dags/{dag_id}/tasks/existence") + + assert response.status_code == 200 + assert response.json() == {"existing": [], "missing": []} + + def test_get_dag_tasks_existence_dag_not_found(self, client, session, dag_maker): + """Test missing Dag when checking task existence.""" + + response = client.get( + "/execution/dags/no_such_dag/tasks/existence", + params={"task_ids": ["a"]}, + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "The Dag with dag_id: `no_such_dag` was not found", + "reason": "not_found", + } + } diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index ef92372f57378..cc09423dc62c1 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1954,6 +1954,8 @@ def get_type_names(union_type): "GetDagRun", "GetDagRunState", "GetDag", + "GetDagTaskGroupsExistence", + "GetDagTasksExistence", "GetDRCount", "GetTaskBreadcrumbs", "GetTaskRescheduleStartDate", @@ -1995,6 +1997,8 @@ def get_type_names(union_type): "DagResult", "DagRunResult", "DagRunStateResult", + "DagTaskGroupsExistenceResult", + "DagTasksExistenceResult", "DRCount", "SentFDs", "StartupDetails", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index eab9540df3cb9..a638453c25902 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1914,6 +1914,8 @@ def get_type_names(union_type): "GetAssetEventByAsset", "GetAssetEventByAssetAlias", "GetDagRun", + "GetDagTaskGroupsExistence", + "GetDagTasksExistence", "GetPrevSuccessfulDagRun", 
"GetPreviousDagRun", "GetTaskBreadcrumbs", @@ -1953,6 +1955,8 @@ def get_type_names(union_type): "AssetsByAliasResult", "AssetEventsResult", "DagRunResult", + "DagTaskGroupsExistenceResult", + "DagTasksExistenceResult", "SentFDs", "StartupDetails", "TaskBreadcrumbsResult", diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 8270386754f0a..8266c76435492 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -443,6 +443,9 @@ def execute(self, context: Context) -> None: dttm_filter = self._get_dttm_filter(context) if AIRFLOW_V_3_0_PLUS: + if self.check_existence and not self._has_checked_existence: + self._check_for_existence_af3(context) + self.defer( timeout=datetime.timedelta(seconds=timeout_value) if timeout_value else None, trigger=WorkflowTrigger( @@ -546,6 +549,58 @@ def _check_for_existence(self, session: Session) -> None: self._has_checked_existence = True + def _check_for_existence_af3(self, context: Context) -> None: + from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType + + ti = context["ti"] + + def _raise_if_dag_missing(exc: AirflowRuntimeError) -> None: + if exc.error.error == ErrorType.DAG_NOT_FOUND: + raise ExternalDagNotFoundError( + f"The external DAG {self.external_dag_id} does not exist." + ) from None + raise exc + + try: + # Workaround for checking Dag existence by calling get_dag_tasks_existence with an empty task list. 
+ ti.get_dag_tasks_existence( + dag_id=self.external_dag_id, + task_ids=[], + ) + except AirflowRuntimeError as exc: + _raise_if_dag_missing(exc) + + if self.external_task_ids: + try: + tasks_existence = ti.get_dag_tasks_existence( + dag_id=self.external_dag_id, + task_ids=list(self.external_task_ids), + ) + except AirflowRuntimeError as exc: + _raise_if_dag_missing(exc) + else: + if tasks_existence.missing: + raise ExternalTaskNotFoundError( + f"The external task(s) {sorted(tasks_existence.missing)} in Dag " + f"{self.external_dag_id} do not exist." + ) + elif self.external_task_group_id: + try: + group_existence = ti.get_dag_task_groups_existence( + dag_id=self.external_dag_id, + task_group_ids=[self.external_task_group_id], + ) + except AirflowRuntimeError as exc: + _raise_if_dag_missing(exc) + else: + if group_existence.missing: + raise ExternalTaskGroupNotFoundError( + f"The external task group '{self.external_task_group_id}' in " + f"Dag '{self.external_dag_id}' does not exist." + ) + + self._has_checked_existence = True + def get_count(self, dttm_filter: Sequence[datetime.datetime], session: Session, states: list[str]) -> int: """ Get the count of records against dttm filter and states. 
diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index 589bd32920e6f..f3a2ad6da5e57 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -1427,6 +1427,68 @@ def test_external_task_sensor_deferrable(self, dag_maker): assert exc.value.trigger.external_task_ids == ["test_task"] assert exc.value.trigger.logical_dates == [DEFAULT_DATE] + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_deferrable_check_existence_dag_not_found(self, dag_maker): + from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="non_existing_dag", + check_existence=True, + deferrable=True, + ) + + self.context["ti"].get_dag_tasks_existence.side_effect = AirflowRuntimeError( + ErrorResponse(error=ErrorType.DAG_NOT_FOUND, detail={"dag_id": "non_existing_dag"}) + ) + + with pytest.raises( + ExternalDagNotFoundError, + ): + op.execute(context=self.context) + + @pytest.mark.execution_timeout(10) + def test_external_task_sensor_deferrable_check_existence_task_not_found(self, dag_maker): + from airflow.sdk.execution_time.comms import DagTasksExistenceResult + + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_ids=["real_task", "missing_task"], + check_existence=True, + deferrable=True, + ) + + self.context["ti"].get_dag_tasks_existence.return_value = DagTasksExistenceResult( + existing=["real_task"], missing=["missing_task"] + ) + + with pytest.raises(ExternalTaskNotFoundError): + op.execute(context=self.context) + + 
@pytest.mark.execution_timeout(10) + def test_external_task_sensor_deferrable_check_existence_task_group_not_found(self, dag_maker): + from airflow.sdk.execution_time.comms import DagTaskGroupsExistenceResult + + with dag_maker("test_dag_child"): + op = ExternalTaskSensor( + task_id="test_external_task_sensor_check", + external_dag_id="test_dag_parent", + external_task_group_id="missing_group", + check_existence=True, + deferrable=True, + ) + + self.context["ti"].get_dag_task_groups_existence.return_value = DagTaskGroupsExistenceResult( + existing=[], missing=["missing_group"] + ) + + with pytest.raises(ExternalTaskGroupNotFoundError): + op.execute(context=self.context) + @pytest.mark.execution_timeout(10) def test_external_task_sensor_only_dag_id(self, dag_maker): """Test that the sensor works correctly when only external_dag_id is provided.""" diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 4a3160d6ecd59..a85a9b9d1c971 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -57,6 +57,8 @@ DagRun, DagRunStateResponse, DagRunType, + DagTaskGroupsExistenceResponse, + DagTasksExistenceResponse, HITLDetailRequest, HITLDetailResponse, HITLUser, @@ -984,6 +986,56 @@ def get(self, dag_id: str) -> DagResponse: resp = self.client.get(f"dags/{dag_id}") return DagResponse.model_validate_json(resp.read()) + def get_dag_task_groups_existence( + self, dag_id: str, task_group_ids: list[str] + ) -> DagTaskGroupsExistenceResponse | ErrorResponse: + """Check the existence of a Dags task groups.""" + try: + resp = self.client.get( + f"dags/{dag_id}/task-groups/existence", + params={"task_group_ids": task_group_ids}, + ) + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.debug( + "Dag not found while checking task group existence", + dag_id=dag_id, + task_group_ids=task_group_ids, + detail=e.detail, + status_code=e.response.status_code, + ) + 
return ErrorResponse( + error=ErrorType.DAG_NOT_FOUND, + detail={"dag_id": dag_id}, + ) + raise + return DagTaskGroupsExistenceResponse.model_validate_json(resp.read()) + + def get_dag_tasks_existence( + self, dag_id: str, task_ids: list[str] + ) -> DagTasksExistenceResponse | ErrorResponse: + """Check the existence of a Dags tasks.""" + try: + resp = self.client.get( + f"dags/{dag_id}/tasks/existence", + params={"task_ids": task_ids}, + ) + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.debug( + "Dag not found while checking task existence", + dag_id=dag_id, + task_ids=task_ids, + detail=e.detail, + status_code=e.response.status_code, + ) + return ErrorResponse( + error=ErrorType.DAG_NOT_FOUND, + detail={"dag_id": dag_id}, + ) + raise + return DagTasksExistenceResponse.model_validate_json(resp.read()) + class HITLOperations: """ diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index e1ac21585bf93..d84889b0e2bb2 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -179,6 +179,24 @@ class DagRunType(str, Enum): ASSET_MATERIALIZATION = "asset_materialization" +class DagTaskGroupsExistenceResponse(BaseModel): + """ + Schema for batch Dag task group existence response. + """ + + existing: Annotated[list[str], Field(title="Existing")] + missing: Annotated[list[str], Field(title="Missing")] + + +class DagTasksExistenceResponse(BaseModel): + """ + Schema for batch Dag task existence response. + """ + + existing: Annotated[list[str], Field(title="Existing")] + missing: Annotated[list[str], Field(title="Missing")] + + class HITLUser(BaseModel): """ Schema for a Human-in-the-loop users. 
diff --git a/task-sdk/src/airflow/sdk/exceptions.py b/task-sdk/src/airflow/sdk/exceptions.py index 0c211de4e362c..7b96b4e12a750 100644 --- a/task-sdk/src/airflow/sdk/exceptions.py +++ b/task-sdk/src/airflow/sdk/exceptions.py @@ -94,6 +94,7 @@ class ErrorType(enum.Enum): TASK_STORE_NOT_FOUND = "TASK_STORE_NOT_FOUND" ASSET_STORE_NOT_FOUND = "ASSET_STORE_NOT_FOUND" DAGRUN_ALREADY_EXISTS = "DAGRUN_ALREADY_EXISTS" + DAG_NOT_FOUND = "DAG_NOT_FOUND" # Distinct from API_SERVER_ERROR: signals an explicit 401/403 from the # Execution API. Callers like ExecutionAPISecretsBackend treat this as # a deny rather than a "not found" so the secrets-backend dispatcher diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 0e0483be70e68..42388e131c821 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -75,6 +75,8 @@ DagResponse, DagRun, DagRunStateResponse, + DagTaskGroupsExistenceResponse, + DagTasksExistenceResponse, HITLDetailRequest, InactiveAssetsResponse, PreviousTIResponse, @@ -771,6 +773,40 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: return cls(**dag_response.model_dump(exclude_defaults=True), type="DagResult") +class DagTaskGroupsExistenceResult(BaseModel): + existing: list[str] + missing: list[str] + type: Literal["DagTaskGroupsExistenceResult"] = "DagTaskGroupsExistenceResult" + + @classmethod + def from_api_response(cls, response: DagTaskGroupsExistenceResponse) -> DagTaskGroupsExistenceResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we convert it to a + Result for communication between the Supervisor and the task + process since it needs a discriminator field. 
+ """ + return cls(existing=response.existing, missing=response.missing, type="DagTaskGroupsExistenceResult") + + +class DagTasksExistenceResult(BaseModel): + existing: list[str] + missing: list[str] + type: Literal["DagTasksExistenceResult"] = "DagTasksExistenceResult" + + @classmethod + def from_api_response(cls, response: DagTasksExistenceResponse) -> DagTasksExistenceResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we convert it to a + Result for communication between the Supervisor and the task + process since it needs a discriminator field. + """ + return cls(existing=response.existing, missing=response.missing, type="DagTasksExistenceResult") + + ToTask = Annotated[ AssetResult | AssetsByAliasResult @@ -779,6 +815,8 @@ def from_api_response(cls, dag_response: DagResponse) -> DagResult: | ConnectionResult | DagRunResult | DagRunStateResult + | DagTaskGroupsExistenceResult + | DagTasksExistenceResult | DRCount | DagResult | ErrorResponse @@ -1192,6 +1230,18 @@ class GetDag(BaseModel): type: Literal["GetDag"] = "GetDag" +class GetDagTaskGroupsExistence(BaseModel): + dag_id: str + task_group_ids: list[str] + type: Literal["GetDagTaskGroupsExistence"] = "GetDagTaskGroupsExistence" + + +class GetDagTasksExistence(BaseModel): + dag_id: str + task_ids: list[str] + type: Literal["GetDagTasksExistence"] = "GetDagTasksExistence" + + ToSupervisor = Annotated[ ClearAssetStoreByName | ClearAssetStoreByUri @@ -1211,6 +1261,8 @@ class GetDag(BaseModel): | GetConnection | GetDagRun | GetDagRunState + | GetDagTaskGroupsExistence + | GetDagTasksExistence | GetDRCount | GetDag | GetPrevSuccessfulDagRun diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 094c3f94d9ee9..022aff97c32fe 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -1287,6 
+1287,66 @@ "title": "DagRunType", "type": "string" }, + "DagTaskGroupsExistenceResult": { + "properties": { + "existing": { + "items": { + "type": "string" + }, + "title": "Existing", + "type": "array" + }, + "missing": { + "items": { + "type": "string" + }, + "title": "Missing", + "type": "array" + }, + "type": { + "const": "DagTaskGroupsExistenceResult", + "default": "DagTaskGroupsExistenceResult", + "title": "Type", + "type": "string" + } + }, + "required": [ + "existing", + "missing" + ], + "title": "DagTaskGroupsExistenceResult", + "type": "object" + }, + "DagTasksExistenceResult": { + "properties": { + "existing": { + "items": { + "type": "string" + }, + "title": "Existing", + "type": "array" + }, + "missing": { + "items": { + "type": "string" + }, + "title": "Missing", + "type": "array" + }, + "type": { + "const": "DagTasksExistenceResult", + "default": "DagTasksExistenceResult", + "title": "Type", + "type": "string" + } + }, + "required": [ + "existing", + "missing" + ], + "title": "DagTasksExistenceResult", + "type": "object" + }, "DeferTask": { "additionalProperties": false, "description": "Update a task instance state to deferred.", @@ -1639,6 +1699,7 @@ "TASK_STORE_NOT_FOUND", "ASSET_STORE_NOT_FOUND", "DAGRUN_ALREADY_EXISTS", + "DAG_NOT_FOUND", "PERMISSION_DENIED", "GENERIC_ERROR", "API_SERVER_ERROR" @@ -2045,6 +2106,60 @@ "title": "GetDagRunState", "type": "object" }, + "GetDagTaskGroupsExistence": { + "properties": { + "dag_id": { + "title": "Dag Id", + "type": "string" + }, + "task_group_ids": { + "items": { + "type": "string" + }, + "title": "Task Group Ids", + "type": "array" + }, + "type": { + "const": "GetDagTaskGroupsExistence", + "default": "GetDagTaskGroupsExistence", + "title": "Type", + "type": "string" + } + }, + "required": [ + "dag_id", + "task_group_ids" + ], + "title": "GetDagTaskGroupsExistence", + "type": "object" + }, + "GetDagTasksExistence": { + "properties": { + "dag_id": { + "title": "Dag Id", + "type": "string" + }, + 
"task_ids": { + "items": { + "type": "string" + }, + "title": "Task Ids", + "type": "array" + }, + "type": { + "const": "GetDagTasksExistence", + "default": "GetDagTasksExistence", + "title": "Type", + "type": "string" + } + }, + "required": [ + "dag_id", + "task_ids" + ], + "title": "GetDagTasksExistence", + "type": "object" + }, "GetHITLDetailResponse": { "description": "Get the response content part of a Human-in-the-loop response.", "properties": { diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 0005c36bc7e9b..7b247150615f2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -70,6 +70,8 @@ CreateHITLDetailPayload, DagResult, DagRunResult, + DagTaskGroupsExistenceResult, + DagTasksExistenceResult, DeferTask, DeleteAssetStoreByName, DeleteAssetStoreByUri, @@ -88,6 +90,8 @@ GetDag, GetDagRun, GetDagRunState, + GetDagTaskGroupsExistence, + GetDagTasksExistence, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -1822,6 +1826,26 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dag_id=msg.dag_id, ) resp = DagResult.from_api_response(dag) + elif isinstance(msg, GetDagTaskGroupsExistence): + groups_existence = self.client.dags.get_dag_task_groups_existence( + dag_id=msg.dag_id, + task_group_ids=msg.task_group_ids, + ) + resp = ( + groups_existence + if isinstance(groups_existence, ErrorResponse) + else DagTaskGroupsExistenceResult.from_api_response(groups_existence) + ) + elif isinstance(msg, GetDagTasksExistence): + tasks_existence = self.client.dags.get_dag_tasks_existence( + dag_id=msg.dag_id, + task_ids=msg.task_ids, + ) + resp = ( + tasks_existence + if isinstance(tasks_existence, ErrorResponse) + else DagTasksExistenceResult.from_api_response(tasks_existence) + ) elif isinstance(msg, GetTaskStore): task_store = self.client.task_store.get(msg.ti_id, msg.key) resp = ( diff 
--git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 22ac90405e027..2c4522ac385ed 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -85,11 +85,15 @@ CommsDecoder, DagResult, DagRunStateResult, + DagTaskGroupsExistenceResult, + DagTasksExistenceResult, DeferTask, DRCount, ErrorResponse, GetDag, GetDagRunState, + GetDagTaskGroupsExistence, + GetDagTasksExistence, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -755,6 +759,32 @@ def get_dag(dag_id: str) -> DagResult: return response + @staticmethod + def get_dag_tasks_existence(dag_id: str, task_ids: list[str]) -> DagTasksExistenceResult: + response = SUPERVISOR_COMMS.send(msg=GetDagTasksExistence(dag_id=dag_id, task_ids=task_ids)) + + if isinstance(response, ErrorResponse): + raise AirflowRuntimeError(response) + + if TYPE_CHECKING: + assert isinstance(response, DagTasksExistenceResult) + + return response + + @staticmethod + def get_dag_task_groups_existence(dag_id: str, task_group_ids: list[str]) -> DagTaskGroupsExistenceResult: + response = SUPERVISOR_COMMS.send( + msg=GetDagTaskGroupsExistence(dag_id=dag_id, task_group_ids=task_group_ids) + ) + + if isinstance(response, ErrorResponse): + raise AirflowRuntimeError(response) + + if TYPE_CHECKING: + assert isinstance(response, DagTaskGroupsExistenceResult) + + return response + @property def log_url(self) -> str: run_id = quote(self.run_id) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 711dbe71ca46c..104030e28cbe4 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -52,7 +52,11 @@ ) from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator - from airflow.sdk.execution_time.comms import DagResult + from airflow.sdk.execution_time.comms import ( + DagResult, + 
DagTaskGroupsExistenceResult, + DagTasksExistenceResult, + ) Operator: TypeAlias = BaseOperator | MappedOperator @@ -219,6 +223,14 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: ... @staticmethod def get_dag(dag_id: str) -> DagResult: ... + @staticmethod + def get_dag_tasks_existence(dag_id: str, task_ids: list[str]) -> DagTasksExistenceResult: ... + + @staticmethod + def get_dag_task_groups_existence( + dag_id: str, task_group_ids: list[str] + ) -> DagTaskGroupsExistenceResult: ... + # Public alias for RuntimeTaskInstanceProtocol class TaskInstance(RuntimeTaskInstanceProtocol): diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index ce4afe51af5fe..fa5f4bcdb7d3e 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -41,6 +41,8 @@ DagResponse, DagRunState, DagRunStateResponse, + DagTaskGroupsExistenceResponse, + DagTasksExistenceResponse, HITLDetailRequest, HITLDetailResponse, HITLUser, @@ -1784,6 +1786,80 @@ def handle_request(request: httpx.Request) -> httpx.Response: with pytest.raises(ServerResponseError): client.dags.get(dag_id="test_dag") + def test_get_dag_task_groups_existence(self): + """Test that the client can partition task group ids by existence.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag/task-groups/existence": + return httpx.Response( + status_code=200, + json={"existing": ["g1"], "missing": ["g2"]}, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get_dag_task_groups_existence(dag_id="test_dag", task_group_ids=["g1", "g2"]) + + assert result == DagTaskGroupsExistenceResponse(existing=["g1"], missing=["g2"]) + + def test_get_dag_task_groups_existence_dag_not_found(self): + """Test a missing dag while checking task group existence.""" + + def handle_request(request: httpx.Request) -> 
httpx.Response: + if request.url.path == "/dags/missing_dag/task-groups/existence": + return httpx.Response( + status_code=404, + json={ + "detail": { + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", + } + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get_dag_task_groups_existence(dag_id="missing_dag", task_group_ids=["g1"]) + + assert result == ErrorResponse(error=ErrorType.DAG_NOT_FOUND, detail={"dag_id": "missing_dag"}) + + def test_get_dag_tasks_existence(self): + """Test that the client can partition task ids by existence.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag/tasks/existence": + return httpx.Response( + status_code=200, + json={"existing": ["a"], "missing": ["b"]}, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get_dag_tasks_existence(dag_id="test_dag", task_ids=["a", "b"]) + + assert result == DagTasksExistenceResponse(existing=["a"], missing=["b"]) + + def test_get_dag_tasks_existence_dag_not_found(self): + """Test a missing dag while checking task existence.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/missing_dag/tasks/existence": + return httpx.Response( + status_code=404, + json={ + "detail": { + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", + } + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get_dag_tasks_existence(dag_id="missing_dag", task_ids=["a"]) + + assert result == ErrorResponse(error=ErrorType.DAG_NOT_FOUND, detail={"dag_id": "missing_dag"}) + class TestTaskStateOperations: TI_ID = uuid7() diff --git 
a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b72fbdb1e30b6..64ce77a7cf43b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -83,6 +83,8 @@ DagResult, DagRunResult, DagRunStateResult, + DagTaskGroupsExistenceResult, + DagTasksExistenceResult, DeferTask, DeleteAssetStoreByName, DeleteAssetStoreByUri, @@ -102,6 +104,8 @@ GetDag, GetDagRun, GetDagRunState, + GetDagTaskGroupsExistence, + GetDagTasksExistence, GetDRCount, GetHITLDetailResponse, GetPreviousDagRun, @@ -2836,6 +2840,52 @@ class RequestTestCase: ), test_id="get_dag", ), + RequestTestCase( + message=GetDagTaskGroupsExistence( + dag_id="test_dag", + task_group_ids=["group_a", "group_b"], + ), + expected_body={ + "existing": ["group_a"], + "missing": ["group_b"], + "type": "DagTaskGroupsExistenceResult", + }, + client_mock=ClientMock( + method_path="dags.get_dag_task_groups_existence", + kwargs={ + "dag_id": "test_dag", + "task_group_ids": ["group_a", "group_b"], + }, + response=DagTaskGroupsExistenceResult( + existing=["group_a"], + missing=["group_b"], + ), + ), + test_id="get_dag_task_groups_existence", + ), + RequestTestCase( + message=GetDagTasksExistence( + dag_id="test_dag", + task_ids=["task_a", "task_b"], + ), + expected_body={ + "existing": ["task_a"], + "missing": ["task_b"], + "type": "DagTasksExistenceResult", + }, + client_mock=ClientMock( + method_path="dags.get_dag_tasks_existence", + kwargs={ + "dag_id": "test_dag", + "task_ids": ["task_a", "task_b"], + }, + response=DagTasksExistenceResult( + existing=["task_a"], + missing=["task_b"], + ), + ), + test_id="get_dag_tasks_existence", + ), RequestTestCase( message=GetTaskStore(ti_id=TI_ID, key="job_id"), test_id="get_task_store", diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 
f8953dc232970..b32203ae84950 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -101,6 +101,8 @@ ConnectionResult, DagResult, DagRunStateResult, + DagTaskGroupsExistenceResult, + DagTasksExistenceResult, DeferTask, DeleteAssetStoreByName, DeleteTaskStore, @@ -113,6 +115,8 @@ GetConnection, GetDag, GetDagRunState, + GetDagTaskGroupsExistence, + GetDagTasksExistence, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -3404,6 +3408,48 @@ def test_get_dag(self, mock_supervisor_comms): assert response.dag_id == "test_dag" assert response.is_paused is False + def test_get_task_groups_existence(self, mock_supervisor_comms): + """Test that get_task_group sends the correct request and returns the partitioned task group existence.""" + mock_supervisor_comms.send.return_value = DagTaskGroupsExistenceResult( + existing=["group_a"], + missing=["group_b"], + ) + + response = RuntimeTaskInstance.get_dag_task_groups_existence( + dag_id="test_dag", + task_group_ids=["group_a", "group_b"], + ) + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDagTaskGroupsExistence( + dag_id="test_dag", + task_group_ids=["group_a", "group_b"], + ) + ) + assert response.existing == ["group_a"] + assert response.missing == ["group_b"] + + def test_get_tasks_existence(self, mock_supervisor_comms): + """Test that get_task sends the correct request and returns the partitioned task existence.""" + mock_supervisor_comms.send.return_value = DagTasksExistenceResult( + existing=["task_a"], + missing=["task_b"], + ) + + response = RuntimeTaskInstance.get_dag_tasks_existence( + dag_id="test_dag", + task_ids=["task_a", "task_b"], + ) + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDagTasksExistence( + dag_id="test_dag", + task_ids=["task_a", "task_b"], + ) + ) + assert response.existing == ["task_a"] + assert response.missing == ["task_b"] + class TestXComAfterTaskExecution: 
@pytest.mark.parametrize( From d17ea3d59692930a62071d55d8b7f5f10f735028 Mon Sep 17 00:00:00 2001 From: Diogo Pan Date: Mon, 1 Jun 2026 13:09:35 +0100 Subject: [PATCH 2/2] fix: exempt new ExternalTaskSensor function for versions below 3.3 The new existence check and tests for ExternalTaskSensor use endpoints that were newly added, so they are skipped on Airflow versions below 3.3. --- .../tests/unit/dag_processing/test_processor.py | 3 ++- airflow-core/tests/unit/jobs/test_triggerer_job.py | 3 ++- .../providers/standard/sensors/external_task.py | 12 +++++++----- .../standard/sensors/test_external_task_sensor.py | 10 +++++++++- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index cc09423dc62c1..26b5b6049ffed 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1954,7 +1954,8 @@ def get_type_names(union_type): "GetDagRun", "GetDagRunState", "GetDag", - "GetDagTaskGroupsExistenceGetDagTasksExistence", + "GetDagTaskGroupsExistence", + "GetDagTasksExistence", "GetDRCount", "GetTaskBreadcrumbs", "GetTaskRescheduleStartDate", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index a638453c25902..e801144414d23 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1955,7 +1955,8 @@ def get_type_names(union_type): "AssetsByAliasResult", "AssetEventsResult", "DagRunResult", - "DagTaskGroupsExistenceResultDagTasksExistenceResult", + "DagTaskGroupsExistenceResult", + "DagTasksExistenceResult", "SentFDs", "StartupDetails", "TaskBreadcrumbsResult", diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 
8266c76435492..63ae86c47970c 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -46,6 +46,7 @@ from airflow.providers.standard.version_compat import ( AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, BaseOperator, ) from airflow.utils.file import correct_maybe_zipped @@ -443,8 +444,11 @@ def execute(self, context: Context) -> None: dttm_filter = self._get_dttm_filter(context) if AIRFLOW_V_3_0_PLUS: - if self.check_existence and not self._has_checked_existence: - self._check_for_existence_af3(context) + # It relies on endpoints (GetDagTasksExistence and GetDagTaskGroupsExistence) + # that don't exist on previous versions. + if AIRFLOW_V_3_3_PLUS: + if self.check_existence and not self._has_checked_existence: + self._check_for_existence_af3(context) self.defer( timeout=datetime.timedelta(seconds=timeout_value) if timeout_value else None, @@ -556,9 +560,7 @@ def _check_for_existence_af3(self, context: Context) -> None: def _raise_if_dag_missing(exc: AirflowRuntimeError) -> None: if exc.error.error == ErrorType.DAG_NOT_FOUND: - raise ExternalDagNotFoundError( - f"The external DAG {self.external_dag_id} does not exist." 
- ) from None + raise ExternalDagNotFoundError(f"The external DAG {self.external_dag_id} does not exist.") raise exc try: diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index f3a2ad6da5e57..1987c265c6918 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -62,7 +62,12 @@ from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db, sync_dags_to_db from tests_common.test_utils.db import clear_db_runs from tests_common.test_utils.mock_operators import MockOperator -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS +from tests_common.test_utils.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_PLUS, + AIRFLOW_V_3_2_PLUS, + AIRFLOW_V_3_3_PLUS, +) if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion @@ -1427,6 +1432,7 @@ def test_external_task_sensor_deferrable(self, dag_maker): assert exc.value.trigger.external_task_ids == ["test_task"] assert exc.value.trigger.logical_dates == [DEFAULT_DATE] + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Existence-check endpoints require Airflow 3.3+") @pytest.mark.execution_timeout(10) def test_external_task_sensor_deferrable_check_existence_dag_not_found(self, dag_maker): from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -1449,6 +1455,7 @@ def test_external_task_sensor_deferrable_check_existence_dag_not_found(self, dag ): op.execute(context=self.context) + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Existence-check endpoints require Airflow 3.3+") @pytest.mark.execution_timeout(10) def test_external_task_sensor_deferrable_check_existence_task_not_found(self, dag_maker): from airflow.sdk.execution_time.comms import DagTasksExistenceResult @@ -1469,6 +1476,7 
@@ def test_external_task_sensor_deferrable_check_existence_task_not_found(self, da with pytest.raises(ExternalTaskNotFoundError): op.execute(context=self.context) + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Existence-check endpoints require Airflow 3.3+") @pytest.mark.execution_timeout(10) def test_external_task_sensor_deferrable_check_existence_task_group_not_found(self, dag_maker): from airflow.sdk.execution_time.comms import DagTaskGroupsExistenceResult