Fix get_logs, pod cleanup and XCom push in ``GKEStartPodOperatorAsync`` (#824)
bharanidharan14 committed Jan 4, 2023
1 parent d939f5b commit b0afff9
Showing 6 changed files with 279 additions and 19 deletions.
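For orientation, a minimal usage sketch of the operator this commit touches. Only ``GKEStartPodOperatorAsync`` and its ``get_logs``, ``do_xcom_push`` and ``logging_interval`` parameters come from the diff below; the DAG wiring, connection id and GCP values are illustrative placeholders.

from datetime import datetime

from airflow import DAG

from astronomer.providers.google.cloud.operators.kubernetes_engine import (
    GKEStartPodOperatorAsync,
)

# Illustrative DAG; only the operator parameters mirror the commit below.
with DAG(dag_id="example_gke_start_pod", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    start_pod = GKEStartPodOperatorAsync(
        task_id="start_pod",
        project_id="my-gcp-project",  # placeholder
        location="us-west1",
        cluster_name="my-gke-cluster",  # placeholder
        name="astro_k8s_gke_test_pod",
        namespace="default",
        image="ubuntu",
        gcp_conn_id="google_cloud_default",
        get_logs=True,  # fetch container logs when the operator resumes from the trigger
        do_xcom_push=True,  # push pod_name / pod_namespace to XCom after completion
        logging_interval=60,  # assumption: resume periodically to pull the latest logs, then defer again
    )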
@@ -107,7 +107,9 @@ async def wait_for_container_completion(self, v1_api: CoreV1Api) -> "TriggerEvent":
while True:
pod = await v1_api.read_namespaced_pod(self.pod_name, self.pod_namespace)
if not container_is_running(pod=pod, container_name=self.container_name):
return TriggerEvent({"status": "done"})
return TriggerEvent(
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
if time_get_more_logs and timezone.utcnow() > time_get_more_logs:
return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
await asyncio.sleep(self.poll_interval)
@@ -120,7 +122,9 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # noqa: D102
v1_api = CoreV1Api(api_client)
state = await self.wait_for_pod_start(v1_api)
if state in PodPhase.terminal_states:
event = TriggerEvent({"status": "done"})
event = TriggerEvent(
{"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
)
else:
event = await self.wait_for_container_completion(v1_api)
yield event
89 changes: 84 additions & 5 deletions astronomer/providers/google/cloud/operators/kubernetes_engine.py
@@ -10,6 +10,12 @@
)
from kubernetes.client import models as k8s

+ from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
+ PodNotFoundException,
+ )
+ from astronomer.providers.cncf.kubernetes.triggers.wait_container import (
+ PodLaunchTimeoutException,
+ )
from astronomer.providers.google.cloud.triggers.kubernetes_engine import (
GKEStartPodTrigger,
)
@@ -66,6 +72,8 @@ def __init__(
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
regional: bool = False,
poll_interval: float = 5,
+ logging_interval: Optional[int] = None,
+ do_xcom_push: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -79,6 +87,8 @@ def __init__(
self.pod_name: str = ""
self.pod_namespace: str = ""
self.poll_interval = poll_interval
+ self.logging_interval = logging_interval
+ self.do_xcom_push = do_xcom_push

def _get_or_create_pod(self, context: Context) -> None:
"""A wrapper to fetch GKE config and get or create a pod"""
@@ -116,13 +126,82 @@ def execute(self, context: Context) -> None:
regional=self.regional,
poll_interval=self.poll_interval,
pending_phase_timeout=self.startup_timeout_seconds,
+ logging_interval=self.logging_interval,
),
method_name="execute_complete",
method_name=self.trigger_reentry.__name__,
)

def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:
"""Callback for trigger once task reach terminal state"""
if event and event["status"] == "done":
self.log.info("Job completed successfully")
else:
raise AirflowException(event["description"])
self.trigger_reentry(context=context, event=event)

@staticmethod
def raise_for_trigger_status(event: Dict[str, Any]) -> None:
"""Raise exception if pod is not in expected state."""
if event["status"] == "error":
description = event["description"]
if "error_type" in event and event["error_type"] == "PodLaunchTimeoutException":
raise PodLaunchTimeoutException(description)
else:
raise AirflowException(description)

def trigger_reentry(self, context: Context, event: Dict[str, Any]) -> Any:
"""
Point of re-entry from trigger.
If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
the logs and exit.
If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
remote_pod = None
self.raise_for_trigger_status(event)
try:
with GKEStartPodOperator.get_gke_config_file(
gcp_conn_id=self.gcp_conn_id,
project_id=self.project_id,
cluster_name=self.cluster_name,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
location=self.location,
use_internal_ip=self.use_internal_ip,
) as config_file:
self.config_file = config_file
self.pod = self.find_pod(
namespace=event["namespace"],
context=context,
)

if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.get_logs:
last_log_time = event and event.get("last_log_time")
if last_log_time:
self.log.info("Resuming logs read from time %r", last_log_time) # pragma: no cover
self.pod_manager.fetch_container_logs(
pod=self.pod,
container_name=self.BASE_CONTAINER_NAME,
follow=self.logging_interval is None,
since_time=last_log_time,
)

if self.do_xcom_push:
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
except Exception:
self.cleanup(
pod=self.pod,
remote_pod=remote_pod,
)
raise
self.cleanup(
pod=self.pod,
remote_pod=remote_pod,
)
if self.do_xcom_push:
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
return result
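Building on the ``ti.xcom_push`` calls added in ``trigger_reentry`` above, a hedged sketch of a downstream task reading those XComs; the TaskFlow wiring and the ``start_pod`` task id are assumptions for illustration, not part of this commit.

from airflow.decorators import task

@task
def report_pod(**context):
    # Pull the XCom keys pushed by GKEStartPodOperatorAsync.trigger_reentry (see diff above).
    ti = context["ti"]
    pod_name = ti.xcom_pull(task_ids="start_pod", key="pod_name")
    pod_namespace = ti.xcom_pull(task_ids="start_pod", key="pod_namespace")
    print(f"Pod {pod_name} ran in namespace {pod_namespace}")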
15 changes: 13 additions & 2 deletions astronomer/providers/google/cloud/triggers/kubernetes_engine.py
@@ -118,8 +118,19 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
async with await hook.get_api_client_async() as api_client:
v1_api = CoreV1Api(api_client)
state = await self.wait_for_pod_start(v1_api)
- if state in PodPhase.terminal_states:
- event = TriggerEvent({"status": "done"})
+ if state == PodPhase.SUCCEEDED:
+ event = TriggerEvent(
+ {"status": "done", "namespace": self.namespace, "pod_name": self.name}
+ )
+ elif state == PodPhase.FAILED:
+ event = TriggerEvent(
+ {
+ "status": "failed",
+ "namespace": self.namespace,
+ "pod_name": self.name,
+ "description": "Failed to start pod operator",
+ }
+ )
else:
event = await self.wait_for_container_completion(v1_api)
yield event
10 changes: 7 additions & 3 deletions tests/cncf/kubernetes/triggers/test_wait_container.py
@@ -137,7 +137,9 @@ async def test_pending_succeeded(self, load_kube_config, wait_completion):
poll_interval=2,
)

- assert await trigger.run().__anext__() == TriggerEvent({"status": "done"})
+ assert await trigger.run().__anext__() == TriggerEvent(
+ {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}
+ )
wait_completion.assert_not_awaited()

@pytest.mark.asyncio
@@ -181,15 +183,17 @@ async def test_failed(self, load_kube_config, wait_completion):
poll_interval=2,
)

- assert await trigger.run().__anext__() == TriggerEvent({"status": "done"})
+ assert await trigger.run().__anext__() == TriggerEvent(
+ {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}
+ )
wait_completion.assert_not_awaited()

@pytest.mark.asyncio
@pytest.mark.parametrize(
"logging_interval, exp_event",
[
param(0, {"status": "running", "last_log_time": DateTime(2022, 1, 1)}, id="short_interval"),
param(None, {"status": "done"}, id="no_interval"),
param(None, {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}, id="no_interval"),
],
)
@mock.patch(READ_NAMESPACED_POD_PATH, new=get_read_pod_mock_containers([1, 1, None, None]))
155 changes: 151 additions & 4 deletions tests/google/cloud/operators/test_kubernetes_engine.py
@@ -1,16 +1,25 @@
from unittest import mock
+ from unittest.mock import MagicMock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
+ from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus
from kubernetes.client import models as k8s
from kubernetes.client.models.v1_object_meta import V1ObjectMeta

+ from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
+ PodNotFoundException,
+ )
+ from astronomer.providers.cncf.kubernetes.triggers.wait_container import (
+ PodLaunchTimeoutException,
+ )
from astronomer.providers.google.cloud.operators.kubernetes_engine import (
GKEStartPodOperatorAsync,
)
from astronomer.providers.google.cloud.triggers.kubernetes_engine import (
GKEStartPodTrigger,
)
+ from tests.utils.airflow_util import create_context

PROJECT_ID = "astronomer-***-providers"
LOCATION = "us-west1"
@@ -82,8 +91,12 @@ def test_execute(mock__get_or_create_pod):
assert isinstance(exc.value.trigger, GKEStartPodTrigger), "Trigger is not a GKEStartPodTrigger"


- def test_execute_complete_success():
+ @mock.patch(
+ "astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync.trigger_reentry"
+ )
+ def test_execute_complete_success(mock_trigger_reentry):
"""assert that execute_complete_success log correct message when a task succeed"""
+ mock_trigger_reentry.return_value = {}
operator = GKEStartPodOperatorAsync(
task_id="start_pod",
project_id=PROJECT_ID,
@@ -94,9 +107,7 @@ def test_execute_complete_success():
image="ubuntu",
gcp_conn_id=GCP_CONN_ID,
)
- with mock.patch.object(operator.log, "info") as mock_log_info:
- operator.execute_complete(context=context, event={"status": "done"})
- mock_log_info.assert_called_with("Job completed successfully")
+ assert operator.execute_complete(context=create_context(operator), event={}) is None


def test_execute_complete_fail():
@@ -113,3 +124,139 @@ def test_execute_complete_fail():
with pytest.raises(AirflowException):
"""assert that execute_complete_success raise exception when a task fail"""
operator.execute_complete(context=context, event={"status": "error", "description": "Pod not found"})


+ def test_raise_for_trigger_status_done():
+ """Assert trigger don't raise exception in case of status is done"""
+ assert (
+ GKEStartPodOperatorAsync(
+ task_id="start_pod",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_name=GKE_CLUSTER_NAME,
+ name="astro_k8s_gke_test_pod",
+ namespace=NAMESPACE,
+ image="ubuntu",
+ gcp_conn_id=GCP_CONN_ID,
+ ).raise_for_trigger_status({"status": "done"})
+ is None
+ )


@mock.patch("airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.client")
@mock.patch("astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync.cleanup")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
@mock.patch(
"astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync"
".raise_for_trigger_status"
)
@mock.patch("astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
@mock.patch(
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.get_gke_config_file"
)
@mock.patch(
"astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync.extract_xcom"
)
def test_get_logs_not_running(
mock_extract_xcom,
mock_gke_config,
mock_get_default_client,
fetch_container_logs,
await_pod_completion,
find_pod,
raise_for_trigger_status,
get_kube_client,
cleanup,
mock_client,
):
mock_extract_xcom.return_value = "{}"
pod = MagicMock()
find_pod.return_value = pod
mock_client.return_value = {}
op = GKEStartPodOperatorAsync(
task_id="start_pod",
project_id=PROJECT_ID,
location=LOCATION,
cluster_name=GKE_CLUSTER_NAME,
name="astro_k8s_gke_test_pod",
namespace=NAMESPACE,
image="ubuntu",
gcp_conn_id=GCP_CONN_ID,
get_logs=True,
do_xcom_push=True,
)
context = create_context(op)
await_pod_completion.return_value = None
fetch_container_logs.return_value = PodLoggingStatus(False, None)
op.trigger_reentry(context, {"namespace": NAMESPACE})
fetch_container_logs.is_called_with(pod, "base")


@mock.patch("airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator.client")
@mock.patch("astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync.cleanup")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
@mock.patch(
"astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync"
".raise_for_trigger_status"
)
@mock.patch("astronomer.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperatorAsync.find_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
@mock.patch(
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.get_gke_config_file"
)
def test_no_pod(
mock_gke_config,
mock_get_default_client,
fetch_container_logs,
await_pod_completion,
find_pod,
raise_for_trigger_status,
get_kube_client,
cleanup,
mock_client,
):
"""Assert if pod not found then raise exception"""
find_pod.return_value = None
op = GKEStartPodOperatorAsync(
task_id="start_pod",
project_id=PROJECT_ID,
location=LOCATION,
cluster_name=GKE_CLUSTER_NAME,
name="astro_k8s_gke_test_pod",
namespace=NAMESPACE,
image="ubuntu",
gcp_conn_id=GCP_CONN_ID,
get_logs=True,
)
context = create_context(op)
with pytest.raises(PodNotFoundException):
op.trigger_reentry(context, {"namespace": NAMESPACE})


+ def test_trigger_error():
+ """Assert that trigger_reentry raise exception in case of error"""
+ op = GKEStartPodOperatorAsync(
+ task_id="start_pod",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_name=GKE_CLUSTER_NAME,
+ name="astro_k8s_gke_test_pod",
+ namespace=NAMESPACE,
+ image="ubuntu",
+ gcp_conn_id=GCP_CONN_ID,
+ get_logs=True,
+ )
+ with pytest.raises(PodLaunchTimeoutException):
+ op.execute_complete(
+ context,
+ {
+ "status": "error",
+ "error_type": "PodLaunchTimeoutException",
+ "description": "any message",
+ },
+ )
