Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -562,16 +562,16 @@ def _check_task_instances_states(
end_date: datetime,
states: Iterable[str],
) -> bool:
found_task_instances_in_window = False
for task_instance in task_instances:
if (
start_date.timestamp()
< parser.parse(
task_instance["execution_date" if self._composer_airflow_version < 3 else "logical_date"]
).timestamp()
< end_date.timestamp()
) and task_instance["state"] not in states:
return False
return True
execution_date = parser.parse(
task_instance["execution_date" if self._composer_airflow_version < 3 else "logical_date"]
)
if start_date.timestamp() < execution_date.timestamp() < end_date.timestamp():
found_task_instances_in_window = True
if task_instance["state"] not in states:
return False
return found_task_instances_in_window

def _get_composer_airflow_version(self) -> int:
"""Return Composer Airflow version."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,16 +444,16 @@ def _check_task_instances_states(
end_date: datetime,
states: Iterable[str],
) -> bool:
found_task_instances_in_window = False
for task_instance in task_instances:
if (
start_date.timestamp()
< parser.parse(
task_instance["execution_date" if self.composer_airflow_version < 3 else "logical_date"]
).timestamp()
< end_date.timestamp()
) and task_instance["state"] not in states:
return False
return True
execution_date = parser.parse(
task_instance["execution_date" if self.composer_airflow_version < 3 else "logical_date"]
)
if start_date.timestamp() < execution_date.timestamp() < end_date.timestamp():
found_task_instances_in_window = True
if task_instance["state"] not in states:
return False
return found_task_instances_in_window

def _get_async_hook(self) -> CloudComposerAsyncHook:
return CloudComposerAsyncHook(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import json
from datetime import datetime
from datetime import datetime, timezone
from unittest import mock

import pytest
Expand Down Expand Up @@ -350,3 +350,91 @@ def test_composer_external_task_group_id_wait_not_ready(self, mock_hook, compose
task._composer_airflow_version = composer_airflow_version

assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_not_ready_when_all_task_instances_outside_window(self, mock_hook, composer_airflow_version):
# All returned task instances are dated 2024-05-22, which is outside the
# execution window derived from a 2024-06-01 logical date.
mock_hook.return_value.get_task_instances.return_value = TEST_GET_TASK_INSTANCES_RESULT(
"success",
"execution_date" if composer_airflow_version < 3 else "logical_date",
TEST_COMPOSER_EXTERNAL_TASK_ID,
)

task = CloudComposerExternalTaskSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
environment_id=TEST_ENVIRONMENT_ID,
composer_external_dag_id="test_dag_id",
allowed_states=["success"],
)
task._composer_airflow_version = composer_airflow_version

assert not task.poke(context={"logical_date": datetime(2024, 6, 1, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_ready_when_in_window_instance_present_with_out_of_window_instances(
self, mock_hook, composer_airflow_version
):
date_key = "execution_date" if composer_airflow_version < 3 else "logical_date"
mock_hook.return_value.get_task_instances.return_value = {
"task_instances": [
{
"task_id": TEST_COMPOSER_EXTERNAL_TASK_ID,
"dag_id": "test_dag_id",
"state": "running",
date_key: "2024-05-20T11:10:00+00:00",
},
{
"task_id": TEST_COMPOSER_EXTERNAL_TASK_ID,
"dag_id": "test_dag_id",
"state": "success",
date_key: "2024-05-22T11:10:00+00:00",
},
],
"total_entries": 2,
}

task = CloudComposerExternalTaskSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
environment_id=TEST_ENVIRONMENT_ID,
composer_external_dag_id="test_dag_id",
allowed_states=["success"],
)
task._composer_airflow_version = composer_airflow_version

assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_not_ready_when_task_instance_on_window_boundary(self, mock_hook, composer_airflow_version):
# The task instance is dated exactly on the window start (start of the
# range is exclusive), so it must not be treated as in-window.
mock_hook.return_value.get_task_instances.return_value = TEST_GET_TASK_INSTANCES_RESULT(
"success",
"execution_date" if composer_airflow_version < 3 else "logical_date",
TEST_COMPOSER_EXTERNAL_TASK_ID,
)

# Window start is set exactly to the task instance date; the start of the
# range is exclusive, so the task instance must not be treated as in-window.
task = CloudComposerExternalTaskSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
environment_id=TEST_ENVIRONMENT_ID,
composer_external_dag_id="test_dag_id",
allowed_states=["success"],
execution_range=[
datetime(2024, 5, 22, 11, 10, 0, tzinfo=timezone.utc),
datetime(2024, 5, 22, 12, 0, 0, tzinfo=timezone.utc),
],
)
task._composer_airflow_version = composer_airflow_version

assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from __future__ import annotations

from datetime import datetime
from datetime import datetime, timezone
from unittest import mock
from unittest.mock import AsyncMock

import pytest

Expand Down Expand Up @@ -209,3 +210,105 @@ def test_serialize(self, external_task_trigger):
},
)
assert actual_data == expected_data

@staticmethod
def _build_task_instances_result(
composer_airflow_version: int, task_instances: list[tuple[str, str]]
) -> dict:
date_key = "execution_date" if composer_airflow_version < 3 else "logical_date"
return {
"task_instances": [
{
"task_id": TEST_COMPOSER_EXTERNAL_TASK_IDS[0],
"dag_id": TEST_COMPOSER_DAG_ID,
"state": state,
date_key: logical_date,
}
for state, logical_date in task_instances
],
"total_entries": len(task_instances),
}

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@pytest.mark.asyncio
@mock.patch(
"airflow.providers.google.cloud.triggers.cloud_composer.asyncio.sleep", new_callable=AsyncMock
)
async def test_trigger_keeps_polling_when_only_out_of_window_task_instances(
self, mock_sleep, composer_airflow_version
):
hook = AsyncMock()
environment = mock.Mock()
environment.config.airflow_uri = "https://composer.example"
hook.get_environment.return_value = environment
# First poll returns only an out-of-window task instance, so the trigger
# must keep polling instead of yielding success.
hook.get_task_instances.side_effect = [
self._build_task_instances_result(
composer_airflow_version, [("success", "2024-03-23T11:10:00+00:00")]
),
self._build_task_instances_result(
composer_airflow_version, [("success", "2024-03-22T11:10:00+00:00")]
),
]
trigger = CloudComposerExternalTaskTrigger(
project_id=TEST_PROJECT_ID,
region=TEST_LOCATION,
environment_id=TEST_ENVIRONMENT_ID,
start_date=datetime(2024, 3, 22, 11, 0, 0, tzinfo=timezone.utc),
end_date=datetime(2024, 3, 22, 12, 0, 0, tzinfo=timezone.utc),
allowed_states=TEST_ALLOWED_STATES,
skipped_states=TEST_SKIPPED_STATES,
failed_states=TEST_FAILED_STATES,
composer_external_dag_id=TEST_COMPOSER_DAG_ID,
composer_external_task_ids=TEST_COMPOSER_EXTERNAL_TASK_IDS,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
poll_interval=TEST_POLL_INTERVAL,
composer_airflow_version=composer_airflow_version,
)

with mock.patch.object(trigger, "_get_async_hook", return_value=hook):
actual_event = await trigger.run().asend(None)

assert actual_event == TriggerEvent({"status": "success"})
assert hook.get_task_instances.await_count == 2
assert mock_sleep.await_count == 1

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@pytest.mark.asyncio
@mock.patch(
"airflow.providers.google.cloud.triggers.cloud_composer.asyncio.sleep", new_callable=AsyncMock
)
async def test_trigger_yields_success_on_in_window_allowed_task_instance(
self, mock_sleep, composer_airflow_version
):
hook = AsyncMock()
environment = mock.Mock()
environment.config.airflow_uri = "https://composer.example"
hook.get_environment.return_value = environment
hook.get_task_instances.return_value = self._build_task_instances_result(
composer_airflow_version, [("success", "2024-03-22T11:10:00+00:00")]
)
trigger = CloudComposerExternalTaskTrigger(
project_id=TEST_PROJECT_ID,
region=TEST_LOCATION,
environment_id=TEST_ENVIRONMENT_ID,
start_date=datetime(2024, 3, 22, 11, 0, 0, tzinfo=timezone.utc),
end_date=datetime(2024, 3, 22, 12, 0, 0, tzinfo=timezone.utc),
allowed_states=TEST_ALLOWED_STATES,
skipped_states=TEST_SKIPPED_STATES,
failed_states=TEST_FAILED_STATES,
composer_external_dag_id=TEST_COMPOSER_DAG_ID,
composer_external_task_ids=TEST_COMPOSER_EXTERNAL_TASK_IDS,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
poll_interval=TEST_POLL_INTERVAL,
composer_airflow_version=composer_airflow_version,
)

with mock.patch.object(trigger, "_get_async_hook", return_value=hook):
actual_event = await trigger.run().asend(None)

assert actual_event == TriggerEvent({"status": "success"})
assert hook.get_task_instances.await_count == 1