diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index c1735158ec1ec..d6b4aa6bfb0c6 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -420,8 +420,7 @@ def await_pod_start(self, pod: k8s.V1Pod): self.pod_manager.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds) except PodLaunchFailedException: if self.log_events_on_failure: - for event in self.pod_manager.read_pod_events(pod).items: - self.log.error("Pod Event: %s - %s", event.reason, event.message) + self._read_pod_log_events(pod, reraise=True) raise def extract_xcom(self, pod: k8s.V1Pod): @@ -472,34 +471,36 @@ def execute(self, context: Context): if self.do_xcom_push: return result + def _read_pod_log_events(self, pod, *, reraise=True): + """Will fetch and emit events from pod""" + with _optionally_suppress(reraise=reraise): + for event in self.pod_manager.read_pod_events(pod).items: + self.log.error("Pod Event: %s - %s", event.reason, event.message) + def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): pod_phase = remote_pod.status.phase if hasattr(remote_pod, "status") else None - if not self.is_delete_operator_pod: - with _suppress(Exception): - self.patch_already_checked(remote_pod) + if pod_phase != PodPhase.SUCCEEDED or not self.is_delete_operator_pod: + self.patch_already_checked(remote_pod, reraise=False) if pod_phase != PodPhase.SUCCEEDED: if self.log_events_on_failure: - with _suppress(Exception): - for event in self.pod_manager.read_pod_events(pod).items: - self.log.error("Pod Event: %s - %s", event.reason, event.message) - with _suppress(Exception): - self.process_pod_deletion(remote_pod) + self._read_pod_log_events(pod, reraise=False) + self.process_pod_deletion(remote_pod, reraise=False) error_message = get_container_termination_message(remote_pod, self.BASE_CONTAINER_NAME) - error_message = "\n" + error_message if error_message else "" raise AirflowException( - f"Pod {pod and pod.metadata.name} returned a failure:{error_message}\n{remote_pod}" + f"Pod {pod and pod.metadata.name} returned a failure:\n{error_message}\n" + f"remote_pod: {remote_pod}" ) else: - with _suppress(Exception): - self.process_pod_deletion(remote_pod) - - def process_pod_deletion(self, pod: k8s.V1Pod): - if pod is not None: - if self.is_delete_operator_pod: - self.log.info("Deleting pod: %s", pod.metadata.name) - self.pod_manager.delete_pod(pod) - else: - self.log.info("skipping deleting pod: %s", pod.metadata.name) + self.process_pod_deletion(remote_pod, reraise=False) + + def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True): + with _optionally_suppress(reraise=reraise): + if pod is not None: + if self.is_delete_operator_pod: + self.log.info("Deleting pod: %s", pod.metadata.name) + self.pod_manager.delete_pod(pod) + else: + self.log.info("skipping deleting pod: %s", pod.metadata.name) def _build_find_pod_label_selector(self, context: Context | None = None, *, exclude_checked=True) -> str: labels = self._get_ti_pod_labels(context, include_try_number=False) @@ -517,11 +518,12 @@ def _set_name(name: str | None) -> str | None: return re.sub(r"[^a-z0-9-]+", "-", name.lower()) return None - def patch_already_checked(self, pod: k8s.V1Pod): + def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True): """Add an "already checked" annotation to ensure we don't reattach on retries""" - pod.metadata.labels[self.POD_CHECKED_KEY] = "True" - body = PodGenerator.serialize_pod(pod) - self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body) + with _optionally_suppress(reraise=reraise): + pod.metadata.labels[self.POD_CHECKED_KEY] = "True" + body = PodGenerator.serialize_pod(pod) + self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body) def on_kill(self) -> None: if self.pod: @@ -643,26 +645,37 @@ def dry_run(self) -> None: print(yaml.dump(prune_dict(pod.to_dict(), mode="strict"))) -class _suppress(AbstractContextManager): +class _optionally_suppress(AbstractContextManager): """ - This behaves the same as ``contextlib.suppress`` but logs the suppressed - exceptions as errors with traceback. + Returns context manager that will swallow and log exceptions. - The caught exception is also stored on the context manager instance under - attribute ``exception``. + By default swallows descendents of Exception, but you can provide other classes through + the vararg ``exceptions``. + + Suppression behavior can be disabled with reraise=True. + + :meta private: """ - def __init__(self, *exceptions): - self._exceptions = exceptions + def __init__(self, *exceptions, reraise=False): + self._exceptions = exceptions or (Exception,) + self.reraise = reraise self.exception = None def __enter__(self): return self def __exit__(self, exctype, excinst, exctb): - caught_error = exctype is not None and issubclass(exctype, self._exceptions) - if caught_error: + error = exctype is not None + matching_error = error and issubclass(exctype, self._exceptions) + if error and not matching_error: + return False + elif matching_error and self.reraise: + return False + elif matching_error: self.exception = excinst logger = logging.getLogger(__name__) logger.exception(excinst) - return caught_error + return True + else: + return True diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index ffd3e7ec4bd88..c139f2f4b13dd 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -21,7 +21,7 @@ import math import time import warnings -from contextlib import closing +from contextlib import closing, suppress from dataclasses import dataclass from datetime import datetime from typing import TYPE_CHECKING, Iterable, cast @@ -85,12 +85,10 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool: def get_container_termination_message(pod: V1Pod, container_name: str): - try: + with suppress(AttributeError, TypeError): container_statuses = pod.status.container_statuses container_status = next((x for x in container_statuses if x.name == container_name), None) return container_status.state.terminated.message if container_status else None - except (AttributeError, TypeError): - return None @dataclass diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py index 12cde9d1d0021..f2022855046bb 100644 --- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py @@ -32,7 +32,7 @@ from airflow.models.xcom import XCom from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import ( KubernetesPodOperator, - _suppress, + _optionally_suppress, _task_id_to_pod_name, ) from airflow.utils import timezone @@ -938,7 +938,6 @@ def test_mark_checked_unexpected_exception(self, mock_patch_already_checked, moc """If we aren't deleting pods and have an exception, mark it so we don't reattach to it""" k = KubernetesPodOperator( task_id="task", - is_delete_operator_pod=False, ) self.await_pod_mock.side_effect = AirflowException("oops") context = create_context(k) @@ -1023,11 +1022,62 @@ def test_task_id_as_name_dag_id_is_ignored(self): assert re.match(r"a-very-reasonable-task-name-[a-z0-9-]+", pod.metadata.name) is not None -def test__suppress(caplog): - with _suppress(ValueError): - raise ValueError("failure") +class TestSuppress: + def test__suppress(self, caplog): + with _optionally_suppress(ValueError): + raise ValueError("failure") + assert "ValueError: failure" in caplog.text + + def test__suppress_no_args(self, caplog): + """By default, suppresses Exception, so should suppress and log RuntimeError""" + with _optionally_suppress(): + raise RuntimeError("failure") + assert "RuntimeError: failure" in caplog.text - assert "ValueError: failure" in caplog.text + def test__suppress_no_args_reraise(self, caplog): + """ + By default, suppresses Exception, but with reraise=True, + should raise RuntimeError and not log. + """ + with pytest.raises(RuntimeError): + with _optionally_suppress(reraise=True): + raise RuntimeError("failure") + assert caplog.text == "" + + def test__suppress_wrong_error(self, caplog): + """ + Here, we specify only catch ValueError. But we raise RuntimeError. + So it should raise and not log. + """ + with pytest.raises(RuntimeError): + with _optionally_suppress(ValueError): + raise RuntimeError("failure") + assert caplog.text == "" + + def test__suppress_wrong_error_multiple(self, caplog): + """ + Here, we specify only catch RuntimeError/IndexError. + But we raise RuntimeError. So it should raise and not log. + """ + with pytest.raises(RuntimeError): + with _optionally_suppress(ValueError, IndexError): + raise RuntimeError("failure") + assert caplog.text == "" + + def test__suppress_right_error_multiple(self, caplog): + """ + Here, we specify catch RuntimeError/IndexError. + And we raise RuntimeError. So it should suppress and log. + """ + with _optionally_suppress(ValueError, IndexError): + raise IndexError("failure") + assert "IndexError: failure" in caplog.text + + def test__suppress_no_error(self, caplog): + """When no error in context, should do nothing.""" + with _optionally_suppress(): + print("hi") + assert caplog.text == "" @pytest.mark.parametrize(