Replace pod_manager.read_pod_logs with client.read_namespaced_pod_log in KubernetesPodOperator._write_logs #39112

Merged · 7 commits · May 5, 2024
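The gist of the change: `_write_logs` now calls the Kubernetes API client directly instead of going through `PodManager.read_pod_logs`. A minimal standalone sketch of the new call pattern (the pod name, namespace, and container below are placeholder values, and a configured kubeconfig is assumed):

```python
from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() inside a cluster
v1 = client.CoreV1Api()

# _preload_content=False returns the raw urllib3 response instead of a decoded
# string, so the log stream can be iterated line by line as bytes.
resp = v1.read_namespaced_pod_log(
    name="my-pod",        # placeholder
    namespace="default",  # placeholder
    container="base",     # placeholder
    follow=False,
    timestamps=False,
    since_seconds=300,    # only logs from the last 5 minutes
    _preload_content=False,
)
for raw_line in resp:
    line = raw_line.decode("utf-8", errors="backslashreplace").rstrip("\n")
    if line:
        print(line)
```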
70 changes: 35 additions & 35 deletions airflow/providers/cncf/kubernetes/operators/pod.py
@@ -553,7 +553,7 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s

         return pod_request_obj

-    def await_pod_start(self, pod: k8s.V1Pod):
+    def await_pod_start(self, pod: k8s.V1Pod) -> None:
         try:
             self.pod_manager.await_pod_start(
                 pod=pod,
@@ -565,23 +565,23 @@ def await_pod_start(self, pod: k8s.V1Pod):
             self._read_pod_events(pod, reraise=False)
             raise

-    def extract_xcom(self, pod: k8s.V1Pod):
+    def extract_xcom(self, pod: k8s.V1Pod) -> dict[Any, Any] | None:
         """Retrieve xcom value and kill xcom sidecar container."""
         result = self.pod_manager.extract_xcom(pod)
         if isinstance(result, str) and result.rstrip() == EMPTY_XCOM_RESULT:
             self.log.info("xcom result file is empty.")
             return None
-        else:
-            self.log.info("xcom result: \n%s", result)
-            return json.loads(result)
+
+        self.log.info("xcom result: \n%s", result)
+        return json.loads(result)

     def execute(self, context: Context):
         """Based on the deferrable parameter runs the pod asynchronously or synchronously."""
-        if self.deferrable:
-            self.execute_async(context)
-        else:
+        if not self.deferrable:
             return self.execute_sync(context)

+        self.execute_async(context)
+
     def execute_sync(self, context: Context):
         result = None
         try:
@@ -669,7 +669,7 @@ def _refresh_cached_properties(self):
         del self.client
         del self.pod_manager

-    def execute_async(self, context: Context):
+    def execute_async(self, context: Context) -> None:
         self.pod_request_obj = self.build_pod_request_obj(context)
         self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
             pod_request_obj=self.pod_request_obj,
@@ -687,7 +687,7 @@ def execute_async(self, context: Context):

         self.invoke_defer_method()

-    def invoke_defer_method(self, last_log_time: DateTime | None = None):
+    def invoke_defer_method(self, last_log_time: DateTime | None = None) -> None:
         """Redefine triggers which are being used in child classes."""
         trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
         self.defer(
@@ -742,7 +742,7 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
             if event["status"] in ("error", "failed", "timeout"):
                # fetch some logs when pod is failed
                if self.get_logs:
-                    self.write_logs(self.pod, follow=follow, since_time=last_log_time)
+                    self._write_logs(self.pod, follow=follow, since_time=last_log_time)

                if self.do_xcom_push:
                    _ = self.extract_xcom(pod=self.pod)
@@ -770,7 +770,7 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
             elif event["status"] == "success":
                # fetch some logs when pod is executed successfully
                if self.get_logs:
-                    self.write_logs(self.pod, follow=follow, since_time=last_log_time)
+                    self._write_logs(self.pod, follow=follow, since_time=last_log_time)

                if self.do_xcom_push:
                    xcom_sidecar_output = self.extract_xcom(pod=self.pod)
@@ -781,7 +781,7 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
         finally:
             self._clean(event)

-    def _clean(self, event: dict[str, Any]):
+    def _clean(self, event: dict[str, Any]) -> None:
         if event["status"] == "running":
             return
         istio_enabled = self.is_istio_enabled(self.pod)
@@ -797,35 +797,35 @@ def _clean(self, event: dict[str, Any]):
             remote_pod=self.pod,
         )

-    @deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
-    def execute_complete(self, context: Context, event: dict, **kwargs):
-        return self.trigger_reentry(context=context, event=event)
-
-    def write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None):
+    def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
         try:
             since_seconds = (
                 math.ceil((datetime.datetime.now(tz=datetime.timezone.utc) - since_time).total_seconds())
                 if since_time
                 else None
             )
-            logs = self.pod_manager.read_pod_logs(
-                pod=pod,
-                container_name=self.base_container_name,
+            logs = self.client.read_namespaced_pod_log(
+                name=pod.metadata.name,
+                namespace=pod.metadata.namespace,
+                container=self.base_container_name,
                 follow=follow,
+                timestamps=False,
                 since_seconds=since_seconds,
+                _preload_content=False,
             )
             for raw_line in logs:
                 line = raw_line.decode("utf-8", errors="backslashreplace").rstrip("\n")
-                self.log.info("Container logs: %s", line)
+                if line:
+                    self.log.info("[%s] logs: %s", self.base_container_name, line)
         except HTTPError as e:
             self.log.warning(
                 "Reading of logs interrupted with error %r; will retry. "
                 "Set log level to DEBUG for traceback.",
                 e,
             )

-    def post_complete_action(self, *, pod, remote_pod, **kwargs):
+    def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None:
         """Actions that must be done after operator finishes logic of the deferrable_execution."""
         self.cleanup(
             pod=pod,
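For reference, the `since_seconds` conversion above exists because `read_namespaced_pod_log` takes a relative offset rather than an absolute timestamp, so the operator converts the trigger's `last_log_time`. A toy illustration with made-up values:

```python
import datetime
import math

now = datetime.datetime.now(tz=datetime.timezone.utc)
since_time = now - datetime.timedelta(seconds=42.2)  # hypothetical last log time

# Round up so the window fully covers the elapsed interval and no lines are lost.
since_seconds = math.ceil((now - since_time).total_seconds())
print(since_seconds)  # 43
```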
@@ -893,7 +893,7 @@ def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
             )
         )

-    def _read_pod_events(self, pod, *, reraise=True):
+    def _read_pod_events(self, pod, *, reraise=True) -> None:
         """Will fetch and emit events from pod."""
         with _optionally_suppress(reraise=reraise):
             for event in self.pod_manager.read_pod_events(pod).items:
@@ -941,15 +941,11 @@ def kill_istio_sidecar(self, pod: V1Pod) -> None:
     def process_pod_deletion(self, pod: k8s.V1Pod, *, reraise=True):
         with _optionally_suppress(reraise=reraise):
             if pod is not None:
-                should_delete_pod = (
-                    (self.on_finish_action == OnFinishAction.DELETE_POD)
-                    or (
-                        self.on_finish_action == OnFinishAction.DELETE_SUCCEEDED_POD
-                        and pod.status.phase == PodPhase.SUCCEEDED
-                    )
-                    or (
-                        self.on_finish_action == OnFinishAction.DELETE_SUCCEEDED_POD
-                        and container_is_succeeded(pod, self.base_container_name)
+                should_delete_pod = (self.on_finish_action == OnFinishAction.DELETE_POD) or (
+                    self.on_finish_action == OnFinishAction.DELETE_SUCCEEDED_POD
+                    and (
+                        pod.status.phase == PodPhase.SUCCEEDED
+                        or container_is_succeeded(pod, self.base_container_name)
                     )
                 )
                 if should_delete_pod:
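The `should_delete_pod` rewrite is a pure boolean factoring: `A or (B and C) or (B and D)` becomes `A or (B and (C or D))`, where A is the DELETE_POD check, B the DELETE_SUCCEEDED_POD check, C the pod-phase check, and D the container check. A quick exhaustive check of the equivalence:

```python
from itertools import product

for a, b, c, d in product([False, True], repeat=4):
    old = a or (b and c) or (b and d)
    new = a or (b and (c or d))
    assert old == new
print("equivalent in all 16 cases")
```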
@@ -966,8 +962,8 @@ def _build_find_pod_label_selector(self, context: Context | None = None, *, excl
         label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())]
         labels_value = ",".join(label_strings)
         if exclude_checked:
-            labels_value += f",{self.POD_CHECKED_KEY}!=True"
-        labels_value += ",!airflow-worker"
+            labels_value = f"{labels_value},{self.POD_CHECKED_KEY}!=True"
+        labels_value = f"{labels_value},!airflow-worker"
         return labels_value

     @staticmethod
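To see what `_build_find_pod_label_selector` produces, here is a sketch with hypothetical label values. It assumes `POD_CHECKED_KEY` is `"already_checked"`, and note that a bare `!airflow-worker` term selects pods that lack that label key entirely:

```python
POD_CHECKED_KEY = "already_checked"  # assumed value of the class attribute
labels = {"dag_id": "my_dag", "task_id": "my_task", "run_id": "manual__2024-05-05"}  # hypothetical

label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())]
labels_value = ",".join(label_strings)
labels_value = f"{labels_value},{POD_CHECKED_KEY}!=True"  # the exclude_checked=True branch
labels_value = f"{labels_value},!airflow-worker"
print(labels_value)
# dag_id=my_dag,run_id=manual__2024-05-05,task_id=my_task,already_checked!=True,!airflow-worker
```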
@@ -1129,6 +1125,10 @@ def dry_run(self) -> None:
         pod = self.build_pod_request_obj()
         print(yaml.dump(prune_dict(pod.to_dict(), mode="strict")))

+    @deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
+    def execute_complete(self, context: Context, event: dict, **kwargs):
+        return self.trigger_reentry(context=context, event=event)
+

 class _optionally_suppress(AbstractContextManager):
     """
@@ -1142,15 +1142,15 @@ class _optionally_suppress(AbstractContextManager):
     :meta private:
     """

-    def __init__(self, *exceptions, reraise=False):
+    def __init__(self, *exceptions, reraise: bool = False) -> None:
         self._exceptions = exceptions or (Exception,)
         self.reraise = reraise
         self.exception = None

     def __enter__(self):
         return self

-    def __exit__(self, exctype, excinst, exctb):
+    def __exit__(self, exctype, excinst, exctb) -> bool:
         error = exctype is not None
         matching_error = error and issubclass(exctype, self._exceptions)
         if (error and not matching_error) or (matching_error and self.reraise):
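Usage-wise, `_optionally_suppress` lets cleanup steps run best-effort: a matching exception is recorded on the context manager and suppressed unless `reraise=True`. A small sketch, assuming the provider package is importable and that `__exit__` stores the suppressed exception on `self.exception` as the constructor suggests:

```python
from airflow.providers.cncf.kubernetes.operators.pod import _optionally_suppress

with _optionally_suppress(reraise=False) as ctx:
    raise RuntimeError("pod deletion failed")  # caught and logged, not raised

print(ctx.exception)  # the stored RuntimeError instance
```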
24 changes: 13 additions & 11 deletions tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1999,7 +1999,7 @@ def test_async_xcom_sidecar_container_resources_default_should_execute_successfu

     @pytest.mark.parametrize("get_logs", [True, False])
     @patch(KUB_OP_PATH.format("post_complete_action"))
-    @patch(KUB_OP_PATH.format("write_logs"))
+    @patch(KUB_OP_PATH.format("_write_logs"))
     @patch(POD_MANAGER_CLASS)
     @patch(HOOK_CLASS)
     def test_async_get_logs_should_execute_successfully(
@@ -2075,9 +2075,9 @@ def test_cleanup_log_pod_spec_on_failure(self, log_pod_spec_on_failure, expect_m
         with pytest.raises(AirflowException, match=expect_match):
             k.cleanup(pod, pod)

-    @mock.patch(f"{HOOK_CLASS}.get_pod")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
+    @patch(f"{HOOK_CLASS}.get_pod")
+    @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
+    @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
     def test_get_logs_running(
         self,
         fetch_container_logs,
@@ -2097,10 +2097,11 @@ def test_get_logs_running(
         )
         fetch_container_logs.is_called_with(pod, "base")

-    @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
-    @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
-    def test_get_logs_not_running(self, fetch_container_logs, find_pod, cleanup):
+    @patch(KUB_OP_PATH.format("_write_logs"))
+    @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
+    @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
+    @patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
+    def test_get_logs_not_running(self, fetch_container_logs, find_pod, cleanup, mock_write_log):
         pod = MagicMock()
         find_pod.return_value = pod
         op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
@@ -2110,9 +2111,10 @@ def test_get_logs_not_running(self, fetch_container_logs, find_pod, cleanup):
         )
         fetch_container_logs.is_called_with(pod, "base")

-    @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
-    @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
-    def test_trigger_error(self, find_pod, cleanup):
+    @patch(KUB_OP_PATH.format("_write_logs"))
+    @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
+    @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
+    def test_trigger_error(self, find_pod, cleanup, mock_write_log):
         """Assert that trigger_reentry raise exception in case of error"""
         find_pod.return_value = MagicMock()
         op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
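The extra `_write_logs` patches in these tests are needed because `trigger_reentry` now invokes `self._write_logs` when `get_logs=True`; without stubbing it out, the tests would attempt a real log read. The pattern, assuming `KUB_OP_PATH` matches the test module's usual template:

```python
from unittest.mock import patch

# Assumed to match the constant defined in the test module.
KUB_OP_PATH = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.{}"

@patch(KUB_OP_PATH.format("_write_logs"))
def test_example(mock_write_logs):
    ...  # drive trigger_reentry here; log writing is stubbed out
```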