Skip to content

Commit

Permalink
Logging from all containers in KubernetesPodOperator (#31663)
Browse files Browse the repository at this point in the history
* Logging from all containers in KubernetesPodOperator

* review comments from uranusjr

* Fixing init logic

* Fixing docs

* Fixing sphinx logging

* Addressing review comments from vincbeck

* nits from potiuk

* nits from potiuk

* reverting return type

* comment from uranusjr

* fixing tests

* fixing tests

* review comments from hussein

* handling nits from hussein

* fixing tests

---------

Co-authored-by: Amogh <adesai@cloudera.com>
Co-authored-by: Amogh Desai <adesai@adesai-MBP16.local>
  • Loading branch information
3 people committed Jul 6, 2023
1 parent e7587b3 commit 9a0f41b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 10 deletions.
20 changes: 15 additions & 5 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Expand Up @@ -27,7 +27,7 @@
from collections.abc import Container
from contextlib import AbstractContextManager
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Iterable, Sequence

from kubernetes.client import CoreV1Api, models as k8s
from slugify import slugify
Expand Down Expand Up @@ -62,6 +62,7 @@
get_container_termination_message,
)
from airflow.settings import pod_mutation_hook
from airflow.typing_compat import Literal
from airflow.utils import yaml
from airflow.utils.helpers import prune_dict, validate_key
from airflow.utils.timezone import utcnow
Expand Down Expand Up @@ -178,6 +179,10 @@ class KubernetesPodOperator(BaseOperator):
:param labels: labels to apply to the Pod. (templated)
:param startup_timeout_seconds: timeout in seconds to startup the pod.
:param get_logs: get the stdout of the base container as logs of the tasks.
:param container_logs: list of containers whose logs will be published to stdout.
Takes a sequence of containers, a single container name or True. If True,
all the containers' logs are published. Works in conjunction with the get_logs param.
The default value is the base container.
:param image_pull_policy: Specify a policy to cache or always pull an image.
:param annotations: non-identifying metadata you can attach to the Pod.
Can be a large range of data, and can include characters
Expand Down Expand Up @@ -278,6 +283,7 @@ def __init__(
reattach_on_restart: bool = True,
startup_timeout_seconds: int = 120,
get_logs: bool = True,
container_logs: Iterable[str] | str | Literal[True] = BASE_CONTAINER_NAME,
image_pull_policy: str | None = None,
annotations: dict | None = None,
container_resources: k8s.V1ResourceRequirements | None = None,
Expand Down Expand Up @@ -350,6 +356,11 @@ def __init__(
self.cluster_context = cluster_context
self.reattach_on_restart = reattach_on_restart
self.get_logs = get_logs
self.container_logs = container_logs
if self.container_logs == KubernetesPodOperator.BASE_CONTAINER_NAME:
self.container_logs = (
base_container_name if base_container_name else KubernetesPodOperator.BASE_CONTAINER_NAME
)
self.image_pull_policy = image_pull_policy
self.node_selector = node_selector or {}
self.annotations = annotations or {}
Expand Down Expand Up @@ -572,11 +583,10 @@ def execute_sync(self, context: Context):
self.await_pod_start(pod=self.pod)

if self.get_logs:
self.pod_manager.fetch_container_logs(
self.pod_manager.fetch_requested_container_logs(
pod=self.pod,
container_name=self.base_container_name,
follow=True,
post_termination_timeout=self.POST_TERMINATION_TIMEOUT,
container_logs=self.container_logs,
follow_logs=True,
)
else:
self.pod_manager.await_container_completion(
Expand Down
95 changes: 91 additions & 4 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
Expand Up @@ -23,6 +23,7 @@
import math
import time
import warnings
from collections.abc import Iterable
from contextlib import closing, suppress
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -43,7 +44,7 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.kubernetes.pod_generator import PodDefaults
from airflow.typing_compat import Protocol
from airflow.typing_compat import Literal, Protocol
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timezone import utcnow

Expand Down Expand Up @@ -125,6 +126,17 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool:
return container_status.state.running is not None


def container_is_completed(pod: V1Pod, container_name: str) -> bool:
    """
    Report whether ``container_name`` inside V1Pod ``pod`` has completed.

    Returns True only when a status for that container is found and its
    ``terminated`` state is set. Returns False in every other case.
    """
    status = get_container_status(pod, container_name)
    # A missing (falsy) status short-circuits to False before ``state`` is touched.
    return bool(status) and status.state.terminated is not None


def container_is_terminated(pod: V1Pod, container_name: str) -> bool:
"""
Examines V1Pod ``pod`` to determine whether ``container_name`` is terminated.
Expand Down Expand Up @@ -379,11 +391,12 @@ def consume_logs(
for raw_line in logs:
line = raw_line.decode("utf-8", errors="backslashreplace")
timestamp, message = self.parse_log_line(line)
self.log.info(message)
self.log.info("[%s] %s", container_name, message)
except BaseHTTPError as e:
self.log.warning(
"Reading of logs interrupted with error %r; will retry. "
"Reading of logs interrupted for container %r with error %r; will retry. "
"Set log level to DEBUG for traceback.",
container_name,
e,
)
self.log.debug(
Expand Down Expand Up @@ -413,14 +426,78 @@ def consume_logs(
)
time.sleep(1)

def fetch_requested_container_logs(
    self, pod: V1Pod, container_logs: Iterable[str] | str | Literal[True], follow_logs=False
) -> list[PodLoggingStatus]:
    """
    Fetch the logs of the requested containers of the pod and publish them to airflow logging.

    Containers are processed one after another, in the order they were requested
    (or in pod order when ``True`` is given). Requested names that are not present
    in the pod are reported with an error log entry and skipped.

    :param pod: pod whose containers' logs are fetched
    :param container_logs: a single container name, an iterable of container names,
        or ``True`` to select every container returned by ``get_container_names``
    :param follow_logs: forwarded as ``follow`` to ``fetch_container_logs``
    :return: one logging status per container whose logs were fetched; empty when
        the pod's containers could not be listed or ``container_logs`` is invalid
    """
    pod_logging_statuses: list[PodLoggingStatus] = []
    all_containers = self.get_container_names(pod)
    if not all_containers:
        self.log.error("Could not retrieve containers for the pod: %s", pod.metadata.name)
        return pod_logging_statuses

    # Normalise the accepted input shapes into one iterable of names.
    # ``str`` must be checked before ``Iterable`` because a string is itself iterable.
    requested: Iterable[str]
    if isinstance(container_logs, str):
        requested = [container_logs]
    elif isinstance(container_logs, bool):
        if not container_logs:
            self.log.error(
                "False is not a valid value for container_logs",
            )
            return pod_logging_statuses
        # True selects every container of the pod
        requested = all_containers
    elif isinstance(container_logs, Iterable):
        requested = container_logs
    else:
        self.log.error(
            "Invalid type %s specified for container names input parameter", type(container_logs)
        )
        return pod_logging_statuses

    for container_name in requested:
        if container_name not in all_containers:
            self.log.error(
                "Container %s whose logs were requested not found in the pod %s",
                container_name,
                pod.metadata.name,
            )
            continue
        status = self.fetch_container_logs(pod=pod, container_name=container_name, follow=follow_logs)
        pod_logging_statuses.append(status)

    return pod_logging_statuses

def await_container_completion(self, pod: V1Pod, container_name: str) -> None:
"""
Wait for the given container in the given pod to be completed.

The pod is re-read with ``read_pod`` and checked with ``container_is_completed``;
while the container is not completed, a progress message is logged and the loop
sleeps for one second before checking again.

:param pod: pod spec that will be monitored
:param container_name: name of the container within the pod to monitor
"""
# NOTE(review): the two ``while`` headers below look like the removed and the added
# side of the same diff hunk -- confirm that only ``while True:`` is the live code.
while not self.container_is_terminated(pod=pod, container_name=container_name):
while True:
remote_pod = self.read_pod(pod)
terminated = container_is_completed(remote_pod, container_name)
if terminated:
break
self.log.info("Waiting for container '%s' state to be completed", container_name)
time.sleep(1)

def await_pod_completion(self, pod: V1Pod) -> V1Pod:
Expand Down Expand Up @@ -513,6 +590,16 @@ def read_pod_logs(
post_termination_timeout=post_termination_timeout,
)

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def get_container_names(self, pod: V1Pod) -> list[str]:
    """List the names of the POD's containers, leaving out the airflow-xcom-sidecar container."""
    names: list[str] = []
    for spec in self.read_pod(pod).spec.containers:
        if spec.name == PodDefaults.SIDECAR_CONTAINER_NAME:
            # the sidecar is never reported to callers
            continue
        names.append(spec.name)
    return names

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod_events(self, pod: V1Pod) -> CoreV1EventList:
"""Reads events from the POD."""
Expand Down
2 changes: 1 addition & 1 deletion kubernetes_tests/test_kubernetes_pod_operator.py
Expand Up @@ -500,7 +500,7 @@ def test_volume_mount(self, mock_get_connection):
)
context = create_context(k)
k.execute(context=context)
mock_logger.info.assert_any_call("retrieved from mount")
mock_logger.info.assert_any_call("[%s] %s", "base", "retrieved from mount")
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod["spec"]["containers"][0]["args"] = args
self.expected_pod["spec"]["containers"][0]["volumeMounts"] = [
Expand Down
40 changes: 40 additions & 0 deletions tests/providers/cncf/kubernetes/utils/test_pod_manager.py
Expand Up @@ -317,6 +317,46 @@ def test_fetch_container_done(self, logs_available, container_running, follow):
assert ret.last_log_time is None
assert ret.running is False

# adds all valid types for container_logs
@pytest.mark.parametrize("follow", [True, False])
@pytest.mark.parametrize("container_logs", ["base", "alpine", True, ["base", "alpine"]])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_requested_container_logs(self, container_is_running, container_logs, follow):
    """Every valid ``container_logs`` value yields one non-running status per selected container."""
    mock_pod = MagicMock()
    self.pod_manager.read_pod = MagicMock()
    self.pod_manager.get_container_names = MagicMock()
    self.pod_manager.get_container_names.return_value = ["base", "alpine"]
    container_is_running.return_value = False
    self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
        stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
    )

    ret_values = self.pod_manager.fetch_requested_container_logs(
        pod=mock_pod, container_logs=container_logs, follow_logs=follow
    )

    # Work out which containers the input selects, so the per-status loop below
    # cannot pass vacuously on an empty result.
    if container_logs is True:
        expected_containers = ["base", "alpine"]
    elif isinstance(container_logs, str):
        expected_containers = [container_logs]
    else:
        expected_containers = list(container_logs)
    assert len(ret_values) == len(expected_containers)

    for ret in ret_values:
        assert ret.running is False

# adds all invalid types for container_logs
@pytest.mark.parametrize("container_logs", [1, None, 6.8, False])
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
def test_fetch_requested_container_logs_invalid(self, container_running, container_logs):
    """Unsupported ``container_logs`` values must produce no logging statuses at all."""
    container_running.return_value = False
    log_stream = mock.MagicMock(return_value=[b"2021-01-01 hi"])
    self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(stream=log_stream)
    self.pod_manager.read_pod = MagicMock()
    self.pod_manager.get_container_names = MagicMock(return_value=["base", "alpine"])

    statuses = self.pod_manager.fetch_requested_container_logs(
        pod=MagicMock(),
        container_logs=container_logs,
    )

    assert len(statuses) == 0

@mock.patch("pendulum.now")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodLogsConsumer.logs_available")
Expand Down

0 comments on commit 9a0f41b

Please sign in to comment.