Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState

if TYPE_CHECKING:
from datetime import datetime

from sqlalchemy.sql.dml import Update

router = VersionedAPIRouter()
Expand Down Expand Up @@ -176,10 +178,28 @@ def ti_run(
# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_run_payload.model_dump(exclude_unset=True)

first_reschedule_start_date: datetime | None = None

# don't update start date when resuming from deferral
if ti.next_kwargs:
data.pop("start_date")
log.debug("Removed start_date from update as task is resuming from deferral")
elif "start_date" in data:
# Preserve the first-poke start_date for a rescheduled task. The supervisor sends
# start_date=utcnow() on every poke; without this guard the metric
# dagrun.first_task_scheduling_delay (computed from start_date - queued_at)
# collapses to ~0 for any DAG fronted by a reschedule-mode sensor.
# prepare_db_for_next_try clears TaskReschedule rows and rotates ti.id on each
# retry, so rows with ti_id == task_instance_id always belong to the current try.
first_reschedule_start_date = session.scalar(
select(TaskReschedule.start_date)
.where(TaskReschedule.ti_id == task_instance_id)
.order_by(TaskReschedule.id.asc())
.limit(1)
)
if first_reschedule_start_date is not None:
data["start_date"] = first_reschedule_start_date
log.debug("Restored start_date from first TaskReschedule entry for rescheduled task")

query = update(TI).where(TI.id == task_instance_id).values(data)

Expand Down Expand Up @@ -302,6 +322,10 @@ def ti_run(
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs
context.start_date = ti.start_date
elif first_reschedule_start_date is not None:
# Mirror the deferral-resume behavior so the supervisor preserves the
# first-poke start_date for context["ti"].start_date as well.
context.start_date = first_reschedule_start_date
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
raise HTTPException(
Expand Down
11 changes: 9 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,11 +1312,18 @@ def _check_and_change_state_before_execution(
# start date that is recorded in task_reschedule table
# If the task continues after being deferred (next_method is set), use the original start_date
ti.start_date = ti.start_date if ti.next_method else timezone.utcnow()
if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
if not ti.next_method:
# Restore start_date for a rescheduled task. The state guard
# `ti.state == UP_FOR_RESCHEDULE` that previously wrapped this lookup is
# unreliable: the scheduler advances UP_FOR_RESCHEDULE -> QUEUED before the
# worker calls this method, so refresh_from_db returns QUEUED and the guard
# never fires. The query is scoped to the current ti.id, which is
# rotated by prepare_db_for_next_try on each retry, and returns no rows for
# non-rescheduled tasks -- so dropping the state guard is safe.
tr_start_date = session.scalar(
TR.stmt_for_task_instance(ti, descending=False).with_only_columns(TR.start_date).limit(1)
)
if tr_start_date:
if tr_start_date is not None:
ti.start_date = tr_start_date

# Secondly we find non-runnable but requeueable tis. We reset its state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,97 @@ def test_ti_run_resume_returns_original_start_date_in_context(
assert ti is not None
assert ti.start_date == original_start_date

def test_ti_run_restores_start_date_for_rescheduled_task(
self,
client,
session,
create_task_instance,
):
"""
For a reschedule-mode sensor, the supervisor sends ``start_date=utcnow()`` on every
poke. The ``ti_run`` endpoint must restore the first-poke ``start_date`` from the
``TaskReschedule`` table so that ``dagrun.first_task_scheduling_delay`` reflects the
actual wait time instead of collapsing to ~0.
"""
original_start_date = timezone.parse("2024-09-30T12:00:00Z")
payload_start_date = timezone.parse("2024-09-30T12:05:00Z")

ti = create_task_instance(
task_id="test_ti_run_restores_start_date_for_rescheduled_task",
state=State.QUEUED,
session=session,
start_date=original_start_date,
dag_id=str(uuid4()),
)
session.commit()

session.add(
TaskReschedule(
ti_id=ti.id,
start_date=original_start_date,
end_date=timezone.parse("2024-09-30T12:00:10Z"),
reschedule_date=timezone.parse("2024-09-30T12:01:00Z"),
)
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": payload_start_date.isoformat(),
},
)

assert response.status_code == 200
result = response.json()
assert timezone.parse(result["start_date"]) == original_start_date

session.expunge_all()
ti = session.get(TaskInstance, ti.id)
assert ti is not None
assert ti.start_date == original_start_date

def test_ti_run_uses_payload_start_date_when_no_reschedule_rows(
self,
client,
session,
create_task_instance,
):
"""
For a non-rescheduled task, the ``TaskReschedule`` lookup returns nothing and the
payload ``start_date`` from the supervisor is preserved unchanged.
"""
payload_start_date = timezone.parse("2024-09-30T12:05:00Z")

ti = create_task_instance(
task_id="test_ti_run_uses_payload_start_date_when_no_reschedule_rows",
state=State.QUEUED,
session=session,
dag_id=str(uuid4()),
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": payload_start_date.isoformat(),
},
)

assert response.status_code == 200
session.expunge_all()
ti = session.get(TaskInstance, ti.id)
assert ti is not None
assert ti.start_date == payload_start_date

@pytest.mark.parametrize(
"initial_ti_state",
[s for s in TaskInstanceState if s not in (TaskInstanceState.QUEUED, TaskInstanceState.RESTARTING)],
Expand Down
30 changes: 30 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,36 @@ def test_check_and_change_state_before_execution_dep_not_met_not_runnable_state(
assert not ti_from_deserialized_task.check_and_change_state_before_execution()
assert ti_from_deserialized_task.state == State.FAILED

def test_check_and_change_state_before_execution_restores_reschedule_start_date(
self, create_task_instance, testing_dag_bundle
):
"""
The state guard ``ti.state == UP_FOR_RESCHEDULE`` does not fire when the scheduler
has already advanced the state to ``QUEUED`` before the worker runs the check (the
normal multi-scheduler flow). The TaskReschedule lookup must restore the original
``start_date`` regardless of the observed state.
"""
first_poke_start_date = timezone.datetime(2024, 1, 1, 10, 0, 0)
ti = create_task_instance(dag_id="test_reschedule_start_date_restored_under_queued")
with create_session() as session:
ti.state = State.QUEUED
session.add(
TaskReschedule(
ti_id=ti.id,
start_date=first_poke_start_date,
end_date=timezone.datetime(2024, 1, 1, 10, 0, 10),
reschedule_date=timezone.datetime(2024, 1, 1, 10, 1, 0),
)
)

serialized_dag = SerializedDagModel.get(ti.task.dag.dag_id).dag
ti_from_deserialized_task = TI(
task=serialized_dag.get_task(ti.task_id), run_id=ti.run_id, dag_version_id=ti.dag_version_id
)

assert ti_from_deserialized_task.check_and_change_state_before_execution()
assert ti_from_deserialized_task.start_date == first_poke_start_date

def test_try_number(self, create_task_instance):
"""
Test the try_number accessor behaves in various running states
Expand Down
Loading