Skip to content

Commit

Permalink
Revert "Remove PodLoggingStatus object #35422" (#35822)
Browse files Browse the repository at this point in the history
* Revert "Remove PodLoggingStatus object #35422"

This object was completely unused in OSS, but others may have depended on it, and it is kinder to remove it in a major release.

* Fix param for the test

---------

Co-authored-by: Pankaj Koti <pankajkoti699@gmail.com>
  • Loading branch information
dstandish and pankajkoti committed Nov 23, 2023
1 parent ef2ad07 commit ca97fee
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 28 deletions.
34 changes: 24 additions & 10 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Expand Up @@ -25,8 +25,9 @@
import warnings
from collections.abc import Iterable
from contextlib import closing, suppress
from dataclasses import dataclass
from datetime import timedelta
from typing import TYPE_CHECKING, Callable, Generator, Literal, Protocol, cast
from typing import TYPE_CHECKING, Callable, Generator, Protocol, cast

import pendulum
import tenacity
Expand All @@ -35,6 +36,7 @@
from kubernetes.stream import stream as kubernetes_stream
from pendulum import DateTime
from pendulum.parsing.exceptions import ParserError
from typing_extensions import Literal
from urllib3.exceptions import HTTPError as BaseHTTPError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
Expand Down Expand Up @@ -271,6 +273,14 @@ def read_pod(self):
return self.read_pod_cache


@dataclass
class PodLoggingStatus:
    """
    Outcome reported by `fetch_container_logs` when it stops reading a container's logs.

    :param running: whether the container was still running when log fetching returned.
    :param last_log_time: timestamp of the last log line that was read, or None if no
        such timestamp is available.
    """

    running: bool
    last_log_time: DateTime | None


class PodManager(LoggingMixin):
"""Create, monitor, and otherwise interact with Kubernetes pods for use with the KubernetesPodOperator."""

Expand Down Expand Up @@ -355,7 +365,7 @@ def await_pod_start(
raise PodLaunchFailedException(msg)
time.sleep(startup_check_interval)

def follow_container_logs(self, pod: V1Pod, container_name: str) -> None:
def follow_container_logs(self, pod: V1Pod, container_name: str) -> PodLoggingStatus:
warnings.warn(
"Method `follow_container_logs` is deprecated. Use `fetch_container_logs` instead"
"with option `follow=True`.",
Expand All @@ -372,7 +382,7 @@ def fetch_container_logs(
follow=False,
since_time: DateTime | None = None,
post_termination_timeout: int = 120,
) -> None:
) -> PodLoggingStatus:
"""
Follow the logs of container and stream to airflow logging.
Expand Down Expand Up @@ -450,16 +460,17 @@ def consume_logs(*, since_time: DateTime | None = None) -> DateTime | None:
last_log_time = since_time
while True:
last_log_time = consume_logs(since_time=last_log_time)
if not self.container_is_running(pod, container_name=container_name):
return PodLoggingStatus(running=False, last_log_time=last_log_time)
if not follow:
return
if self.container_is_running(pod, container_name=container_name):
return PodLoggingStatus(running=True, last_log_time=last_log_time)
else:
self.log.warning(
"Follow requested but pod log read interrupted and container %s still running",
"Pod %s log read interrupted but container %s still running",
pod.metadata.name,
container_name,
)
time.sleep(1)
else: # follow requested, but container is done
break

def _reconcile_requested_log_containers(
self, requested: Iterable[str] | str | bool, actual: list[str], pod_name
Expand Down Expand Up @@ -507,22 +518,25 @@ def _reconcile_requested_log_containers(

def fetch_requested_container_logs(
self, pod: V1Pod, containers: Iterable[str] | str | Literal[True], follow_logs=False
) -> None:
) -> list[PodLoggingStatus]:
"""
Follow the logs of containers in the specified pod and publish it to airflow logging.
Returns when all the containers exit.
:meta private:
"""
pod_logging_statuses = []
all_containers = self.get_container_names(pod)
containers_to_log = self._reconcile_requested_log_containers(
requested=containers,
actual=all_containers,
pod_name=pod.metadata.name,
)
for c in containers_to_log:
self.fetch_container_logs(pod=pod, container_name=c, follow=follow_logs)
status = self.fetch_container_logs(pod=pod, container_name=c, follow=follow_logs)
pod_logging_statuses.append(status)
return pod_logging_statuses

def await_container_completion(self, pod: V1Pod, container_name: str) -> None:
"""
Expand Down
66 changes: 48 additions & 18 deletions tests/providers/cncf/kubernetes/utils/test_pod_manager.py
Expand Up @@ -20,6 +20,7 @@
from datetime import datetime
from json.decoder import JSONDecodeError
from types import SimpleNamespace
from typing import cast
from unittest import mock
from unittest.mock import MagicMock

Expand Down Expand Up @@ -257,6 +258,19 @@ def test_parse_log_line(self):
assert timestamp == pendulum.parse(real_timestamp)
assert line == log_message

# Note: mock.patch decorators are applied bottom-up, so the lowest decorator
# (read_pod_logs) supplies the first mock argument after `self`.
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
def test_fetch_container_logs_returning_last_timestamp(
    self, mock_read_pod_logs, mock_container_is_running
):
    """Check that `fetch_container_logs` returns the timestamp of the last timestamped log line."""
    timestamp_string = "2020-10-08T14:16:17.793417674Z"
    # The second line carries no timestamp; the expected last_log_time is the first line's timestamp.
    mock_read_pod_logs.return_value = [bytes(f"{timestamp_string} message", "utf-8"), b"notimestamp"]
    # Container is reported as running on the first check and as finished on the second.
    mock_container_is_running.side_effect = [True, False]

    status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)

    assert status.last_log_time == cast(DateTime, pendulum.parse(timestamp_string))

@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
def test_fetch_container_logs_invoke_progress_callback(
Expand Down Expand Up @@ -291,7 +305,8 @@ def consumer_iter():
with mock.patch.object(PodLogsConsumer, "__iter__") as mock_consumer_iter:
mock_consumer_iter.side_effect = consumer_iter
mock_container_is_running.side_effect = [True, True, False]
self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)
status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)
assert status.last_log_time == cast(DateTime, pendulum.parse(last_timestamp_string))
assert self.mock_progress_callback.call_count == expected_call_count

@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
Expand Down Expand Up @@ -396,21 +411,15 @@ def test_fetch_container_done(self, logs_available, container_running, follow):
mock_pod = MagicMock()
logs_available.return_value = False
container_running.return_value = False
self.pod_manager.fetch_container_logs(pod=mock_pod, container_name="base", follow=follow)
ret = self.pod_manager.fetch_container_logs(pod=mock_pod, container_name="base", follow=follow)
assert ret.last_log_time is None
assert ret.running is False

# adds all valid types for container_logs
@pytest.mark.parametrize("follow", [True, False])
@pytest.mark.parametrize(
"container_logs, exp_cont",
[
("base", ["base"]),
("alpine", ["alpine"]),
(True, ["base", "alpine"]),
(["base", "alpine"], ["base", "alpine"]),
],
)
@pytest.mark.parametrize("container_logs", ["base", "alpine", True, ["base", "alpine"]])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_requested_container_logs(self, container_is_running, container_logs, follow, exp_cont):
def test_fetch_requested_container_logs(self, container_is_running, container_logs, follow):
mock_pod = MagicMock()
self.pod_manager.read_pod = MagicMock()
self.pod_manager.get_container_names = MagicMock()
Expand All @@ -420,12 +429,31 @@ def test_fetch_requested_container_logs(self, container_is_running, container_lo
stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
)

self.pod_manager.fetch_requested_container_logs(
ret_values = self.pod_manager.fetch_requested_container_logs(
pod=mock_pod, containers=container_logs, follow_logs=follow
)
calls = {tuple(x[1].values()) for x in container_is_running.call_args_list}
pod = self.pod_manager.read_pod.return_value
assert calls == {(pod, x) for x in exp_cont}
for ret in ret_values:
assert ret.running is False

# adds all invalid types for container_logs
@pytest.mark.parametrize("container_logs", [1, None, 6.8, False])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_requested_container_logs_invalid(self, container_running, container_logs):
    """Check that unsupported `containers` values produce an empty list of logging statuses."""
    mock_pod = MagicMock()
    self.pod_manager.read_pod = MagicMock()
    self.pod_manager.get_container_names = MagicMock(return_value=["base", "alpine"])
    container_running.return_value = False
    self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
        stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
    )

    ret_values = self.pod_manager.fetch_requested_container_logs(
        pod=mock_pod,
        containers=container_logs,
    )

    # No container should have been selected for logging, so no status is collected.
    assert len(ret_values) == 0

@mock.patch("pendulum.now")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
Expand All @@ -444,7 +472,7 @@ def test_fetch_container_since_time(self, logs_available, container_running, moc
args, kwargs = self.mock_kube_client.read_namespaced_pod_log.call_args_list[0]
assert kwargs["since_seconds"] == 5

@pytest.mark.parametrize("follow, is_running_calls, exp_running", [(True, 3, False), (False, 2, False)])
@pytest.mark.parametrize("follow, is_running_calls, exp_running", [(True, 3, False), (False, 3, False)])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_container_running_follow(
self, container_running_mock, follow, is_running_calls, exp_running
Expand All @@ -458,8 +486,10 @@ def test_fetch_container_running_follow(
self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
)
self.pod_manager.fetch_container_logs(pod=mock_pod, container_name="base", follow=follow)
ret = self.pod_manager.fetch_container_logs(pod=mock_pod, container_name="base", follow=follow)
assert len(container_running_mock.call_args_list) == is_running_calls
assert ret.last_log_time == DateTime(2021, 1, 1, tzinfo=Timezone("UTC"))
assert ret.running is exp_running

@pytest.mark.parametrize(
"container_state, expected_is_terminated",
Expand Down

0 comments on commit ca97fee

Please sign in to comment.