Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ class TIRunContext(BaseModel):
task_reschedule_count: int = 0
"""How many times the task has been rescheduled."""

first_task_reschedule_start_date: UtcDateTime | None = None
"""The first reschedule start date for the task instance, if it has been rescheduled."""

max_tries: int
"""Maximum number of tries for the task instance (from DB)."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,14 @@ def ti_run(
)
or 0
)
first_task_reschedule_start_date = None
if task_reschedule_count > 0:
first_task_reschedule_start_date = session.scalar(
select(TaskReschedule.start_date)
.where(TaskReschedule.ti_id == task_instance_id)
.order_by(TaskReschedule.id.asc())
.limit(1)
)

from airflow.api_fastapi.execution_api.security import get_team_name_for_ti

Expand All @@ -302,6 +310,8 @@ def ti_run(
xcom_keys_to_clear=xcom_keys,
should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries),
)
if first_task_reschedule_start_date is not None:
context.first_task_reschedule_start_date = first_task_reschedule_start_date

# Only set if they are non-null
if ti.next_method:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,18 @@
)
from airflow.api_fastapi.execution_api.versions.v2026_06_30 import (
AddConnectionTestEndpoint,
AddFirstTaskRescheduleStartDateField,
AddVariableKeysEndpoint,
)

bundle = VersionBundle(
HeadVersion(),
Version("2026-06-30", AddVariableKeysEndpoint, AddConnectionTestEndpoint),
Version(
"2026-06-30",
AddFirstTaskRescheduleStartDateField,
AddVariableKeysEndpoint,
AddConnectionTestEndpoint,
),
Version(
"2026-06-16",
AddRetryPolicyFields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

from __future__ import annotations

from cadwyn import VersionChange, endpoint
from cadwyn import VersionChange, endpoint, schema

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext


class AddVariableKeysEndpoint(VersionChange):
Expand All @@ -37,3 +39,13 @@ class AddConnectionTestEndpoint(VersionChange):
endpoint("/connection-tests/{connection_test_id}", ["PATCH"]).didnt_exist,
endpoint("/connection-tests/{connection_test_id}/connection", ["GET"]).didnt_exist,
)


# Version change describing the addition of ``first_task_reschedule_start_date`` to TIRunContext.
class AddFirstTaskRescheduleStartDateField(VersionChange):
"""Add first_task_reschedule_start_date field to TIRunContext."""

# The class docstring doubles as the change description, so editing the docstring
# also changes this runtime value.
description = __doc__

# When a response is migrated to the previous API version, the field is declared as
# not existing on TIRunContext.
instructions_to_migrate_to_previous_version = (
schema(TIRunContext).field("first_task_reschedule_start_date").didnt_exist,
)
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,60 @@ def test_ti_run_state_to_running(
)
assert response.status_code == 409

def test_ti_run_state_includes_first_task_reschedule_start_date(
    self,
    client,
    session,
    create_task_instance,
):
    """Check that the ``/run`` response exposes the start date of the first reschedule.

    Two ``TaskReschedule`` rows (starting at 10:00 and 11:00) are stored for a queued
    task instance. The response must report a reschedule count of two and the start
    date of the row that was added first.
    """
    run_started_at = "2024-09-30T12:00:00Z"

    ti = create_task_instance(
        task_id="test_ti_run_state_includes_first_task_reschedule_start_date",
        state=State.QUEUED,
        dagrun_state=DagRunState.RUNNING,
        session=session,
        start_date=timezone.parse(run_started_at),
        dag_id=str(uuid4()),
    )

    # Rows are added in chronological order: the 10:00 reschedule first, then 11:00.
    reschedules = [
        TaskReschedule(
            ti_id=ti.id,
            start_date=timezone.datetime(2024, 9, 30, hour),
            end_date=timezone.datetime(2024, 9, 30, hour, 1),
            reschedule_date=timezone.datetime(2024, 9, 30, hour, 2),
        )
        for hour in (10, 11)
    ]
    session.add_all(reschedules)
    session.commit()

    payload = {
        "state": "running",
        "hostname": "random-hostname",
        "unixname": "random-unixname",
        "pid": 100,
        "start_date": run_started_at,
    }
    response = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["task_reschedule_count"] == 2
    assert body["first_task_reschedule_start_date"] == "2024-09-30T10:00:00Z"

def test_ti_run_returns_execution_token(
self, client, exec_app, session, create_task_instance, time_machine
):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from uuid import uuid4

import pytest

from airflow._shared.timezones import timezone
from airflow.models import TaskReschedule
from airflow.utils.state import DagRunState, State

from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags

pytestmark = pytest.mark.db_test


@pytest.fixture
def old_ver_client(client):
    """Return the test client pinned to execution API version ``2026-06-16``.

    The ``Airflow-API-Version`` header is set on the ``client`` fixture itself, so the
    returned object is that same client.
    """
    previous_api_version = "2026-06-16"
    client.headers["Airflow-API-Version"] = previous_api_version
    return client


@pytest.fixture(autouse=True)
def setup_teardown():
    """Clear runs, serialized DAGs and DAGs from the database around every test."""
    cleaners = (clear_db_runs, clear_db_serialized_dags, clear_db_dags)

    # Clean up before the test ...
    for clean in cleaners:
        clean()
    yield
    # ... and again afterwards, in the same order.
    for clean in cleaners:
        clean()


def test_first_task_reschedule_start_date_removed_from_previous_version(
    old_ver_client,
    session,
    create_task_instance,
):
    """Check that the older API version omits ``first_task_reschedule_start_date``.

    A queued task instance with a single ``TaskReschedule`` row is started through the
    client pinned to the previous API version. The reschedule count must still be
    reported, but the new field must be absent from the response.
    """
    ti = create_task_instance(
        task_id="test_first_task_reschedule_start_date_removed_from_previous_version",
        state=State.QUEUED,
        dagrun_state=DagRunState.RUNNING,
        session=session,
        start_date=timezone.datetime(2024, 9, 30, 12),
        dag_id=str(uuid4()),
    )

    reschedule = TaskReschedule(
        ti_id=ti.id,
        start_date=timezone.datetime(2024, 9, 30, 10),
        end_date=timezone.datetime(2024, 9, 30, 10, 1),
        reschedule_date=timezone.datetime(2024, 9, 30, 10, 2),
    )
    session.add(reschedule)
    session.commit()

    payload = {
        "state": "running",
        "hostname": "random-hostname",
        "unixname": "random-unixname",
        "pid": 100,
        "start_date": "2024-09-30T12:00:00Z",
    }
    response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["task_reschedule_count"] == 1
    assert "first_task_reschedule_start_date" not in body
3 changes: 3 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,9 @@ class TIRunContext(BaseModel):

dag_run: DagRun
task_reschedule_count: Annotated[int | None, Field(title="Task Reschedule Count")] = 0
first_task_reschedule_start_date: Annotated[
AwareDatetime | None, Field(title="First Task Reschedule Start Date")
] = None
max_tries: Annotated[int, Field(title="Max Tries")]
variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None
connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None
Expand Down
6 changes: 6 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,12 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
# If the task has not been rescheduled, there is no need to ask the supervisor
return None

first_task_reschedule_start_date = getattr(
self._ti_context_from_server, "first_task_reschedule_start_date", None
)
if first_task_reschedule_start_date is not None:
return first_task_reschedule_start_date

max_tries: int = self.max_tries
retries: int = self.task.retries or 0
first_try_number = max_tries - retries + 1
Expand Down
6 changes: 6 additions & 0 deletions task-sdk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def __call__(
run_after: str | datetime = ...,
run_type: str = ...,
task_reschedule_count: int = ...,
first_task_reschedule_start_date: str | datetime | None = ...,
conf: dict[str, Any] | None = ...,
should_retry: bool = ...,
max_tries: int = ...,
Expand All @@ -249,6 +250,7 @@ def __call__(
run_after: str | datetime = ...,
run_type: str = ...,
task_reschedule_count: int = ...,
first_task_reschedule_start_date: str | datetime | None = ...,
conf=None,
consumed_asset_events: Sequence[AssetEventDagRunReference] = ...,
) -> dict[str, Any]: ...
Expand All @@ -271,6 +273,7 @@ def _make_context(
run_after: str | datetime = "2024-12-01T01:00:00Z",
run_type: str = "manual",
task_reschedule_count: int = 0,
first_task_reschedule_start_date: str | datetime | None = None,
conf: dict[str, Any] | None = None,
should_retry: bool = False,
max_tries: int = 0,
Expand All @@ -292,6 +295,7 @@ def _make_context(
consumed_asset_events=list(consumed_asset_events),
),
task_reschedule_count=task_reschedule_count,
first_task_reschedule_start_date=first_task_reschedule_start_date,
max_tries=max_tries,
should_retry=should_retry,
)
Expand All @@ -314,6 +318,7 @@ def _make_context_dict(
run_after: str | datetime = "2024-12-01T00:00:00Z",
run_type: str = "manual",
task_reschedule_count: int = 0,
first_task_reschedule_start_date: str | datetime | None = None,
conf=None,
consumed_asset_events: Sequence[AssetEventDagRunReference] = (),
) -> dict[str, Any]:
Expand All @@ -329,6 +334,7 @@ def _make_context_dict(
run_type=run_type,
conf=conf,
task_reschedule_count=task_reschedule_count,
first_task_reschedule_start_date=first_task_reschedule_start_date,
consumed_asset_events=consumed_asset_events,
)
return context.model_dump(exclude_unset=True, mode="json")
Expand Down
18 changes: 17 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2722,7 +2722,7 @@ def __init__(self, command, *args, **kwargs):
def test_get_first_reschedule_date(
self, create_runtime_ti, mock_supervisor_comms, task_reschedule_count, expected_date
):
"""Test that the first reschedule date is fetched from the Supervisor."""
"""Test that the first reschedule date falls back to the Supervisor."""
task = BaseOperator(task_id="hello")
runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count)

Expand All @@ -2733,6 +2733,22 @@ def test_get_first_reschedule_date(
context = runtime_ti.get_template_context()
assert runtime_ti.get_first_reschedule_date(context=context) == expected_date

def test_get_first_reschedule_date_uses_context_from_server(
    self, create_runtime_ti, make_ti_context, mock_supervisor_comms
):
    """Check that a date already present in the server context is returned as-is.

    When the server context carries ``first_task_reschedule_start_date``, that value
    must be returned and nothing may be sent through the Supervisor comms mock.
    """
    expected_date = timezone.datetime(2025, 1, 1)

    runtime_ti = create_runtime_ti(task=BaseOperator(task_id="hello"), task_reschedule_count=1)
    # Replace the server-provided context with one that includes the first reschedule date.
    runtime_ti._ti_context_from_server = make_ti_context(
        task_reschedule_count=1,
        first_task_reschedule_start_date=expected_date,
    )

    template_context = runtime_ti.get_template_context()
    assert runtime_ti.get_first_reschedule_date(context=template_context) == expected_date
    mock_supervisor_comms.send.assert_not_called()

def test_get_ti_count(self, mock_supervisor_comms):
"""Test that get_ti_count sends the correct request and returns the count."""
mock_supervisor_comms.send.return_value = TICount(count=2)
Expand Down
Loading