diff --git a/astronomer/providers/amazon/aws/sensors/s3.py b/astronomer/providers/amazon/aws/sensors/s3.py
index 3b3ceed21..c2eb7454d 100644
--- a/astronomer/providers/amazon/aws/sensors/s3.py
+++ b/astronomer/providers/amazon/aws/sensors/s3.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 
-import typing
 import warnings
 from datetime import timedelta
-from typing import Any, Callable, List, Sequence, cast
+from typing import Any, Callable, Sequence, cast
 
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
@@ -78,6 +77,8 @@ def __init__(
             verify=verify,
             **kwargs,
         )
+        self.check_fn = check_fn
+        self.should_check = True if check_fn else False
 
     def execute(self, context: Context) -> None:
         """Check for a keys in s3 and defers using the trigger"""
@@ -89,20 +90,23 @@ def execute(self, context: Context) -> None:
             else:
                 raise e
         if not poke:
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=S3KeyTrigger(
-                    bucket_name=cast(str, self.bucket_name),
-                    bucket_key=self.bucket_key,
-                    wildcard_match=self.wildcard_match,
-                    check_fn=self.check_fn,
-                    aws_conn_id=self.aws_conn_id,
-                    verify=self.verify,
-                    poke_interval=self.poke_interval,
-                    soft_fail=self.soft_fail,
-                ),
-                method_name="execute_complete",
-            )
+            self._defer()
+
+    def _defer(self) -> None:
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=S3KeyTrigger(
+                bucket_name=cast(str, self.bucket_name),
+                bucket_key=self.bucket_key,
+                wildcard_match=self.wildcard_match,
+                aws_conn_id=self.aws_conn_id,
+                verify=self.verify,
+                poke_interval=self.poke_interval,
+                soft_fail=self.soft_fail,
+                should_check=self.should_check,
+            ),
+            method_name="execute_complete",
+        )
 
     def execute_complete(self, context: Context, event: Any = None) -> bool | None:
         """
@@ -110,14 +114,15 @@ def execute_complete(self, context: Context, event: Any = None) -> bool | None:
         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
+        if event["status"] == "running":
+            if self.check_fn(event["files"]):  # type: ignore[misc]
+                return None
+            else:
+                self._defer()
         if event["status"] == "error":
             if event["soft_fail"]:
                 raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])
-        elif event["status"] == "success" and "s3_objects" in event:
-            files = typing.cast(List[str], event["s3_objects"])
-            if self.check_fn:
-                return self.check_fn(files)
         return None
diff --git a/astronomer/providers/amazon/aws/triggers/s3.py b/astronomer/providers/amazon/aws/triggers/s3.py
index 0f1d14142..039608ed0 100644
--- a/astronomer/providers/amazon/aws/triggers/s3.py
+++ b/astronomer/providers/amazon/aws/triggers/s3.py
@@ -2,7 +2,7 @@
 
 import asyncio
 from datetime import datetime
-from typing import Any, AsyncIterator, Callable
+from typing import Any, AsyncIterator
 
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -23,8 +23,6 @@ class S3KeyTrigger(BaseTrigger):
     :param aws_conn_id: reference to the s3 connection
     :param hook_params: params for hook its optional
     :param soft_fail: Set to true to mark the task as SKIPPED on failure
-    :param check_fn: Function that receives the list of the S3 objects,
-        and returns a boolean
     """
 
     def __init__(
@@ -32,21 +30,21 @@ def __init__(
         bucket_name: str,
         bucket_key: list[str],
         wildcard_match: bool = False,
-        check_fn: Callable[..., bool] | None = None,
         aws_conn_id: str = "aws_default",
         poke_interval: float = 5.0,
         soft_fail: bool = False,
+        should_check: bool = False,
         **hook_params: Any,
     ):
         super().__init__()
         self.bucket_name = bucket_name
         self.bucket_key = bucket_key
         self.wildcard_match = wildcard_match
-        self.check_fn = check_fn
         self.aws_conn_id = aws_conn_id
         self.hook_params = hook_params
         self.poke_interval = poke_interval
         self.soft_fail = soft_fail
+        self.should_check = should_check
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -56,11 +54,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "bucket_name": self.bucket_name,
                 "bucket_key": self.bucket_key,
                 "wildcard_match": self.wildcard_match,
-                "check_fn": self.check_fn,
                 "aws_conn_id": self.aws_conn_id,
                 "hook_params": self.hook_params,
                 "poke_interval": self.poke_interval,
                 "soft_fail": self.soft_fail,
+                "should_check": self.should_check,
             },
         )
 
@@ -71,15 +69,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
             async with await hook.get_client_async() as client:
                 while True:
                     if await hook.check_key(client, self.bucket_name, self.bucket_key, self.wildcard_match):
-                        if self.check_fn is None:
-                            yield TriggerEvent({"status": "success"})
-                        else:
+                        if self.should_check:
                             s3_objects = await hook.get_files(
                                 client, self.bucket_name, self.bucket_key, self.wildcard_match
                             )
-                            yield TriggerEvent({"status": "success", "s3_objects": s3_objects})
+                            await asyncio.sleep(self.poke_interval)
+                            yield TriggerEvent({"status": "running", "files": s3_objects})
+                        else:
+                            yield TriggerEvent({"status": "success"})
                     await asyncio.sleep(self.poke_interval)
-
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e), "soft_fail": self.soft_fail})
diff --git a/tests/amazon/aws/sensors/test_s3_sensors.py b/tests/amazon/aws/sensors/test_s3_sensors.py
index 238ba40c3..1983aecd4 100644
--- a/tests/amazon/aws/sensors/test_s3_sensors.py
+++ b/tests/amazon/aws/sensors/test_s3_sensors.py
@@ -141,8 +141,7 @@ def check_fn(files: List[Any]) -> bool:
             check_fn=check_fn,
         )
         assert (
-            sensor.execute_complete(context={}, event={"status": "success", "s3_objects": [{"Size": 10}]})
-            is True
+            sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) is None
         )
 
     @parameterized.expand(
diff --git a/tests/amazon/aws/triggers/test_s3_triggers.py b/tests/amazon/aws/triggers/test_s3_triggers.py
index 94645bd94..0768279e4 100644
--- a/tests/amazon/aws/triggers/test_s3_triggers.py
+++ b/tests/amazon/aws/triggers/test_s3_triggers.py
@@ -28,9 +28,9 @@ def test_serialization(self):
             "wildcard_match": True,
             "aws_conn_id": "aws_default",
             "hook_params": {},
-            "check_fn": None,
             "soft_fail": False,
             "poke_interval": 5.0,
+            "should_check": False,
         }
 
     @pytest.mark.asyncio
@@ -89,17 +89,14 @@ async def test_run_exception(self, mock_client):
     async def test_run_check_fn_success(self, mock_get_files, mock_client):
         """Test if the task is run is in trigger with check_fn."""
 
-        def dummy_check_fn(list_obj):
-            return True
-
         mock_get_files.return_value = ["test"]
         mock_client.return_value.check_key.return_value = True
         trigger = S3KeyTrigger(
-            bucket_key="s3://test_bucket/file", bucket_name="test_bucket", check_fn=dummy_check_fn
+            bucket_key="s3://test_bucket/file", bucket_name="test_bucket", poke_interval=1, should_check=True
        )
         generator = trigger.run()
         actual = await generator.asend(None)
-        assert TriggerEvent({"status": "success", "s3_objects": ["test"]}) == actual
+        assert TriggerEvent({"status": "running", "files": ["test"]}) == actual
 
 
 class TestS3KeysUnchangedTrigger:
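A design note on the serialization change above: trigger kwargs must round-trip through the triggerer's serialization, and an arbitrary callable does not serialize cleanly, which is presumably why check_fn is replaced by the plain should_check bool while the callable stays on the worker. A minimal sketch of the resulting round-trip follows; the expected classpath string is an assumption derived from the module path, and the values are illustrative.

# Hedged sketch: exercises serialize() after this change. The expected
# classpath below is assumed from the module path, not quoted from the PR.
from astronomer.providers.amazon.aws.triggers.s3 import S3KeyTrigger

classpath, kwargs = S3KeyTrigger(
    bucket_key=["s3://test_bucket/file"],
    bucket_name="test_bucket",
    should_check=True,
).serialize()

assert classpath == "astronomer.providers.amazon.aws.triggers.s3.S3KeyTrigger"
assert kwargs["should_check"] is True  # a plain bool travels to the triggerer
assert "check_fn" not in kwargs  # the callable is no longer serialized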