Skip to content

Commit

Permalink
Move KubernetesPodTrigger hook to a cached property (#36290)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Dec 19, 2023
1 parent b5cd96a commit 5ab43d5
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 52 deletions.
28 changes: 15 additions & 13 deletions airflow/providers/cncf/kubernetes/triggers/pod.py
Expand Up @@ -22,6 +22,7 @@
import warnings
from asyncio import CancelledError
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.exceptions import AirflowProviderDeprecationWarning
Expand Down Expand Up @@ -116,7 +117,6 @@ def __init__(
self.on_finish_action = OnFinishAction(on_finish_action)
self.should_delete_pod = self.on_finish_action == OnFinishAction.DELETE_POD

self._hook: AsyncKubernetesHook | None = None
self._since_time = None

def serialize(self) -> tuple[str, dict[str, Any]]:
Expand All @@ -142,11 +142,10 @@ def serialize(self) -> tuple[str, dict[str, Any]]:

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pod status and yield a TriggerEvent."""
hook = self._get_async_hook()
self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
while True:
pod = await hook.get_pod(
pod = await self.hook.get_pod(
name=self.pod_name,
namespace=self.pod_namespace,
)
Expand Down Expand Up @@ -206,13 +205,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
# That means that task was marked as failed
if self.get_logs:
self.log.info("Outputting container logs...")
await self._get_async_hook().read_logs(
await self.hook.read_logs(
name=self.pod_name,
namespace=self.pod_namespace,
)
if self.on_finish_action == OnFinishAction.DELETE_POD:
self.log.info("Deleting pod...")
await self._get_async_hook().delete_pod(
await self.hook.delete_pod(
name=self.pod_name,
namespace=self.pod_namespace,
)
Expand All @@ -237,14 +236,17 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
)

def _get_async_hook(self) -> AsyncKubernetesHook:
if self._hook is None:
self._hook = AsyncKubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
cluster_context=self.cluster_context,
)
return self._hook
# TODO: Remove this method when the min version of kubernetes provider is 7.12.0 in Google provider.
return AsyncKubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
cluster_context=self.cluster_context,
)

@cached_property
def hook(self) -> AsyncKubernetesHook:
return self._get_async_hook()

def define_container_state(self, pod: V1Pod) -> ContainerState:
pod_containers = pod.status.container_statuses
Expand Down
4 changes: 3 additions & 1 deletion airflow/providers/google/cloud/triggers/kubernetes_engine.py
Expand Up @@ -19,6 +19,7 @@

import asyncio
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence

from google.cloud.container_v1.types import Operation
Expand Down Expand Up @@ -137,7 +138,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
},
)

def _get_async_hook(self) -> GKEPodAsyncHook: # type: ignore[override]
@cached_property
def hook(self) -> GKEPodAsyncHook: # type: ignore[override]
return GKEPodAsyncHook(
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
Expand Down
44 changes: 22 additions & 22 deletions tests/providers/cncf/kubernetes/triggers/test_pod.py
Expand Up @@ -94,9 +94,9 @@ def test_serialize(self, trigger):

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED

expected_event = TriggerEvent(
Expand All @@ -113,9 +113,9 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigger):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
message=FAILED_RESULT_MSG,
Expand All @@ -138,9 +138,9 @@ async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigge

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigger, caplog):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.WAITING

caplog.set_level(logging.INFO)
Expand All @@ -154,9 +154,9 @@ async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigg

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigger, caplog):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.RUNNING

caplog.set_level(logging.INFO)
Expand All @@ -169,15 +169,15 @@ async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigg
assert f"Sleeping for {POLL_INTERVAL} seconds."

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_logging_in_trigger_when_exception_should_execute_successfully(
self, mock_hook, trigger, caplog
):
"""
Test that KubernetesPodTrigger fires the correct event in case of an error.
"""

mock_hook.return_value.get_pod.side_effect = Exception("Test exception")
mock_hook.get_pod.side_effect = Exception("Test exception")

generator = trigger.run()
actual = await generator.asend(None)
Expand All @@ -192,15 +192,15 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_logging_in_trigger_when_fail_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
"""
Test that KubernetesPodTrigger fires the correct event in case of fail.
"""

mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.FAILED
caplog.set_level(logging.INFO)

Expand All @@ -209,7 +209,7 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully(
assert "Container logs:"

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and_delete_pod(
self,
mock_hook,
Expand All @@ -219,9 +219,9 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and
Test that KubernetesPodTrigger fires the correct event in case if the task was cancelled.
"""

mock_hook.return_value.get_pod.side_effect = CancelledError()
mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.side_effect = CancelledError()
mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())

trigger = KubernetesPodTrigger(
pod_name=POD_NAME,
Expand Down Expand Up @@ -255,7 +255,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and
assert "Deleting pod..." in caplog.text

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_logging_in_trigger_when_cancelled_should_execute_successfully_without_delete_pod(
self,
mock_hook,
Expand All @@ -265,9 +265,9 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_wit
Test that KubernetesPodTrigger fires the correct event if the task was cancelled.
"""

mock_hook.return_value.get_pod.side_effect = CancelledError()
mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.side_effect = CancelledError()
mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())

trigger = KubernetesPodTrigger(
pod_name=POD_NAME,
Expand Down Expand Up @@ -341,12 +341,12 @@ def test_define_container_state_should_execute_successfully(
@pytest.mark.asyncio
@pytest.mark.parametrize("container_state", [ContainerState.WAITING, ContainerState.UNDEFINED])
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_timeout_event(
self, mock_hook, mock_method, trigger, caplog, container_state
):
trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(minutes=2)
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
phase=PodPhase.PENDING,
Expand Down
32 changes: 16 additions & 16 deletions tests/providers/google/cloud/triggers/test_kubernetes_engine.py
Expand Up @@ -105,11 +105,11 @@ def test_serialize_should_execute_successfully(self, trigger):

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_success_event_should_execute_successfully(
self, mock_hook, mock_method, trigger
):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED

expected_event = TriggerEvent(
Expand All @@ -126,11 +126,11 @@ async def test_run_loop_return_success_event_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_failed_event_should_execute_successfully(
self, mock_hook, mock_method, trigger
):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
message=FAILED_RESULT_MSG,
Expand All @@ -153,11 +153,11 @@ async def test_run_loop_return_failed_event_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_waiting_event_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.WAITING

caplog.set_level(logging.INFO)
Expand All @@ -171,11 +171,11 @@ async def test_run_loop_return_waiting_event_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_running_event_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.RUNNING

caplog.set_level(logging.INFO)
Expand All @@ -188,14 +188,14 @@ async def test_run_loop_return_running_event_should_execute_successfully(
assert f"Sleeping for {POLL_INTERVAL} seconds."

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_logging_in_trigger_when_exception_should_execute_successfully(
self, mock_hook, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case of an error.
"""
mock_hook.return_value.get_pod.side_effect = Exception("Test exception")
mock_hook.get_pod.side_effect = Exception("Test exception")

generator = trigger.run()
actual = await generator.asend(None)
Expand All @@ -210,14 +210,14 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully(

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_logging_in_trigger_when_fail_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case of fail.
"""
mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.FAILED
caplog.set_level(logging.INFO)

Expand All @@ -226,16 +226,16 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully(
assert "Container logs:"

@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
@mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_logging_in_trigger_when_cancelled_should_execute_successfully(
self, mock_hook, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case if the task was cancelled.
"""
mock_hook.return_value.get_pod.side_effect = CancelledError()
mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.get_pod.side_effect = CancelledError()
mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())

generator = trigger.run()
actual = await generator.asend(None)
Expand Down

0 comments on commit 5ab43d5

Please sign in to comment.