From 5ab43d5541a68c5c90fe849f19e344bcdeddd44f Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Tue, 19 Dec 2023 09:11:52 +0100 Subject: [PATCH] Move KubernetesPodTrigger hook to a cached property (#36290) --- .../providers/cncf/kubernetes/triggers/pod.py | 28 ++++++------ .../cloud/triggers/kubernetes_engine.py | 4 +- .../cncf/kubernetes/triggers/test_pod.py | 44 +++++++++---------- .../cloud/triggers/test_kubernetes_engine.py | 32 +++++++------- 4 files changed, 56 insertions(+), 52 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index b7f0348b667f7..3dd9eb173ca57 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -22,6 +22,7 @@ import warnings from asyncio import CancelledError from enum import Enum +from functools import cached_property from typing import TYPE_CHECKING, Any, AsyncIterator from airflow.exceptions import AirflowProviderDeprecationWarning @@ -116,7 +117,6 @@ def __init__( self.on_finish_action = OnFinishAction(on_finish_action) self.should_delete_pod = self.on_finish_action == OnFinishAction.DELETE_POD - self._hook: AsyncKubernetesHook | None = None self._since_time = None def serialize(self) -> tuple[str, dict[str, Any]]: @@ -142,11 +142,10 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current pod status and yield a TriggerEvent.""" - hook = self._get_async_hook() self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace) try: while True: - pod = await hook.get_pod( + pod = await self.hook.get_pod( name=self.pod_name, namespace=self.pod_namespace, ) @@ -206,13 +205,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] # That means that task was marked as failed if self.get_logs: self.log.info("Outputting container logs...") - await self._get_async_hook().read_logs( + await self.hook.read_logs( name=self.pod_name, namespace=self.pod_namespace, ) if self.on_finish_action == OnFinishAction.DELETE_POD: self.log.info("Deleting pod...") - await self._get_async_hook().delete_pod( + await self.hook.delete_pod( name=self.pod_name, namespace=self.pod_namespace, ) @@ -237,14 +236,17 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] ) def _get_async_hook(self) -> AsyncKubernetesHook: - if self._hook is None: - self._hook = AsyncKubernetesHook( - conn_id=self.kubernetes_conn_id, - in_cluster=self.in_cluster, - config_file=self.config_file, - cluster_context=self.cluster_context, - ) - return self._hook + # TODO: Remove this method when the min version of kubernetes provider is 7.12.0 in Google provider. + return AsyncKubernetesHook( + conn_id=self.kubernetes_conn_id, + in_cluster=self.in_cluster, + config_file=self.config_file, + cluster_context=self.cluster_context, + ) + + @cached_property + def hook(self) -> AsyncKubernetesHook: + return self._get_async_hook() def define_container_state(self, pod: V1Pod) -> ContainerState: pod_containers = pod.status.container_statuses diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index 1fbaef72a9ce0..eb1194369ef23 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -19,6 +19,7 @@ import asyncio import warnings +from functools import cached_property from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence from google.cloud.container_v1.types import Operation @@ -137,7 +138,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) - def _get_async_hook(self) -> GKEPodAsyncHook: # type: ignore[override] + @cached_property + def hook(self) -> GKEPodAsyncHook: # type: ignore[override] return GKEPodAsyncHook( cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 42a0196ed77ed..9c016ea8cfb9a 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -94,9 +94,9 @@ def test_serialize(self, trigger): @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( @@ -113,9 +113,9 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigger): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result( + mock_hook.get_pod.return_value = self._mock_pod_result( mock.MagicMock( status=mock.MagicMock( message=FAILED_RESULT_MSG, @@ -138,9 +138,9 @@ async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigge @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigger, caplog): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.WAITING caplog.set_level(logging.INFO) @@ -154,9 +154,9 @@ async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigg @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigger, caplog): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.RUNNING caplog.set_level(logging.INFO) @@ -169,7 +169,7 @@ async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigg assert f"Sleeping for {POLL_INTERVAL} seconds." @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_logging_in_trigger_when_exception_should_execute_successfully( self, mock_hook, trigger, caplog ): @@ -177,7 +177,7 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully( Test that KubernetesPodTrigger fires the correct event in case of an error. """ - mock_hook.return_value.get_pod.side_effect = Exception("Test exception") + mock_hook.get_pod.side_effect = Exception("Test exception") generator = trigger.run() actual = await generator.asend(None) @@ -192,7 +192,7 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully( @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_logging_in_trigger_when_fail_should_execute_successfully( self, mock_hook, mock_method, trigger, caplog ): @@ -200,7 +200,7 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully( Test that KubernetesPodTrigger fires the correct event in case of fail. """ - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.FAILED caplog.set_level(logging.INFO) @@ -209,7 +209,7 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully( assert "Container logs:" @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and_delete_pod( self, mock_hook, @@ -219,9 +219,9 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and Test that KubernetesPodTrigger fires the correct event in case if the task was cancelled. """ - mock_hook.return_value.get_pod.side_effect = CancelledError() - mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) - mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.side_effect = CancelledError() + mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) trigger = KubernetesPodTrigger( pod_name=POD_NAME, @@ -255,7 +255,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and assert "Deleting pod..." in caplog.text @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_logging_in_trigger_when_cancelled_should_execute_successfully_without_delete_pod( self, mock_hook, @@ -265,9 +265,9 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_wit Test that KubernetesPodTrigger fires the correct event if the task was cancelled. """ - mock_hook.return_value.get_pod.side_effect = CancelledError() - mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) - mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.side_effect = CancelledError() + mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) trigger = KubernetesPodTrigger( pod_name=POD_NAME, @@ -341,12 +341,12 @@ def test_define_container_state_should_execute_successfully( @pytest.mark.asyncio @pytest.mark.parametrize("container_state", [ContainerState.WAITING, ContainerState.UNDEFINED]) @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_PATH}.hook") async def test_run_loop_return_timeout_event( self, mock_hook, mock_method, trigger, caplog, container_state ): trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(minutes=2) - mock_hook.return_value.get_pod.return_value = self._mock_pod_result( + mock_hook.get_pod.return_value = self._mock_pod_result( mock.MagicMock( status=mock.MagicMock( phase=PodPhase.PENDING, diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index b28d9418ef2b1..b252ea4e30a3d 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -105,11 +105,11 @@ def test_serialize_should_execute_successfully(self, trigger): @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_success_event_should_execute_successfully( self, mock_hook, mock_method, trigger ): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( @@ -126,11 +126,11 @@ async def test_run_loop_return_success_event_should_execute_successfully( @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_failed_event_should_execute_successfully( self, mock_hook, mock_method, trigger ): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result( + mock_hook.get_pod.return_value = self._mock_pod_result( mock.MagicMock( status=mock.MagicMock( message=FAILED_RESULT_MSG, @@ -153,11 +153,11 @@ async def test_run_loop_return_failed_event_should_execute_successfully( @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_waiting_event_should_execute_successfully( self, mock_hook, mock_method, trigger, caplog ): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.WAITING caplog.set_level(logging.INFO) @@ -171,11 +171,11 @@ async def test_run_loop_return_waiting_event_should_execute_successfully( @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_run_loop_return_running_event_should_execute_successfully( self, mock_hook, mock_method, trigger, caplog ): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.RUNNING caplog.set_level(logging.INFO) @@ -188,14 +188,14 @@ async def test_run_loop_return_running_event_should_execute_successfully( assert f"Sleeping for {POLL_INTERVAL} seconds." @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_logging_in_trigger_when_exception_should_execute_successfully( self, mock_hook, trigger, caplog ): """ Test that GKEStartPodTrigger fires the correct event in case of an error. """ - mock_hook.return_value.get_pod.side_effect = Exception("Test exception") + mock_hook.get_pod.side_effect = Exception("Test exception") generator = trigger.run() actual = await generator.asend(None) @@ -210,14 +210,14 @@ async def test_logging_in_trigger_when_exception_should_execute_successfully( @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_logging_in_trigger_when_fail_should_execute_successfully( self, mock_hook, mock_method, trigger, caplog ): """ Test that GKEStartPodTrigger fires the correct event in case of fail. """ - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) mock_method.return_value = ContainerState.FAILED caplog.set_level(logging.INFO) @@ -226,16 +226,16 @@ async def test_logging_in_trigger_when_fail_should_execute_successfully( assert "Container logs:" @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook") + @mock.patch(f"{TRIGGER_GKE_PATH}.hook") async def test_logging_in_trigger_when_cancelled_should_execute_successfully( self, mock_hook, trigger, caplog ): """ Test that GKEStartPodTrigger fires the correct event in case if the task was cancelled. """ - mock_hook.return_value.get_pod.side_effect = CancelledError() - mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) - mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.get_pod.side_effect = CancelledError() + mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) + mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) generator = trigger.run() actual = await generator.asend(None)