Skip to content

Commit

Permalink
Patch "checked" when pod not successful (#27845)
Browse files Browse the repository at this point in the history
In KPO (KubernetesPodOperator), if a pod gets stuck in terminating status, we might keep reattaching to it. This also includes a drive-by refactor that cleans up the `cleanup` method by pushing the suppression logic down.
  • Loading branch information
dstandish committed Dec 2, 2022
1 parent 6dd658b commit ebd7b67
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 46 deletions.
85 changes: 49 additions & 36 deletions airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ def await_pod_start(self, pod: k8s.V1Pod):
self.pod_manager.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds)
except PodLaunchFailedException:
if self.log_events_on_failure:
for event in self.pod_manager.read_pod_events(pod).items:
self.log.error("Pod Event: %s - %s", event.reason, event.message)
self._read_pod_log_events(pod, reraise=True)
raise

def extract_xcom(self, pod: k8s.V1Pod):
Expand Down Expand Up @@ -472,34 +471,36 @@ def execute(self, context: Context):
if self.do_xcom_push:
return result

def _read_pod_log_events(self, pod, *, reraise=True):
    """
    Read the events recorded for ``pod`` and log each one at error level.

    :param pod: pod whose events are fetched through ``self.pod_manager``.
    :param reraise: forwarded to ``_optionally_suppress``; with the default
        ``True`` a failure while reading the events propagates to the caller.
    """
    with _optionally_suppress(reraise=reraise):
        pod_events = self.pod_manager.read_pod_events(pod)
        for pod_event in pod_events.items:
            self.log.error("Pod Event: %s - %s", pod_event.reason, pod_event.message)

def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
pod_phase = remote_pod.status.phase if hasattr(remote_pod, "status") else None
if not self.is_delete_operator_pod:
with _suppress(Exception):
self.patch_already_checked(remote_pod)
if pod_phase != PodPhase.SUCCEEDED or not self.is_delete_operator_pod:
self.patch_already_checked(remote_pod, reraise=False)
if pod_phase != PodPhase.SUCCEEDED:
if self.log_events_on_failure:
with _suppress(Exception):
for event in self.pod_manager.read_pod_events(pod).items:
self.log.error("Pod Event: %s - %s", event.reason, event.message)
with _suppress(Exception):
self.process_pod_deletion(remote_pod)
self._read_pod_log_events(pod, reraise=False)
self.process_pod_deletion(remote_pod, reraise=False)
error_message = get_container_termination_message(remote_pod, self.BASE_CONTAINER_NAME)
error_message = "\n" + error_message if error_message else ""
raise AirflowException(
f"Pod {pod and pod.metadata.name} returned a failure:{error_message}\n{remote_pod}"
f"Pod {pod and pod.metadata.name} returned a failure:\n{error_message}\n"
f"remote_pod: {remote_pod}"
)
else:
with _suppress(Exception):
self.process_pod_deletion(remote_pod)

def process_pod_deletion(self, pod: k8s.V1Pod):
if pod is not None:
if self.is_delete_operator_pod:
self.log.info("Deleting pod: %s", pod.metadata.name)
self.pod_manager.delete_pod(pod)
else:
self.log.info("skipping deleting pod: %s", pod.metadata.name)
self.process_pod_deletion(remote_pod, reraise=False)

def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True):
    """
    Delete ``pod`` if the operator is configured to delete its pods; otherwise only log.

    Nothing happens when ``pod`` is None.

    :param pod: the pod to delete, or None.
    :param reraise: forwarded to ``_optionally_suppress``; with the default
        ``True`` a failure during deletion propagates to the caller.
    """
    with _optionally_suppress(reraise=reraise):
        # Guard clauses: no pod, or deletion disabled, means there is nothing to delete.
        if pod is None:
            return
        if not self.is_delete_operator_pod:
            self.log.info("skipping deleting pod: %s", pod.metadata.name)
            return
        self.log.info("Deleting pod: %s", pod.metadata.name)
        self.pod_manager.delete_pod(pod)

def _build_find_pod_label_selector(self, context: Context | None = None, *, exclude_checked=True) -> str:
labels = self._get_ti_pod_labels(context, include_try_number=False)
Expand All @@ -517,11 +518,12 @@ def _set_name(name: str | None) -> str | None:
return re.sub(r"[^a-z0-9-]+", "-", name.lower())
return None

def patch_already_checked(self, pod: k8s.V1Pod):
def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
"""Add an "already checked" annotation to ensure we don't reattach on retries"""
pod.metadata.labels[self.POD_CHECKED_KEY] = "True"
body = PodGenerator.serialize_pod(pod)
self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body)
with _optionally_suppress(reraise=reraise):
pod.metadata.labels[self.POD_CHECKED_KEY] = "True"
body = PodGenerator.serialize_pod(pod)
self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body)

def on_kill(self) -> None:
if self.pod:
Expand Down Expand Up @@ -643,26 +645,37 @@ def dry_run(self) -> None:
print(yaml.dump(prune_dict(pod.to_dict(), mode="strict")))


class _suppress(AbstractContextManager):
class _optionally_suppress(AbstractContextManager):
"""
This behaves the same as ``contextlib.suppress`` but logs the suppressed
exceptions as errors with traceback.
Returns context manager that will swallow and log exceptions.
The caught exception is also stored on the context manager instance under
attribute ``exception``.
By default swallows descendants of Exception, but you can provide other classes through
the vararg ``exceptions``.
Suppression behavior can be disabled with reraise=True.
:meta private:
"""

def __init__(self, *exceptions):
self._exceptions = exceptions
def __init__(self, *exceptions, reraise=False):
self._exceptions = exceptions or (Exception,)
self.reraise = reraise
self.exception = None

def __enter__(self):
return self

def __exit__(self, exctype, excinst, exctb):
caught_error = exctype is not None and issubclass(exctype, self._exceptions)
if caught_error:
error = exctype is not None
matching_error = error and issubclass(exctype, self._exceptions)
if error and not matching_error:
return False
elif matching_error and self.reraise:
return False
elif matching_error:
self.exception = excinst
logger = logging.getLogger(__name__)
logger.exception(excinst)
return caught_error
return True
else:
return True
6 changes: 2 additions & 4 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import math
import time
import warnings
from contextlib import closing
from contextlib import closing, suppress
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Iterable, cast
Expand Down Expand Up @@ -85,12 +85,10 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool:


def get_container_termination_message(pod: V1Pod, container_name: str):
try:
with suppress(AttributeError, TypeError):
container_statuses = pod.status.container_statuses
container_status = next((x for x in container_statuses if x.name == container_name), None)
return container_status.state.terminated.message if container_status else None
except (AttributeError, TypeError):
return None


@dataclass
Expand Down
62 changes: 56 additions & 6 deletions tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.models.xcom import XCom
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator,
_suppress,
_optionally_suppress,
_task_id_to_pod_name,
)
from airflow.utils import timezone
Expand Down Expand Up @@ -938,7 +938,6 @@ def test_mark_checked_unexpected_exception(self, mock_patch_already_checked, moc
"""If we aren't deleting pods and have an exception, mark it so we don't reattach to it"""
k = KubernetesPodOperator(
task_id="task",
is_delete_operator_pod=False,
)
self.await_pod_mock.side_effect = AirflowException("oops")
context = create_context(k)
Expand Down Expand Up @@ -1023,11 +1022,62 @@ def test_task_id_as_name_dag_id_is_ignored(self):
assert re.match(r"a-very-reasonable-task-name-[a-z0-9-]+", pod.metadata.name) is not None


def test__suppress(caplog):
with _suppress(ValueError):
raise ValueError("failure")
class TestSuppress:
def test__suppress(self, caplog):
    """A ValueError raised while ValueError is listed should be suppressed and logged."""
    with _optionally_suppress(ValueError):
        raise ValueError("failure")
    logged_output = caplog.text
    assert "ValueError: failure" in logged_output

def test__suppress_no_args(self, caplog):
    """With no exception classes given, a RuntimeError should be suppressed and logged."""
    with _optionally_suppress():
        raise RuntimeError("failure")
    logged_output = caplog.text
    assert "RuntimeError: failure" in logged_output

assert "ValueError: failure" in caplog.text
def test__suppress_no_args_reraise(self, caplog):
    """
    No exception classes are given, but reraise=True is set,
    so the RuntimeError should propagate and nothing should be logged.
    """
    # One ``with`` statement listing both managers is equivalent to nesting them.
    with pytest.raises(RuntimeError), _optionally_suppress(reraise=True):
        raise RuntimeError("failure")
    assert caplog.text == ""

def test__suppress_wrong_error(self, caplog):
    """
    Only ValueError is listed, but a RuntimeError is raised,
    so it should propagate and nothing should be logged.
    """
    # One ``with`` statement listing both managers is equivalent to nesting them.
    with pytest.raises(RuntimeError), _optionally_suppress(ValueError):
        raise RuntimeError("failure")
    assert caplog.text == ""

def test__suppress_wrong_error_multiple(self, caplog):
    """
    Here, we specify only catch ValueError/IndexError.
    But we raise RuntimeError. So it should raise and not log.
    """
    with pytest.raises(RuntimeError):
        with _optionally_suppress(ValueError, IndexError):
            raise RuntimeError("failure")
    assert caplog.text == ""

def test__suppress_right_error_multiple(self, caplog):
    """
    Here, we specify catch ValueError/IndexError.
    And we raise IndexError. So it should suppress and log.
    """
    with _optionally_suppress(ValueError, IndexError):
        raise IndexError("failure")
    assert "IndexError: failure" in caplog.text

def test__suppress_no_error(self, caplog):
    """If nothing is raised inside the context, nothing should be logged."""
    with _optionally_suppress():
        print("hi")
    logged_output = caplog.text
    assert logged_output == ""


@pytest.mark.parametrize(
Expand Down

0 comments on commit ebd7b67

Please sign in to comment.