Skip to content

Commit

Permalink
Fix reattach_on_restart parameter for the sync mode
Browse files Browse the repository at this point in the history
  • Loading branch information
e-galan committed May 6, 2024
1 parent 98c05af commit 0519403
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 9 deletions.
61 changes: 52 additions & 9 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ class PodReattachFailure(AirflowException):
"""When we expect to be able to find a pod but cannot."""


class IdenticalLabelPodError(AirflowException):
    """
    Raised when several pods carry identical labels and that is not expected.

    ``find_pod`` raises this when more than one pod matches the label selector
    while ``reattach_on_restart`` is enabled.
    """


class KubernetesPodOperator(BaseOperator):
"""
Execute a task in a Kubernetes Pod.
Expand Down Expand Up @@ -468,7 +472,9 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool
"""
Generate labels for the pod to track the pod in case of Operator crash.
:param context: task context provided by airflow DAG
:param context: task context provided by airflow DAG.
:param include_try_number: if set to True will add the try number
from the task context to the pod labels.
:return: dict
"""
if not context:
Expand Down Expand Up @@ -533,24 +539,31 @@ def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool =

pod = None
num_pods = len(pod_list)
if num_pods > 1:
raise AirflowException(f"More than one pod running with labels {label_selector}")
elif num_pods == 1:

if num_pods == 1:
pod = pod_list[0]
self.log.info("Found matching pod %s with labels %s", pod.metadata.name, pod.metadata.labels)
self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
self.log_matching_pod(pod=pod, context=context)
elif num_pods > 1 and not self.reattach_on_restart:
self.log.warning("Found more than one pod running with labels %s, resolving ...", label_selector)
pod = self.process_duplicate_label_pods(pod_list)
self.log_matching_pod(pod=pod, context=context)
elif num_pods > 1:
raise IdenticalLabelPodError(f"More than one pod running with labels {label_selector}")

return pod

def log_matching_pod(self, pod: k8s.V1Pod, context: Context) -> None:
    """
    Log details of the pod that was matched for the current task instance.

    Writes the pod name and labels, then the ``try_number`` of the task instance
    next to the ``try_number`` label of the pod, so that a mismatch between the
    two is visible in the task log.

    :param pod: the matching pod.
    :param context: task context; ``context["ti"].try_number`` is logged.
    """
    # The label map may be missing entirely, or may lack ``try_number``.
    # This method only reports, so it must never fail the task over that.
    pod_labels = pod.metadata.labels or {}
    self.log.info("Found matching pod %s with labels %s", pod.metadata.name, pod.metadata.labels)
    self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
    self.log.info("`try_number` of pod: %s", pod_labels.get("try_number", "unknown"))

def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s.V1Pod:
    """
    Return an already existing pod for this task, or create the requested one.

    When ``reattach_on_restart`` is set, ``find_pod`` is consulted first and its
    result is returned if it is truthy. Otherwise the pod request is logged at
    debug level, submitted through ``pod_manager.create_pod`` and returned as is.

    :param pod_request_obj: the pod definition to create when nothing is reused.
    :param context: task context, passed on to ``find_pod``.
    :return: the pod found by ``find_pod`` or ``pod_request_obj`` itself.
    """
    if self.reattach_on_restart:
        # The operator-level namespace takes precedence over the one in the request.
        target_namespace = self.namespace or pod_request_obj.metadata.namespace
        existing_pod = self.find_pod(target_namespace, context=context)
        if existing_pod:
            return existing_pod

    pod_request_yaml = yaml.safe_dump(pod_request_obj.to_dict())
    self.log.debug("Starting pod:\n%s", pod_request_yaml)
    self.pod_manager.create_pod(pod=pod_request_obj)
    return pod_request_obj

def await_pod_start(self, pod: k8s.V1Pod) -> None:
Expand Down Expand Up @@ -1129,6 +1142,36 @@ def dry_run(self) -> None:
def execute_complete(self, context: Context, event: dict, **kwargs):
    """
    Hand the event over to ``trigger_reentry`` and return its result.

    :param context: task context, forwarded unchanged.
    :param event: event payload, forwarded unchanged.
    :param kwargs: accepted but not forwarded to ``trigger_reentry``.
    """
    return self.trigger_reentry(context=context, event=event)

def process_duplicate_label_pods(self, pod_list: list[k8s.V1Pod]) -> k8s.V1Pod:
    """
    Keep the most recent of several identically labelled pods and retire the rest.

    This handles an edge case that can happen only if the reattach_on_restart
    flag is False and the previous run attempt has failed because the task
    process has been killed externally by the cluster or another process.
    If the task process is killed externally, it breaks the code execution and
    immediately exits the task. As a result the pod created in the previous
    attempt will not be properly deleted or patched by the cleanup() method.

    Every pod other than the most recent one is patched as already checked and,
    when ``on_finish_action`` is ``OnFinishAction.DELETE_POD``, also deleted.

    :param pod_list: pods that share the same labels; the list is not modified.
    :return: the most recent pod, to be used for the next run attempt.
    """
    newest_index = self._get_most_recent_pod_index(pod_list)
    new_pod = pod_list[newest_index]

    # Retire every stale pod, not just one: any leftover would still match the
    # label selector and be reported as a duplicate again on the next lookup.
    for index, old_pod in enumerate(pod_list):
        if index == newest_index:
            continue
        self.patch_already_checked(old_pod, reraise=False)
        if self.on_finish_action == OnFinishAction.DELETE_POD:
            self.process_pod_deletion(old_pod)

    return new_pod

@staticmethod
def _get_most_recent_pod_index(pod_list: list[k8s.V1Pod]) -> int:
    """
    Return the index of the most recently started pod in ``pod_list``.

    Start times are read from ``pod.to_dict()["status"]["start_time"]``. A pod
    whose status or start time is not populated is ranked as more recent than
    any pod that has a start time. The assumption is that such a pod has not
    been started yet and is therefore the newest one (NOTE(review): confirm
    against cluster behaviour). On ties the first matching index wins.

    :param pod_list: pods to inspect; must not be empty.
    :return: index into ``pod_list`` of the most recent pod.
    :raises ValueError: if ``pod_list`` is empty.
    """
    if not pod_list:
        raise ValueError("Cannot select the most recent pod from an empty pod list")

    pod_start_times = [(pod.to_dict().get("status") or {}).get("start_time") for pod in pod_list]

    def recency(index: int) -> tuple:
        start_time = pod_start_times[index]
        # The leading flag sorts pods without a start time last and keeps
        # ``None`` from ever being compared with a datetime.
        return (start_time is None, start_time)

    return max(range(len(pod_list)), key=recency)


class _optionally_suppress(AbstractContextManager):
"""
Expand Down
69 changes: 69 additions & 0 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import datetime
import re
from contextlib import contextmanager, nullcontext
from io import BytesIO
Expand Down Expand Up @@ -1677,6 +1678,74 @@ def test_await_container_completion_retries_on_specific_exception(
[mock.call(pod=pod, container_name=container_name)] * expected_call_count
)

@pytest.mark.parametrize(
    "on_finish_action", [OnFinishAction.KEEP_POD, OnFinishAction.DELETE_SUCCEEDED_POD]
)
@patch(KUB_OP_PATH.format("patch_already_checked"))
@patch(KUB_OP_PATH.format("process_pod_deletion"))
def test_process_duplicate_label_pods__label_patched_if_action_is_not_delete_pod(
    self,
    process_pod_deletion_mock,
    patch_already_checked_mock,
    on_finish_action,
):
    """
    The older duplicate pod is only patched when the action is not DELETE_POD.

    Two pods are obtained from the same operator, given start times 60 seconds
    apart, and handed to ``process_duplicate_label_pods``. The later pod must be
    returned, the earlier one patched, and no deletion attempted.
    """
    operator = KubernetesPodOperator(
        task_id="task",
        name="test",
        namespace="default",
        image="ubuntu:22.04",
        cmds=["bash", "-cx"],
        arguments=["echo 12"],
        do_xcom_push=False,
        reattach_on_restart=False,
        on_finish_action=on_finish_action,
    )
    context = create_context(operator)
    older_pod = operator.get_or_create_pod(
        pod_request_obj=operator.build_pod_request_obj(context), context=context
    )
    newer_pod = operator.get_or_create_pod(
        pod_request_obj=operator.build_pod_request_obj(context), context=context
    )

    # Make the second pod unambiguously the more recent one.
    first_start = datetime.datetime.now()
    older_pod.status = {"start_time": first_start}
    newer_pod.status = {"start_time": first_start + datetime.timedelta(seconds=60)}
    newer_pod.metadata.labels["try_number"] = "2"

    kept_pod = operator.process_duplicate_label_pods([older_pod, newer_pod])

    assert kept_pod.metadata.name == newer_pod.metadata.name
    patch_already_checked_mock.assert_called_once_with(older_pod, reraise=False)
    process_pod_deletion_mock.assert_not_called()

@patch(KUB_OP_PATH.format("patch_already_checked"))
@patch(KUB_OP_PATH.format("process_pod_deletion"))
def test_process_duplicate_label_pods__pod_removed_if_delete_pod(
    self, process_pod_deletion_mock, patch_already_checked_mock
):
    """
    The older duplicate pod is patched and deleted when the action is DELETE_POD.

    Two pods are obtained from the same operator, given start times 60 seconds
    apart, and handed to ``process_duplicate_label_pods``. The later pod must be
    returned, and the earlier one both patched and passed to pod deletion.
    """
    operator = KubernetesPodOperator(
        task_id="task",
        name="test",
        namespace="default",
        image="ubuntu:22.04",
        cmds=["bash", "-cx"],
        arguments=["echo 12"],
        do_xcom_push=False,
        reattach_on_restart=False,
        on_finish_action=OnFinishAction.DELETE_POD,
    )
    context = create_context(operator)
    older_pod = operator.get_or_create_pod(
        pod_request_obj=operator.build_pod_request_obj(context), context=context
    )
    newer_pod = operator.get_or_create_pod(
        pod_request_obj=operator.build_pod_request_obj(context), context=context
    )

    # Make the second pod unambiguously the more recent one.
    first_start = datetime.datetime.now()
    older_pod.status = {"start_time": first_start}
    newer_pod.status = {"start_time": first_start + datetime.timedelta(seconds=60)}
    newer_pod.metadata.labels["try_number"] = "2"

    kept_pod = operator.process_duplicate_label_pods([older_pod, newer_pod])

    assert kept_pod.metadata.name == newer_pod.metadata.name
    patch_already_checked_mock.assert_called_once_with(older_pod, reraise=False)
    process_pod_deletion_mock.assert_called_once_with(older_pod)


class TestSuppress:
def test__suppress(self, caplog):
Expand Down

0 comments on commit 0519403

Please sign in to comment.