diff --git a/airflow/providers/cncf/kubernetes/CHANGELOG.rst b/airflow/providers/cncf/kubernetes/CHANGELOG.rst index d3cd9d28173a..26f185246bd6 100644 --- a/airflow/providers/cncf/kubernetes/CHANGELOG.rst +++ b/airflow/providers/cncf/kubernetes/CHANGELOG.rst @@ -19,6 +19,90 @@ Changelog --------- +3.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +* ``Simplify KubernetesPodOperator (#19572)`` + +.. warning:: Many methods in :class:`~.KubernetesPodOperator` and :class:`~.PodLauncher` have been renamed. + If you have subclassed :class:`~.KubernetesPodOperator`, you will need to update your subclass to reflect + the new structure. Additionally, the ``PodStatus`` enum has been renamed to ``PodPhase``. + +Notes on changes to KubernetesPodOperator and PodLauncher +`````````````````````````````````````````````````````````` + +Overview +'''''''' + +Generally speaking, if you did not subclass ``KubernetesPodOperator`` and you didn't use the ``PodLauncher`` class directly, +then you don't need to worry about this change. If, however, you have subclassed ``KubernetesPodOperator``, what +follows are some notes on the changes in this release. + +One of the principal goals of the refactor is to clearly separate the "get or create pod" and +"wait for pod completion" phases. Previously, the "wait for pod completion" logic would be invoked +differently depending on whether the operator was to "attach to an existing pod" (e.g. after a +worker failure) or "create a new pod", and this resulted in some code duplication and a bit more +nesting of logic. With this refactor, we encapsulate the "get or create" step +into method :meth:`~.KubernetesPodOperator.get_or_create_pod`, and pull the monitoring and XCom logic up +into the top level of ``execute`` because it can be the same for "attached" pods and "new" pods. + +:meth:`~.KubernetesPodOperator.get_or_create_pod` tries first to find an existing pod using labels +specific to the task instance (see :meth:`~.KubernetesPodOperator.find_pod`). +If one does not exist, it :meth:`creates a pod <~.PodLauncher.create_pod>`. + +The "waiting" part of execution has three components. The first step is to wait for the pod to leave the +``Pending`` phase (:meth:`~.KubernetesPodOperator.await_pod_start`). Next, if configured to do so, +the operator will :meth:`follow the base container logs <~.PodLauncher.follow_container_logs>` +and forward these logs to the task logger until the ``base`` container is done. If not configured to harvest the +logs, the operator will instead :meth:`poll for container completion until done <~.PodLauncher.await_container_completion>`; +either way, we must await container completion before harvesting XCom. After (optionally) extracting the XCom +value from the base container, we :meth:`await pod completion <~.PodLauncher.await_pod_completion>`. + +Previously, depending on whether the pod was "reattached to" (e.g. after a worker failure) or +created anew, the waiting logic may have occurred in either ``handle_pod_overlap`` or ``create_new_pod_for_operator``. + +After the pod terminates, we execute different cleanup tasks depending on whether the pod terminated successfully. + +If the pod terminates *unsuccessfully*, we attempt to :meth:`log the pod events <~.PodLauncher.read_pod_events>`. If +additionally the task is configured *not* to delete the pod after termination, :meth:`we apply a label <~.KubernetesPodOperator.patch_already_checked>` +indicating that the pod failed and should not be "reattached to" in a retry.
If the task is configured +to delete its pod, we :meth:`delete it <~.KubernetesPodOperator.process_pod_deletion>`. Finally, +we raise an ``AirflowException`` to fail the task instance. + +If the pod terminates successfully, we :meth:`delete the pod <~.KubernetesPodOperator.process_pod_deletion>` +(if configured to delete the pod) and push XCom (if configured to push XCom). + +Details on method renames, refactors, and deletions +''''''''''''''''''''''''''''''''''''''''''''''''''' + +In ``KubernetesPodOperator``: + +* Method ``create_pod_launcher`` is converted to cached property ``launcher``. +* Construction of the k8s ``CoreV1Api`` client is now encapsulated within cached property ``client``. +* Logic to search for an existing pod (e.g. after an Airflow worker failure) is moved out of ``execute`` and into method ``find_pod``. +* Method ``handle_pod_overlap`` is removed. Previously it monitored a "found" pod until completion. With this change, the pod monitoring (and log following) is orchestrated directly from ``execute``, and it is the same whether it's a "found" pod or a "new" pod. See methods ``await_pod_start``, ``follow_container_logs``, ``await_container_completion`` and ``await_pod_completion``. +* Method ``create_pod_request_obj`` is renamed ``build_pod_request_obj``. It now takes argument ``context`` in order to add TI-specific pod labels; previously they were added after return. +* Method ``create_labels_for_pod`` is renamed ``_get_ti_pod_labels``. This method doesn't return *all* labels, but only those specific to the TI. We also add parameter ``include_try_number`` to control the inclusion of this label instead of possibly filtering it out later. +* Method ``_get_pod_identifying_label_string`` is renamed ``_build_find_pod_label_selector``. +* Method ``_try_numbers_match`` is removed. +* Method ``create_new_pod_for_operator`` is removed. Previously it would mutate the labels on ``self.pod``, launch the pod, monitor the pod to completion, etc. Now this logic is in part handled by ``get_or_create_pod``, where a new pod will be created if necessary. The monitoring etc. is now orchestrated directly from ``execute``. Again, see the calls to methods ``await_pod_start``, ``follow_container_logs``, ``await_container_completion`` and ``await_pod_completion``. + +In ``pod_launcher.py``, in class ``PodLauncher``: + +* Method ``start_pod`` is removed and split into two methods: ``create_pod`` and ``await_pod_start``. +* Method ``monitor_pod`` is removed and split into methods ``follow_container_logs``, ``await_container_completion``, and ``await_pod_completion``. +* Methods ``pod_not_started``, ``pod_is_running``, ``process_status``, and ``_task_status`` are removed. These were needed due to the way in which pod ``phase`` was mapped to task instance states, but we no longer do such a mapping and instead deal with pod phases directly and untransformed. +* Method ``_extract_xcom`` is renamed ``extract_xcom``. +* Method ``read_pod_logs`` now takes kwarg ``container_name``. + + +Other changes in ``pod_launcher.py``: + +* Enum-like class ``PodStatus`` is renamed ``PodPhase``, and the values are no longer lower-cased. + 2.2.0 .....
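To make the rename notes above concrete, here is a minimal sketch of how a user-maintained subclass might be updated for 3.0.0. It is an illustration only: ``MyPodOperator`` and the extra ``team`` label are assumptions made for the example, not code shipped with the provider; the class, method, and ``PodPhase`` names are the ones introduced in this release.

.. code-block:: python

    # Illustrative sketch only: ``MyPodOperator`` and the "team" label are hypothetical.
    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
    from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodPhase


    class MyPodOperator(KubernetesPodOperator):
        # 2.x subclasses overrode ``create_pod_request_obj(self)``; in 3.0.0 the hook is
        # ``build_pod_request_obj`` and it receives the task ``context``, so that the
        # task-instance labels (formerly ``create_labels_for_pod``, now
        # ``_get_ti_pod_labels``) are applied inside the method rather than after return.
        def build_pod_request_obj(self, context=None):
            pod = super().build_pod_request_obj(context)
            pod.metadata.labels["team"] = "data-platform"  # hypothetical extra label
            return pod


    # ``PodStatus`` is gone; ``PodPhase`` exposes the raw, capitalized Kubernetes phases.
    assert PodPhase.SUCCEEDED == "Succeeded"
    assert PodPhase.terminal_states == {"Failed", "Succeeded"}

Code that previously compared against the lower-cased ``PodStatus`` strings (or against the Airflow ``State`` values returned by the removed ``monitor_pod``) now deals with the capitalized pod phases directly, as the ``assert`` lines above spell out.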
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 50e75a456599..7fd93951a1a9 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -15,17 +15,27 @@ # specific language governing permissions and limitations # under the License. """Executes task in a Kubernetes POD""" +import json +import logging import re import warnings -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional from kubernetes.client import CoreV1Api, models as k8s +from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLaunchFailedException, PodPhase + try: import airflow.utils.yaml as yaml except ImportError: import yaml +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + from airflow.exceptions import AirflowException from airflow.kubernetes import kube_client, pod_generator from airflow.kubernetes.pod_generator import PodGenerator @@ -46,13 +56,16 @@ from airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env import PodRuntimeInfoEnv from airflow.providers.cncf.kubernetes.utils import pod_launcher, xcom_sidecar from airflow.utils.helpers import validate_key -from airflow.utils.state import State from airflow.version import version as airflow_version if TYPE_CHECKING: import jinja2 +class PodReattachFailure(AirflowException): + """When we expect to be able to find a pod but cannot.""" + + class KubernetesPodOperator(BaseOperator): """ Execute a task in a Kubernetes Pod @@ -163,8 +176,12 @@ class KubernetesPodOperator(BaseOperator): :param termination_grace_period: Termination grace period if task killed in UI, defaults to kubernetes default :type termination_grace_period: int + """ + BASE_CONTAINER_NAME = 'base' + POD_CHECKED_KEY = 'already_checked' + template_fields: Iterable[str] = ( 'image', 'cmds', @@ -176,9 +193,7 @@ class KubernetesPodOperator(BaseOperator): 'namespace', ) - # fmt: off def __init__( - # fmt: on self, *, namespace: Optional[str] = None, @@ -269,8 +284,9 @@ def __init__( self.service_account_name = service_account_name self.is_delete_operator_pod = is_delete_operator_pod self.hostnetwork = hostnetwork - self.tolerations = [convert_toleration(toleration) for toleration in tolerations] \ - if tolerations else [] + self.tolerations = ( + [convert_toleration(toleration) for toleration in tolerations] if tolerations else [] + ) self.security_context = security_context or {} self.dnspolicy = dnspolicy self.schedulername = schedulername @@ -282,8 +298,8 @@ def __init__( self.name = self._set_name(name) self.random_name_suffix = random_name_suffix self.termination_grace_period = termination_grace_period - self.client: CoreV1Api = None - self.pod: k8s.V1Pod = None + self.pod_request_obj: Optional[k8s.V1Pod] = None + self.pod: Optional[k8s.V1Pod] = None def _render_nested_template_fields( self, @@ -297,27 +313,26 @@ def _render_nested_template_fields( self._do_render_template_fields(content, ('value', 'name'), context, jinja_env, seen_oids) return - super()._render_nested_template_fields( - content, - context, - jinja_env, - seen_oids - ) + super()._render_nested_template_fields(content, context, jinja_env, seen_oids) @staticmethod - def create_labels_for_pod(context) -> dict: + def 
_get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool = True) -> dict: """ Generate labels for the pod to track the pod in case of Operator crash :param context: task context provided by airflow DAG :return: dict """ + if not context: + return {} + labels = { 'dag_id': context['dag'].dag_id, 'task_id': context['task'].task_id, 'execution_date': context['ts'], - 'try_number': context['ti'].try_number, } + if include_try_number: + labels.update(try_number=context['ti'].try_number) # In the case of sub dags this is just useful if context['dag'].is_subdag: labels['parent_dag_id'] = context['dag'].parent_dag.dag_id @@ -328,101 +343,125 @@ def create_labels_for_pod(context) -> dict: labels[label_id] = safe_label return labels - def create_pod_launcher(self) -> Type[pod_launcher.PodLauncher]: - return pod_launcher.PodLauncher(kube_client=self.client, extract_xcom=self.do_xcom_push) + @cached_property + def launcher(self) -> pod_launcher.PodLauncher: + return pod_launcher.PodLauncher(kube_client=self.client) - def execute(self, context) -> Optional[str]: + @cached_property + def client(self) -> CoreV1Api: + # todo: use airflow Connection / hook to authenticate to the cluster + kwargs: Dict[str, Any] = dict( + cluster_context=self.cluster_context, + config_file=self.config_file, + ) + if self.in_cluster is not None: + kwargs.update(in_cluster=self.in_cluster) + return kube_client.get_kube_client(**kwargs) + + def find_pod(self, namespace, context) -> Optional[k8s.V1Pod]: + """Returns an already-running pod for this task instance if one exists.""" + label_selector = self._build_find_pod_label_selector(context) + pod_list = self.client.list_namespaced_pod( + namespace=namespace, + label_selector=label_selector, + ).items + + num_pods = len(pod_list) + if num_pods > 1: + raise AirflowException(f'More than one pod running with labels {label_selector}') + elif num_pods == 1: + pod = pod_list[0] + self.log.info("Found matching pod %s with labels %s", pod.metadata.name, pod.metadata.labels) + self.log.info("`try_number` of task_instance: %s", context['ti'].try_number) + self.log.info("`try_number` of pod: %s", pod.metadata.labels['try_number']) + return pod + + def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context): + if self.reattach_on_restart: + pod = self.find_pod(self.namespace or pod_request_obj.metadata.namespace, context=context) + if pod: + return pod + self.log.debug("Starting pod:\n%s", yaml.safe_dump(pod_request_obj.to_dict())) + self.launcher.create_pod(pod=pod_request_obj) + return pod_request_obj + + def await_pod_start(self, pod): try: - if self.in_cluster is not None: - client = kube_client.get_kube_client( - in_cluster=self.in_cluster, - cluster_context=self.cluster_context, - config_file=self.config_file, - ) - else: - client = kube_client.get_kube_client( - cluster_context=self.cluster_context, config_file=self.config_file - ) - - self.client = client - - self.pod = self.create_pod_request_obj() - self.namespace = self.pod.metadata.namespace - - # Add combination of labels to uniquely identify a running pod - labels = self.create_labels_for_pod(context) - - label_selector = self._get_pod_identifying_label_string(labels) - - pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector) + self.launcher.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds) + except PodLaunchFailedException: + if self.log_events_on_failure: + for event in self.launcher.read_pod_events(pod).items: + self.log.error("Pod Event: %s - 
%s", event.reason, event.message) + raise - if len(pod_list.items) > 1 and self.reattach_on_restart: - raise AirflowException( - f'More than one pod running with labels: {label_selector}' - ) + def extract_xcom(self, pod): + """Retrieves xcom value and kills xcom sidecar container""" + result = self.launcher.extract_xcom(pod) + self.log.info("xcom result: \n%s", result) + return json.loads(result) - launcher = self.create_pod_launcher() + def execute(self, context): + remote_pod = None + try: + self.pod_request_obj = self.build_pod_request_obj(context) + self.pod = self.get_or_create_pod( # must set `self.pod` for `on_kill` + pod_request_obj=self.pod_request_obj, + context=context, + ) + self.await_pod_start(pod=self.pod) - if len(pod_list.items) == 1: - try_numbers_match = self._try_numbers_match(context, pod_list.items[0]) - final_state, remote_pod, result = self.handle_pod_overlap( - labels, try_numbers_match, launcher, pod_list.items[0] + if self.get_logs: + self.launcher.follow_container_logs( + pod=self.pod, + container_name=self.BASE_CONTAINER_NAME, ) else: - self.log.info("creating pod with labels %s and launcher %s", labels, launcher) - final_state, remote_pod, result = self.create_new_pod_for_operator(labels, launcher) - if final_state != State.SUCCESS: - raise AirflowException(f'Pod {self.pod.metadata.name} returned a failure: {remote_pod}') - context['task_instance'].xcom_push(key='pod_name', value=self.pod.metadata.name) - context['task_instance'].xcom_push(key='pod_namespace', value=self.namespace) - return result - except AirflowException as ex: - raise AirflowException(f'Pod Launching failed: {ex}') - - def handle_pod_overlap( - self, labels: dict, try_numbers_match: bool, launcher: Any, pod: k8s.V1Pod - ) -> Tuple[State, k8s.V1Pod, Optional[str]]: - """ + self.launcher.await_container_completion( + pod=self.pod, container_name=self.BASE_CONTAINER_NAME + ) - In cases where the Scheduler restarts while a KubernetesPodOperator task is running, - this function will either continue to monitor the existing pod or launch a new pod - based on the `reattach_on_restart` parameter. + if self.do_xcom_push: + result = self.extract_xcom(pod=self.pod) + remote_pod = self.launcher.await_pod_completion(self.pod) + finally: + self.cleanup( + pod=self.pod or self.pod_request_obj, + remote_pod=remote_pod, + ) + ti = context['ti'] + ti.xcom_push(key='pod_name', value=self.pod.metadata.name) + ti.xcom_push(key='pod_namespace', value=self.pod.metadata.namespace) + if self.do_xcom_push: + return result - :param labels: labels used to determine if a pod is repeated - :type labels: dict - :param try_numbers_match: do the try numbers match? Only needed for logging purposes - :type try_numbers_match: bool - :param launcher: PodLauncher - :param pod: Pod found with matching labels - """ - if try_numbers_match: - log_line = f"found a running pod with labels {labels} and the same try_number." - else: - log_line = f"found a running pod with labels {labels} but a different try_number." - - # In case of failed pods, should reattach the first time, but only once - # as the task will have already failed. 
- if self.reattach_on_restart and not pod.metadata.labels.get("already_checked"): - log_line += " Will attach to this pod and monitor instead of starting new one" - self.log.info(log_line) - self.pod = pod - final_state, remote_pod, result = self.monitor_launched_pod(launcher, pod) + def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): + pod_phase = remote_pod.status.phase if hasattr(remote_pod, 'status') else None + if pod_phase != PodPhase.SUCCEEDED: + if self.log_events_on_failure: + with _suppress(Exception): + for event in self.launcher.read_pod_events(pod).items: + self.log.error("Pod Event: %s - %s", event.reason, event.message) + if not self.is_delete_operator_pod: + with _suppress(Exception): + self.patch_already_checked(pod) + with _suppress(Exception): + self.process_pod_deletion(pod) + raise AirflowException(f'Pod {pod and pod.metadata.name} returned a failure: {remote_pod}') else: - log_line += f"creating pod with labels {labels} and launcher {launcher}" - self.log.info(log_line) - final_state, remote_pod, result = self.create_new_pod_for_operator(labels, launcher) - return final_state, remote_pod, result + with _suppress(Exception): + self.process_pod_deletion(pod) - @staticmethod - def _get_pod_identifying_label_string(labels) -> str: - label_strings = [ - f'{label_id}={label}' for label_id, label in sorted(labels.items()) if label_id != 'try_number' - ] - return ','.join(label_strings) + ',already_checked!=True' + def process_pod_deletion(self, pod): + if self.is_delete_operator_pod: + self.log.info("Deleting pod: %s", pod.metadata.name) + self.launcher.delete_pod(pod) + else: + self.log.info("skipping deleting pod: %s", pod.metadata.name) - @staticmethod - def _try_numbers_match(context, pod) -> bool: - return pod.metadata.labels['try_number'] == context['ti'].try_number + def _build_find_pod_label_selector(self, context: Optional[dict] = None) -> str: + labels = self._get_ti_pod_labels(context, include_try_number=False) + label_strings = [f'{label_id}={label}' for label_id, label in sorted(labels.items())] + return ','.join(label_strings) + f',{self.POD_CHECKED_KEY}!=True' def _set_name(self, name): if name is None: @@ -433,7 +472,24 @@ def _set_name(self, name): validate_key(name, max_length=220) return re.sub(r'[^a-z0-9.-]+', '-', name.lower()) - def create_pod_request_obj(self) -> k8s.V1Pod: + def patch_already_checked(self, pod: k8s.V1Pod): + """Add an "already checked" annotation to ensure we don't reattach on retries""" + pod.metadata.labels[self.POD_CHECKED_KEY] = "True" + body = PodGenerator.serialize_pod(pod) + self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body) + + def on_kill(self) -> None: + if self.pod: + pod = self.pod + kwargs = dict( + name=pod.metadata.name, + namespace=pod.metadata.namespace, + ) + if self.termination_grace_period is not None: + kwargs.update(grace_period_seconds=self.termination_grace_period) + self.client.delete_namespaced_pod(**kwargs) + + def build_pod_request_obj(self, context=None): """ Creates a V1Pod based on user parameters. Note that a `pod` or `pod_template_file` will supersede all other values. 
@@ -467,7 +523,7 @@ def create_pod_request_obj(self) -> k8s.V1Pod: containers=[ k8s.V1Container( image=self.image, - name="base", + name=self.BASE_CONTAINER_NAME, command=self.cmds, ports=self.ports, image_pull_policy=self.image_pull_policy, @@ -501,83 +557,43 @@ def create_pod_request_obj(self) -> k8s.V1Pod: if self.do_xcom_push: self.log.debug("Adding xcom sidecar to task %s", self.task_id) pod = xcom_sidecar.add_xcom_sidecar(pod) - return pod - def create_new_pod_for_operator(self, labels, launcher) -> Tuple[State, k8s.V1Pod, Optional[str]]: - """ - Creates a new pod and monitors for duration of task - - :param labels: labels used to track pod - :param launcher: pod launcher that will manage launching and monitoring pods - :return: - """ - self.log.debug( - "Adding KubernetesPodOperator labels to pod before launch for task %s", self.task_id - ) + labels = self._get_ti_pod_labels(context) + self.log.info("Creating pod %s with labels: %s", pod.metadata.name, labels) # Merge Pod Identifying labels with labels passed to operator - self.pod.metadata.labels.update(labels) + pod.metadata.labels.update(labels) # Add Airflow Version to the label # And a label to identify that pod is launched by KubernetesPodOperator - self.pod.metadata.labels.update( + pod.metadata.labels.update( { 'airflow_version': airflow_version.replace('+', '-'), 'kubernetes_pod_operator': 'True', } ) + return pod - self.log.debug("Starting pod:\n%s", yaml.safe_dump(self.pod.to_dict())) - final_state = None - try: - launcher.start_pod(self.pod, startup_timeout=self.startup_timeout_seconds) - final_state, remote_pod, result = launcher.monitor_pod(pod=self.pod, get_logs=self.get_logs) - except AirflowException: - if self.log_events_on_failure: - for event in launcher.read_pod_events(self.pod).items: - self.log.error("Pod Event: %s - %s", event.reason, event.message) - raise - finally: - if self.is_delete_operator_pod: - self.log.debug("Deleting pod for task %s", self.task_id) - launcher.delete_pod(self.pod) - elif final_state != State.SUCCESS: - self.patch_already_checked(self.pod) - return final_state, remote_pod, result - def patch_already_checked(self, pod: k8s.V1Pod): - """Add an "already tried annotation to ensure we only retry once""" - pod.metadata.labels["already_checked"] = "True" - body = PodGenerator.serialize_pod(pod) - self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body) +class _suppress(AbstractContextManager): + """ + This behaves the same as ``contextlib.suppress`` but logs the suppressed + exceptions as errors with traceback. - def monitor_launched_pod(self, launcher, pod) -> Tuple[State, Optional[str]]: - """ - Monitors a pod to completion that was created by a previous KubernetesPodOperator + The caught exception is also stored on the context manager instance under + attribute ``exception``. 
+ """ - :param launcher: pod launcher that will manage launching and monitoring pods - :param pod: podspec used to find pod using k8s API - :return: - """ - try: - (final_state, remote_pod, result) = launcher.monitor_pod(pod, get_logs=self.get_logs) - finally: - if self.is_delete_operator_pod: - launcher.delete_pod(pod) - if final_state != State.SUCCESS: - if self.log_events_on_failure: - for event in launcher.read_pod_events(pod).items: - self.log.error("Pod Event: %s - %s", event.reason, event.message) - if not self.is_delete_operator_pod: - self.patch_already_checked(pod) - raise AirflowException(f'Pod returned a failure: {final_state}') - return final_state, remote_pod, result + def __init__(self, *exceptions): + self._exceptions = exceptions + self.exception = None - def on_kill(self) -> None: - if self.pod: - pod: k8s.V1Pod = self.pod - namespace = pod.metadata.namespace - name = pod.metadata.name - kwargs = {} - if self.termination_grace_period is not None: - kwargs = {"grace_period_seconds": self.termination_grace_period} - self.client.delete_namespaced_pod(name=name, namespace=namespace, **kwargs) + def __enter__(self): + return self + + def __exit__(self, exctype, excinst, exctb): + caught_error = exctype is not None and issubclass(exctype, self._exceptions) + if caught_error: + self.exception = excinst + logger = logging.getLogger() + logger.error(str(excinst), exc_info=True) + return caught_error diff --git a/airflow/providers/cncf/kubernetes/utils/pod_launcher.py b/airflow/providers/cncf/kubernetes/utils/pod_launcher.py index e76ae40b7b1e..43c4ebe597ef 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_launcher.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_launcher.py @@ -18,13 +18,13 @@ import json import math import time -from datetime import datetime as dt +from contextlib import closing +from datetime import datetime from typing import Iterable, Optional, Tuple, Union import pendulum import tenacity from kubernetes import client, watch -from kubernetes.client.models.v1_event import V1Event from kubernetes.client.models.v1_event_list import V1EventList from kubernetes.client.models.v1_pod import V1Pod from kubernetes.client.rest import ApiException @@ -38,7 +38,10 @@ from airflow.kubernetes.pod_generator import PodDefaults from airflow.settings import pod_mutation_hook from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import State + + +class PodLaunchFailedException(AirflowException): + """When pod launching fails in KubernetesPodOperator.""" def should_retry_start_pod(exception: Exception) -> bool: @@ -48,13 +51,32 @@ def should_retry_start_pod(exception: Exception) -> bool: return False -class PodStatus: - """Status of the PODs""" +class PodPhase: + """ + Possible pod phases + See https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase. + """ + + PENDING = 'Pending' + RUNNING = 'Running' + FAILED = 'Failed' + SUCCEEDED = 'Succeeded' - PENDING = 'pending' - RUNNING = 'running' - FAILED = 'failed' - SUCCEEDED = 'succeeded' + terminal_states = {FAILED, SUCCEEDED} + + +def container_is_running(pod: V1Pod, container_name: str) -> bool: + """ + Examines V1Pod ``pod`` to determine whether ``container_name`` is running. + If that container is present and running, returns True. Returns False otherwise. 
+ """ + container_statuses = pod.status.container_statuses if pod and pod.status else None + if not container_statuses: + return False + container_status = next(iter([x for x in container_statuses if x.name == container_name]), None) + if not container_status: + return False + return container_status.state.running is not None class PodLauncher(LoggingMixin): @@ -65,7 +87,6 @@ def __init__( kube_client: client.CoreV1Api = None, in_cluster: bool = True, cluster_context: Optional[str] = None, - extract_xcom: bool = False, ): """ Creates the launcher. @@ -73,12 +94,10 @@ def __init__( :param kube_client: kubernetes client :param in_cluster: whether we are in cluster :param cluster_context: context of the cluster - :param extract_xcom: whether we should extract xcom """ super().__init__() self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context) self._watch = watch.Watch() - self.extract_xcom = extract_xcom def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: """Runs POD asynchronously""" @@ -117,79 +136,104 @@ def delete_pod(self, pod: V1Pod) -> None: reraise=True, retry=tenacity.retry_if_exception(should_retry_start_pod), ) - def start_pod(self, pod: V1Pod, startup_timeout: int = 120) -> None: + def create_pod(self, pod: V1Pod) -> V1Pod: + """Launches the pod asynchronously.""" + return self.run_pod_async(pod) + + def await_pod_start(self, pod: V1Pod, startup_timeout: int = 120) -> None: """ - Launches the pod synchronously and waits for completion. + Waits for the pod to reach phase other than ``Pending`` :param pod: :param startup_timeout: Timeout (in seconds) for startup of the pod (if pod is pending for too long, fails task) :return: """ - resp = self.run_pod_async(pod) - curr_time = dt.now() - if resp.status.start_time is None: - while self.pod_not_started(pod): - self.log.warning("Pod not yet started: %s", pod.metadata.name) - delta = dt.now() - curr_time - if delta.total_seconds() >= startup_timeout: - msg = ( - f"Pod took longer than {startup_timeout} seconds to start. " - "Check the pod events in kubernetes to determine why." - ) - raise AirflowException(msg) - time.sleep(1) - - def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, V1Pod, Optional[str]]: + curr_time = datetime.now() + while True: + remote_pod = self.read_pod(pod) + if remote_pod.status.phase != PodPhase.PENDING: + break + self.log.warning("Pod not yet started: %s", pod.metadata.name) + delta = datetime.now() - curr_time + if delta.total_seconds() >= startup_timeout: + msg = ( + f"Pod took longer than {startup_timeout} seconds to start. " + "Check the pod events in kubernetes to determine why." + ) + raise PodLaunchFailedException(msg) + time.sleep(1) + + def follow_container_logs(self, pod: V1Pod, container_name: str) -> None: """ - Monitors a pod and returns the final state, pod and xcom result + Follows the logs of container and streams to airflow logging. + Returns when container exits. - :param pod: pod spec that will be monitored - :param get_logs: whether to read the logs locally - :return: Tuple[State, Optional[str]] + .. note:: :meth:`read_pod_logs` follows the logs, so we shouldn't necessarily *need* to + loop as we do here. But in a long-running process we might temporarily lose connectivity. + So the looping logic is there to let us resume following the logs. 
""" - if get_logs: - read_logs_since_sec = None - last_log_time = None - while True: - try: - logs = self.read_pod_logs(pod, timestamps=True, since_seconds=read_logs_since_sec) - for line in logs: - timestamp, message = self.parse_log_line(line.decode('utf-8')) - self.log.info(message) - if timestamp: - last_log_time = timestamp - except BaseHTTPError: - # Catches errors like ProtocolError(TimeoutError). - self.log.warning( - 'Failed to read logs for pod %s', - pod.metadata.name, - exc_info=True, - ) + def follow_logs(since_time: Optional[datetime] = None) -> Optional[datetime]: + """ + Tries to follow container logs until container completes. + For a long-running container, sometimes the log read may be interrupted + Such errors of this kind are suppressed. + + Returns the last timestamp observed in logs. + """ + timestamp = None + try: + logs = self.read_pod_logs( + pod=pod, + container_name=container_name, + timestamps=True, + since_seconds=( + math.ceil((pendulum.now() - since_time).total_seconds()) if since_time else None + ), + ) + for line in logs: # type: bytes + timestamp, message = self.parse_log_line(line.decode('utf-8')) + self.log.info(message) + except BaseHTTPError: # Catches errors like ProtocolError(TimeoutError). + self.log.warning( + 'Failed to read logs for pod %s', + pod.metadata.name, + exc_info=True, + ) + return timestamp or since_time + + last_log_time = None + while True: + last_log_time = follow_logs(since_time=last_log_time) + if not self.container_is_running(pod, container_name=container_name): + return + else: + self.log.warning( + 'Pod %s log read interrupted but container %s still running', + pod.metadata.name, + container_name, + ) time.sleep(1) - if not self.base_container_is_running(pod): - break + def await_container_completion(self, pod: V1Pod, container_name: str) -> None: + while not self.container_is_running(pod=pod, container_name=container_name): + time.sleep(1) - self.log.warning('Pod %s log read interrupted', pod.metadata.name) - if last_log_time: - delta = pendulum.now() - last_log_time - # Prefer logs duplication rather than loss - read_logs_since_sec = math.ceil(delta.total_seconds()) - result = None - if self.extract_xcom: - while self.base_container_is_running(pod): - self.log.info('Container %s has state %s', pod.metadata.name, State.RUNNING) - time.sleep(2) - result = self._extract_xcom(pod) - self.log.info(result) - result = json.loads(result) - while self.pod_is_running(pod): - self.log.info('Pod %s has state %s', pod.metadata.name, State.RUNNING) + def await_pod_completion(self, pod: V1Pod) -> V1Pod: + """ + Monitors a pod and returns the final state + + :param pod: pod spec that will be monitored + :return: Tuple[State, Optional[str]] + """ + while True: + remote_pod = self.read_pod(pod) + if remote_pod.status.phase in PodPhase.terminal_states: + break + self.log.info('Pod %s has phase %s', pod.metadata.name, remote_pod.status.phase) time.sleep(2) - remote_pod = self.read_pod(pod) - return self._task_status(remote_pod), remote_pod, result + return remote_pod def parse_log_line(self, line: str) -> Tuple[Optional[Union[Date, Time, DateTime, Duration]], str]: """ @@ -212,35 +256,16 @@ def parse_log_line(self, line: str) -> Tuple[Optional[Union[Date, Time, DateTime return None, line return last_log_time, message - def _task_status(self, event: V1Event) -> str: - self.log.info('Event: %s had an event of type %s', event.metadata.name, event.status.phase) - status = self.process_status(event.metadata.name, event.status.phase) - return status 
- - def pod_not_started(self, pod: V1Pod) -> bool: - """Tests if pod has not started""" - state = self._task_status(self.read_pod(pod)) - return state == State.QUEUED - - def pod_is_running(self, pod: V1Pod) -> bool: - """Tests if pod is running""" - state = self._task_status(self.read_pod(pod)) - return state not in (State.SUCCESS, State.FAILED) - - def base_container_is_running(self, pod: V1Pod) -> bool: - """Tests if base container is running""" - event = self.read_pod(pod) - if not (event and event.status and event.status.container_statuses): - return False - status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None) - if not status: - return False - return status.state.running is not None + def container_is_running(self, pod: V1Pod, container_name: str) -> bool: + """Reads pod and checks if container is running""" + remote_pod = self.read_pod(pod) + return container_is_running(pod=remote_pod, container_name=container_name) @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) def read_pod_logs( self, pod: V1Pod, + container_name: str, tail_lines: Optional[int] = None, timestamps: bool = False, since_seconds: Optional[int] = None, @@ -257,7 +282,7 @@ def read_pod_logs( return self._client.read_namespaced_pod_log( name=pod.metadata.name, namespace=pod.metadata.namespace, - container='base', + container=container_name, follow=True, timestamps=timestamps, _preload_content=False, @@ -265,7 +290,6 @@ def read_pod_logs( ) except BaseHTTPError: self.log.exception('There was an error reading the kubernetes API.') - # Reraise to be caught by self.monitor_pod. raise @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True) @@ -286,29 +310,29 @@ def read_pod(self, pod: V1Pod) -> V1Pod: except BaseHTTPError as e: raise AirflowException(f'There was an error reading the kubernetes API: {e}') - def _extract_xcom(self, pod: V1Pod) -> str: - resp = kubernetes_stream( - self._client.connect_get_namespaced_pod_exec, - pod.metadata.name, - pod.metadata.namespace, - container=PodDefaults.SIDECAR_CONTAINER_NAME, - command=['/bin/sh'], - stdin=True, - stdout=True, - stderr=True, - tty=False, - _preload_content=False, - ) - try: + def extract_xcom(self, pod: V1Pod) -> str: + """Retrieves XCom value and kills xcom sidecar container""" + with closing( + kubernetes_stream( + self._client.connect_get_namespaced_pod_exec, + pod.metadata.name, + pod.metadata.namespace, + container=PodDefaults.SIDECAR_CONTAINER_NAME, + command=['/bin/sh'], + stdin=True, + stdout=True, + stderr=True, + tty=False, + _preload_content=False, + ) + ) as resp: result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json') self._exec_pod_command(resp, 'kill -s SIGINT 1') - finally: - resp.close() if result is None: raise AirflowException(f'Failed to extract xcom from pod: {pod.metadata.name}') return result - def _exec_pod_command(self, resp, command: str) -> None: + def _exec_pod_command(self, resp, command: str) -> Optional[str]: if resp.is_open(): self.log.info('Running command... 
%s\n', command) resp.write_stdin(command + '\n') @@ -317,23 +341,6 @@ def _exec_pod_command(self, resp, command: str) -> None: if resp.peek_stdout(): return resp.read_stdout() if resp.peek_stderr(): - self.log.info(resp.read_stderr()) + self.log.info("stderr from command: %s", resp.read_stderr()) break return None - - def process_status(self, job_id: str, status: str) -> str: - """Process status information for the JOB""" - status = status.lower() - if status == PodStatus.PENDING: - return State.QUEUED - elif status == PodStatus.FAILED: - self.log.error('Event with job id %s Failed', job_id) - return State.FAILED - elif status == PodStatus.SUCCEEDED: - self.log.info('Event with job id %s Succeeded', job_id) - return State.SUCCESS - elif status == PodStatus.RUNNING: - return State.RUNNING - else: - self.log.error('Event: Invalid state %s on job %s', status, job_id) - return State.FAILED diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b196744b5b24..d652742c8749 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1297,6 +1297,7 @@ storedInfoType str stringified subchart +subclassed subclasses subclassing subcluster @@ -1409,6 +1410,7 @@ unpausing unpredicted unqueued unterminated +untransformed unutilized updateMask updateonly diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 3944dc7974db..082d2ceb39b1 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -23,7 +23,7 @@ import textwrap import unittest from unittest import mock -from unittest.mock import ANY +from unittest.mock import ANY, MagicMock import pendulum import pytest @@ -34,7 +34,7 @@ from airflow.exceptions import AirflowException from airflow.kubernetes import kube_client from airflow.kubernetes.secret import Secret -from airflow.models import DAG, DagRun, TaskInstance +from airflow.models import DAG, XCOM_RETURN_KEY, DagRun, TaskInstance from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults @@ -156,6 +156,7 @@ def test_config_path_move(self): task_id="task" + self.get_current_task_name(), in_cluster=False, do_xcom_push=False, + is_delete_operator_pod=False, config_file=new_config_path, ) context = create_context(k) @@ -516,12 +517,10 @@ def test_faulty_service_account(self): startup_timeout_seconds=5, service_account_name=bad_service_account_name, ) - with pytest.raises(ApiException): - context = create_context(k) - k.execute(context) - actual_pod = self.api_client.sanitize_for_serialization(k.pod) - self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name - assert self.expected_pod == actual_pod + context = create_context(k) + pod = k.build_pod_request_obj(context) + with pytest.raises(ApiException, match="error looking up service account default/foobar"): + k.get_or_create_pod(pod, context) def test_pod_failure(self): """ @@ -546,7 +545,8 @@ def test_pod_failure(self): self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command assert self.expected_pod == actual_pod - def test_xcom_push(self): + @mock.patch("airflow.models.taskinstance.TaskInstance.xcom_push") + def test_xcom_push(self, xcom_push): return_value = '{"foo": "bar"\n, "buzz": 2}' args = [f'echo \'{return_value}\' > /airflow/xcom/return.json'] k = 
KubernetesPodOperator( @@ -561,7 +561,8 @@ def test_xcom_push(self): do_xcom_push=True, ) context = create_context(k) - assert k.execute(context) == json.loads(return_value) + k.execute(context) + assert xcom_push.called_once_with(key=XCOM_RETURN_KEY, value=json.loads(return_value)) actual_pod = self.api_client.sanitize_for_serialization(k.pod) volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME) volume_mount = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT) @@ -572,12 +573,11 @@ def test_xcom_push(self): self.expected_pod['spec']['containers'].append(container) assert self.expected_pod == actual_pod - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock): + def test_envs_from_secrets(self, mock_client, await_pod_completion_mock, create_pod): # GIVEN - from airflow.utils.state import State secret_ref = 'secret_name' secrets = [Secret('env', None, secret_ref)] @@ -595,10 +595,11 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock): do_xcom_push=False, ) # THEN - monitor_mock.return_value = (State.SUCCESS, None, None) + await_pod_completion_mock.return_value = None context = create_context(k) - k.execute(context) - assert start_mock.call_args[0][0].spec.containers[0].env_from == [ + with pytest.raises(AirflowException): + k.execute(context) + assert create_pod.call_args[1]['pod'].spec.containers[0].env_from == [ k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref)) ] @@ -625,12 +626,9 @@ def test_env_vars(self): in_cluster=False, do_xcom_push=False, ) - - context = create_context(k) - k.execute(context) - # THEN - actual_pod = self.api_client.sanitize_for_serialization(k.pod) + context = create_context(k) + actual_pod = self.api_client.sanitize_for_serialization(k.build_pod_request_obj(context)) self.expected_pod['spec']['containers'][0]['env'] = [ {'name': 'ENV1', 'value': 'val1'}, {'name': 'ENV2', 'value': 'val2'}, @@ -741,6 +739,7 @@ def test_full_pod_spec(self): in_cluster=False, full_pod_spec=pod_spec, do_xcom_push=True, + is_delete_operator_pod=False, ) context = create_context(k) @@ -814,12 +813,12 @@ def test_init_container(self): ] assert self.expected_pod == actual_pod - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.extract_xcom") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_pod_template_file(self, mock_client, monitor_mock, start_mock): - from airflow.utils.state import State - + def test_pod_template_file(self, mock_client, await_pod_completion_mock, create_mock, extract_xcom_mock): + extract_xcom_mock.return_value = '{}' path = sys.path[0] + '/tests/kubernetes/pod.yaml' k = KubernetesPodOperator( task_id="task" + 
self.get_current_task_name(), @@ -827,8 +826,9 @@ def test_pod_template_file(self, mock_client, monitor_mock, start_mock): pod_template_file=path, do_xcom_push=True, ) - - monitor_mock.return_value = (State.SUCCESS, None, None) + pod_mock = MagicMock() + pod_mock.status.phase = 'Succeeded' + await_pod_completion_mock.return_value = pod_mock context = create_context(k) with self.assertLogs(k.log, level=logging.DEBUG) as cm: k.execute(context) @@ -899,12 +899,11 @@ def test_pod_template_file(self, mock_client, monitor_mock, start_mock): del actual_pod['metadata']['labels']['airflow_version'] assert expected_dict == actual_pod - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_pod_priority_class_name(self, mock_client, monitor_mock, start_mock): + def test_pod_priority_class_name(self, mock_client, await_pod_completion_mock, create_mock): """Test ability to assign priorityClassName to pod""" - from airflow.utils.state import State priority_class_name = "medium-test" k = KubernetesPodOperator( @@ -920,7 +919,9 @@ def test_pod_priority_class_name(self, mock_client, monitor_mock, start_mock): priority_class_name=priority_class_name, ) - monitor_mock.return_value = (State.SUCCESS, None, None) + pod_mock = MagicMock() + pod_mock.status.phase = 'Succeeded' + await_pod_completion_mock.return_value = pod_mock context = create_context(k) k.execute(context) actual_pod = self.api_client.sanitize_for_serialization(k.pod) @@ -942,9 +943,8 @@ def test_pod_name(self): do_xcom_push=False, ) - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") - def test_on_kill(self, monitor_mock): - from airflow.utils.state import State + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") + def test_on_kill(self, await_pod_completion_mock): client = kube_client.get_kube_client(in_cluster=False) name = "test" @@ -959,21 +959,20 @@ def test_on_kill(self, monitor_mock): task_id=name, in_cluster=False, do_xcom_push=False, + get_logs=False, termination_grace_period=0, ) context = create_context(k) - monitor_mock.return_value = (State.SUCCESS, None, None) - k.execute(context) + with pytest.raises(AirflowException): + k.execute(context) name = k.pod.metadata.name pod = client.read_namespaced_pod(name=name, namespace=namespace) assert pod.status.phase == "Running" k.on_kill() - with pytest.raises(ApiException): - pod = client.read_namespaced_pod(name=name, namespace=namespace) + with pytest.raises(ApiException, match=r'pods \\"test.[a-z0-9]+\\" not found'): + client.read_namespaced_pod(name=name, namespace=namespace) def test_reattach_failing_pod_once(self): - from airflow.utils.state import State - client = kube_client.get_kube_client(in_cluster=False) name = "test" namespace = "default" @@ -993,24 +992,38 @@ def test_reattach_failing_pod_once(self): context = create_context(k) + # launch pod with mock.patch( - "airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod" - ) as monitor_mock: - monitor_mock.return_value = (State.SUCCESS, None, None) + 
"airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion" + ) as await_pod_completion_mock: + pod_mock = MagicMock() + + # we don't want failure because we don't want the pod to be patched as "already_checked" + pod_mock.status.phase = 'Succeeded' + await_pod_completion_mock.return_value = pod_mock k.execute(context) name = k.pod.metadata.name pod = client.read_namespaced_pod(name=name, namespace=namespace) while pod.status.phase != "Failed": pod = client.read_namespaced_pod(name=name, namespace=namespace) - with pytest.raises(AirflowException): - k.execute(context) - pod = client.read_namespaced_pod(name=name, namespace=namespace) - assert pod.metadata.labels["already_checked"] == "True" + assert 'already_checked' not in pod.metadata.labels + + # should not call `create_pod`, because there's a pod there it should find + # should use the found pod and patch as "already_checked" (in failure block) with mock.patch( - "airflow.providers.cncf.kubernetes" - ".operators.kubernetes_pod.KubernetesPodOperator" - ".create_new_pod_for_operator" + "airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod" ) as create_mock: - create_mock.return_value = ("success", {}, {}) - k.execute(context) + with pytest.raises(AirflowException): + k.execute(context) + pod = client.read_namespaced_pod(name=name, namespace=namespace) + assert pod.metadata.labels["already_checked"] == "True" + create_mock.assert_not_called() + + # `create_pod` should be called because though there's still a pod to be found, + # it will be `already_checked` + with mock.patch( + "airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod" + ) as create_mock: + with pytest.raises(AirflowException): + k.execute(context) create_mock.assert_called_once() diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py index 6cbfabd5e91a..c11375579a67 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py +++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py @@ -19,7 +19,7 @@ import sys import unittest from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch import kubernetes.client.models as k8s import pendulum @@ -39,7 +39,6 @@ from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults from airflow.utils import timezone -from airflow.utils.state import State from airflow.version import version as airflow_version # noinspection DuplicatedCode @@ -118,10 +117,10 @@ def tearDown(self): client = kube_client.get_kube_client(in_cluster=False) client.delete_collection_namespaced_pod(namespace="default") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start_mock): + def test_image_pull_secrets_correctly_set(self, mock_client, await_pod_completion_mock, create_mock): fake_pull_secrets = "fakeSecret" k = KubernetesPodOperator( namespace='default', @@ -136,10 +135,12 @@ def 
test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start image_pull_secrets=fake_pull_secrets, cluster_context='default', ) - monitor_mock.return_value = (State.SUCCESS, None, None) + mock_pod = MagicMock() + mock_pod.status.phase = 'Succeeded' + await_pod_completion_mock.return_value = mock_pod context = create_context(k) k.execute(context=context) - assert start_mock.call_args[0][0].spec.image_pull_secrets == [ + assert create_mock.call_args[1]['pod'].spec.image_pull_secrets == [ k8s.V1LocalObjectReference(name=fake_pull_secrets) ] @@ -378,9 +379,11 @@ def test_fs_group(self): assert self.expected_pod == actual_pod def test_faulty_service_account(self): - bad_service_account_name = "foobar" + """pod creation should fail when service account does not exist""" + service_account = "foobar" + namespace = "default" k = KubernetesPodOperator( - namespace='default', + namespace=namespace, image="ubuntu:16.04", cmds=["bash", "-cx"], arguments=["echo 10"], @@ -390,14 +393,14 @@ def test_faulty_service_account(self): in_cluster=False, do_xcom_push=False, startup_timeout_seconds=5, - service_account_name=bad_service_account_name, + service_account_name=service_account, ) - with pytest.raises(ApiException): - context = create_context(k) - k.execute(context) - actual_pod = self.api_client.sanitize_for_serialization(k.pod) - self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name - assert self.expected_pod == actual_pod + context = create_context(k) + pod = k.build_pod_request_obj(context) + with pytest.raises( + ApiException, match=f"error looking up service account {namespace}/{service_account}" + ): + k.get_or_create_pod(pod, context) def test_pod_failure(self): """ @@ -448,8 +451,8 @@ def test_xcom_push(self): self.expected_pod['spec']['containers'].append(container) assert self.expected_pod == actual_pod - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): # GIVEN @@ -468,17 +471,19 @@ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start): configmaps=[configmap], ) # THEN - mock_monitor.return_value = (State.SUCCESS, None, None) + mock_pod = MagicMock() + mock_pod.status.phase = 'Succeeded' + mock_monitor.return_value = mock_pod context = create_context(k) k.execute(context) - assert mock_start.call_args[0][0].spec.containers[0].env_from == [ + assert mock_start.call_args[1]['pod'].spec.containers[0].env_from == [ k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap)) ] - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod") - @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.create_pod") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_pod_completion") @mock.patch("airflow.kubernetes.kube_client.get_kube_client") - def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock): + def test_envs_from_secrets(self, mock_client, await_pod_completion_mock, 
create_mock): # GIVEN secret_ref = 'secret_name' secrets = [Secret('env', None, secret_ref)] @@ -496,10 +501,13 @@ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock): do_xcom_push=False, ) # THEN - monitor_mock.return_value = (State.SUCCESS, None, None) + + mock_pod = MagicMock() + mock_pod.status.phase = 'Succeeded' + await_pod_completion_mock.return_value = mock_pod context = create_context(k) k.execute(context) - assert start_mock.call_args[0][0].spec.containers[0].env_from == [ + assert create_mock.call_args[1]['pod'].spec.containers[0].env_from == [ k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=secret_ref)) ] diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py index 5caa50062c5b..6b897592ca6d 100644 --- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import unittest from tempfile import NamedTemporaryFile from unittest import mock +from unittest.mock import MagicMock import pytest from kubernetes.client import ApiClient, models as k8s @@ -25,34 +25,51 @@ from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance from airflow.models.xcom import IN_MEMORY_DAGRUN_ID -from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator +from airflow.operators.dummy import DummyOperator +from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator, _suppress from airflow.utils import timezone -from airflow.utils.state import State DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0) -class TestKubernetesPodOperator(unittest.TestCase): - def setUp(self): - self.start_patch = mock.patch( - "airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.start_pod" - ) - self.monitor_patch = mock.patch( - "airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.monitor_pod" - ) +def create_context(task): + dag = DAG(dag_id="dag") + task_instance = TaskInstance(task=task, run_id=IN_MEMORY_DAGRUN_ID) + task_instance.dag_run = DagRun(run_id=IN_MEMORY_DAGRUN_ID) + return { + "dag": dag, + "ts": DEFAULT_DATE.isoformat(), + "task": task, + "ti": task_instance, + "task_instance": task_instance, + } + + +POD_LAUNCHER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher" + + +class TestKubernetesPodOperator: + def setup_method(self): + self.create_pod_patch = mock.patch(f"{POD_LAUNCHER_CLASS}.create_pod") + self.await_pod_patch = mock.patch(f"{POD_LAUNCHER_CLASS}.await_pod_start") + self.await_pod_completion_patch = mock.patch(f"{POD_LAUNCHER_CLASS}.await_pod_completion") self.client_patch = mock.patch("airflow.kubernetes.kube_client.get_kube_client") - self.start_mock = self.start_patch.start() - self.monitor_mock = self.monitor_patch.start() + self.create_mock = self.create_pod_patch.start() + self.await_start_mock = self.await_pod_patch.start() + self.await_pod_mock = self.await_pod_completion_patch.start() self.client_mock = self.client_patch.start() - self.addCleanup(self.start_patch.stop) - self.addCleanup(self.monitor_patch.stop) - self.addCleanup(self.client_patch.stop) + + def teardown_method(self): + self.create_pod_patch.stop() + self.await_pod_patch.stop() + self.await_pod_completion_patch.stop() 
+        self.client_patch.stop()
 
     @staticmethod
     def create_context(task):
         dag = DAG(dag_id="dag")
         task_instance = TaskInstance(task=task, run_id=IN_MEMORY_DAGRUN_ID)
-        task_instance.dag_run = DagRun(run_id=IN_MEMORY_DAGRUN_ID, execution_date=DEFAULT_DATE)
+        task_instance.dag_run = DagRun(run_id=IN_MEMORY_DAGRUN_ID)
         return {
             "dag": dag,
             "ts": DEFAULT_DATE.isoformat(),
@@ -62,10 +79,12 @@ def create_context(task):
         }
 
     def run_pod(self, operator) -> k8s.V1Pod:
-        self.monitor_mock.return_value = (State.SUCCESS, None, None)
-        context = self.create_context(operator)
+        context = create_context(operator)
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = 'Succeeded'
+        self.await_pod_mock.return_value = remote_pod_mock
         operator.execute(context=context)
-        return self.start_mock.call_args[0][0]
+        return self.await_start_mock.call_args[1]['pod']
 
     def sanitize_for_serialization(self, obj):
         return ApiClient().sanitize_for_serialization(obj)
@@ -85,9 +104,11 @@ def test_config_path(self):
             config_file=file_path,
             cluster_context="default",
         )
-        self.monitor_mock.return_value = (State.SUCCESS, None, None)
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = 'Succeeded'
+        self.await_pod_mock.return_value = remote_pod_mock
         self.client_mock.list_namespaced_pod.return_value = []
-        context = self.create_context(k)
+        context = create_context(k)
         k.execute(context=context)
         self.client_mock.assert_called_once_with(
             in_cluster=False,
@@ -155,6 +176,25 @@ def test_labels(self):
             "execution_date": mock.ANY,
         }
 
+    def test_find_pod_labels(self):
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            labels={"foo": "bar"},
+            name="test",
+            task_id="task",
+            in_cluster=False,
+            do_xcom_push=False,
+        )
+        self.run_pod(k)
+        self.client_mock.return_value.list_namespaced_pod.assert_called_once()
+        _, kwargs = self.client_mock.return_value.list_namespaced_pod.call_args
+        assert (
+            kwargs['label_selector']
+            == 'dag_id=dag,execution_date=2016-01-01T0100000000-26816529d,task_id=task,already_checked!=True'
+        )
+
     def test_image_pull_secrets_correctly_set(self):
         fake_pull_secrets = "fakeSecret"
         k = KubernetesPodOperator(
@@ -170,7 +210,8 @@ def test_image_pull_secrets_correctly_set(self):
             image_pull_secrets=[k8s.V1LocalObjectReference(fake_pull_secrets)],
             cluster_context="default",
         )
-        pod = k.create_pod_request_obj()
+
+        pod = k.build_pod_request_obj(create_context(k))
         assert pod.spec.image_pull_secrets == [k8s.V1LocalObjectReference(name=fake_pull_secrets)]
 
     def test_image_pull_policy_correctly_set(self):
@@ -187,7 +228,7 @@ def test_image_pull_policy_correctly_set(self):
             image_pull_policy="Always",
             cluster_context="default",
         )
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         assert pod.spec.containers[0].image_pull_policy == "Always"
 
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.delete_pod")
@@ -205,29 +246,30 @@ def test_pod_delete_even_on_launcher_error(self, delete_pod_mock):
             cluster_context="default",
             is_delete_operator_pod=True,
         )
-        self.monitor_mock.side_effect = AirflowException("fake failure")
+        self.await_pod_mock.side_effect = AirflowException("fake failure")
         with pytest.raises(AirflowException):
-            context = self.create_context(k)
+            context = create_context(k)
             k.execute(context=context)
         assert delete_pod_mock.called
 
-    @parameterized.expand([[True], [False]])
-    def test_provided_pod_name(self, randomize_name):
+    @pytest.mark.parametrize('randomize', [True, False])
+    def test_provided_pod_name(self, randomize):
         name_base = "test"
 
         k = KubernetesPodOperator(
             namespace="default",
             image="ubuntu:16.04",
             name=name_base,
-            random_name_suffix=randomize_name,
+            random_name_suffix=randomize,
             task_id="task",
             in_cluster=False,
             do_xcom_push=False,
             cluster_context="default",
         )
-        pod = k.create_pod_request_obj()
+        context = create_context(k)
+        pod = k.build_pod_request_obj(context)
 
-        if randomize_name:
+        if randomize:
             assert pod.metadata.name.startswith(name_base)
             assert pod.metadata.name != name_base
         else:
@@ -462,7 +504,9 @@ def test_pod_template_file(self, randomize_name):
             "execution_date": mock.ANY,
         }
 
-    def test_describes_pod_on_failure(self):
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.follow_container_logs")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_container_completion")
+    def test_describes_pod_on_failure(self, await_container_mock, follow_container_mock):
         name_base = "test"
 
         k = KubernetesPodOperator(
@@ -477,22 +521,20 @@ def test_describes_pod_on_failure(self):
             do_xcom_push=False,
             cluster_context="default",
         )
 
-        failed_pod_status = "read_pod_namespaced_result"
-        self.monitor_mock.return_value = (State.FAILED, failed_pod_status, None)
-        read_namespaced_pod_mock = self.client_mock.return_value.read_namespaced_pod
-        read_namespaced_pod_mock.return_value = failed_pod_status
+        follow_container_mock.return_value = None
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = 'Failed'
+        self.await_pod_mock.return_value = remote_pod_mock
 
-        with pytest.raises(AirflowException) as ctx:
-            context = self.create_context(k)
+        with pytest.raises(AirflowException, match=f"Pod {name_base}.[a-z0-9]+ returned a failure: .*"):
+            context = create_context(k)
             k.execute(context=context)
-        assert (
-            str(ctx.value)
-            == f"Pod Launching failed: Pod {k.pod.metadata.name} returned a failure: {failed_pod_status}"
-        )
         assert not self.client_mock.return_value.read_namespaced_pod.called
 
-    def test_no_need_to_describe_pod_on_success(self):
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.follow_container_logs")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.await_container_completion")
+    def test_no_handle_failure_on_success(self, await_container_mock, follow_container_mock):
         name_base = "test"
 
         k = KubernetesPodOperator(
@@ -507,12 +549,16 @@
             do_xcom_push=False,
             cluster_context="default",
         )
-        self.monitor_mock.return_value = (State.SUCCESS, None, None)
-        context = self.create_context(k)
-        k.execute(context=context)
+        follow_container_mock.return_value = None
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = 'Succeeded'
+        self.await_pod_mock.return_value = remote_pod_mock
 
-        assert not self.client_mock.return_value.read_namespaced_pod.called
+        context = create_context(k)
+
+        # assert does not raise
+        k.execute(context=context)
 
     def test_create_with_affinity(self):
         name_base = "test"
 
@@ -544,7 +590,7 @@ def test_create_with_affinity(self):
             affinity=affinity,
         )
 
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         sanitized_pod = self.sanitize_for_serialization(pod)
         assert isinstance(pod.spec.affinity, k8s.V1Affinity)
         assert sanitized_pod["spec"]["affinity"] == affinity
@@ -578,7 +624,7 @@ def test_create_with_affinity(self):
             affinity=k8s_api_affinity,
         )
 
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         sanitized_pod = self.sanitize_for_serialization(pod)
         assert isinstance(pod.spec.affinity, k8s.V1Affinity)
         assert sanitized_pod["spec"]["affinity"] == affinity
@@ -602,7 +648,7 @@ def test_tolerations(self):
             tolerations=tolerations,
         )
 
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         sanitized_pod = self.sanitize_for_serialization(pod)
         assert isinstance(pod.spec.tolerations[0], k8s.V1Toleration)
         assert sanitized_pod["spec"]["tolerations"] == tolerations
@@ -621,7 +667,7 @@ def test_tolerations(self):
             tolerations=k8s_api_tolerations,
         )
 
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         sanitized_pod = self.sanitize_for_serialization(pod)
         assert isinstance(pod.spec.tolerations[0], k8s.V1Toleration)
         assert sanitized_pod["spec"]["tolerations"] == tolerations
@@ -643,7 +689,7 @@ def test_node_selector(self):
             node_selector=node_selector,
         )
 
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         sanitized_pod = self.sanitize_for_serialization(pod)
         assert isinstance(pod.spec.node_selector, dict)
         assert sanitized_pod["spec"]["nodeSelector"] == node_selector
@@ -666,12 +712,16 @@ def test_node_selector(self):
             node_selectors=node_selector,
         )
 
-        pod = k.create_pod_request_obj()
+        pod = k.build_pod_request_obj(create_context(k))
         sanitized_pod = self.sanitize_for_serialization(pod)
         assert isinstance(pod.spec.node_selector, dict)
         assert sanitized_pod["spec"]["nodeSelector"] == node_selector
 
-    def test_push_xcom_pod_info(self):
+    @pytest.mark.parametrize('do_xcom_push', [True, False])
+    @mock.patch(f"{POD_LAUNCHER_CLASS}.extract_xcom")
+    def test_push_xcom_pod_info(self, extract_xcom, do_xcom_push):
+        """pod name and namespace are *always* pushed; do_xcom_push only controls xcom sidecar"""
+        extract_xcom.return_value = '{}'
         k = KubernetesPodOperator(
             namespace="default",
             image="ubuntu:16.04",
@@ -679,11 +729,12 @@ def test_push_xcom_pod_info(self):
             name="test",
             task_id="task",
             in_cluster=False,
-            do_xcom_push=False,
+            do_xcom_push=do_xcom_push,
         )
         pod = self.run_pod(k)
-        ti = TaskInstance(task=k, run_id=IN_MEMORY_DAGRUN_ID)
-        ti.dag_run = DagRun(run_id=IN_MEMORY_DAGRUN_ID, execution_date=DEFAULT_DATE)
+        other_task = DummyOperator(task_id='task_to_pull_xcom')
+        ti = TaskInstance(task=other_task, run_id=IN_MEMORY_DAGRUN_ID)
+        ti.dag_run = DagRun(run_id=IN_MEMORY_DAGRUN_ID)
         pod_name = ti.xcom_pull(task_ids=k.task_id, key='pod_name')
         pod_namespace = ti.xcom_pull(task_ids=k.task_id, key='pod_namespace')
         assert pod_name and pod_name == pod.metadata.name
@@ -719,8 +770,10 @@ def test_mark_created_pod_if_not_deleted(self, mock_patch_already_checked, mock_
             task_id="task",
             is_delete_operator_pod=False,
         )
-        self.monitor_mock.return_value = (State.FAILED, None, None)
-        context = self.create_context(k)
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = 'Failed'
+        self.await_pod_mock.return_value = remote_pod_mock
+        context = create_context(k)
         with pytest.raises(AirflowException):
             k.execute(context=context)
         mock_patch_already_checked.assert_called_once()
@@ -742,8 +795,8 @@ def test_mark_created_pod_if_not_deleted_during_exception(
             task_id="task",
             is_delete_operator_pod=False,
         )
-        self.monitor_mock.side_effect = AirflowException("oops")
-        context = self.create_context(k)
+        self.await_pod_mock.side_effect = AirflowException("oops")
+        context = create_context(k)
         with pytest.raises(AirflowException):
             k.execute(context=context)
         mock_patch_already_checked.assert_called_once()
@@ -751,8 +804,8 @@ def test_mark_created_pod_if_not_deleted_during_exception(
 
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.delete_pod")
     @mock.patch(
-        "airflow.providers.cncf.kubernetes.operators.kubernetes_pod"
-        ".KubernetesPodOperator.patch_already_checked"
+        "airflow.providers.cncf.kubernetes.operators."
+        "kubernetes_pod.KubernetesPodOperator.patch_already_checked"
     )
     def test_mark_reattached_pod_if_not_deleted(self, mock_patch_already_checked, mock_delete_pod):
         """If we aren't deleting pods and have a failure, mark it so we don't reattach to it"""
         k = KubernetesPodOperator(
             namespace="default",
             image="ubuntu:16.04",
             name="test",
             task_id="task",
             is_delete_operator_pod=False,
         )
-        # Run it first to easily get the pod
-        pod = self.run_pod(k)
+        remote_pod_mock = MagicMock()
+        remote_pod_mock.status.phase = 'Failed'
+        self.await_pod_mock.return_value = remote_pod_mock
 
-        # Now try and "reattach"
-        mock_patch_already_checked.reset_mock()
-        mock_delete_pod.reset_mock()
-        self.client_mock.return_value.list_namespaced_pod.return_value.items = [pod]
-        self.monitor_mock.return_value = (State.FAILED, None, None)
-
-        context = self.create_context(k)
+        context = create_context(k)
         with pytest.raises(AirflowException):
             k.execute(context=context)
         mock_patch_already_checked.assert_called_once()
         mock_delete_pod.assert_not_called()
+
+
+def test__suppress():
+    with mock.patch('logging.Logger.error') as mock_error:
+
+        with _suppress(ValueError):
+            raise ValueError("failure")
+
+        mock_error.assert_called_once_with("failure", exc_info=True)
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_launcher.py b/tests/providers/cncf/kubernetes/utils/test_pod_launcher.py
index eaef295d37e3..c02c025fcd87 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_launcher.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_launcher.py
@@ -14,8 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import unittest
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pendulum
 import pytest
@@ -23,18 +23,18 @@
 from urllib3.exceptions import HTTPError as BaseHTTPError
 
 from airflow.exceptions import AirflowException
-from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher, PodStatus
+from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher, PodPhase, container_is_running
 
 
-class TestPodLauncher(unittest.TestCase):
-    def setUp(self):
+class TestPodLauncher:
+    def setup_method(self):
         self.mock_kube_client = mock.Mock()
         self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client)
 
     def test_read_pod_logs_successfully_returns_logs(self):
         mock.sentinel.metadata = mock.MagicMock()
         self.mock_kube_client.read_namespaced_pod_log.return_value = mock.sentinel.logs
-        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
+        logs = self.pod_launcher.read_pod_logs(pod=mock.sentinel, container_name='base')
         assert mock.sentinel.logs == logs
 
     def test_read_pod_logs_retries_successfully(self):
@@ -43,7 +43,7 @@ def test_read_pod_logs_retries_successfully(self):
             BaseHTTPError('Boom'),
             mock.sentinel.logs,
         ]
-        logs = self.pod_launcher.read_pod_logs(mock.sentinel)
+        logs = self.pod_launcher.read_pod_logs(pod=mock.sentinel, container_name='base')
         assert mock.sentinel.logs == logs
         self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
             [
@@ -74,12 +74,12 @@ def test_read_pod_logs_retries_fails(self):
             BaseHTTPError('Boom'),
         ]
         with pytest.raises(BaseHTTPError):
-            self.pod_launcher.read_pod_logs(mock.sentinel)
+            self.pod_launcher.read_pod_logs(pod=mock.sentinel, container_name='base')
 
     def test_read_pod_logs_successfully_with_tail_lines(self):
         mock.sentinel.metadata = mock.MagicMock()
         self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs]
-        logs = self.pod_launcher.read_pod_logs(mock.sentinel, tail_lines=100)
+        logs = self.pod_launcher.read_pod_logs(pod=mock.sentinel, container_name='base', tail_lines=100)
         assert mock.sentinel.logs == logs
         self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
             [
@@ -98,7 +98,7 @@ def test_read_pod_logs_successfully_with_tail_lines(self):
     def test_read_pod_logs_successfully_with_since_seconds(self):
         mock.sentinel.metadata = mock.MagicMock()
         self.mock_kube_client.read_namespaced_pod_log.side_effect = [mock.sentinel.logs]
-        logs = self.pod_launcher.read_pod_logs(mock.sentinel, since_seconds=2)
+        logs = self.pod_launcher.read_pod_logs(mock.sentinel, 'base', since_seconds=2)
         assert mock.sentinel.logs == logs
         self.mock_kube_client.read_namespaced_pod_log.assert_has_calls(
             [
@@ -177,7 +177,7 @@ def test_monitor_pod_empty_logs(self):
         running_status = mock.MagicMock()
         running_status.configure_mock(**{'name': 'base', 'state.running': True})
         pod_info_running = mock.MagicMock(**{'status.container_statuses': [running_status]})
-        pod_info_succeeded = mock.MagicMock(**{'status.phase': PodStatus.SUCCEEDED})
+        pod_info_succeeded = mock.MagicMock(**{'status.phase': PodPhase.SUCCEEDED})
 
         def pod_state_gen():
             yield pod_info_running
@@ -186,14 +186,14 @@ def pod_state_gen():
 
         self.mock_kube_client.read_namespaced_pod.side_effect = pod_state_gen()
         self.mock_kube_client.read_namespaced_pod_log.return_value = iter(())
-        self.pod_launcher.monitor_pod(mock.sentinel, get_logs=True)
+        self.pod_launcher.follow_container_logs(mock.sentinel, 'base')
 
     def test_monitor_pod_logs_failures_non_fatal(self):
         mock.sentinel.metadata = mock.MagicMock()
         running_status = mock.MagicMock()
         running_status.configure_mock(**{'name': 'base', 'state.running': True})
         pod_info_running = mock.MagicMock(**{'status.container_statuses': [running_status]})
-        pod_info_succeeded = mock.MagicMock(**{'status.phase': PodStatus.SUCCEEDED})
+        pod_info_succeeded = mock.MagicMock(**{'status.phase': PodPhase.SUCCEEDED})
 
         def pod_state_gen():
             yield pod_info_running
@@ -209,7 +209,7 @@ def pod_log_gen():
 
         self.mock_kube_client.read_namespaced_pod_log.side_effect = pod_log_gen()
 
-        self.pod_launcher.monitor_pod(mock.sentinel, get_logs=True)
+        self.pod_launcher.follow_container_logs(mock.sentinel, 'base')
 
     def test_read_pod_retries_fails(self):
         mock.sentinel.metadata = mock.MagicMock()
@@ -224,13 +224,13 @@ def test_read_pod_retries_fails(self):
     def test_parse_log_line(self):
         log_message = "This should return no timestamp"
         timestamp, line = self.pod_launcher.parse_log_line(log_message)
-        self.assertEqual(timestamp, None)
-        self.assertEqual(line, log_message)
+        assert timestamp is None
+        assert line == log_message
 
         real_timestamp = "2020-10-08T14:16:17.793417674Z"
         timestamp, line = self.pod_launcher.parse_log_line(" ".join([real_timestamp, log_message]))
-        self.assertEqual(timestamp, pendulum.parse(real_timestamp))
-        self.assertEqual(line, log_message)
+        assert timestamp == pendulum.parse(real_timestamp)
+        assert line == log_message
 
         with pytest.raises(Exception):
             self.pod_launcher.parse_log_line('2020-10-08T14:16:17.793417674ZInvalidmessage\n')
@@ -241,14 +241,14 @@ def test_start_pod_retries_on_409_error(self, mock_run_pod_async):
             ApiException(status=409),
             mock.MagicMock(),
         ]
-        self.pod_launcher.start_pod(mock.sentinel)
+        self.pod_launcher.create_pod(mock.sentinel)
         assert mock_run_pod_async.call_count == 2
 
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.run_pod_async")
     def test_start_pod_fails_on_other_exception(self, mock_run_pod_async):
         mock_run_pod_async.side_effect = [ApiException(status=504)]
         with pytest.raises(ApiException):
-            self.pod_launcher.start_pod(mock.sentinel)
+            self.pod_launcher.create_pod(mock.sentinel)
 
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.run_pod_async")
     def test_start_pod_retries_three_times(self, mock_run_pod_async):
@@ -259,31 +259,83 @@ def test_start_pod_retries_three_times(self, mock_run_pod_async):
             ApiException(status=409),
         ]
         with pytest.raises(ApiException):
-            self.pod_launcher.start_pod(mock.sentinel)
+            self.pod_launcher.create_pod(mock.sentinel)
 
         assert mock_run_pod_async.call_count == 3
 
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.pod_not_started")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_launcher.PodLauncher.run_pod_async")
-    def test_start_pod_raises_informative_error_on_timeout(self, mock_run_pod_async, mock_pod_not_started):
+    def test_start_pod_raises_informative_error_on_timeout(self):
         pod_response = mock.MagicMock()
-        pod_response.status.start_time = None
-        mock_run_pod_async.return_value = pod_response
-        mock_pod_not_started.return_value = True
+        pod_response.status.phase = 'Pending'
+        self.mock_kube_client.read_namespaced_pod.return_value = pod_response
         expected_msg = "Check the pod events in kubernetes"
+        mock_pod = MagicMock()
         with pytest.raises(AirflowException, match=expected_msg):
-            self.pod_launcher.start_pod(
-                pod=mock.sentinel,
+            self.pod_launcher.await_pod_start(
+                pod=mock_pod,
                 startup_timeout=0,
             )
 
-    def test_base_container_is_running_none_event(self):
-        event = mock.MagicMock()
-        event_status = mock.MagicMock()
-        event_status.status = None
-        event_container_statuses = mock.MagicMock()
-        event_container_statuses.status = mock.MagicMock()
-        event_container_statuses.status.container_statuses = None
-        for e in [event, event_status, event_container_statuses]:
-            self.pod_launcher.read_pod = mock.MagicMock(return_value=e)
-            assert self.pod_launcher.base_container_is_running(None) is False
+    @mock.patch('airflow.providers.cncf.kubernetes.utils.pod_launcher.container_is_running')
+    def test_container_is_running(self, container_is_running_mock):
+        mock_pod = MagicMock()
+        self.pod_launcher.read_pod = mock.MagicMock(return_value=mock_pod)
+        self.pod_launcher.container_is_running(None, 'base')
+        container_is_running_mock.assert_called_with(pod=mock_pod, container_name='base')
+
+
+def params_for_test_container_is_running():
+    """The `container_is_running` method is designed to handle an assortment of bad objects
+    returned from `read_pod`. E.g. a None object, an object `e` such that `e.status` is None,
+    an object `e` such that `e.status.container_statuses` is None, and so on. This function
+    emits params used in `test_container_is_running` to verify this behavior.
+
+    We create mock classes not derived from MagicMock because with an instance `e` of MagicMock,
+    tests like `e.hello is not None` are always True.
+    """
+
+    class RemotePodMock:
+        pass
+
+    class ContainerStatusMock:
+        def __init__(self, name):
+            self.name = name
+
+    def remote_pod(running=None, not_running=None):
+        e = RemotePodMock()
+        e.status = RemotePodMock()
+        e.status.container_statuses = []
+        for r in not_running or []:
+            e.status.container_statuses.append(container(r, False))
+        for r in running or []:
+            e.status.container_statuses.append(container(r, True))
+        return e
+
+    def container(name, running):
+        c = ContainerStatusMock(name)
+        c.state = RemotePodMock()
+        c.state.running = {'a': 'b'} if running else None
+        return c
+
+    pod_mock_list = []
+    pod_mock_list.append(pytest.param(None, False, id='None remote_pod'))
+    p = RemotePodMock()
+    p.status = None
+    pod_mock_list.append(pytest.param(p, False, id='None remote_pod.status'))
+    p = RemotePodMock()
+    p.status = RemotePodMock()
+    p.status.container_statuses = []
+    pod_mock_list.append(pytest.param(p, False, id='empty remote_pod.status.container_statuses'))
+    pod_mock_list.append(pytest.param(remote_pod(), False, id='filter empty'))
+    pod_mock_list.append(pytest.param(remote_pod(None, ['base']), False, id='filter 0 running'))
+    pod_mock_list.append(pytest.param(remote_pod(['hello'], ['base']), False, id='filter 1 not running'))
+    pod_mock_list.append(pytest.param(remote_pod(['base'], ['hello']), True, id='filter 1 running'))
+    return pod_mock_list
+
+
+@pytest.mark.parametrize('remote_pod, result', params_for_test_container_is_running())
+def test_container_is_running(remote_pod, result):
+    """The `container_is_running` function is designed to handle an assortment of bad objects
+    returned from `read_pod`. E.g. a None object, an object `e` such that `e.status` is None,
+    an object `e` such that `e.status.container_statuses` is None, and so on. This test
+    verifies the expected behavior."""
+    assert container_is_running(remote_pod, 'base') is result
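The parametrized cases generated by ``params_for_test_container_is_running`` pin down the contract of the
new module-level ``container_is_running`` helper: a missing pod, a missing ``status``, and a missing or empty
``container_statuses`` list must all read as "not running", and only the named container with a truthy
``state.running`` counts. A minimal sketch that satisfies those cases (an illustration only, not necessarily
the provider's implementation) looks like::

    def container_is_running(pod, container_name: str) -> bool:
        # Tolerate a None pod, a None ``status`` and a None/empty container list,
        # mirroring the 'None remote_pod', 'None remote_pod.status' and 'empty ...' params above.
        container_statuses = pod.status.container_statuses if pod and pod.status else None
        if not container_statuses:
            return False
        # Keep only the container we were asked about (the 'filter ...' params above).
        status = next((s for s in container_statuses if s.name == container_name), None)
        if status is None:
            return False
        # A container counts as running only if kubernetes reports a truthy ``state.running``.
        return status.state is not None and status.state.running is not None

Because such a helper returns a plain ``bool``, the ``is result`` identity assertion in
``test_container_is_running`` holds for both the ``True`` and ``False`` cases.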