Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, Any

from botocore.exceptions import ClientError
Expand All @@ -29,8 +28,6 @@
from airflow.triggers.base import TriggerEvent

if TYPE_CHECKING:
from pendulum import DateTime

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


Expand Down Expand Up @@ -118,49 +115,11 @@ def __init__(
eks_cluster_name: str,
aws_conn_id: str | None = None,
region: str | None = None,
pod_name: str,
pod_namespace: str,
trigger_start_time: datetime.datetime,
base_container_name: str,
kubernetes_conn_id: str | None = None,
connection_extras: dict | None = None,
poll_interval: float = 2,
cluster_context: str | None = None,
config_dict: dict | None = None,
in_cluster: bool | None = None,
get_logs: bool = True,
startup_timeout: int = 120,
startup_check_interval: float = 5,
schedule_timeout: int = 120,
on_finish_action: str = "delete_pod",
on_kill_action: str = "delete_pod",
termination_grace_period: int | None = None,
last_log_time: DateTime | None = None,
logging_interval: int | None = None,
trigger_kwargs: dict | None = None,
**kwargs,
):
super().__init__(
pod_name=pod_name,
pod_namespace=pod_namespace,
trigger_start_time=trigger_start_time,
base_container_name=base_container_name,
kubernetes_conn_id=kubernetes_conn_id,
connection_extras=connection_extras,
poll_interval=poll_interval,
cluster_context=cluster_context,
config_dict=config_dict,
in_cluster=in_cluster,
get_logs=get_logs,
startup_timeout=startup_timeout,
startup_check_interval=startup_check_interval,
schedule_timeout=schedule_timeout,
on_finish_action=on_finish_action,
on_kill_action=on_kill_action,
termination_grace_period=termination_grace_period,
last_log_time=last_log_time,
logging_interval=logging_interval,
trigger_kwargs=trigger_kwargs,
)
# Forward base-trigger kwargs through ``**kwargs`` rather than
# listing each one explicitly.
super().__init__(**kwargs)
self.eks_cluster_name = eks_cluster_name
self._aws_conn_id = aws_conn_id
self.region = region
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import re
import shlex
import string
import time
from collections.abc import Callable, Container, Iterable, Mapping, Sequence
from contextlib import AbstractContextManager, suppress
from enum import Enum
Expand Down Expand Up @@ -908,6 +909,28 @@ def invoke_defer_method(

trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)

# Translate ``execution_timeout`` into an absolute deadline plumbed
# into the trigger. Anchoring on ``ti.start_date`` keeps the deadline
# stable across re-deferrals (``logging_interval`` re-entries), since
# Airflow preserves the original ``start_date`` when a task resumes
# from defer.
execution_deadline: int | None = None
defer_timeout: datetime.timedelta | None = None
if self.execution_timeout is not None and context is not None:
ti_start_date = context["ti"].start_date
execution_deadline = int(ti_start_date.timestamp() + self.execution_timeout.total_seconds())
# ``defer.timeout`` bounds the trigger's lifetime via the
# framework's ``trigger_timeout``. Clamp to a 60s minimum buffer:
# the trigger's first-iteration deadline check fires within
# ``poll_interval`` seconds and emits the operator-handled
# ``status="timeout"`` event, which runs ``_clean()`` and deletes
# the pod.
remaining = execution_deadline - time.time()
defer_timeout = max(
datetime.timedelta(seconds=remaining),
datetime.timedelta(seconds=60),
)

trigger = KubernetesPodTrigger(
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr]
Expand All @@ -929,6 +952,7 @@ def invoke_defer_method(
last_log_time=last_log_time,
logging_interval=self.logging_interval,
trigger_kwargs=self.trigger_kwargs,
execution_deadline=execution_deadline,
)
pod_container_state = trigger.define_pod_container_state(self.pod) if self.pod else None
if context and (
Expand All @@ -949,7 +973,7 @@ def invoke_defer_method(
},
)
else:
self.defer(trigger=trigger, method_name="trigger_reentry")
self.defer(trigger=trigger, method_name="trigger_reentry", timeout=defer_timeout)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
"""
Expand Down Expand Up @@ -1018,8 +1042,13 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
)
self.trigger_kwargs = dict(self.trigger_kwargs or {})
self.trigger_kwargs["_redefer_count"] = redefer_count + 1
# Re-pass ``context`` so ``invoke_defer_method`` can recompute the
# ``execution_deadline`` for this re-deferral. ``ti.start_date`` is
# preserved across resumes, so the deadline stays anchored to the
# original task start.
self.invoke_defer_method(
last_log_time=last_log_time,
context=context,
)
# invoke_defer_method raises TaskDeferred, execution does not continue here

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import asyncio
import contextlib
import datetime
import time
import traceback
from collections.abc import AsyncIterator
from enum import Enum
Expand Down Expand Up @@ -99,6 +100,13 @@ class KubernetesPodTrigger(BaseTrigger):
the operator to print latest logs. If ``None`` will wait until container done.
:param last_log_time: where to resume logs from
:param trigger_kwargs: additional keyword parameters to send in the event
:param execution_deadline: Optional absolute timestamp (integer Unix epoch seconds)
after which the trigger emits a ``timeout`` event so the operator can fail the
task and delete the pod. Checked at the top of ``run()``, around
``_wait_for_pod_start()``, and on every iteration of
``_wait_for_container_completion()`` so the deadline is enforced regardless of
which phase the pod is in. Used to enforce ``execution_timeout`` semantics for
deferred tasks.
"""

def __init__(
Expand All @@ -123,6 +131,7 @@ def __init__(
last_log_time: DateTime | None = None,
logging_interval: int | None = None,
trigger_kwargs: dict | None = None,
execution_deadline: int | None = None,
):
super().__init__()
self.pod_name = pod_name
Expand All @@ -145,6 +154,7 @@ def __init__(
self.on_kill_action = OnKillAction(on_kill_action)
self.termination_grace_period = termination_grace_period
self.trigger_kwargs = trigger_kwargs or {}
self.execution_deadline = execution_deadline
self._fired_event = False
self._since_time = None

Expand Down Expand Up @@ -173,6 +183,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"last_log_time": self.last_log_time,
"logging_interval": self.logging_interval,
"trigger_kwargs": self.trigger_kwargs,
"execution_deadline": self.execution_deadline,
},
)

Expand All @@ -184,8 +195,26 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.pod_namespace,
self.poll_interval,
)
# Fast-path the timeout when the deadline has already elapsed before
# the trigger even starts (e.g. a long-paused triggerer queue, or a
# re-defer after the deadline passed).
if self.execution_deadline is not None and time.time() >= self.execution_deadline:
self._fired_event = True
yield TriggerEvent(
{
"status": "timeout",
"namespace": self.pod_namespace,
"name": self.pod_name,
"message": (
f"Pod {self.pod_namespace}/{self.pod_name} reached the task's "
"execution_timeout deadline before the trigger could begin polling."
),
**self.trigger_kwargs,
}
)
return
try:
state = await self._wait_for_pod_start()
state = await self._wait_for_pod_start_within_deadline()
if state == ContainerState.TERMINATED:
event = TriggerEvent(
{
Expand Down Expand Up @@ -272,6 +301,35 @@ def _format_exception_description(self, exc: Exception) -> Any:
description += f"\ntrigger traceback:\n{curr_traceback}"
return description

async def _wait_for_pod_start_within_deadline(self) -> ContainerState:
"""
Run ``_wait_for_pod_start`` bounded by ``execution_deadline``.

Wraps the underlying call in :func:`asyncio.wait_for` when an
``execution_deadline`` is set so the startup phase honours
``execution_timeout`` too — otherwise a Pending pod would not time
out until ``startup_timeout`` (default 120s) regardless of how
short the user's ``execution_timeout`` was. On timeout we raise
:class:`PodLaunchTimeoutException` so the existing handler in
:meth:`run` emits the operator's expected ``status="timeout"``
event.
"""
if self.execution_deadline is None:
return await self._wait_for_pod_start()
remaining = self.execution_deadline - time.time()
if remaining <= 0:
raise PodLaunchTimeoutException(
f"Pod {self.pod_namespace}/{self.pod_name} reached the task's "
"execution_timeout deadline before the pod left the Pending phase."
)
try:
return await asyncio.wait_for(self._wait_for_pod_start(), timeout=remaining)
except asyncio.TimeoutError as exc:
raise PodLaunchTimeoutException(
f"Pod {self.pod_namespace}/{self.pod_name} reached the task's "
"execution_timeout deadline while waiting for the pod to start."
) from exc

async def _wait_for_pod_start(self) -> ContainerState:
"""Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error."""
pod = await self._get_pod()
Expand Down Expand Up @@ -306,6 +364,27 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
if self.logging_interval is not None:
time_get_more_logs = time_begin + datetime.timedelta(seconds=self.logging_interval)
while True:
# ``execution_deadline`` is the operator's translation of the
# task-level ``execution_timeout`` into an absolute UTC timestamp
if self.execution_deadline is not None and time.time() >= self.execution_deadline:
self.log.info(
"Execution deadline reached for pod %s/%s — emitting timeout event.",
self.pod_namespace,
self.pod_name,
)
return TriggerEvent(
{
"status": "timeout",
"namespace": self.pod_namespace,
"name": self.pod_name,
"message": (
f"Pod {self.pod_namespace}/{self.pod_name} reached the task's "
"execution_timeout deadline."
),
"last_log_time": self.last_log_time,
**self.trigger_kwargs,
}
)
pod = await self._get_pod()
pod_container_state = self.define_pod_container_state(pod)
if pod_container_state == ContainerState.TERMINATED:
Expand Down
Loading
Loading