Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix condition for k8 container completion check #27502

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion airflow/providers/cncf/kubernetes/utils/pod_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,14 @@ def consume_logs(*, since_time: DateTime | None = None, follow: bool = True) ->
time.sleep(1)

def await_container_completion(self, pod: V1Pod, container_name: str) -> None:
while not self.container_is_running(pod=pod, container_name=container_name):
"""
Monitors a container of a pod until it reaches its final state

:param pod: pod spec that will be monitored
:param container_name: name of the container of the pod that will be monitored
:return:
"""
while self.container_is_running(pod=pod, container_name=container_name):
time.sleep(1)

def await_pod_completion(self, pod: V1Pod) -> V1Pod:
Expand Down
61 changes: 38 additions & 23 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
import logging
import os
import random
import re
import shutil
import sys
import textwrap
import unittest
from copy import copy
from tempfile import NamedTemporaryFile
from unittest import mock
Expand Down Expand Up @@ -77,19 +76,26 @@ def get_kubeconfig_path():
return kubeconfig_path if kubeconfig_path else os.path.expanduser("~/.kube/config")


def get_normalized_test_name():
test = os.environ["PYTEST_CURRENT_TEST"]
test = test.replace("(setup)", "").replace("call", "")
return "".join(filter(str.isalnum, test)).lower()


def get_label():
test = os.environ.get("PYTEST_CURRENT_TEST")
label = "".join(filter(str.isalnum, test)).lower()
label = get_normalized_test_name()
label = label.strip()
return label[-63]


@pytest.mark.execution_timeout(180)
class TestKubernetesPodOperatorSystem(unittest.TestCase):
class TestKubernetesPodOperatorSystem:
def get_current_task_name(self):
test_name = get_normalized_test_name()
# reverse test name to make pod name unique (it has limited length)
return "_" + unittest.TestCase.id(self).replace(".", "_")[::-1]
return "_" + test_name[::-1]

def setUp(self):
def setup_method(self, method):
self.maxDiff = None
self.api_client = ApiClient()
self.labels = {"test_label": get_label()}
Expand Down Expand Up @@ -810,7 +816,11 @@ def test_pod_template_file_with_full_pod_spec(self):
assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")]
assert result == {"hello": "world"}

def test_full_pod_spec(self):
@pytest.mark.parametrize(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding support for this required me to move from a unittest.TestCase to pytest. This required a couple of other fixes, but is worth it in my mind, since other test classes are using pytest.

"should_get_logs",
[True, False],
)
def test_full_pod_spec(self, should_get_logs: bool):
pod_spec = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
labels={"test_label": get_label(), "fizz": "buzz"}, namespace="default", name="test-pod"
Expand All @@ -836,11 +846,13 @@ def test_full_pod_spec(self):
do_xcom_push=True,
is_delete_operator_pod=False,
startup_timeout_seconds=30,
get_logs=should_get_logs,
)

context = create_context(k)
result = k.execute(context)
assert result is not None
assert k.pod is not None
assert k.pod.metadata.labels == {
"fizz": "buzz",
"test_label": get_label(),
Expand Down Expand Up @@ -916,7 +928,12 @@ def test_init_container(self):
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS)
def test_pod_template_file(
self, hook_mock, await_pod_completion_mock, extract_xcom_mock, await_xcom_sidecar_container_start_mock
self,
hook_mock,
await_pod_completion_mock,
extract_xcom_mock,
await_xcom_sidecar_container_start_mock,
caplog,
):
# todo: This isn't really a system test
await_xcom_sidecar_container_start_mock.return_value = None
Expand All @@ -934,21 +951,19 @@ def test_pod_template_file(
pod_mock.status.phase = "Succeeded"
await_pod_completion_mock.return_value = pod_mock
context = create_context(k)
with self.assertLogs(k.log, level=logging.DEBUG) as cm:
with caplog.at_level(logging.DEBUG, logger="airflow.task"):
k.execute(context)
expected_line = textwrap.dedent(
"""\
DEBUG:airflow.task.operators:Starting pod:
api_version: v1
kind: Pod
metadata:
annotations: {}
cluster_name: null
creation_timestamp: null
deletion_grace_period_seconds: null\
"""
).strip()
assert any(line.startswith(expected_line) for line in cm.output)
expected_line = (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to the different log format of caplog, I decided to reformat this to a regex search instead of the any(... logic. This was easier to do with this string format instead of the multiline.

r"DEBUG\s+airflow.task.operators.kubernetes_pod.py:\d+\sStarting pod:\n"
r"api_version: v1\n"
r"kind: Pod\n"
r"metadata:\n"
r"\s+annotations: {}\n"
r"\s+cluster_name: null\n"
r"\s+creation_timestamp: null\n"
r"\s+deletion_grace_period_seconds: null\n"
)
assert re.search(expected_line, caplog.text) is not None

actual_pod = self.api_client.sanitize_for_serialization(k.pod)
expected_dict = {
Expand Down