Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.cncf.kubernetes.exceptions import PodMutationHookException, PodReconciliationError
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
ADOPTED,
Expand All @@ -51,8 +52,8 @@
from airflow.providers.cncf.kubernetes.kube_config import KubeConfig
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annotations_to_key
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.common.compat.sdk import Stats, conf
from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import Stats, conf, timezone
from airflow.utils.log.logging_mixin import remove_escape_codes
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
Expand All @@ -66,8 +67,8 @@

from airflow.cli.cli_config import GroupCommand
from airflow.executors import workloads
from airflow.executors.workloads.types import WorkloadKey
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure this needs a back-compat wrapper. See how the ECS executor handled it:

     if AIRFLOW_V_3_3_PLUS:
         from airflow.executors.workloads.types import WorkloadKey as _EcsWorkloadKey
         WorkloadKey: TypeAlias = _EcsWorkloadKey
     else:
         WorkloadKey: TypeAlias = TaskInstanceKey

from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import (
AirflowKubernetesScheduler,
)
Expand All @@ -78,6 +79,7 @@ class KubernetesExecutor(BaseExecutor):

RUNNING_POD_LOG_LINES = 100
supports_ad_hoc_ti_run: bool = True
supports_callbacks: bool = AIRFLOW_V_3_3_PLUS
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is clever, but I don't know if it actually works. In the other executors we used:

     if AIRFLOW_V_3_3_PLUS:
         supports_callbacks: bool = True

That leaves `supports_callbacks` undefined, rather than `False`, when `AIRFLOW_V_3_3_PLUS` is false.

supports_multi_team: bool = True

if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(self, *args, **kwargs):
self._last_completed_pod_adoption = 0.0
self.last_handled: dict[TaskInstanceKey, float] = {}
self.kubernetes_queue: str | None = None
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
self.task_publish_retries: Counter[WorkloadKey] = Counter()
self.task_publish_max_retries = self.conf.getint(
"kubernetes_executor", "task_publish_max_retries", fallback=0
)
Expand Down Expand Up @@ -232,31 +234,30 @@ def execute_async(
# try and remove it from the QUEUED state while we process it
self.last_handled[key] = time.time()

def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
ti = workload.ti
self.queued_tasks[ti.key] = workload
Comment on lines -235 to -241
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ECS and Celery both had to keep this method for back-compat, and added the note `# TODO: Remove this once the minimum supported version is 3.3+, and defer to BaseExecutor.queue_workload`. I'm reasonably certain the same should be applied here, unless you know of a reason to drop it that I've missed.


def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
from airflow.executors.workloads import ExecuteTask

# Airflow V3 version
for w in workloads:
if not isinstance(w, ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")

# TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled.
command = [w]
key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}

del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)
def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None:
# Dispatch each workload in ``workload_items`` according to its type:
#   * ExecuteTask     -> removed from ``queued_tasks`` and handed to ``execute_async``.
#   * ExecuteCallback -> removed from ``queued_callbacks`` and put on ``task_queue`` directly.
#   * anything else   -> RuntimeError.
# In both handled cases the workload's key is added to ``self.running``.
# NOTE(review): leading indentation is missing in this capture; the nesting described
# here is read from the statement order — confirm against the real file.
from airflow.executors.workloads import ExecuteCallback, ExecuteTask
from airflow.utils.state import CallbackState

for workload in workload_items:
if isinstance(workload, ExecuteTask):
# TODO: AIP-72 handle populating tokens once https://github.com/apache/airflow/issues/45107 is handled.
# The workload object itself is passed as the single-element "command".
command = [workload]
key = workload.ti.key
queue = workload.ti.queue
executor_config = workload.ti.executor_config or {}

# Drop the task from the queued set before submitting it.
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)
elif isinstance(workload, ExecuteCallback):
callback_key = workload.callback.key
del self.queued_callbacks[callback_key]
# Put on task_queue for pod creation (no executor_config for callbacks)
# The two trailing ``None`` values fill KubernetesJob's kube_executor_config and
# pod_template_file fields.
self.task_queue.put(KubernetesJob(callback_key, [workload], None, None))
# Record the callback as QUEUED in the event buffer (no extra info attached).
self.event_buffer[callback_key] = (CallbackState.QUEUED, None)
self.running.add(callback_key)
else:
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def sync(self) -> None:
"""Synchronize task state."""
Expand Down Expand Up @@ -332,7 +333,7 @@ def sync(self) -> None:
"Pod reconciliation failed, likely due to kubernetes library upgrade. "
"Try clearing the task to re-run.",
)
self.fail(task[0], e)
self.fail(task.key, e)
except ApiException as e:
try:
if e.body:
Expand Down Expand Up @@ -384,7 +385,6 @@ def sync(self) -> None:
break
else:
self.log.error("Pod creation failed with reason %r. Failing task", e.reason)
key = task.key
self.fail(key, e)
self.task_publish_retries.pop(key, None)
except PodMutationHookException as e:
Expand All @@ -404,11 +404,16 @@ def _change_state(
results: KubernetesResults,
session: Session = NEW_SESSION,
) -> None:
"""Change state of the task based on KubernetesResults."""
"""Change state of the workload based on KubernetesResults."""
if TYPE_CHECKING:
assert self.kube_scheduler

key = results.key

if not isinstance(key, TaskInstanceKey):
self._change_callback_state(results)
return

state = results.state
pod_name = results.pod_name
namespace = results.namespace
Expand Down Expand Up @@ -500,6 +505,51 @@ def _change_state(

self.event_buffer[key] = state, termination_reason

def _change_callback_state(self, results: KubernetesResults) -> None:
"""Change state of a callback based on KubernetesResults."""
# Flow: ADOPTED results only drop the key from ``running``; otherwise the pod is
# cleaned up (deleted or patched as done), the key is removed from ``running``, and
# a CallbackState (FAILED or SUCCESS) is written to ``event_buffer``.
# NOTE(review): leading indentation is missing in this capture; confirm nesting
# against the real file.
from airflow.utils.state import CallbackState

if TYPE_CHECKING:
assert self.kube_scheduler

key = results.key
state = results.state
pod_name = results.pod_name
namespace = results.namespace

# An adopted pod is no longer tracked here; no event is emitted for it.
if state == ADOPTED:
self.running.discard(key)
return

if state == TaskInstanceState.FAILED:
self.log.warning("Callback %s failed in pod %s/%s", key, namespace, pod_name)

# Clean up pod
# Failed pods are deleted only when delete_worker_pods_on_failure is also set.
if self.kube_config.delete_worker_pods:
if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
self.log.info(
"Deleted pod for callback %s. Pod name: %s. Namespace: %s",
key,
pod_name,
namespace,
)
# NOTE(review): with indentation stripped it is not visible whether this ``else``
# pairs with the ``delete_worker_pods`` check or the inner failure check — confirm.
else:
self.kube_scheduler.patch_pod_executor_done(pod_name=pod_name, namespace=namespace)

# If the key is not being tracked, log at debug level and emit no event.
try:
self.running.remove(key)
except KeyError:
self.log.debug("Callback key not in running: %s", key)
return

# Map pod state to CallbackState
# Every state other than FAILED that reaches this point is recorded as SUCCESS.
# NOTE(review): confirm no intermediate pod state (e.g. running) can arrive here.
if state == TaskInstanceState.FAILED:
self.event_buffer[key] = CallbackState.FAILED, None
else:
# Pod succeeded (state is None for successful pods in K8s executor)
self.event_buffer[key] = CallbackState.SUCCESS, None

def _get_pod_namespace(self, ti: TaskInstance):
pod_override = (ti.executor_config or {}).get("pod_override")
namespace = None
Expand Down Expand Up @@ -661,25 +711,42 @@ def adopt_launched_task(

self.log.info("attempting to adopt pod %s", pod.metadata.name)
ti_key = annotations_to_key(pod.metadata.annotations)

if not isinstance(ti_key, TaskInstanceKey):
# Callback pod — re-adopt by patching the worker label so the new
# watcher can track it; no tis_to_flush_by_key entry exists for callbacks.
if self._patch_pod_worker_label(kube_client, pod):
self.running.add(ti_key)
return

if ti_key not in tis_to_flush_by_key:
self.log.error("attempting to adopt taskinstance which was not specified by database: %s", ti_key)
return

new_worker_id_label = self._make_safe_label_value(self.scheduler_job_id)
if not self._patch_pod_worker_label(kube_client, pod):
return

del tis_to_flush_by_key[ti_key]
self.running.add(ti_key)

def _patch_pod_worker_label(self, kube_client: client.CoreV1Api, pod: k8s.V1Pod) -> bool:
"""Patch the airflow-worker label on a pod to claim it for the current scheduler."""
if TYPE_CHECKING:
assert self.scheduler_job_id

from kubernetes.client.rest import ApiException

new_worker_id_label = self._make_safe_label_value(self.scheduler_job_id)
try:
kube_client.patch_namespaced_pod(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
body={"metadata": {"labels": {"airflow-worker": new_worker_id_label}}},
)
return True
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
return

del tis_to_flush_by_key[ti_key]
self.running.add(ti_key)
return False

def _alive_other_scheduler_job_ids(self) -> set[int]:
"""
Expand Down Expand Up @@ -715,7 +782,6 @@ def _alive_other_scheduler_job_ids(self) -> set[int]:
from sqlalchemy import select

from airflow.jobs.job import Job
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import JobState

Expand Down Expand Up @@ -793,16 +859,7 @@ def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None:
pod_list = self._list_pods(query_kwargs)
for pod in pod_list:
self.log.info("Attempting to adopt pod %s", pod.metadata.name)
from kubernetes.client.rest import ApiException

try:
kube_client.patch_namespaced_pod(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
body={"metadata": {"labels": {"airflow-worker": self_label}}},
)
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
if not self._patch_pod_worker_label(kube_client, pod):
continue

ti_id = annotations_to_key(pod.metadata.annotations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from airflow.models.taskinstance import TaskInstanceKey
from airflow.executors.workloads.types import WorkloadKey
from airflow.utils.state import TaskInstanceState


Expand All @@ -43,9 +43,9 @@ class FailureDetails(TypedDict, total=False):


class KubernetesResults(NamedTuple):
"""Results from Kubernetes task execution."""
"""Results from Kubernetes workload execution (task or callback)."""

key: TaskInstanceKey
key: WorkloadKey
state: TaskInstanceState | str | None
pod_name: str
namespace: str
Expand All @@ -69,10 +69,10 @@ class KubernetesWatch(NamedTuple):


class KubernetesJob(NamedTuple):
"""Job definition for Kubernetes execution."""
"""Job definition for Kubernetes execution (task or callback)."""

key: TaskInstanceKey
command: Sequence[str]
key: WorkloadKey
command: Sequence[Any]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I liked the way the ECS Executor handled this; they did:

  if AIRFLOW_V_3_3_PLUS:
      CommandType: TypeAlias = Sequence[str] | Sequence[ExecuteTask] | Sequence[ExecuteCallback]
  else:
      CommandType: TypeAlias = Sequence[str]

and then annotated `command` as `CommandType` instead of `Sequence[Any]`.

kube_executor_config: Any
pod_template_file: str | None

Expand Down
Loading
Loading