diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 13d8245621d63..d960d49044b40 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -82,6 +82,8 @@ from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState if TYPE_CHECKING: + from datetime import datetime + from sqlalchemy.sql.dml import Update router = VersionedAPIRouter() @@ -176,10 +178,28 @@ def ti_run( # We exclude_unset to avoid updating fields that are not set in the payload data = ti_run_payload.model_dump(exclude_unset=True) + first_reschedule_start_date: datetime | None = None + # don't update start date when resuming from deferral if ti.next_kwargs: data.pop("start_date") log.debug("Removed start_date from update as task is resuming from deferral") + elif "start_date" in data: + # Preserve the first-poke start_date for a rescheduled task. The supervisor sends + # start_date=utcnow() on every poke; without this guard the metric + # dagrun.first_task_scheduling_delay (computed from start_date - queued_at) + # collapses to ~0 for any DAG fronted by a reschedule-mode sensor. + # prepare_db_for_next_try clears TaskReschedule rows and rotates ti.id on each + # retry, so rows with ti_id == task_instance_id always belong to the current try. + first_reschedule_start_date = session.scalar( + select(TaskReschedule.start_date) + .where(TaskReschedule.ti_id == task_instance_id) + .order_by(TaskReschedule.id.asc()) + .limit(1) + ) + if first_reschedule_start_date is not None: + data["start_date"] = first_reschedule_start_date + log.debug("Restored start_date from first TaskReschedule entry for rescheduled task") query = update(TI).where(TI.id == task_instance_id).values(data) @@ -302,6 +322,10 @@ def ti_run( context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs context.start_date = ti.start_date + elif first_reschedule_start_date is not None: + # Mirror the deferral-resume behavior so the supervisor preserves the + # first-poke start_date for context["ti"].start_date as well. + context.start_date = first_reschedule_start_date except SQLAlchemyError: log.exception("Error marking Task Instance state as running") raise HTTPException( diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index f4bba7c65ea3d..2a8e4df5d94fb 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1312,11 +1312,18 @@ def _check_and_change_state_before_execution( # start date that is recorded in task_reschedule table # If the task continues after being deferred (next_method is set), use the original start_date ti.start_date = ti.start_date if ti.next_method else timezone.utcnow() - if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE: + if not ti.next_method: + # Restore start_date for a rescheduled task. The state guard + # `ti.state == UP_FOR_RESCHEDULE` that previously wrapped this lookup is + # unreliable: the scheduler advances UP_FOR_RESCHEDULE -> QUEUED before the + # worker calls this method, so refresh_from_db returns QUEUED and the guard + # never fires. The query is scoped to the current ti.id, which is + # rotated by prepare_db_for_next_try on each retry, and returns no rows for + # non-rescheduled tasks -- so dropping the state guard is safe. tr_start_date = session.scalar( TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1) ) - if tr_start_date: + if tr_start_date is not None: ti.start_date = tr_start_date # Secondly we find non-runnable but requeueable tis. We reset its state. diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index d26bcf7bfd862..de4bd27b8116a 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -790,6 +790,97 @@ def test_ti_run_resume_returns_original_start_date_in_context( assert ti is not None assert ti.start_date == original_start_date + def test_ti_run_restores_start_date_for_rescheduled_task( + self, + client, + session, + create_task_instance, + ): + """ + For a reschedule-mode sensor, the supervisor sends ``start_date=utcnow()`` on every + poke. The ``ti_run`` endpoint must restore the first-poke ``start_date`` from the + ``TaskReschedule`` table so that ``dagrun.first_task_scheduling_delay`` reflects the + actual wait time instead of collapsing to ~0. + """ + original_start_date = timezone.parse("2024-09-30T12:00:00Z") + payload_start_date = timezone.parse("2024-09-30T12:05:00Z") + + ti = create_task_instance( + task_id="test_ti_run_restores_start_date_for_rescheduled_task", + state=State.QUEUED, + session=session, + start_date=original_start_date, + dag_id=str(uuid4()), + ) + session.commit() + + session.add( + TaskReschedule( + ti_id=ti.id, + start_date=original_start_date, + end_date=timezone.parse("2024-09-30T12:00:10Z"), + reschedule_date=timezone.parse("2024-09-30T12:01:00Z"), + ) + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": payload_start_date.isoformat(), + }, + ) + + assert response.status_code == 200 + result = response.json() + assert timezone.parse(result["start_date"]) == original_start_date + + session.expunge_all() + ti = session.get(TaskInstance, ti.id) + assert ti is not None + assert ti.start_date == original_start_date + + def test_ti_run_uses_payload_start_date_when_no_reschedule_rows( + self, + client, + session, + create_task_instance, + ): + """ + For a non-rescheduled task, the ``TaskReschedule`` lookup returns nothing and the + payload ``start_date`` from the supervisor is preserved unchanged. + """ + payload_start_date = timezone.parse("2024-09-30T12:05:00Z") + + ti = create_task_instance( + task_id="test_ti_run_uses_payload_start_date_when_no_reschedule_rows", + state=State.QUEUED, + session=session, + dag_id=str(uuid4()), + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": payload_start_date.isoformat(), + }, + ) + + assert response.status_code == 200 + session.expunge_all() + ti = session.get(TaskInstance, ti.id) + assert ti is not None + assert ti.start_date == payload_start_date + @pytest.mark.parametrize( "initial_ti_state", [s for s in TaskInstanceState if s not in (TaskInstanceState.QUEUED, TaskInstanceState.RESTARTING)], diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index b9e5452855ff1..1258c73d97fae 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -1562,6 +1562,36 @@ def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state( assert not ti_from_deserialized_task.check_and_change_state_before_execution() assert ti_from_deserialized_task.state == State.FAILED + def test_check_and_change_state_before_execution_restores_reschedule_start_date( + self, create_task_instance, testing_dag_bundle + ): + """ + The state guard ``ti.state == UP_FOR_RESCHEDULE`` does not fire when the scheduler + has already advanced the state to ``QUEUED`` before the worker runs the check (the + normal multi-scheduler flow). The TaskReschedule lookup must restore the original + ``start_date`` regardless of the observed state. + """ + first_poke_start_date = timezone.datetime(2024, 1, 1, 10, 0, 0) + ti = create_task_instance(dag_id="test_reschedule_start_date_restored_under_queued") + with create_session() as session: + ti.state = State.QUEUED + session.add( + TaskReschedule( + ti_id=ti.id, + start_date=first_poke_start_date, + end_date=timezone.datetime(2024, 1, 1, 10, 0, 10), + reschedule_date=timezone.datetime(2024, 1, 1, 10, 1, 0), + ) + ) + + serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag + ti_from_deserialized_task = TI( + task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id + ) + + assert ti_from_deserialized_task.check_and_change_state_before_execution() + assert ti_from_deserialized_task.start_date == first_poke_start_date + def test_try_number(self, create_task_instance): """ Test the try_number accessor behaves in various running states