From 1bd0b94987c7ea27eb003e2f335fc1e194dfa2fa Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 22 May 2026 15:48:28 +0300 Subject: [PATCH] Fixing pod leak in KubernetesJobOperator (#1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(providers/cncf/kubernetes): clean up monitoring pods in KubernetesJobOperator KubernetesJobOperator inherited from KubernetesPodOperator but overrode execute() without calling post_complete_action(), so the monitoring / log-streaming pods discovered via get_pods() were never deleted. These pods have no ownerReferences to the V1Job, so ttl_seconds_after_finished and the Foreground cascade in on_kill don't reap them either. - execute() and execute_complete() now wrap their work in try/finally and call post_complete_action() for each pod in self.pods. on_finish_action (delete_pod / delete_succeeded_pod / keep_pod) is now honoured. - on_kill() additionally calls pod_manager.delete_pod() for each monitoring pod (the Job's foreground cascade doesn't reach them). - Per-pod cleanup errors are logged but never mask the in-flight exception, so Job-level failures keep propagating. - execute_complete() resolves monitoring pods once and shares the lookup between the log-retrieval path and the cleanup path. - Added unit tests, a bugfix newsfragment, and an operators.rst section documenting the cleanup contract. * Address code review feedback: remove dead PodNotFoundException check, drop unused import, relax pod-deletion ordering in test, fix trailing comma * Potential fix for pull request finding In _cleanup_monitoring_pods, remote_pod is resolved via find_pod(), which is designed to locate a single matching pod by task-instance labels and can invoke duplicate-pod resolution logic (process_duplicate_label_pods). 
For KubernetesJobOperator with parallelism > 1, this lookup can return the wrong pod (or trigger duplicate-handling side effects), so post_complete_action() may receive a mismatched remote_pod. Consider using the already-discovered pod’s name/namespace to refresh state (e.g. via hook.get_pod) or just pass remote_pod=pod when you already have the V1Pod object from get_pods(). Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Use isinstance(exc, TaskDeferred) instead of brittle string comparison * Potential fix for pull request finding The new unit tests add several mock.MagicMock() instances (pods, jobs, TI, etc.) without spec/autospec, and some patch() usages also create non-spec'd mocks by default. Using autospec=True on patches and create_autospec(...)/MagicMock(spec=...) for key Kubernetes objects helps catch typos/attribute mismatches in these tests and aligns with Airflow’s test hardening guidance. Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Address PR review comments: fix trigger pod_names, on_kill logging, and test assertions - triggers/job.py: Always include pod_names/pod_namespace in trigger event regardless of get_logs setting, so execute_complete() can reliably clean up monitoring pods even when get_logs=False - operators/job.py: Log unexpected ApiException in on_kill() instead of suppressing all ApiExceptions; remove unused `suppress` import - tests/test_job.py: Rewrite test_execute_respects_keep_pod and test_execute_deletes_pod_default to keep process_pod_deletion real and assert on pod_manager.delete_pod; stub hook.get_pod for remote_pod resolution - tests/test_job.py: Add regression test for get_logs=False deferrable path * Fix orphaned test_on_kill_deletes_monitoring_pods method body after accidental deletion of method signature * Make pod resolution best-effort in execute_complete * Address remaining KubernetesJobOperator review comments * Finalize 
review-comment fixes for KubernetesJobOperator * Fix remaining KubernetesJobOperator review comments * Update KubernetesJobOperator docs for action semantics * Improve KubernetesJobOperator newsfragment readability --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Ville Jyrkkä Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- providers/cncf/kubernetes/docs/operators.rst | 23 ++ ...perator-cleanup-monitoring-pods.bugfix.rst | 9 + .../cncf/kubernetes/operators/job.py | 276 +++++++++++---- .../providers/cncf/kubernetes/triggers/job.py | 4 +- .../cncf/kubernetes/operators/test_job.py | 323 +++++++++++++++++- 5 files changed, 557 insertions(+), 78 deletions(-) create mode 100644 providers/cncf/kubernetes/newsfragments/kubernetes-job-operator-cleanup-monitoring-pods.bugfix.rst diff --git a/providers/cncf/kubernetes/docs/operators.rst b/providers/cncf/kubernetes/docs/operators.rst index 72521af704fbc..a3f8e417237e8 100644 --- a/providers/cncf/kubernetes/docs/operators.rst +++ b/providers/cncf/kubernetes/docs/operators.rst @@ -713,6 +713,29 @@ It means that user can use all parameters from :class:`~airflow.providers.cncf.k More information about the Jobs here: `Kubernetes Job Documentation <https://kubernetes.io/docs/concepts/workloads/controllers/job/>`__ +Pod cleanup and ``on_finish_action`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When ``wait_until_job_complete=True``, the operator discovers Job pods via +``get_pods()`` and streams logs/XCom from those pods while the Job runs. + +The inherited ``on_finish_action`` parameter controls what happens to these +discovered pods at the end of the task: + +* ``delete_pod`` (default) — the pod is deleted after the task + finishes (success or failure). +* ``delete_succeeded_pod`` — the pod is deleted only when the task + succeeded. +* ``delete_active_pod`` — the pod is deleted only if it is still + active (``Pending`` or ``Running``). 
+* ``keep_pod`` — the pod is kept (useful for offline log + inspection). + +When the task is killed, ``on_kill`` deletes the Job (with foreground cascade). +For discovered pods, deletion is controlled by ``on_kill_action``: +``delete_pod`` attempts direct pod deletion and ``keep_pod`` skips it. + + .. _howto/operator:KubernetesDeleteJobOperator: diff --git a/providers/cncf/kubernetes/newsfragments/kubernetes-job-operator-cleanup-monitoring-pods.bugfix.rst b/providers/cncf/kubernetes/newsfragments/kubernetes-job-operator-cleanup-monitoring-pods.bugfix.rst new file mode 100644 index 0000000000000..c39d5847c0f06 --- /dev/null +++ b/providers/cncf/kubernetes/newsfragments/kubernetes-job-operator-cleanup-monitoring-pods.bugfix.rst @@ -0,0 +1,9 @@ +Fix pod cleanup gaps in ``KubernetesJobOperator``. +``execute()`` and ``execute_complete()`` now consistently clean up pods discovered via ``get_pods()``, +including deferrable resume paths where pod lookup can fail. +The inherited ``on_finish_action`` parameter (``delete_pod`` / ``delete_succeeded_pod`` / +``delete_active_pod`` / ``keep_pod``) is honored for these pods, matching +``KubernetesPodOperator`` behavior. +In ``on_kill()``, pod cleanup now respects ``on_kill_action`` (``delete_pod`` deletes discovered pods; +``keep_pod`` skips pod deletion). +Per-pod cleanup errors are logged but never mask a Job-level failure. 
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py index 510346d34ad97..500515066aaad 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py @@ -22,6 +22,7 @@ import json import logging import os +import sys import warnings from collections.abc import Sequence from functools import cached_property @@ -41,9 +42,9 @@ from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, merge_objects from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger -from airflow.providers.cncf.kubernetes.utils.pod_manager import EMPTY_XCOM_RESULT, PodNotFoundException +from airflow.providers.cncf.kubernetes.utils.pod_manager import EMPTY_XCOM_RESULT, OnKillAction from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_1_PLUS -from airflow.providers.common.compat.sdk import AirflowException, conf +from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred, conf from airflow.utils import yaml if AIRFLOW_V_3_1_PLUS: @@ -218,44 +219,49 @@ def execute(self, context: Context): ti.xcom_push(key="job_name", value=self.job.metadata.name) ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace) - if self.wait_until_job_complete: - self.pods: Sequence[k8s.V1Pod] = self.get_pods( - pod_request_obj=self.pod_request_obj, context=context - ) + try: + if self.wait_until_job_complete: + self.pods: Sequence[k8s.V1Pod] = self.get_pods( + pod_request_obj=self.pod_request_obj, context=context + ) - if self.deferrable: - self.execute_deferrable() - return + if self.deferrable: + self.execute_deferrable() + # execute_deferrable raises TaskDeferred; cleanup is handled + # by execute_complete on 
resume. + return + + if self.do_xcom_push: + xcom_result = [] + for pod in self.pods: + self.pod_manager.await_container_completion( + pod=pod, container_name=self.base_container_name + ) + self.pod_manager.await_xcom_sidecar_container_start(pod=pod) + xcom_result.append(self.extract_xcom(pod=pod)) + self.job = self.hook.wait_until_job_complete( + job_name=self.job.metadata.name, + namespace=self.job.metadata.namespace, + job_poll_interval=self.job_poll_interval, + ) + if self.get_logs: + for pod in self.pods: + self.pod_manager.fetch_requested_container_logs( + pod=pod, + containers=self.container_logs, + follow_logs=True, + ) - if self.do_xcom_push: - xcom_result = [] - for pod in self.pods: - self.pod_manager.await_container_completion( - pod=pod, container_name=self.base_container_name - ) - self.pod_manager.await_xcom_sidecar_container_start(pod=pod) - xcom_result.append(self.extract_xcom(pod=pod)) - self.job = self.hook.wait_until_job_complete( - job_name=self.job.metadata.name, - namespace=self.job.metadata.namespace, - job_poll_interval=self.job_poll_interval, - ) - if self.get_logs: - for pod in self.pods: - self.pod_manager.fetch_requested_container_logs( - pod=pod, - containers=self.container_logs, - follow_logs=True, + ti.xcom_push(key="job", value=self.job.to_dict()) + if self.wait_until_job_complete: + if error_message := self.hook.is_job_failed(job=self.job): + raise AirflowException( + f"Kubernetes job '{self.job.metadata.name}' is failed with error '{error_message}'" ) - - ti.xcom_push(key="job", value=self.job.to_dict()) - if self.wait_until_job_complete: - if error_message := self.hook.is_job_failed(job=self.job): - raise AirflowException( - f"Kubernetes job '{self.job.metadata.name}' is failed with error '{error_message}'" - ) - if self.do_xcom_push: - return xcom_result[0] if self.unwrap_single and len(xcom_result) == 1 else xcom_result + if self.do_xcom_push: + return xcom_result[0] if self.unwrap_single and len(xcom_result) == 1 else 
xcom_result + finally: + self._cleanup_monitoring_pods(context) def execute_deferrable(self): self.defer( @@ -277,39 +283,85 @@ def execute_deferrable(self): ) def execute_complete(self, context: Context, event: dict, **kwargs): - ti = context["ti"] - ti.xcom_push(key="job", value=event["job"]) - if event["status"] == "error": - raise AirflowException(event["message"]) - - if self.get_logs: - for pod_name in event["pod_names"]: - pod_namespace = event["pod_namespace"] - try: - pod = self.hook.get_pod(pod_name, pod_namespace) - except ApiException as e: - if e.status == 404: + # Resolve monitoring pods up front so the log-retrieval path and the + # cleanup path in the finally block share the same lookup (no double + # ``hook.get_pod`` calls). + pods_by_name: dict[str, k8s.V1Pod] = {} + event_job = event.get("job") + job_namespace = ( + event_job.get("metadata", {}).get("namespace") + if isinstance(event_job, dict) + else None + ) + pod_namespace = event.get("pod_namespace") or event.get("namespace") or job_namespace + unresolved_pods: list[tuple[str, str]] = [] + for pod_name in event.get("pod_names") or []: + if not pod_namespace: + self.log.warning( + "Skipping pod %s lookup because no pod namespace was provided in trigger event.", + pod_name, + ) + continue + try: + pod = self.hook.get_pod(pod_name, pod_namespace) + except ApiException as e: + if e.status == 404: + self.log.warning( + "Pod %s in namespace %s not found (possibly deleted).", + pod_name, + pod_namespace, + ) + else: + self.log.warning( + "Failed to retrieve pod %s in namespace %s: %s. Skipping.", + pod_name, + pod_namespace, + e, + ) + unresolved_pods.append((pod_name, pod_namespace)) + continue + except Exception as e: + self.log.warning( + "Failed to retrieve pod %s in namespace %s: %s. 
Skipping.", + pod_name, + pod_namespace, + e, + ) + unresolved_pods.append((pod_name, pod_namespace)) + continue + if pod is not None: + pods_by_name[pod_name] = pod + + try: + ti = context["ti"] + ti.xcom_push(key="job", value=event["job"]) + if event["status"] == "error": + raise AirflowException(event["message"]) + + if self.get_logs: + for pod_name in event.get("pod_names") or []: + if pod_name not in pods_by_name: + # Pod was reported by the trigger but missing now (e.g. 404) self.log.warning( - "Pod %s in namespace %s not found (possibly deleted). Skipping log retrieval.", - pod_name, - pod_namespace, + "Skipping log retrieval for pod %s (not found).", pod_name ) continue - raise - if not pod: - raise PodNotFoundException("Could not find pod after resuming from deferral") - self._write_logs(pod) - - if self.do_xcom_push: - xcom_results: list[Any | None] = [] - for xcom_result in event["xcom_result"]: - if isinstance(xcom_result, str) and xcom_result.rstrip() == EMPTY_XCOM_RESULT: - self.log.info("xcom result file is empty.") - xcom_results.append(None) - continue - self.log.info("xcom result: \n%s", xcom_result) - xcom_results.append(json.loads(xcom_result)) - return xcom_results[0] if self.unwrap_single and len(xcom_results) == 1 else xcom_results + self._write_logs(pods_by_name[pod_name]) + + if self.do_xcom_push: + xcom_results: list[Any | None] = [] + for xcom_result in event["xcom_result"]: + if isinstance(xcom_result, str) and xcom_result.rstrip() == EMPTY_XCOM_RESULT: + self.log.info("xcom result file is empty.") + xcom_results.append(None) + continue + self.log.info("xcom result: \n%s", xcom_result) + xcom_results.append(json.loads(xcom_result)) + return xcom_results[0] if self.unwrap_single and len(xcom_results) == 1 else xcom_results + finally: + self._cleanup_monitoring_pods_from_dict( + context, pods_by_name, unresolved_pods=unresolved_pods, event_status=event.get("status") + ) @staticmethod def deserialize_job_template_file(path: str) -> 
k8s.V1Job: @@ -334,6 +386,7 @@ def deserialize_job_template_file(path: str) -> k8s.V1Job: return api_client._ApiClient__deserialize_model(job, k8s.V1Job) def on_kill(self) -> None: + self._killed = True if self.job: job = self.job kwargs = { @@ -344,6 +397,95 @@ def on_kill(self) -> None: if self.termination_grace_period is not None: kwargs.update(grace_period_seconds=self.termination_grace_period) self.job_client.delete_namespaced_job(**kwargs) + if self.on_kill_action == OnKillAction.KEEP_POD: + self.log.info( + "Skipping monitoring pod deletion since on_kill_action is set to %r.", + self.on_kill_action.value, + ) + return + # Monitoring pods discovered via get_pods() have no ownerReferences and + # are not reaped by the Job's foreground cascade. Delete them directly. + for pod in getattr(self, "pods", None) or []: + try: + self.pod_manager.delete_pod(pod) + except ApiException: + self.log.exception( + "Unable to delete monitoring pod %s", + getattr(pod.metadata, "name", ""), + ) + + def _cleanup_monitoring_pods(self, context: Context) -> None: + """Run ``post_complete_action`` on each monitoring pod from ``self.pods``. + + Honours ``on_finish_action`` (inherited from ``KubernetesPodOperator``) + and runs as a side-effect: any per-pod cleanup error is logged but never + masks the in-flight exception (e.g. an ``AirflowException`` raised because + the Job itself failed). + """ + # Skip cleanup when control is leaving execute() via TaskDeferred: the + # deferred trigger still needs the monitoring pods to exist; the pods + # will be cleaned up by execute_complete() on resume. 
+ exc = sys.exc_info()[1] + if isinstance(exc, TaskDeferred): + return + for pod in getattr(self, "pods", None) or []: + remote_pod = pod + try: + pod_name = getattr(pod.metadata, "name", None) + pod_namespace = getattr(pod.metadata, "namespace", None) + if pod_name and pod_namespace: + remote_pod = self.hook.get_pod(name=pod_name, namespace=pod_namespace) or pod + except Exception: + remote_pod = pod + try: + self.post_complete_action( + pod=pod, + remote_pod=remote_pod, + context=context, + result=None, + ) + except Exception: + # cleanup() can raise AirflowException for failed pods, and the + # k8s client can raise transport errors. For the Job operator we + # prefer the Job-level failure (or the original exception) to + # propagate instead of any per-pod cleanup error. + self.log.warning( + "Error while cleaning up monitoring pod %s", + getattr(pod.metadata, "name", ""), + exc_info=True, + ) + + def _cleanup_monitoring_pods_from_dict( + self, + context: Context, + pods_by_name: dict[str, k8s.V1Pod], + *, + unresolved_pods: list[tuple[str, str]] | None = None, + event_status: str | None = None, + ) -> None: + """Run ``post_complete_action`` on each pod previously resolved via the trigger event. + + Same semantics as :meth:`_cleanup_monitoring_pods` - errors are logged + but never mask the in-flight exception. 
+ """ + for pod_name, pod in pods_by_name.items(): + try: + self.post_complete_action( + pod=pod, remote_pod=pod, context=context, result=None + ) + except Exception: + self.log.warning( + "Error while cleaning up monitoring pod %s", + pod_name, + exc_info=True, + ) + pod_phase = "Succeeded" if event_status == "success" else "Failed" if event_status == "error" else None + for pod_name, pod_namespace in unresolved_pods or []: + fallback_pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta(name=pod_name, namespace=pod_namespace), + status=k8s.V1PodStatus(phase=pod_phase), + ) + self.process_pod_deletion(fallback_pod, reraise=False) def build_job_request_obj(self, context: Context | None = None) -> k8s.V1Job: """ diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py index b60373c2d5339..271099598f67c 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py @@ -145,8 +145,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "name": job.metadata.name, "namespace": job.metadata.namespace, - "pod_names": [pod_name for pod_name in self.pod_names] if self.get_logs else None, - "pod_namespace": self.pod_namespace if self.get_logs else None, + "pod_names": list(self.pod_names), + "pod_namespace": self.pod_namespace, "status": "error" if error_message else "success", "message": f"Job failed with error: {error_message}" if error_message diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py index 6bde7c4772deb..e39e2aa412be0 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py @@ -885,11 +885,11 @@ def 
test_execute_complete_pod_not_found_skips_logs(self, mock_hook, mocked_write @pytest.mark.non_db_test_override @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._write_logs")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook")) - def test_execute_complete_pod_api_error_reraises(self, mock_hook, mocked_write_logs): - """Non-404 ApiExceptions should still be raised.""" + def test_execute_complete_pod_api_error_skips_logs(self, mock_hook, mocked_write_logs): + """Non-404 pod lookup errors should not fail execute_complete.""" from kubernetes.client.rest import ApiException - mock_ti = mock.MagicMock() + mock_ti = mock.create_autospec(TaskInstance, instance=True) context = {"ti": mock_ti} mock_job = mock.MagicMock() event = { @@ -902,10 +902,65 @@ def test_execute_complete_pod_api_error_reraises(self, mock_hook, mocked_write_l mock_hook.get_pod.side_effect = ApiException(status=403, reason="Forbidden") - with pytest.raises(ApiException): - KubernetesJobOperator(task_id="test_task_id", get_logs=True, do_xcom_push=False).execute_complete( - context=context, event=event - ) + KubernetesJobOperator(task_id="test_task_id", get_logs=True, do_xcom_push=False).execute_complete( + context=context, event=event + ) + mock_hook.get_pod.assert_called_once_with(POD_NAME, POD_NAMESPACE) + mocked_write_logs.assert_not_called() + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.pod_manager"), new_callable=mock.PropertyMock) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook")) + def test_execute_complete_pod_api_error_still_attempts_cleanup(self, mock_hook, mock_pod_manager_prop): + from kubernetes.client.rest import ApiException + + mock_ti = mock.create_autospec(TaskInstance, instance=True) + context = {"ti": mock_ti} + mock_job = mock.MagicMock() + event = { + "job": mock_job, + "status": "success", + "pod_names": [POD_NAME], + "pod_namespace": POD_NAMESPACE, + "xcom_result": None, + } + 
mock_hook.get_pod.side_effect = ApiException(status=403, reason="Forbidden") + mock_pod_manager = mock.MagicMock() + mock_pod_manager_prop.return_value = mock_pod_manager + + KubernetesJobOperator(task_id="test_task_id", get_logs=False, do_xcom_push=False).execute_complete( + context=context, event=event + ) + + mock_hook.get_pod.assert_called_once_with(POD_NAME, POD_NAMESPACE) + mock_pod_manager.delete_pod.assert_called_once() + assert mock_pod_manager.delete_pod.call_args.args[0].metadata.name == POD_NAME + assert mock_pod_manager.delete_pod.call_args.args[0].metadata.namespace == POD_NAMESPACE + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._write_logs")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.hook")) + def test_execute_complete_uses_event_namespace_fallback(self, mock_hook, mocked_write_logs): + mock_ti = mock.create_autospec(TaskInstance, instance=True) + context = {"ti": mock_ti} + mock_job = {"metadata": {"name": JOB_NAME, "namespace": JOB_NAMESPACE}} + pod = mock.create_autospec(k8s.V1Pod, instance=True) + event = { + "job": mock_job, + "namespace": JOB_NAMESPACE, + "status": "success", + "pod_names": [POD_NAME], + "xcom_result": None, + } + + mock_hook.get_pod.return_value = pod + + KubernetesJobOperator(task_id="test_task_id", get_logs=True, do_xcom_push=False).execute_complete( + context=context, event=event + ) + + mock_hook.get_pod.assert_called_once_with(POD_NAME, JOB_NAMESPACE) + mocked_write_logs.assert_called_once_with(pod) @pytest.mark.non_db_test_override @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator._write_logs")) @@ -944,7 +999,7 @@ def get_pod_side_effect(name, namespace): @pytest.mark.non_db_test_override @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.job_client")) def test_on_kill(self, mock_client): - mock_job = mock.MagicMock() + mock_job = mock.create_autospec(k8s.V1Job, instance=True) mock_job.metadata.name = JOB_NAME mock_job.metadata.namespace = 
JOB_NAMESPACE @@ -952,6 +1007,7 @@ def test_on_kill(self, mock_client): op.job = mock_job op.on_kill() + assert op._killed is True mock_client.delete_namespaced_job.assert_called_once_with( name=JOB_NAME, namespace=JOB_NAMESPACE, @@ -961,7 +1017,7 @@ def test_on_kill(self, mock_client): @pytest.mark.non_db_test_override @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.job_client")) def test_on_kill_termination_grace_period(self, mock_client): - mock_job = mock.MagicMock() + mock_job = mock.create_autospec(k8s.V1Job, instance=True) mock_job.metadata.name = JOB_NAME mock_job.metadata.namespace = JOB_NAMESPACE mock_termination_grace_period = mock.MagicMock() @@ -991,6 +1047,26 @@ def test_on_kill_none_job(self, mock_hook, mock_client): mock_client.delete_namespaced_job.assert_not_called() mock_serialize.assert_not_called() + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.pod_manager"), new_callable=mock.PropertyMock) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.job_client")) + def test_on_kill_respects_keep_pod_action(self, mock_client, mock_pod_manager_prop): + mock_pod_manager = mock.MagicMock() + mock_pod_manager_prop.return_value = mock_pod_manager + mock_job = mock.create_autospec(k8s.V1Job, instance=True) + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = JOB_NAMESPACE + pod = mock.create_autospec(k8s.V1Pod, instance=True) + + op = KubernetesJobOperator(task_id="test_task_id", on_kill_action="keep_pod") + op.job = mock_job + op.pods = [pod] + op.on_kill() + + assert op._killed is True + mock_client.delete_namespaced_job.assert_called_once() + mock_pod_manager.delete_pod.assert_not_called() + @pytest.mark.parametrize("parallelism", [1, 2]) @pytest.mark.parametrize("do_xcom_push", [True, False]) @pytest.mark.parametrize("get_logs", [True, False]) @@ -1193,6 +1269,235 @@ def test_create_zero_parallelism_fails_validation( mock_hook.return_value.create_job.assert_not_called() 
mock_get_pods.assert_not_called() + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.post_complete_action")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.find_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + @patch(HOOK_CLASS) + def test_execute_calls_post_complete_action_on_success( + self, + mock_hook, + mock_create_job, + mock_build_job_request_obj, + mock_get_pods, + mock_find_pod, + mock_post_complete_action, + ): + mock_hook.return_value.is_job_failed.return_value = False + mock_pod_1 = mock.create_autospec(k8s.V1Pod, instance=True) + mock_pod_2 = mock.create_autospec(k8s.V1Pod, instance=True) + mock_get_pods.return_value = [mock_pod_1, mock_pod_2] + mock_find_pod.side_effect = ( + lambda namespace, context: mock.create_autospec(k8s.V1Pod, instance=True) + ) + + op = KubernetesJobOperator( + task_id="test_task_id", wait_until_job_complete=True, parallelism=2 + ) + op.execute(context=dict(ti=mock.create_autospec(TaskInstance, instance=True))) + + assert mock_post_complete_action.call_count == 2 + called_pods = [call.kwargs["pod"] for call in mock_post_complete_action.call_args_list] + assert called_pods == [mock_pod_1, mock_pod_2] + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.post_complete_action")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.find_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + @patch(HOOK_CLASS) + def test_execute_calls_post_complete_action_on_failure( + self, + mock_hook, + mock_create_job, + mock_build_job_request_obj, + mock_get_pods, + mock_find_pod, + 
mock_post_complete_action, + ): + mock_hook.return_value.is_job_failed.return_value = "Error" + mock_pod_1 = mock.create_autospec(k8s.V1Pod, instance=True) + mock_get_pods.return_value = [mock_pod_1] + mock_find_pod.return_value = mock_pod_1 + + op = KubernetesJobOperator(task_id="test_task_id", wait_until_job_complete=True) + with pytest.raises(AirflowException, match="is failed with error"): + op.execute(context=dict(ti=mock.create_autospec(TaskInstance, instance=True))) + + # Cleanup still ran in spite of the job-level failure. + mock_post_complete_action.assert_called_once() + assert mock_post_complete_action.call_args.kwargs["pod"] is mock_pod_1 + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.pod_manager"), new_callable=mock.PropertyMock) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + @patch(HOOK_CLASS) + def test_execute_respects_keep_pod( + self, + mock_hook, + mock_create_job, + mock_build_job_request_obj, + mock_get_pods, + mock_pod_manager_prop, + ): + """When on_finish_action=keep_pod, no monitoring pod should be deleted.""" + mock_pod_manager = mock.MagicMock() + mock_pod_manager_prop.return_value = mock_pod_manager + mock_hook.return_value.is_job_failed.return_value = False + # Return a pod with SUCCEEDED phase so cleanup() doesn't raise. 
+ remote_pod = mock.create_autospec(k8s.V1Pod, instance=True) + remote_pod.status.phase = "Succeeded" + remote_pod.status.container_statuses = [] + mock_pod_1 = mock.create_autospec(k8s.V1Pod, instance=True) + mock_get_pods.return_value = [mock_pod_1] + mock_hook.return_value.get_pod.return_value = remote_pod + + op = KubernetesJobOperator( + task_id="test_task_id", + wait_until_job_complete=True, + on_finish_action="keep_pod", + ) + op.execute(context=dict(ti=mock.create_autospec(TaskInstance, instance=True))) + + mock_pod_manager.delete_pod.assert_not_called() + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.pod_manager"), new_callable=mock.PropertyMock) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + @patch(HOOK_CLASS) + def test_execute_deletes_pod_default( + self, + mock_hook, + mock_create_job, + mock_build_job_request_obj, + mock_get_pods, + mock_pod_manager_prop, + ): + """With the default on_finish_action=delete_pod the monitoring pod is deleted.""" + mock_pod_manager = mock.MagicMock() + mock_pod_manager_prop.return_value = mock_pod_manager + mock_hook.return_value.is_job_failed.return_value = False + remote_pod = mock.create_autospec(k8s.V1Pod, instance=True) + remote_pod.status.phase = "Succeeded" + remote_pod.status.container_statuses = [] + mock_pod_1 = mock.create_autospec(k8s.V1Pod, instance=True) + mock_get_pods.return_value = [mock_pod_1] + mock_hook.return_value.get_pod.return_value = remote_pod + + op = KubernetesJobOperator(task_id="test_task_id", wait_until_job_complete=True) + op.execute(context=dict(ti=mock.create_autospec(TaskInstance, instance=True))) + + mock_pod_manager.delete_pod.assert_called_once_with(remote_pod) + + @pytest.mark.non_db_test_override + 
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.post_complete_action")) + @patch(HOOK_CLASS) + def test_execute_complete_deletes_pod(self, mock_hook, mock_post_complete_action): + """The deferrable resume path cleans up monitoring pods too.""" + pod = mock.create_autospec(k8s.V1Pod, instance=True) + mock_hook.return_value.get_pod.return_value = pod + event = { + "status": "success", + "message": "ok", + "job": {"metadata": {"name": JOB_NAME, "namespace": JOB_NAMESPACE}}, + "pod_names": [POD_NAME], + "pod_namespace": POD_NAMESPACE, + "xcom_result": [], + } + + KubernetesJobOperator(task_id="test_task_id").execute_complete( + context=dict(ti=mock.create_autospec(TaskInstance, instance=True)), event=event + ) + + mock_hook.return_value.get_pod.assert_called_with(POD_NAME, POD_NAMESPACE) + mock_post_complete_action.assert_called_once() + assert mock_post_complete_action.call_args.kwargs["pod"] is pod + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.post_complete_action")) + @patch(HOOK_CLASS) + def test_execute_complete_cleans_up_pods_when_get_logs_false(self, mock_hook, mock_post_complete_action): + """Monitoring pods are cleaned up in execute_complete even when get_logs=False.""" + pod = mock.create_autospec(k8s.V1Pod, instance=True) + mock_hook.return_value.get_pod.return_value = pod + # Simulate an event emitted by the trigger when get_logs=False: + # pod_names and pod_namespace are now always present in the event. 
+ event = { + "status": "success", + "message": "ok", + "job": {"metadata": {"name": JOB_NAME, "namespace": JOB_NAMESPACE}}, + "pod_names": [POD_NAME], + "pod_namespace": POD_NAMESPACE, + "xcom_result": None, + } + + KubernetesJobOperator(task_id="test_task_id", get_logs=False).execute_complete( + context=dict(ti=mock.create_autospec(TaskInstance, instance=True)), event=event + ) + + mock_hook.return_value.get_pod.assert_called_with(POD_NAME, POD_NAMESPACE) + mock_post_complete_action.assert_called_once() + assert mock_post_complete_action.call_args.kwargs["pod"] is pod + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.pod_manager"), new_callable=mock.PropertyMock) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.job_client")) + def test_on_kill_deletes_monitoring_pods(self, mock_client, mock_pod_manager_prop): + mock_pod_manager = mock.MagicMock() + mock_pod_manager_prop.return_value = mock_pod_manager + + mock_job = mock.create_autospec(k8s.V1Job, instance=True) + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = JOB_NAMESPACE + pod_1 = mock.create_autospec(k8s.V1Pod, instance=True) + pod_2 = mock.create_autospec(k8s.V1Pod, instance=True) + + op = KubernetesJobOperator(task_id="test_task_id") + op.job = mock_job + op.pods = [pod_1, pod_2] + op.on_kill() + + mock_client.delete_namespaced_job.assert_called_once() + assert mock_pod_manager.delete_pod.call_count == 2 + mock_pod_manager.delete_pod.assert_has_calls( + [mock.call(pod_1), mock.call(pod_2)], any_order=True + ) + + @pytest.mark.non_db_test_override + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.post_complete_action")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.find_pod")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + 
@patch(HOOK_CLASS) + def test_cleanup_error_does_not_mask_job_failure( + self, + mock_hook, + mock_create_job, + mock_build_job_request_obj, + mock_get_pods, + mock_find_pod, + mock_post_complete_action, + ): + mock_hook.return_value.is_job_failed.return_value = "Error" + mock_pod_1 = mock.create_autospec(k8s.V1Pod, instance=True) + mock_get_pods.return_value = [mock_pod_1] + mock_find_pod.return_value = mock_pod_1 + mock_post_complete_action.side_effect = AirflowException("cleanup boom") + + op = KubernetesJobOperator(task_id="test_task_id", wait_until_job_complete=True) + with pytest.raises(AirflowException, match="is failed with error"): + op.execute(context=dict(ti=mock.create_autospec(TaskInstance, instance=True))) + + mock_post_complete_action.assert_called_once() + @pytest.mark.db_test @pytest.mark.execution_timeout(300)