From 08823ae3078da1ad15cd851897a941672e0b3c82 Mon Sep 17 00:00:00 2001 From: Hemkumar Chheda Date: Thu, 4 Jun 2026 15:23:21 +0530 Subject: [PATCH] Fix rescheduled sensors hanging before poke closes: #68010 --- .../execution_api/datamodels/taskinstance.py | 3 + .../execution_api/routes/task_instances.py | 10 +++ .../execution_api/versions/__init__.py | 8 +- .../execution_api/versions/v2026_06_30.py | 14 ++- .../versions/head/test_task_instances.py | 54 ++++++++++++ .../v2026_06_30/test_task_instances.py | 88 +++++++++++++++++++ .../airflow/sdk/api/datamodels/_generated.py | 3 + .../airflow/sdk/execution_time/task_runner.py | 6 ++ task-sdk/tests/conftest.py | 6 ++ .../execution_time/test_task_runner.py | 18 +++- 10 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index a0d9739080118..95b94b23100f3 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -366,6 +366,9 @@ class TIRunContext(BaseModel): task_reschedule_count: int = 0 """How many times the task has been rescheduled.""" + first_task_reschedule_start_date: UtcDateTime | None = None + """The first reschedule start date for the task instance, if it has been rescheduled.""" + max_tries: int """Maximum number of tries for the task instance (from DB).""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 2afd96806c473..a1abcc2495a0f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -287,6 +287,14 @@ def ti_run( ) or 0 ) + first_task_reschedule_start_date = None + if task_reschedule_count > 0: + first_task_reschedule_start_date = session.scalar( + select(TaskReschedule.start_date) + .where(TaskReschedule.ti_id == task_instance_id) + .order_by(TaskReschedule.id.asc()) + .limit(1) + ) from airflow.api_fastapi.execution_api.security import get_team_name_for_ti @@ -302,6 +310,8 @@ def ti_run( xcom_keys_to_clear=xcom_keys, should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), ) + if first_task_reschedule_start_date is not None: + context.first_task_reschedule_start_date = first_task_reschedule_start_date # Only set if they are non-null if ti.next_method: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 50ddb6985e890..bdffb31522428 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -48,12 +48,18 @@ ) from airflow.api_fastapi.execution_api.versions.v2026_06_30 import ( AddConnectionTestEndpoint, + AddFirstTaskRescheduleStartDateField, AddVariableKeysEndpoint, ) bundle = VersionBundle( HeadVersion(), - Version("2026-06-30", AddVariableKeysEndpoint, AddConnectionTestEndpoint), + Version( + "2026-06-30", + AddFirstTaskRescheduleStartDateField, + AddVariableKeysEndpoint, + AddConnectionTestEndpoint, + ), Version( "2026-06-16", AddRetryPolicyFields, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py index cc751bcc79765..b8f666c16c7c5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py @@ -17,7 +17,9 @@ from __future__ import annotations -from cadwyn import VersionChange, endpoint +from cadwyn import VersionChange, endpoint, schema + +from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext class AddVariableKeysEndpoint(VersionChange): @@ -37,3 +39,13 @@ class AddConnectionTestEndpoint(VersionChange): endpoint("/connection-tests/{connection_test_id}", ["PATCH"]).didnt_exist, endpoint("/connection-tests/{connection_test_id}/connection", ["GET"]).didnt_exist, ) + + +class AddFirstTaskRescheduleStartDateField(VersionChange): + """Add first_task_reschedule_start_date field to TIRunContext.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(TIRunContext).field("first_task_reschedule_start_date").didnt_exist, + ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 3022bbfea06e3..d51a0b9c6d3f8 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -289,6 +289,60 @@ def test_ti_run_state_to_running( ) assert response.status_code == 409 + def test_ti_run_state_includes_first_task_reschedule_start_date( + self, + client, + session, + create_task_instance, + ): + """Test that running a rescheduled Task Instance includes its first reschedule start date.""" + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + first_reschedule_start_date = timezone.datetime(2024, 9, 30, 10) + second_reschedule_start_date = timezone.datetime(2024, 9, 30, 11) + + ti = create_task_instance( + task_id="test_ti_run_state_includes_first_task_reschedule_start_date", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.add_all( + [ + TaskReschedule( + ti_id=ti.id, + start_date=first_reschedule_start_date, + end_date=timezone.datetime(2024, 9, 30, 10, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 10, 2), + ), + TaskReschedule( + ti_id=ti.id, + start_date=second_reschedule_start_date, + end_date=timezone.datetime(2024, 9, 30, 11, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 11, 2), + ), + ] + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["task_reschedule_count"] == 2 + assert result["first_task_reschedule_start_date"] == "2024-09-30T10:00:00Z" + def test_ti_run_returns_execution_token( self, client, exec_app, session, create_task_instance, time_machine ): diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py new file mode 100644 index 0000000000000..7304d0e879ed2 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from airflow._shared.timezones import timezone +from airflow.models import TaskReschedule +from airflow.utils.state import DagRunState, State + +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + """Last released execution API before first_task_reschedule_start_date was added.""" + client.headers["Airflow-API-Version"] = "2026-06-16" + return client + + +@pytest.fixture(autouse=True) +def setup_teardown(): + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + yield + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + + +def test_first_task_reschedule_start_date_removed_from_previous_version( + old_ver_client, + session, + create_task_instance, +): + ti = create_task_instance( + task_id="test_first_task_reschedule_start_date_removed_from_previous_version", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=timezone.datetime(2024, 9, 30, 12), + dag_id=str(uuid4()), + ) + session.add( + TaskReschedule( + ti_id=ti.id, + start_date=timezone.datetime(2024, 9, 30, 10), + end_date=timezone.datetime(2024, 9, 30, 10, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 10, 2), + ) + ) + session.commit() + + response = old_ver_client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-09-30T12:00:00Z", + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["task_reschedule_count"] == 1 + assert "first_task_reschedule_start_date" not in result diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index e1ac21585bf93..4bc566923f135 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -768,6 +768,9 @@ class TIRunContext(BaseModel): dag_run: DagRun task_reschedule_count: Annotated[int | None, Field(title="Task Reschedule Count")] = 0 + first_task_reschedule_start_date: Annotated[ + AwareDatetime | None, Field(title="First Task Reschedule Start Date") + ] = None max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 22ac90405e027..316cd75514738 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -575,6 +575,12 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: # If the task has not been rescheduled, there is no need to ask the supervisor return None + first_task_reschedule_start_date = getattr( + self._ti_context_from_server, "first_task_reschedule_start_date", None + ) + if first_task_reschedule_start_date is not None: + return first_task_reschedule_start_date + max_tries: int = self.max_tries retries: int = self.task.retries or 0 first_try_number = max_tries - retries + 1 diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index c1ef3b72c92f4..321dbc7c349de 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -229,6 +229,7 @@ def __call__( run_after: str | datetime = ..., run_type: str = ..., task_reschedule_count: int = ..., + first_task_reschedule_start_date: str | datetime | None = ..., conf: dict[str, Any] | None = ..., should_retry: bool = ..., max_tries: int = ..., @@ -249,6 +250,7 @@ def __call__( run_after: str | datetime = ..., run_type: str = ..., task_reschedule_count: int = ..., + first_task_reschedule_start_date: str | datetime | None = ..., conf=None, consumed_asset_events: Sequence[AssetEventDagRunReference] = ..., ) -> dict[str, Any]: ... @@ -271,6 +273,7 @@ def _make_context( run_after: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", task_reschedule_count: int = 0, + first_task_reschedule_start_date: str | datetime | None = None, conf: dict[str, Any] | None = None, should_retry: bool = False, max_tries: int = 0, @@ -292,6 +295,7 @@ def _make_context( consumed_asset_events=list(consumed_asset_events), ), task_reschedule_count=task_reschedule_count, + first_task_reschedule_start_date=first_task_reschedule_start_date, max_tries=max_tries, should_retry=should_retry, ) @@ -314,6 +318,7 @@ def _make_context_dict( run_after: str | datetime = "2024-12-01T00:00:00Z", run_type: str = "manual", task_reschedule_count: int = 0, + first_task_reschedule_start_date: str | datetime | None = None, conf=None, consumed_asset_events: Sequence[AssetEventDagRunReference] = (), ) -> dict[str, Any]: @@ -329,6 +334,7 @@ def _make_context_dict( run_type=run_type, conf=conf, task_reschedule_count=task_reschedule_count, + first_task_reschedule_start_date=first_task_reschedule_start_date, consumed_asset_events=consumed_asset_events, ) return context.model_dump(exclude_unset=True, mode="json") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index f8953dc232970..497e31c3b771f 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2722,7 +2722,7 @@ def __init__(self, command, *args, **kwargs): def test_get_first_reschedule_date( self, create_runtime_ti, mock_supervisor_comms, task_reschedule_count, expected_date ): - """Test that the first reschedule date is fetched from the Supervisor.""" + """Test that the first reschedule date falls back to the Supervisor.""" task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count) @@ -2733,6 +2733,22 @@ def test_get_first_reschedule_date( context = runtime_ti.get_template_context() assert runtime_ti.get_first_reschedule_date(context=context) == expected_date + def test_get_first_reschedule_date_uses_context_from_server( + self, create_runtime_ti, make_ti_context, mock_supervisor_comms + ): + """Test that first reschedule date from server context avoids a Supervisor request.""" + first_reschedule_date = timezone.datetime(2025, 1, 1) + task = BaseOperator(task_id="hello") + runtime_ti = create_runtime_ti(task=task, task_reschedule_count=1) + runtime_ti._ti_context_from_server = make_ti_context( + task_reschedule_count=1, + first_task_reschedule_start_date=first_reschedule_date, + ) + + context = runtime_ti.get_template_context() + assert runtime_ti.get_first_reschedule_date(context=context) == first_reschedule_date + mock_supervisor_comms.send.assert_not_called() + def test_get_ti_count(self, mock_supervisor_comms): """Test that get_ti_count sends the correct request and returns the count.""" mock_supervisor_comms.send.return_value = TICount(count=2)