diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index ffd3e7ec4bd88..fec783d39e7e1 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -267,7 +267,14 @@ def consume_logs(*, since_time: DateTime | None = None, follow: bool = True) -> time.sleep(1) def await_container_completion(self, pod: V1Pod, container_name: str) -> None: - while not self.container_is_running(pod=pod, container_name=container_name): + """ + Monitors a container of a pod until it reaches its final state + + :param pod: pod spec that will be monitored + :param container_name: name of the container of the pod that will be monitored + :return: + """ + while self.container_is_running(pod=pod, container_name=container_name): time.sleep(1) def await_pod_completion(self, pod: V1Pod) -> V1Pod: diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index e1e6812018fc0..e6e3441d21879 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -20,10 +20,9 @@ import logging import os import random +import re import shutil import sys -import textwrap -import unittest from copy import copy from tempfile import NamedTemporaryFile from unittest import mock @@ -77,19 +76,26 @@ def get_kubeconfig_path(): return kubeconfig_path if kubeconfig_path else os.path.expanduser("~/.kube/config") +def get_normalized_test_name(): + test = os.environ["PYTEST_CURRENT_TEST"] + test = test.replace("(setup)", "").replace("call", "") + return "".join(filter(str.isalnum, test)).lower() + + def get_label(): - test = os.environ.get("PYTEST_CURRENT_TEST") - label = "".join(filter(str.isalnum, test)).lower() + label = get_normalized_test_name() + label = label.strip() return label[-63] @pytest.mark.execution_timeout(180) -class TestKubernetesPodOperatorSystem(unittest.TestCase): +class TestKubernetesPodOperatorSystem: def get_current_task_name(self): + test_name = get_normalized_test_name() # reverse test name to make pod name unique (it has limited length) - return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1] + return "_" + test_name[::-1] - def setUp(self): + def setup_method(self, method): self.maxDiff = None self.api_client = ApiClient() self.labels = {"test_label": get_label()} @@ -810,7 +816,11 @@ def test_pod_template_file_with_full_pod_spec(self): assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] assert result == {"hello": "world"} - def test_full_pod_spec(self): + @pytest.mark.parametrize( + "should_get_logs", + [True, False], + ) + def test_full_pod_spec(self, should_get_logs: bool): pod_spec = k8s.V1Pod( metadata=k8s.V1ObjectMeta( labels={"test_label": get_label(), "fizz": "buzz"}, namespace="default", name="test-pod" @@ -836,11 +846,13 @@ def test_full_pod_spec(self): do_xcom_push=True, is_delete_operator_pod=False, startup_timeout_seconds=30, + get_logs=should_get_logs, ) context = create_context(k) result = k.execute(context) assert result is not None + assert k.pod is not None assert k.pod.metadata.labels == { "fizz": "buzz", "test_label": get_label(), @@ -916,7 +928,12 @@ def test_init_container(self): @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock) @mock.patch(HOOK_CLASS) def test_pod_template_file( - self, hook_mock, await_pod_completion_mock, extract_xcom_mock, await_xcom_sidecar_container_start_mock + self, + hook_mock, + await_pod_completion_mock, + extract_xcom_mock, + await_xcom_sidecar_container_start_mock, + caplog, ): # todo: This isn't really a system test await_xcom_sidecar_container_start_mock.return_value = None @@ -934,21 +951,19 @@ def test_pod_template_file( pod_mock.status.phase = "Succeeded" await_pod_completion_mock.return_value = pod_mock context = create_context(k) - with self.assertLogs(k.log, level=logging.DEBUG) as cm: + with caplog.at_level(logging.DEBUG, logger="airflow.task"): k.execute(context) - expected_line = textwrap.dedent( - """\ - DEBUG:airflow.task.operators:Starting pod: - api_version: v1 - kind: Pod - metadata: - annotations: {} - cluster_name: null - creation_timestamp: null - deletion_grace_period_seconds: null\ - """ - ).strip() - assert any(line.startswith(expected_line) for line in cm.output) + expected_line = ( + r"DEBUG\s+airflow.task.operators.kubernetes_pod.py:\d+\sStarting pod:\n" + r"api_version: v1\n" + r"kind: Pod\n" + r"metadata:\n" + r"\s+annotations: {}\n" + r"\s+cluster_name: null\n" + r"\s+creation_timestamp: null\n" + r"\s+deletion_grace_period_seconds: null\n" + ) + assert re.search(expected_line, caplog.text) is not None actual_pod = self.api_client.sanitize_for_serialization(k.pod) expected_dict = {