Add deferrable mode to GCSObjectsWithPrefixExistenceSensor (#30618)
phanikumv committed Apr 27, 2023
1 parent 35b1aad commit eed5d5b
Showing 6 changed files with 278 additions and 6 deletions.
41 changes: 37 additions & 4 deletions airflow/providers/google/cloud/sensors/gcs.py
@@ -29,7 +29,11 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger, GCSCheckBlobUpdateTimeTrigger
from airflow.providers.google.cloud.triggers.gcs import (
GCSBlobTrigger,
GCSCheckBlobUpdateTimeTrigger,
GCSPrefixBlobTrigger,
)
from airflow.sensors.base import BaseSensorOperator, poke_mode_only

if TYPE_CHECKING:
@@ -274,6 +278,7 @@ class GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run sensor in deferrable mode
"""

template_fields: Sequence[str] = (
@@ -289,6 +294,7 @@ def __init__(
prefix: str,
google_cloud_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -297,6 +303,7 @@ def __init__(
self.google_cloud_conn_id = google_cloud_conn_id
self._matches: list[str] = []
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
self.log.info("Checking for existence of object: %s, %s", self.bucket, self.prefix)
@@ -307,10 +314,36 @@ def poke(self, context: Context) -> bool:
self._matches = hook.list(self.bucket, prefix=self.prefix)
return bool(self._matches)

def execute(self, context: Context) -> list[str]:
def execute(self, context: Context):
"""Overridden to allow matches to be passed"""
super().execute(context)
return self._matches
self.log.info("Checking for existence of object: %s, %s", self.bucket, self.prefix)
if not self.deferrable:
super().execute(context)
return self._matches
else:
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=GCSPrefixBlobTrigger(
bucket=self.bucket,
prefix=self.prefix,
poke_interval=self.poke_interval,
google_cloud_conn_id=self.google_cloud_conn_id,
hook_params={
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]:
"""
Callback for when the trigger fires; returns immediately.

Relies on the trigger to send a "success" event; any other status raises an AirflowException.
"""
self.log.info("Resuming from trigger and checking status")
if event["status"] == "success":
return event["matches"]
raise AirflowException(event["message"])


def get_time():
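In deferrable mode the sensor frees its worker slot while it waits: execute() defers straight to a GCSPrefixBlobTrigger, and execute_complete() picks the matches back up when the trigger fires. A minimal usage sketch, assuming hypothetical bucket and prefix values (not part of this commit):

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.gcs import GCSObjectsWithPrefixExistenceSensor

with DAG(dag_id="gcs_prefix_sensor_demo", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    wait_for_files = GCSObjectsWithPrefixExistenceSensor(
        task_id="wait_for_files",
        bucket="my-bucket",  # hypothetical bucket
        prefix="incoming/2023-",  # hypothetical prefix
        deferrable=True,  # hand the wait over to the triggerer
        timeout=60 * 60,  # fail if nothing matches within an hour
    )

Either way, the matched object names end up as the task's return value, so they reach downstream tasks via XCom.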
81 changes: 81 additions & 0 deletions airflow/providers/google/cloud/triggers/gcs.py
@@ -36,6 +36,8 @@ class GCSBlobTrigger(BaseTrigger):
:param object_name: the file or folder present in the bucket
:param google_cloud_conn_id: reference to the Google Connection
:param poke_interval: polling period in seconds to check for file/folder
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
"""

def __init__(
@@ -200,3 +202,82 @@ async def _is_blob_updated_after(
if blob_updated_time > target_date:
return True, {"status": "success", "message": "success"}
return False, {"status": "pending", "message": "pending"}


class GCSPrefixBlobTrigger(GCSBlobTrigger):
"""
Looks for objects in a bucket matching a prefix.

If none are found, sleeps for the poke interval and checks again; otherwise, returns the matches.

:param bucket: the bucket in Google Cloud Storage where the objects reside
:param prefix: the prefix of the blob names to match in the bucket
:param google_cloud_conn_id: reference to the Google Connection
:param poke_interval: polling period in seconds to check
:param hook_params: Extra config params to be passed to the underlying hook.
    Should match the desired hook constructor params.
"""

def __init__(
self,
bucket: str,
prefix: str,
poke_interval: float,
google_cloud_conn_id: str,
hook_params: dict[str, Any],
):
super().__init__(
bucket=bucket,
object_name=prefix,
poke_interval=poke_interval,
google_cloud_conn_id=google_cloud_conn_id,
hook_params=hook_params,
)
self.prefix = prefix

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes GCSPrefixBlobTrigger arguments and classpath."""
return (
"airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger",
{
"bucket": self.bucket,
"prefix": self.prefix,
"poke_interval": self.poke_interval,
"google_cloud_conn_id": self.google_cloud_conn_id,
"hook_params": self.hook_params,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the matches are found for the given prefix on the bucket."""
try:
hook = self._get_async_hook()
while True:
self.log.info(
"Checking for existence of blobs with prefix %s in bucket %s", self.prefix, self.bucket
)
res = await self._list_blobs_with_prefix(
hook=hook, bucket_name=self.bucket, prefix=self.prefix
)
if len(res) > 0:
yield TriggerEvent(
{"status": "success", "message": "Successfully completed", "matches": res}
)
await asyncio.sleep(self.poke_interval)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
return

async def _list_blobs_with_prefix(self, hook: GCSAsyncHook, bucket_name: str, prefix: str) -> list[str]:
"""
Returns names of blobs which match the given prefix for a given bucket.

:param hook: The async hook to use for listing the blobs
:param bucket_name: The Google Cloud Storage bucket where the object is.
:param prefix: The prefix of the blob names to match in the Google Cloud
    Storage bucket.
"""
async with ClientSession() as session:
client = await hook.get_storage_client(session)
bucket = client.get_bucket(bucket_name)
object_response = await bucket.list_blobs(prefix=prefix)
return object_response
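serialize() is what lets deferral cross process boundaries: the triggerer imports the returned classpath and re-instantiates the trigger with the returned kwargs. A rough sketch of that round trip (simplified; not the actual triggerer code):

import importlib

from airflow.providers.google.cloud.triggers.gcs import GCSPrefixBlobTrigger

trigger = GCSPrefixBlobTrigger(
    bucket="my-bucket",  # hypothetical values throughout
    prefix="incoming/",
    poke_interval=10.0,
    google_cloud_conn_id="google_cloud_default",
    hook_params={},
)
classpath, kwargs = trigger.serialize()
module_path, class_name = classpath.rsplit(".", 1)
trigger_cls = getattr(importlib.import_module(module_path), class_name)
rehydrated = trigger_cls(**kwargs)  # an equivalent trigger, rebuilt in the triggerer process

This is why every constructor argument must round-trip through the serialized kwargs, which the serialization test below checks field by field.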
12 changes: 12 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/gcs.rst
@@ -199,8 +199,20 @@ Use the :class:`~airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefix
:start-after: [START howto_sensor_object_with_prefix_exists_task]
:end-before: [END howto_sensor_object_with_prefix_exists_task]

You can set the ``deferrable`` param to True if you want this sensor to run asynchronously, leading to more
efficient utilization of resources in your Airflow deployment. However, the triggerer component (started with
``airflow triggerer``) needs to be running for this functionality to work.

.. exampleinclude:: /../../tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_object_with_prefix_exists_task_async]
:end-before: [END howto_sensor_object_with_prefix_exists_task_async]


.. _howto/sensor:GCSUploadSessionCompleteSensor:


GCSUploadSessionCompleteSensor
------------------------------

Expand Down
44 changes: 43 additions & 1 deletion tests/providers/google/cloud/sensors/test_gcs.py
@@ -34,7 +34,11 @@
GCSUploadSessionCompleteSensor,
ts_function,
)
from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger, GCSCheckBlobUpdateTimeTrigger
from airflow.providers.google.cloud.triggers.gcs import (
GCSBlobTrigger,
GCSCheckBlobUpdateTimeTrigger,
GCSPrefixBlobTrigger,
)
from tests.providers.google.cloud.utils.airflow_util import create_context

TEST_BUCKET = "TEST_BUCKET"
@@ -333,6 +337,44 @@ def test_execute_timeout(self, mock_hook):
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX)


class TestGCSObjectsWithPrefixExistenceSensorAsync:
OPERATOR = GCSObjectsWithPrefixExistenceSensor(
task_id="gcs-obj-prefix",
bucket=TEST_BUCKET,
prefix=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
deferrable=True,
)

def test_gcs_object_with_prefix_existence_sensor_async(self, context):
"""
Asserts that the task is deferred and a GCSPrefixBlobTrigger is fired
when the GCSObjectsWithPrefixExistenceSensor is executed with deferrable=True.
"""

with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(context)
assert isinstance(exc.value.trigger, GCSPrefixBlobTrigger), "Trigger is not a GCSPrefixBlobTrigger"

def test_gcs_object_with_prefix_existence_sensor_async_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.OPERATOR.execute_complete(
context=context, event={"status": "error", "message": "test failure message"}
)

def test_gcs_object_with_prefix_existence_sensor_async_execute_complete(self, context):
"""Asserts that logging occurs as expected"""

with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info:
self.OPERATOR.execute_complete(
context=context,
event={"status": "success", "message": "Job completed", "matches": [TEST_OBJECT]},
)
mock_log_info.assert_called_with("Resuming from trigger and checking status")


class TestGCSUploadSessionCompleteSensor:
def setup_method(self):
self.dag = DAG(
97 changes: 96 additions & 1 deletion tests/providers/google/cloud/triggers/test_gcs.py
@@ -24,7 +24,11 @@
from gcloud.aio.storage import Bucket, Storage

from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger, GCSCheckBlobUpdateTimeTrigger
from airflow.providers.google.cloud.triggers.gcs import (
GCSBlobTrigger,
GCSCheckBlobUpdateTimeTrigger,
GCSPrefixBlobTrigger,
)
from airflow.triggers.base import TriggerEvent
from tests.providers.google.cloud.utils.compat import AsyncMock, async_mock

@@ -129,6 +133,97 @@ async def test_object_exists(self, exists, response, trigger):
bucket.blob_exists.assert_called_once_with(blob_name=TEST_OBJECT)


class TestGCSPrefixBlobTrigger:
TRIGGER = GCSPrefixBlobTrigger(
bucket=TEST_BUCKET,
prefix=TEST_PREFIX,
poke_interval=TEST_POLLING_INTERVAL,
google_cloud_conn_id=TEST_GCP_CONN_ID,
hook_params=TEST_HOOK_PARAMS,
)

def test_gcs_prefix_blob_trigger_serialization(self):
"""
Asserts that the GCSPrefixBlobTrigger correctly serializes its arguments
and classpath.
"""

classpath, kwargs = self.TRIGGER.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger"
assert kwargs == {
"bucket": TEST_BUCKET,
"prefix": TEST_PREFIX,
"poke_interval": TEST_POLLING_INTERVAL,
"google_cloud_conn_id": TEST_GCP_CONN_ID,
"hook_params": TEST_HOOK_PARAMS,
}

@pytest.mark.asyncio
@async_mock.patch(
"airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger" "._list_blobs_with_prefix"
)
async def test_gcs_prefix_blob_trigger_success(self, mock_list_blobs_with_prefix):
"""
Tests the success case for the GCSPrefixBlobTrigger.
"""
mock_list_blobs_with_prefix.return_value = ["success"]

generator = self.TRIGGER.run()
actual = await generator.asend(None)
assert (
TriggerEvent({"status": "success", "message": "Successfully completed", "matches": ["success"]})
== actual
)

@pytest.mark.asyncio
@async_mock.patch(
"airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger" "._list_blobs_with_prefix"
)
async def test_gcs_prefix_blob_trigger_exception(self, mock_list_blobs_with_prefix):
"""
Tests that the GCSPrefixBlobTrigger fires an error event if there is an exception.
"""
mock_list_blobs_with_prefix.side_effect = AsyncMock(side_effect=Exception("Test exception"))

task = [i async for i in self.TRIGGER.run()]
assert len(task) == 1
assert TriggerEvent({"status": "error", "message": "Test exception"}) in task

@pytest.mark.asyncio
@async_mock.patch(
"airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger" "._list_blobs_with_prefix"
)
async def test_gcs_prefix_blob_trigger_pending(self, mock_list_blobs_with_prefix):
"""
Tests that the GCSPrefixBlobTrigger stays in its polling loop if no matching blobs are found.
"""
mock_list_blobs_with_prefix.return_value = []

task = asyncio.create_task(self.TRIGGER.run().__anext__())
await asyncio.sleep(0.5)

# TriggerEvent was not returned
assert task.done() is False
asyncio.get_event_loop().stop()

@pytest.mark.asyncio
async def test_list_blobs_with_prefix(self):
"""
Tests that _list_blobs_with_prefix returns the names of blobs matching the given prefix.
"""
hook = AsyncMock(GCSAsyncHook)
storage = AsyncMock(Storage)
hook.get_storage_client.return_value = storage
bucket = AsyncMock(Bucket)
storage.get_bucket.return_value = bucket
bucket.list_blobs.return_value = ["test_string"]

res = await self.TRIGGER._list_blobs_with_prefix(hook, TEST_BUCKET, TEST_PREFIX)
assert res == ["test_string"]
bucket.list_blobs.assert_called_once_with(prefix=TEST_PREFIX)


class TestGCSCheckBlobUpdateTimeTrigger:
TRIGGER = GCSCheckBlobUpdateTimeTrigger(
bucket=TEST_BUCKET,
9 changes: 9 additions & 0 deletions tests/system/providers/google/cloud/gcs/example_gcs_sensor.py
@@ -144,6 +144,15 @@ def mode_setter(self, value):
)
# [END howto_sensor_object_with_prefix_exists_task]

# [START howto_sensor_object_with_prefix_exists_task_async]
gcs_object_with_prefix_exists_async = GCSObjectsWithPrefixExistenceSensor(
bucket=BUCKET_NAME,
prefix=FILE_NAME[:5],
task_id="gcs_object_with_prefix_exists_task_async",
deferrable=True,
)
# [END howto_sensor_object_with_prefix_exists_task_async]

delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
)
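Because the deferrable sensor still returns the matched object names, a downstream task can consume them through XCom. A short sketch using the TaskFlow API; the process_matches task is illustrative and not part of this commit:

from airflow.decorators import task

@task
def process_matches(matches: list[str]) -> None:
    # 'matches' is the list returned by the sensor's execute_complete()
    for name in matches:
        print(f"Found object: {name}")

process_matches(gcs_object_with_prefix_exists_async.output)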
