From 5bad0bc006d654e617292a410647aa74168ae601 Mon Sep 17 00:00:00 2001 From: Ashwin Upadhyay Date: Wed, 20 May 2026 23:10:07 +0530 Subject: [PATCH] Fix CloudComposerExternalTaskSensor success on empty execution window CloudComposerExternalTaskSensor and CloudComposerExternalTaskTrigger reported success whenever no in-window task instance was in a disallowed state. When every returned task instance fell outside the requested execution_range, the helper still returned success instead of continuing to wait for a relevant in-window task instance. _check_task_instances_states now only returns success when at least one task instance is inside the requested window and all in-window task instances satisfy the expected state set, mirroring the equivalent CloudComposerDAGRunSensor fix. closes: #67051 --- .../google/cloud/sensors/cloud_composer.py | 18 +-- .../google/cloud/triggers/cloud_composer.py | 18 +-- .../cloud/sensors/test_cloud_composer.py | 90 ++++++++++++++- .../cloud/triggers/test_cloud_composer.py | 105 +++++++++++++++++- 4 files changed, 211 insertions(+), 20 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py index 713fc3d8ea3e7..def0e8e92ae75 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -562,16 +562,16 @@ def _check_task_instances_states( end_date: datetime, states: Iterable[str], ) -> bool: + found_task_instances_in_window = False for task_instance in task_instances: - if ( - start_date.timestamp() - < parser.parse( - task_instance["execution_date" if self._composer_airflow_version < 3 else "logical_date"] - ).timestamp() - < end_date.timestamp() - ) and task_instance["state"] not in states: - return False - return True + execution_date = parser.parse( + task_instance["execution_date" if self._composer_airflow_version < 3 else "logical_date"] + ) + if start_date.timestamp() < execution_date.timestamp() < end_date.timestamp(): + found_task_instances_in_window = True + if task_instance["state"] not in states: + return False + return found_task_instances_in_window def _get_composer_airflow_version(self) -> int: """Return Composer Airflow version.""" diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py index 040e35f2a3859..704972d037010 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -444,16 +444,16 @@ def _check_task_instances_states( end_date: datetime, states: Iterable[str], ) -> bool: + found_task_instances_in_window = False for task_instance in task_instances: - if ( - start_date.timestamp() - < parser.parse( - task_instance["execution_date" if self.composer_airflow_version < 3 else "logical_date"] - ).timestamp() - < end_date.timestamp() - ) and task_instance["state"] not in states: - return False - return True + execution_date = parser.parse( + task_instance["execution_date" if self.composer_airflow_version < 3 else "logical_date"] + ) + if start_date.timestamp() < execution_date.timestamp() < end_date.timestamp(): + found_task_instances_in_window = True + if task_instance["state"] not in states: + return False + return found_task_instances_in_window def _get_async_hook(self) -> CloudComposerAsyncHook: return CloudComposerAsyncHook( diff --git a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py index 21a545a9dc373..76fa3f51d4176 100644 --- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py +++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py @@ -18,7 +18,7 @@ from __future__ import annotations import json -from datetime import datetime +from datetime import datetime, timezone from unittest import mock import pytest @@ -350,3 +350,91 @@ def test_composer_external_task_group_id_wait_not_ready(self, mock_hook, compose task._composer_airflow_version = composer_airflow_version assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") + def test_wait_not_ready_when_all_task_instances_outside_window(self, mock_hook, composer_airflow_version): + # All returned task instances are dated 2024-05-22, which is outside the + # execution window derived from a 2024-06-01 logical date. + mock_hook.return_value.get_task_instances.return_value = TEST_GET_TASK_INSTANCES_RESULT( + "success", + "execution_date" if composer_airflow_version < 3 else "logical_date", + TEST_COMPOSER_EXTERNAL_TASK_ID, + ) + + task = CloudComposerExternalTaskSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + environment_id=TEST_ENVIRONMENT_ID, + composer_external_dag_id="test_dag_id", + allowed_states=["success"], + ) + task._composer_airflow_version = composer_airflow_version + + assert not task.poke(context={"logical_date": datetime(2024, 6, 1, 0, 0, 0)}) + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") + def test_wait_ready_when_in_window_instance_present_with_out_of_window_instances( + self, mock_hook, composer_airflow_version + ): + date_key = "execution_date" if composer_airflow_version < 3 else "logical_date" + mock_hook.return_value.get_task_instances.return_value = { + "task_instances": [ + { + "task_id": TEST_COMPOSER_EXTERNAL_TASK_ID, + "dag_id": "test_dag_id", + "state": "running", + date_key: "2024-05-20T11:10:00+00:00", + }, + { + "task_id": TEST_COMPOSER_EXTERNAL_TASK_ID, + "dag_id": "test_dag_id", + "state": "success", + date_key: "2024-05-22T11:10:00+00:00", + }, + ], + "total_entries": 2, + } + + task = CloudComposerExternalTaskSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + environment_id=TEST_ENVIRONMENT_ID, + composer_external_dag_id="test_dag_id", + allowed_states=["success"], + ) + task._composer_airflow_version = composer_airflow_version + + assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") + def test_wait_not_ready_when_task_instance_on_window_boundary(self, mock_hook, composer_airflow_version): + # The task instance is dated exactly on the window start (start of the + # range is exclusive), so it must not be treated as in-window. + mock_hook.return_value.get_task_instances.return_value = TEST_GET_TASK_INSTANCES_RESULT( + "success", + "execution_date" if composer_airflow_version < 3 else "logical_date", + TEST_COMPOSER_EXTERNAL_TASK_ID, + ) + + # Window start is set exactly to the task instance date; the start of the + # range is exclusive, so the task instance must not be treated as in-window. + task = CloudComposerExternalTaskSensor( + task_id="task-id", + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + environment_id=TEST_ENVIRONMENT_ID, + composer_external_dag_id="test_dag_id", + allowed_states=["success"], + execution_range=[ + datetime(2024, 5, 22, 11, 10, 0, tzinfo=timezone.utc), + datetime(2024, 5, 22, 12, 0, 0, tzinfo=timezone.utc), + ], + ) + task._composer_airflow_version = composer_airflow_version + + assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py index f19fb6e2ca414..374471577878f 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py @@ -17,8 +17,9 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timezone from unittest import mock +from unittest.mock import AsyncMock import pytest @@ -209,3 +210,105 @@ def test_serialize(self, external_task_trigger): }, ) assert actual_data == expected_data + + @staticmethod + def _build_task_instances_result( + composer_airflow_version: int, task_instances: list[tuple[str, str]] + ) -> dict: + date_key = "execution_date" if composer_airflow_version < 3 else "logical_date" + return { + "task_instances": [ + { + "task_id": TEST_COMPOSER_EXTERNAL_TASK_IDS[0], + "dag_id": TEST_COMPOSER_DAG_ID, + "state": state, + date_key: logical_date, + } + for state, logical_date in task_instances + ], + "total_entries": len(task_instances), + } + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.google.cloud.triggers.cloud_composer.asyncio.sleep", new_callable=AsyncMock + ) + async def test_trigger_keeps_polling_when_only_out_of_window_task_instances( + self, mock_sleep, composer_airflow_version + ): + hook = AsyncMock() + environment = mock.Mock() + environment.config.airflow_uri = "https://composer.example" + hook.get_environment.return_value = environment + # First poll returns only an out-of-window task instance, so the trigger + # must keep polling instead of yielding success. + hook.get_task_instances.side_effect = [ + self._build_task_instances_result( + composer_airflow_version, [("success", "2024-03-23T11:10:00+00:00")] + ), + self._build_task_instances_result( + composer_airflow_version, [("success", "2024-03-22T11:10:00+00:00")] + ), + ] + trigger = CloudComposerExternalTaskTrigger( + project_id=TEST_PROJECT_ID, + region=TEST_LOCATION, + environment_id=TEST_ENVIRONMENT_ID, + start_date=datetime(2024, 3, 22, 11, 0, 0, tzinfo=timezone.utc), + end_date=datetime(2024, 3, 22, 12, 0, 0, tzinfo=timezone.utc), + allowed_states=TEST_ALLOWED_STATES, + skipped_states=TEST_SKIPPED_STATES, + failed_states=TEST_FAILED_STATES, + composer_external_dag_id=TEST_COMPOSER_DAG_ID, + composer_external_task_ids=TEST_COMPOSER_EXTERNAL_TASK_IDS, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_interval=TEST_POLL_INTERVAL, + composer_airflow_version=composer_airflow_version, + ) + + with mock.patch.object(trigger, "_get_async_hook", return_value=hook): + actual_event = await trigger.run().asend(None) + + assert actual_event == TriggerEvent({"status": "success"}) + assert hook.get_task_instances.await_count == 2 + assert mock_sleep.await_count == 1 + + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.google.cloud.triggers.cloud_composer.asyncio.sleep", new_callable=AsyncMock + ) + async def test_trigger_yields_success_on_in_window_allowed_task_instance( + self, mock_sleep, composer_airflow_version + ): + hook = AsyncMock() + environment = mock.Mock() + environment.config.airflow_uri = "https://composer.example" + hook.get_environment.return_value = environment + hook.get_task_instances.return_value = self._build_task_instances_result( + composer_airflow_version, [("success", "2024-03-22T11:10:00+00:00")] + ) + trigger = CloudComposerExternalTaskTrigger( + project_id=TEST_PROJECT_ID, + region=TEST_LOCATION, + environment_id=TEST_ENVIRONMENT_ID, + start_date=datetime(2024, 3, 22, 11, 0, 0, tzinfo=timezone.utc), + end_date=datetime(2024, 3, 22, 12, 0, 0, tzinfo=timezone.utc), + allowed_states=TEST_ALLOWED_STATES, + skipped_states=TEST_SKIPPED_STATES, + failed_states=TEST_FAILED_STATES, + composer_external_dag_id=TEST_COMPOSER_DAG_ID, + composer_external_task_ids=TEST_COMPOSER_EXTERNAL_TASK_IDS, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_interval=TEST_POLL_INTERVAL, + composer_airflow_version=composer_airflow_version, + ) + + with mock.patch.object(trigger, "_get_async_hook", return_value=hook): + actual_event = await trigger.run().asend(None) + + assert actual_event == TriggerEvent({"status": "success"}) + assert hook.get_task_instances.await_count == 1