Handle check_fn in S3KeySensor
pankajastro committed Jun 6, 2023
1 parent ebb00d0 commit 9696d8f
Showing 4 changed files with 38 additions and 39 deletions.
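
For context, a minimal usage sketch of the behavior this commit enables (assumed, not part of the commit: the task wiring, bucket, key, and object-metadata shape are illustrative; `S3KeySensorAsync` is the class modified below). The sensor now keeps deferring until `check_fn` returns True for the matched objects:

```python
# Hypothetical usage sketch: attaching check_fn to the async sensor.
from typing import Any

from astronomer.providers.amazon.aws.sensors.s3 import S3KeySensorAsync


def check_fn(files: list[Any]) -> bool:
    # Succeed only once every matched object is non-empty.
    return all(f.get("Size", 0) > 0 for f in files)


wait_for_key = S3KeySensorAsync(
    task_id="wait_for_key",
    bucket_name="my-bucket",
    bucket_key="data/*.csv",
    wildcard_match=True,
    check_fn=check_fn,
)
```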
45 changes: 25 additions & 20 deletions astronomer/providers/amazon/aws/sensors/s3.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 
-import typing
 import warnings
 from datetime import timedelta
-from typing import Any, Callable, List, Sequence, cast
+from typing import Any, Callable, Sequence, cast
 
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
@@ -78,6 +77,8 @@ def __init__(
             verify=verify,
             **kwargs,
         )
+        self.check_fn = check_fn
+        self.should_check = True if check_fn else False
 
     def execute(self, context: Context) -> None:
         """Check for keys in S3 and defer using the trigger."""
@@ -89,35 +90,39 @@ def execute(self, context: Context) -> None:
             else:
                 raise e
         if not poke:
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=S3KeyTrigger(
-                    bucket_name=cast(str, self.bucket_name),
-                    bucket_key=self.bucket_key,
-                    wildcard_match=self.wildcard_match,
-                    check_fn=self.check_fn,
-                    aws_conn_id=self.aws_conn_id,
-                    verify=self.verify,
-                    poke_interval=self.poke_interval,
-                    soft_fail=self.soft_fail,
-                ),
-                method_name="execute_complete",
-            )
+            self._defer()
+
+    def _defer(self) -> None:
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=S3KeyTrigger(
+                bucket_name=cast(str, self.bucket_name),
+                bucket_key=self.bucket_key,
+                wildcard_match=self.wildcard_match,
+                aws_conn_id=self.aws_conn_id,
+                verify=self.verify,
+                poke_interval=self.poke_interval,
+                soft_fail=self.soft_fail,
+                should_check=self.should_check,
+            ),
+            method_name="execute_complete",
+        )

     def execute_complete(self, context: Context, event: Any = None) -> bool | None:
         """
         Callback for when the trigger fires; returns immediately unless check_fn
         rejects the reported files, in which case the sensor defers again. Relies
         on the trigger to throw an exception, otherwise it assumes execution was
         successful.
         """
+        if event["status"] == "running":
+            if self.check_fn(event["files"]):  # type: ignore[misc]
+                return None
+            else:
+                self._defer()

if event["status"] == "error":
if event["soft_fail"]:
raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
elif event["status"] == "success" and "s3_objects" in event:
files = typing.cast(List[str], event["s3_objects"])
if self.check_fn:
return self.check_fn(files)
return None
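
The net effect of this file's change: check_fn no longer travels to the triggerer. The trigger reports candidate files with status "running", and execute_complete evaluates the check on the worker, deferring again when it fails. A condensed, self-contained sketch of that loop (stand-in names, not the provider's API; the real re-defer raises TaskDeferred via `_defer()`):

```python
# Condensed sketch of the new worker/triggerer loop (stand-in names).
def handle_event(event, check_fn, redefer):
    """Mirror execute_complete(): finish when check_fn passes, else re-defer."""
    if event["status"] == "running":
        if check_fn(event["files"]):
            return None  # check passed: the sensor task completes
        redefer()  # check failed: hand control back to the triggerer
    elif event["status"] == "error":
        raise RuntimeError(event["message"])
    return None


# The first event fails the size check, so the sensor would defer again.
handle_event(
    {"status": "running", "files": [{"Size": 0}]},
    check_fn=lambda files: all(f["Size"] > 0 for f in files),
    redefer=lambda: print("deferring again"),
)
```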


20 changes: 9 additions & 11 deletions astronomer/providers/amazon/aws/triggers/s3.py
@@ -2,7 +2,7 @@

 import asyncio
 from datetime import datetime
-from typing import Any, AsyncIterator, Callable
+from typing import Any, AsyncIterator
 
 from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -23,30 +23,28 @@ class S3KeyTrigger(BaseTrigger):
     :param aws_conn_id: reference to the s3 connection
     :param hook_params: params for the hook, optional
     :param soft_fail: Set to true to mark the task as SKIPPED on failure
-    :param check_fn: Function that receives the list of the S3 objects,
-        and returns a boolean
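+    :param should_check: when True, yield the matched S3 objects with status "running" so the sensor can evaluate check_fn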
"""

     def __init__(
         self,
         bucket_name: str,
         bucket_key: list[str],
         wildcard_match: bool = False,
-        check_fn: Callable[..., bool] | None = None,
         aws_conn_id: str = "aws_default",
         poke_interval: float = 5.0,
         soft_fail: bool = False,
+        should_check: bool = False,
         **hook_params: Any,
     ):
         super().__init__()
         self.bucket_name = bucket_name
         self.bucket_key = bucket_key
         self.wildcard_match = wildcard_match
-        self.check_fn = check_fn
         self.aws_conn_id = aws_conn_id
         self.hook_params = hook_params
         self.poke_interval = poke_interval
         self.soft_fail = soft_fail
+        self.should_check = should_check

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -56,11 +54,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"bucket_name": self.bucket_name,
"bucket_key": self.bucket_key,
"wildcard_match": self.wildcard_match,
"check_fn": self.check_fn,
"aws_conn_id": self.aws_conn_id,
"hook_params": self.hook_params,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
"should_check": self.should_check,
},
)

@@ -71,15 +69,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
             async with await hook.get_client_async() as client:
                 while True:
                     if await hook.check_key(client, self.bucket_name, self.bucket_key, self.wildcard_match):
-                        if self.check_fn is None:
-                            yield TriggerEvent({"status": "success"})
-                        else:
+                        if self.should_check:
                             s3_objects = await hook.get_files(
                                 client, self.bucket_name, self.bucket_key, self.wildcard_match
                             )
-                            yield TriggerEvent({"status": "success", "s3_objects": s3_objects})
-                    await asyncio.sleep(self.poke_interval)
+                            yield TriggerEvent({"status": "running", "files": s3_objects})
+                        else:
+                            yield TriggerEvent({"status": "success"})
+                    await asyncio.sleep(self.poke_interval)
 
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e), "soft_fail": self.soft_fail})

3 changes: 1 addition & 2 deletions tests/amazon/aws/sensors/test_s3_sensors.py
@@ -141,8 +141,7 @@ def check_fn(files: List[Any]) -> bool:
             check_fn=check_fn,
         )
         assert (
-            sensor.execute_complete(context={}, event={"status": "success", "s3_objects": [{"Size": 10}]})
-            is True
+            sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) is None
         )
 
     @parameterized.expand(
9 changes: 3 additions & 6 deletions tests/amazon/aws/triggers/test_s3_triggers.py
@@ -28,9 +28,9 @@ def test_serialization(self):
             "wildcard_match": True,
             "aws_conn_id": "aws_default",
             "hook_params": {},
-            "check_fn": None,
             "soft_fail": False,
             "poke_interval": 5.0,
+            "should_check": False,
         }

     @pytest.mark.asyncio
@@ -89,17 +89,14 @@ async def test_run_exception(self, mock_client):
     async def test_run_check_fn_success(self, mock_get_files, mock_client):
         """Test the trigger run with should_check enabled."""
 
-        def dummy_check_fn(list_obj):
-            return True
-
         mock_get_files.return_value = ["test"]
         mock_client.return_value.check_key.return_value = True
         trigger = S3KeyTrigger(
-            bucket_key="s3://test_bucket/file", bucket_name="test_bucket", check_fn=dummy_check_fn
+            bucket_key="s3://test_bucket/file", bucket_name="test_bucket", poke_interval=1, should_check=True
         )
         generator = trigger.run()
         actual = await generator.asend(None)
-        assert TriggerEvent({"status": "success", "s3_objects": ["test"]}) == actual
+        assert TriggerEvent({"status": "running", "files": ["test"]}) == actual


 class TestS3KeysUnchangedTrigger:
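
The coverage annotation on this commit flagged the re-defer branch of execute_complete as untested. A sketch of a test that would exercise it (hypothetical, not part of the commit; assumes pytest, and that TaskDeferred is the exception BaseOperator.defer() raises):

```python
# Hypothetical test for the uncovered re-defer branch.
import pytest
from airflow.exceptions import TaskDeferred

from astronomer.providers.amazon.aws.sensors.s3 import S3KeySensorAsync


def test_execute_complete_redefers_when_check_fn_fails():
    sensor = S3KeySensorAsync(
        task_id="s3_key_sensor",
        bucket_key="key",
        bucket_name="test_bucket",
        check_fn=lambda files: False,  # never satisfied, so it must re-defer
    )
    with pytest.raises(TaskDeferred):
        sensor.execute_complete(
            context={}, event={"status": "running", "files": [{"Size": 10}]}
        )
```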
