Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
from airflow.providers.cncf.kubernetes.secret import Secret
from airflow.sdk import Context

# Module-level logger; handed to tenacity's before_sleep_log by the retry
# decorator on _write_logs so retry pauses are reported at WARNING level.
log = logging.getLogger(__name__)

# Lowercase ASCII letters followed by the decimal digits.
alphanum_lower = f"{string.ascii_lowercase}{string.digits}"

# Environment variable name kept as a constant (presumably the variable that
# points at the kubeconfig file; its use is outside this view).
KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
Expand Down Expand Up @@ -972,7 +974,14 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:

if event["status"] in ("error", "failed", "timeout", "success"):
if self.get_logs:
self._write_logs(self.pod, follow=follow, since_time=last_log_time)
try:
self._write_logs(self.pod, follow=follow, since_time=last_log_time)
except (HTTPError, ApiException) as e:
self.log.warning(
"Reading of logs interrupted with error %r. "
"Set log level to DEBUG for traceback.",
e if not isinstance(e, ApiException) else e.reason,
)

for callback in self.callbacks:
callback.on_pod_completion(
Expand Down Expand Up @@ -1035,32 +1044,32 @@ def _clean(self, event: dict[str, Any], result: dict | None, context: Context) -
result=result,
)

@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_exponential(max=15),
    retry=tenacity.retry_if_exception_type((HTTPError, ApiException)),
    before_sleep=tenacity.before_sleep_log(log, logging.WARNING),
    reraise=True,
)
def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
    """
    Read the base container's log from the pod and forward it to the task log.

    Each non-empty line is logged at INFO level, prefixed with the base
    container name.

    The call is retried by tenacity when it raises ``HTTPError`` or
    ``ApiException``: at most 3 attempts, exponential back-off capped at 15,
    and a WARNING logged before every pause. Those exceptions are deliberately
    NOT caught in the body; swallowing them here would hide the failure from
    the retry decorator and no retry would ever happen.

    :param pod: Pod whose ``metadata.name`` / ``metadata.namespace`` identify the log to read.
    :param follow: Passed through to ``read_namespaced_pod_log`` as ``follow``.
    :param since_time: When set, only logs newer than this moment are requested
        (converted to a whole number of seconds relative to "now" in UTC).
    :raises HTTPError: If the last attempt fails with it (``reraise=True``).
    :raises ApiException: If the last attempt fails with it (``reraise=True``).
    """
    # Recomputed on every attempt from the unchanged ``since_time``, so a retry
    # may log lines that an interrupted earlier attempt already emitted.
    since_seconds = (
        math.ceil((datetime.datetime.now(tz=datetime.timezone.utc) - since_time).total_seconds())
        if since_time
        else None
    )
    # Ask for the raw (non-preloaded) response so it can be iterated below.
    logs = self.client.read_namespaced_pod_log(
        name=pod.metadata.name,
        namespace=pod.metadata.namespace,
        container=self.base_container_name,
        follow=follow,
        timestamps=False,
        since_seconds=since_seconds,
        _preload_content=False,
    )
    for raw_line in logs:
        # Undecodable bytes are kept as backslash escapes instead of failing the task.
        line = raw_line.decode("utf-8", errors="backslashreplace").rstrip("\n")
        if line:
            self.log.info("[%s] logs: %s", self.base_container_name, line)

def post_complete_action(
self, *, pod: k8s.V1Pod, remote_pod: k8s.V1Pod, context: Context, result: dict | None, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import pendulum
import pytest
import tenacity
from kubernetes.client import ApiClient, V1Pod, V1PodSecurityContext, V1PodStatus, models as k8s
from kubernetes.client.exceptions import ApiException

Expand Down Expand Up @@ -2764,7 +2765,7 @@ def test_async_write_logs_should_execute_successfully(
@patch(HOOK_CLASS)
@patch(KUB_OP_PATH.format("pod_manager"))
def test_async_write_logs_handler_api_exception(
self, mock_manager, mocked_hook, mock_extract_xcom, post_complete_action, mocked_client
self, mock_manager, mocked_hook, mock_extract_xcom, mocked_client, post_complete_action
):
mocked_client.read_namespaced_pod_log.side_effect = ApiException(status=404)
mock_manager.await_pod_completion.side_effect = ApiException(status=404)
Expand All @@ -2777,9 +2778,70 @@ def test_async_write_logs_handler_api_exception(
get_logs=True,
deferrable=True,
)
# Patch tenacity wait to avoid real delays from _write_logs retries
k._write_logs.retry.wait = tenacity.wait_none()
self.run_pod_async(k)
post_complete_action.assert_not_called()

@patch(KUB_OP_PATH.format("post_complete_action"))
@patch(KUB_OP_PATH.format("client"))
@patch(HOOK_CLASS)
@patch(KUB_OP_PATH.format("pod_manager"))
def test_write_logs_retries_on_api_exception(
    self, mock_manager, mocked_hook, mocked_client, post_complete_action
):
    """A first-attempt ApiException from the log read is retried and the second attempt succeeds."""
    # First read raises, second read yields a single log line.
    log_payload = b"log line\n"
    read_outcomes = [ApiException(status=500), [log_payload]]
    mocked_client.read_namespaced_pod_log.side_effect = read_outcomes

    # Two separate pod objects, one per mocked lookup.
    completed_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAMESPACE))
    fetched_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAMESPACE))
    mock_manager.await_pod_completion.return_value = completed_pod
    mocked_hook.return_value.get_pod.return_value = fetched_pod

    operator = KubernetesPodOperator(task_id="task", get_logs=True, deferrable=True)
    # Drop the back-off wait so the retry does not slow the test down.
    operator._write_logs.retry.wait = tenacity.wait_none()

    self.run_pod_async(operator)

    # One failed attempt plus one successful attempt.
    assert mocked_client.read_namespaced_pod_log.call_count == 2
    post_complete_action.assert_called_once()

@patch(KUB_OP_PATH.format("post_complete_action"))
@patch(KUB_OP_PATH.format("client"))
@patch(HOOK_CLASS)
@patch(KUB_OP_PATH.format("pod_manager"))
def test_write_logs_gives_up_after_max_retries(
    self, mock_manager, mocked_hook, mocked_client, post_complete_action, caplog
):
    """Test that _write_logs gives up after 3 failed attempts and trigger_reentry catches the error."""
    # Every log read fails, so the retry budget is exhausted.
    mocked_client.read_namespaced_pod_log.side_effect = ApiException(status=500)
    mock_manager.await_pod_completion.return_value = k8s.V1Pod(
        metadata=k8s.V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAMESPACE)
    )
    mocked_hook.return_value.get_pod.return_value = k8s.V1Pod(
        metadata=k8s.V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAMESPACE)
    )
    k = KubernetesPodOperator(
        task_id="task",
        get_logs=True,
        deferrable=True,
    )
    # Patch tenacity wait to avoid real delays in tests
    k._write_logs.retry.wait = tenacity.wait_none()
    self.run_pod_async(k)
    # Exactly 3 attempts (stop_after_attempt(3)); pinning the number catches a
    # changed retry limit, which a loose "> 1" check would let through.
    assert mocked_client.read_namespaced_pod_log.call_count == 3
    # trigger_reentry catches the error and continues; post_complete_action still called via _clean
    post_complete_action.assert_called_once()
    assert "Reading of logs interrupted with error" in caplog.text

@pytest.mark.parametrize(
("log_pod_spec_on_failure", "expect_match"),
[
Expand Down