Commit

Revert "KPO Maintain backward compatibility for execute_complete and …
Browse files Browse the repository at this point in the history
…trigger run method (#37363)" (#37446)

This reverts commit 0640e6d.
potiuk committed Feb 15, 2024
1 parent df132b2 commit 0be6430
Showing 5 changed files with 189 additions and 208 deletions.
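
For context, the code paths changed below are the operator's deferrable-mode plumbing, which only runs when a task defers to KubernetesPodTrigger. As a point of reference, a deferrable KubernetesPodOperator task is typically declared roughly like the sketch below; the DAG id, pod name, image, and timings are placeholders and not part of this commit.

from datetime import datetime

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

with DAG(dag_id="kpo_deferrable_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    # Hand the wait over to KubernetesPodTrigger instead of blocking a worker slot.
    run_pod = KubernetesPodOperator(
        task_id="run_pod",
        name="example-pod",
        namespace="default",
        image="busybox",
        cmds=["sh", "-c", "echo hello && sleep 30"],
        get_logs=True,
        deferrable=True,
        logging_interval=10,  # resume periodically to fetch logs, then defer again
    )
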
150 changes: 90 additions & 60 deletions airflow/providers/cncf/kubernetes/operators/pod.py
@@ -18,7 +18,6 @@

from __future__ import annotations

import datetime
import json
import logging
import re
@@ -31,7 +30,6 @@
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import kubernetes
from deprecated import deprecated
from kubernetes.client import CoreV1Api, V1Pod, models as k8s
from kubernetes.stream import stream
from urllib3.exceptions import HTTPError
@@ -70,6 +68,7 @@
EMPTY_XCOM_RESULT,
OnFinishAction,
PodLaunchFailedException,
PodLaunchTimeoutException,
PodManager,
PodNotFoundException,
PodOperatorHookProtocol,
@@ -80,6 +79,7 @@
from airflow.settings import pod_mutation_hook
from airflow.utils import yaml
from airflow.utils.helpers import prune_dict, validate_key
from airflow.utils.timezone import utcnow
from airflow.version import version as airflow_version

if TYPE_CHECKING:
@@ -656,7 +656,7 @@ def execute_async(self, context: Context):

def invoke_defer_method(self, last_log_time: DateTime | None = None):
"""Redefine triggers which are being used in child classes."""
trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
trigger_start_time = utcnow()
self.defer(
trigger=KubernetesPodTrigger(
pod_name=self.pod.metadata.name, # type: ignore[union-attr]
@@ -678,87 +678,117 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None):
method_name="trigger_reentry",
)

@staticmethod
def raise_for_trigger_status(event: dict[str, Any]) -> None:
"""Raise exception if pod is not in expected state."""
if event["status"] == "error":
error_type = event["error_type"]
description = event["description"]
if error_type == "PodLaunchTimeoutException":
raise PodLaunchTimeoutException(description)
else:
raise AirflowException(description)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
"""
Point of re-entry from trigger.
If ``logging_interval`` is None, then at this point, the pod should be done, and we'll just fetch
If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
the logs and exit.
If ``logging_interval`` is not None, it could be that the pod is still running, and we'll just
If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
self.pod = None
remote_pod = None
try:
pod_name = event["name"]
pod_namespace = event["namespace"]
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod = self.find_pod(
namespace=self.namespace or self.pod_request_obj.metadata.namespace,
context=context,
)

self.pod = self.hook.get_pod(pod_name, pod_namespace)
# we try to find pod before possibly raising so that on_kill will have `pod` attr
self.raise_for_trigger_status(event)

if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.callbacks and event["status"] != "running":
self.callbacks.on_operator_resuming(
pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
if self.get_logs:
last_log_time = event and event.get("last_log_time")
if last_log_time:
self.log.info("Resuming logs read from time %r", last_log_time)
pod_log_status = self.pod_manager.fetch_container_logs(
pod=self.pod,
container_name=self.BASE_CONTAINER_NAME,
follow=self.logging_interval is None,
since_time=last_log_time,
)
if pod_log_status.running:
self.log.info("Container still running; deferring again.")
self.invoke_defer_method(pod_log_status.last_log_time)

if self.do_xcom_push:
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
except TaskDeferred:
raise
except Exception:
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
raise
self.cleanup(
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
if self.do_xcom_push:
return result

def execute_complete(self, context: Context, event: dict, **kwargs):
self.log.debug("Triggered with event: %s", event)
pod = None
try:
pod = self.hook.get_pod(
event["name"],
event["namespace"],
)
if self.callbacks:
self.callbacks.on_operator_resuming(
pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC
)
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when the pod has failed
if self.get_logs:
self.write_logs(pod)
if "stack_trace" in event:
message = f"{event['message']}\n{event['stack_trace']}"
else:
message = event["message"]
if self.do_xcom_push:
_ = self.extract_xcom(pod=self.pod)

message = event.get("stack_trace", event["message"])
# In the event of base container failure, we need to kill the xcom sidecar.
# We disregard xcom output and do that here
_ = self.extract_xcom(pod=pod)
raise AirflowException(message)

elif event["status"] == "running":
elif event["status"] == "success":
# fetch some logs when the pod has finished successfully
if self.get_logs:
last_log_time = event.get("last_log_time")
self.log.info("Resuming logs read from time %r", last_log_time)

pod_log_status = self.pod_manager.fetch_container_logs(
pod=self.pod,
container_name=self.BASE_CONTAINER_NAME,
follow=self.logging_interval is None,
since_time=last_log_time,
)
self.write_logs(pod)

if pod_log_status.running:
self.log.info("Container still running; deferring again.")
self.invoke_defer_method(pod_log_status.last_log_time)
else:
self.invoke_defer_method()

elif event["status"] == "success":
if self.do_xcom_push:
xcom_sidecar_output = self.extract_xcom(pod=self.pod)
xcom_sidecar_output = self.extract_xcom(pod=pod)
return xcom_sidecar_output
return
except TaskDeferred:
raise
finally:
self._clean(event)

def _clean(self, event: dict[str, Any]):
if event["status"] == "running":
return
if self.get_logs:
self.write_logs(self.pod)
istio_enabled = self.is_istio_enabled(self.pod)
# Skip await_pod_completion when the event is 'timeout', because the pod can hang
# on the ErrImagePull or ContainerCreating step and will never complete
if event["status"] != "timeout":
self.pod = self.pod_manager.await_pod_completion(
self.pod, istio_enabled, self.base_container_name
)
if self.pod is not None:
self.post_complete_action(
pod=self.pod,
remote_pod=self.pod,
)

@deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
def execute_complete(self, context: Context, event: dict, **kwargs):
self.trigger_reentry(context=context, event=event)
istio_enabled = self.is_istio_enabled(pod)
# Skip await_pod_completion when the event is 'timeout', because the pod can hang
# on the ErrImagePull or ContainerCreating step and will never complete
if event["status"] != "timeout":
pod = self.pod_manager.await_pod_completion(pod, istio_enabled, self.base_container_name)
if pod is not None:
self.post_complete_action(
pod=pod,
remote_pod=pod,
)

def write_logs(self, pod: k8s.V1Pod):
try:
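
Both versions of the operator code above key off a plain event dict handed back by the trigger when the task resumes from deferral. As a rough sketch of that contract (the keys mirror what the hunk above reads; the concrete values are invented for illustration), the payloads being inspected look like this:

# Illustrative payloads only -- the values are made up; the keys mirror what the
# operator code in the hunk above reads from the trigger event.
error_event = {
    "status": "error",
    "name": "example-pod",
    "namespace": "default",
    "error_type": "PodLaunchTimeoutException",
    "description": "Pod did not leave 'Pending' phase within specified timeout",
}
failed_event = {
    "status": "failed",
    "name": "example-pod",
    "namespace": "default",
    "message": "pod failed",
}
running_event = {
    "status": "running",
    "name": "example-pod",
    "namespace": "default",
    "last_log_time": None,  # time of the last log line already fetched, if any
}
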
70 changes: 17 additions & 53 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -30,8 +30,10 @@
OnFinishAction,
PodLaunchTimeoutException,
PodPhase,
container_is_running,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone

if TYPE_CHECKING:
from kubernetes_asyncio.client.models import V1Pod
@@ -158,49 +160,22 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
state = await self._wait_for_pod_start()
if state == ContainerState.TERMINATED:
if state in PodPhase.terminal_states:
event = TriggerEvent(
{
"status": "success",
"namespace": self.pod_namespace,
"name": self.pod_name,
"message": "All containers inside pod have started successfully.",
}
)
elif state == ContainerState.FAILED:
event = TriggerEvent(
{
"status": "failed",
"namespace": self.pod_namespace,
"name": self.pod_name,
"message": "pod failed",
}
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
else:
event = await self._wait_for_container_completion()
yield event
return
except PodLaunchTimeoutException as e:
message = self._format_exception_description(e)
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "timeout",
"message": message,
}
)
except Exception as e:
description = self._format_exception_description(e)
yield TriggerEvent(
{
"name": self.pod_name,
"namespace": self.pod_namespace,
"status": "error",
"message": str(e),
"stack_trace": traceback.format_exc(),
"error_type": e.__class__.__name__,
"description": description,
}
)
return

def _format_exception_description(self, exc: Exception) -> Any:
if isinstance(exc, PodLaunchTimeoutException):
@@ -214,13 +189,14 @@ def _format_exception_description(self, exc: Exception) -> Any:
description += f"\ntrigger traceback:\n{curr_traceback}"
return description

async def _wait_for_pod_start(self) -> ContainerState:
async def _wait_for_pod_start(self) -> Any:
"""Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error."""
delta = datetime.datetime.now(tz=datetime.timezone.utc) - self.trigger_start_time
while self.startup_timeout >= delta.total_seconds():
start_time = timezone.utcnow()
timeout_end = start_time + datetime.timedelta(seconds=self.startup_timeout)
while timeout_end > timezone.utcnow():
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
if not pod.status.phase == "Pending":
return self.define_container_state(pod)
return pod.status.phase
self.log.info("Still waiting for pod to start. The pod state is %s", pod.status.phase)
await asyncio.sleep(self.poll_interval)
raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")
@@ -232,30 +208,18 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
Waits until the container is no longer in a running state. If the trigger is configured with a
logging period, it emits an event to resume the task so that more logs can be fetched.
"""
time_begin = datetime.datetime.now(tz=datetime.timezone.utc)
time_begin = timezone.utcnow()
time_get_more_logs = None
if self.logging_interval is not None:
time_get_more_logs = time_begin + datetime.timedelta(seconds=self.logging_interval)
while True:
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
container_state = self.define_container_state(pod)
if container_state == ContainerState.TERMINATED:
return TriggerEvent(
{"status": "success", "namespace": self.pod_namespace, "name": self.pod_name}
)
elif container_state == ContainerState.FAILED:
return TriggerEvent(
{"status": "failed", "namespace": self.pod_namespace, "name": self.pod_name}
)
if time_get_more_logs and datetime.datetime.now(tz=datetime.timezone.utc) > time_get_more_logs:
if not container_is_running(pod=pod, container_name=self.base_container_name):
return TriggerEvent(
{
"status": "running",
"last_log_time": self.last_log_time,
"namespace": self.pod_namespace,
"name": self.pod_name,
}
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
if time_get_more_logs and timezone.utcnow() > time_get_more_logs:
return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
await asyncio.sleep(self.poll_interval)

def _get_async_hook(self) -> AsyncKubernetesHook:
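
Both sides of the ``_wait_for_pod_start`` change implement the same idea: poll the pod until it leaves the ``Pending`` phase, or give up once the startup timeout elapses. A minimal, standalone sketch of that pattern (``wait_for_pod_start``, ``get_phase``, and the timings are placeholders, not Airflow APIs) looks like this:

import time
from datetime import datetime, timedelta, timezone


def wait_for_pod_start(get_phase, startup_timeout: int, poll_interval: int = 2) -> str:
    """Poll until the pod leaves 'Pending' or the startup timeout is exceeded."""
    deadline = datetime.now(tz=timezone.utc) + timedelta(seconds=startup_timeout)
    while datetime.now(tz=timezone.utc) < deadline:
        phase = get_phase()  # e.g. read pod.status.phase from the Kubernetes API
        if phase != "Pending":
            return phase
        time.sleep(poll_interval)
    raise TimeoutError("Pod did not leave 'Pending' phase within specified timeout")

The async versions in the hunk above follow the same shape, with ``asyncio.sleep`` in place of ``time.sleep`` and the hook's ``get_pod`` call in place of ``get_phase``.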
