Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 88 additions & 24 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,18 @@ class ActivitySubprocess(WatchedSubprocess):

_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)
# Terminal-state request (succeed / retry / defer / reschedule) whose
# dedicated API call has been started but has not yet returned successfully.
#
# Lifecycle:
#   * `_send_terminal_state_msg` stores the message here BEFORE issuing the
#     dedicated API call.
#   * If that call raises (network blip, server 5xx, etc.), the message stays
#     here, and the dispatcher in `update_task_state_if_needed` re-issues the
#     matching API call on subprocess exit — re-attempting the original
#     transition rather than falling back to `finish()`, which doesn't accept
#     SUCCESS / DEFERRED / SERVER_TERMINATED on the server side.
#   * Only after the API call returns successfully is `_terminal_state` set
#     and this slot reset to None.
_pending_terminal_state_msg: SucceedTask | RetryTask | DeferTask | RescheduleTask | None = attrs.field(
default=None, init=False
)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand Down Expand Up @@ -1269,10 +1281,23 @@ def wait(self) -> int:
return self._exit_code

def update_task_state_if_needed(self):
# If the process has finished non-directly patched state (directly means deferred, reschedule, etc.),
# update the state of the TaskInstance to reflect the final state of the process.
# For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated
# by the subprocess in the `handle_requests` method.
# If a direct-state API call (succeed / retry / defer / reschedule)
# was attempted but raised, `_pending_terminal_state_msg` still holds
# the original request. Re-issue the matching dedicated API call so
# the server learns the terminal state we couldn't deliver earlier.
# Without this recovery, a transient API failure during the direct
# call would leave the TI stuck RUNNING on the server — `finish()`
# cannot substitute because the server-side `finish` endpoint does
# not accept SUCCESS / DEFERRED / SERVER_TERMINATED transitions.
if self._pending_terminal_state_msg is not None:
self._replay_pending_terminal_state_msg()
return

# Otherwise, if the process finished in a state that is not sent
# directly (e.g. FAILED, or UP_FOR_RETRY without a RetryTask), report it
# through `finish()`, the dedicated endpoint for those transitions. For
# states in STATES_SENT_DIRECTLY whose direct API call succeeded, no
# further action is needed.
if self.final_state not in STATES_SENT_DIRECTLY:
self.client.task_instances.finish(
id=self.id,
Expand All @@ -1281,6 +1306,58 @@ def update_task_state_if_needed(self):
rendered_map_index=self._rendered_map_index,
)

def _send_terminal_state_msg(self, msg: SucceedTask | RetryTask | DeferTask | RescheduleTask) -> None:
    """
    Send a terminal-state request to its dedicated task-instance API endpoint.

    The request is stored in ``_pending_terminal_state_msg`` before the call
    is made, so ``update_task_state_if_needed`` can replay it if the call
    raises. ``_terminal_state`` is recorded and the pending slot is emptied
    only once the call has returned without raising. Any exception raised by
    the client propagates to the caller with the pending slot still populated.

    :param msg: The terminal-state request to deliver.
    """
    # Park the request first: if the API call below raises (network blip,
    # transient server 5xx), this is what the replay logic picks up.
    self._pending_terminal_state_msg = msg

    reached_state: str
    if isinstance(msg, SucceedTask):
        self.client.task_instances.succeed(
            id=self.id,
            when=msg.end_date,
            task_outlets=msg.task_outlets,
            outlet_events=msg.outlet_events,
            rendered_map_index=self._rendered_map_index,
        )
        reached_state = msg.state
    elif isinstance(msg, RetryTask):
        self.client.task_instances.retry(
            id=self.id,
            end_date=msg.end_date,
            rendered_map_index=self._rendered_map_index,
            retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
            retry_reason=getattr(msg, "retry_reason", None),
        )
        reached_state = msg.state
    elif isinstance(msg, DeferTask):
        self.client.task_instances.defer(self.id, msg)
        reached_state = TaskInstanceState.DEFERRED
    elif isinstance(msg, RescheduleTask):
        self.client.task_instances.reschedule(self.id, msg)
        reached_state = TaskInstanceState.UP_FOR_RESCHEDULE
    else:
        # Not one of the four handled message types: no API call is made and
        # no state is recorded; only the pending slot is emptied.
        self._pending_terminal_state_msg = None
        return

    # Reaching this point means the dedicated call returned without raising:
    # record the resulting state and drop the pending request.
    self._terminal_state = reached_state
    self._pending_terminal_state_msg = None

def _replay_pending_terminal_state_msg(self) -> None:
    """
    Retry the dedicated API call for a terminal-state msg that was never delivered.

    Does nothing when no message is pending. The retry is best-effort: if it
    raises, the exception is logged and swallowed rather than propagated to
    the caller.
    """
    pending = self._pending_terminal_state_msg
    if pending is not None:
        try:
            # Goes through the same sender as the first attempt, so on
            # success the pending slot is emptied and `_terminal_state` set.
            self._send_terminal_state_msg(pending)
        except Exception:
            log.exception(
                "Recovery retry of terminal-state API call failed; TI may be stuck on the server",
                ti_id=self.id,
                msg_type=type(pending).__name__,
            )

def _upload_logs(self):
"""
Upload all log files found to the remote storage.
Expand Down Expand Up @@ -1452,31 +1529,20 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
resp: BaseModel | None = None
dump_opts: dict[str, bool] = {}
if isinstance(msg, TaskState):
# No direct API call here — once the subprocess exits,
# `update_task_state_if_needed` calls `finish()` for non-direct
# states (FAILED, etc.).
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.succeed(
id=self.id,
when=msg.end_date,
task_outlets=msg.task_outlets,
outlet_events=msg.outlet_events,
rendered_map_index=self._rendered_map_index,
)
self._send_terminal_state_msg(msg)
elif isinstance(msg, RetryTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.retry(
id=self.id,
end_date=msg.end_date,
rendered_map_index=self._rendered_map_index,
retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
retry_reason=getattr(msg, "retry_reason", None),
)
self._send_terminal_state_msg(msg)
elif isinstance(msg, GetConnection):
resp, dump_opts = handle_get_connection(self.client, msg)
elif isinstance(msg, GetVariable):
Expand Down Expand Up @@ -1512,12 +1578,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
)
resp = XComSequenceSliceResult.from_response(xcoms)
elif isinstance(msg, DeferTask):
self._terminal_state = TaskInstanceState.DEFERRED
self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.defer(self.id, msg)
self._send_terminal_state_msg(msg)
elif isinstance(msg, RescheduleTask):
self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
self.client.task_instances.reschedule(self.id, msg)
self._send_terminal_state_msg(msg)
elif isinstance(msg, SkipDownstreamTasks):
self.client.task_instances.skip_downstream_tasks(self.id, msg)
elif isinstance(msg, SetXCom):
Expand Down
131 changes: 131 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,137 @@ def test_handle_requests_network_exception_does_not_crash_loop(self, watched_sub
# Should not raise StopIteration (which would mean the loop crashed).
generator.send(req2)

@pytest.mark.parametrize(
    ("msg", "api_method", "expected_state"),
    [
        pytest.param(
            SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
            "succeed",
            TaskInstanceState.SUCCESS,
            id="succeed",
        ),
        pytest.param(
            RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
            "retry",
            TaskInstanceState.UP_FOR_RETRY,
            id="retry",
        ),
        pytest.param(
            DeferTask(
                next_method="execute_complete",
                classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
                trigger_kwargs={},
            ),
            "defer",
            TaskInstanceState.DEFERRED,
            id="defer",
        ),
        pytest.param(
            RescheduleTask(
                reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
                end_date=timezone.parse("2024-10-31T12:00:00Z"),
            ),
            "reschedule",
            TaskInstanceState.UP_FOR_RESCHEDULE,
            id="reschedule",
        ),
    ],
)
def test_terminal_state_not_set_when_direct_api_fails(
    self, watched_subprocess, mocker, msg, api_method, expected_state
):
    """A failing dedicated terminal-state API call must leave `_terminal_state` unset.

    The request has to remain in `_pending_terminal_state_msg` (it is stored
    there before the API call is made) so that `update_task_state_if_needed`
    can re-issue it later. Parametrized over all four terminal-state message
    types; `expected_state` belongs to the shared parametrization and is not
    asserted in this test.
    """
    supervisor, _ = watched_subprocess

    # Make the endpoint that matches this message type fail on every call.
    failing_endpoint = mocker.Mock(side_effect=httpx.ConnectError("connection refused"))
    setattr(supervisor.client.task_instances, api_method, failing_endpoint)

    with pytest.raises(httpx.ConnectError):
        supervisor._handle_request(msg, mocker.Mock(), req_id=1)

    # No terminal state recorded, and the very same message object is kept
    # around for the replay.
    assert supervisor._terminal_state is None
    assert supervisor._pending_terminal_state_msg is msg

@pytest.mark.parametrize(
    ("msg", "api_method", "expected_state"),
    [
        pytest.param(
            SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
            "succeed",
            TaskInstanceState.SUCCESS,
            id="succeed",
        ),
        pytest.param(
            RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
            "retry",
            TaskInstanceState.UP_FOR_RETRY,
            id="retry",
        ),
        pytest.param(
            DeferTask(
                next_method="execute_complete",
                classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
                trigger_kwargs={},
            ),
            "defer",
            TaskInstanceState.DEFERRED,
            id="defer",
        ),
        pytest.param(
            RescheduleTask(
                reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
                end_date=timezone.parse("2024-10-31T12:00:00Z"),
            ),
            "reschedule",
            TaskInstanceState.UP_FOR_RESCHEDULE,
            id="reschedule",
        ),
    ],
)
def test_update_task_state_replays_pending_terminal_state_call(
    self, watched_subprocess, mocker, msg, api_method, expected_state
):
    """A pending terminal-state msg is replayed through its dedicated endpoint.

    When a direct terminal-state API call was attempted and raised, the
    recovery dispatcher must re-issue that same dedicated endpoint rather
    than `finish()` (which the server-side endpoint refuses for SUCCESS /
    DEFERRED / SERVER_TERMINATED). Parametrized over all four message types.
    """
    supervisor, _ = watched_subprocess
    supervisor._exit_code = 0
    # Reproduce the post-failure situation: the original call raised and the
    # message was left in the pending slot.
    supervisor._pending_terminal_state_msg = msg

    supervisor.update_task_state_if_needed()

    # The dedicated endpoint is hit exactly once; finish() is never used.
    dedicated_endpoint = getattr(supervisor.client.task_instances, api_method)
    dedicated_endpoint.assert_called_once()
    supervisor.client.task_instances.finish.assert_not_called()
    # A successful replay records the state and empties the pending slot.
    assert supervisor._terminal_state == expected_state
    assert supervisor._pending_terminal_state_msg is None

def test_update_task_state_no_recovery_without_pending_msg(self, watched_subprocess, mocker):
    """Happy path: with nothing pending, neither a replay nor `finish()` happens.

    Guards the original STATES_SENT_DIRECTLY short-circuit.
    """
    supervisor, _ = watched_subprocess
    # SUCCESS already recorded and no request left pending.
    supervisor._exit_code = 0
    supervisor._terminal_state = TaskInstanceState.SUCCESS
    supervisor._pending_terminal_state_msg = None

    supervisor.update_task_state_if_needed()

    supervisor.client.task_instances.finish.assert_not_called()
    supervisor.client.task_instances.succeed.assert_not_called()


class TestSetSupervisorComms:
class DummyComms:
Expand Down
Loading