Add use_glob to GCSObjectExistenceSensor (#34137)
---------

Co-authored-by: Alberto Costa <alberto.costa@windtre.it>
A-Costa and Alberto Costa committed Dec 17, 2023
1 parent 0ef3225 commit 9233541
Showing 7 changed files with 87 additions and 8 deletions.
11 changes: 10 additions & 1 deletion airflow/providers/google/cloud/sensors/gcs.py
@@ -50,6 +50,7 @@ class GCSObjectExistenceSensor(BaseSensorOperator):
:param bucket: The Google Cloud Storage bucket where the object is.
:param object: The name of the object to check in the Google cloud
storage bucket.
:param use_glob: When set to True, the object parameter is interpreted as a glob pattern
:param google_cloud_conn_id: The connection ID to use when
connecting to Google Cloud Storage.
:param impersonation_chain: Optional service account to impersonate using short-term
@@ -75,6 +76,7 @@ def __init__(
*,
bucket: str,
object: str,
use_glob: bool = False,
google_cloud_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
retry: Retry = DEFAULT_RETRY,
@@ -84,7 +86,9 @@ def __init__(
super().__init__(**kwargs)
self.bucket = bucket
self.object = object
self.use_glob = use_glob
self.google_cloud_conn_id = google_cloud_conn_id
self._matches: list[str] = []
self.impersonation_chain = impersonation_chain
self.retry = retry

@@ -96,7 +100,11 @@ def poke(self, context: Context) -> bool:
gcp_conn_id=self.google_cloud_conn_id,
impersonation_chain=self.impersonation_chain,
)
return hook.exists(self.bucket, self.object, self.retry)
if self.use_glob:
self._matches = hook.list(self.bucket, match_glob=self.object)
return bool(self._matches)
else:
return hook.exists(self.bucket, self.object, self.retry)

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
@@ -109,6 +117,7 @@ def execute(self, context: Context) -> None:
trigger=GCSBlobTrigger(
bucket=self.bucket,
object_name=self.object,
use_glob=self.use_glob,
poke_interval=self.poke_interval,
google_cloud_conn_id=self.google_cloud_conn_id,
hook_params={
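For context, a minimal usage sketch of the new sensor parameter in a DAG (the DAG id, bucket name, and glob pattern below are hypothetical and not part of this commit):

from __future__ import annotations

import datetime

from airflow import DAG
from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceSensor

with DAG(
    dag_id="example_gcs_glob_sensor",  # hypothetical DAG id
    start_date=datetime.datetime(2023, 12, 1),
    schedule=None,
):
    # With use_glob=True the `object` argument is passed to GCS as a
    # match_glob expression instead of being checked as an exact object name.
    wait_for_any_csv = GCSObjectExistenceSensor(
        task_id="wait_for_any_csv",
        bucket="my-bucket",           # hypothetical bucket
        object="incoming/*.csv",      # glob pattern rather than an exact key
        use_glob=True,
    )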
16 changes: 13 additions & 3 deletions airflow/providers/google/cloud/triggers/gcs.py
@@ -35,6 +35,7 @@ class GCSBlobTrigger(BaseTrigger):
:param bucket: the bucket in the google cloud storage where the objects are residing.
:param object_name: the file or folder present in the bucket
:param use_glob: if True, object_name is interpreted as a glob pattern
:param google_cloud_conn_id: reference to the Google Connection
:param poke_interval: polling period in seconds to check for file/folder
:param hook_params: Extra config params to be passed to the underlying hook.
@@ -45,13 +46,15 @@ def __init__(
self,
bucket: str,
object_name: str,
use_glob: bool,
poke_interval: float,
google_cloud_conn_id: str,
hook_params: dict[str, Any],
):
super().__init__()
self.bucket = bucket
self.object_name = object_name
self.use_glob = use_glob
self.poke_interval = poke_interval
self.google_cloud_conn_id: str = google_cloud_conn_id
self.hook_params = hook_params
@@ -63,6 +66,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
{
"bucket": self.bucket,
"object_name": self.object_name,
"use_glob": self.use_glob,
"poke_interval": self.poke_interval,
"google_cloud_conn_id": self.google_cloud_conn_id,
"hook_params": self.hook_params,
@@ -98,9 +102,14 @@ async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name
async with ClientSession() as s:
client = await hook.get_storage_client(s)
bucket = client.get_bucket(bucket_name)
object_response = await bucket.blob_exists(blob_name=object_name)
if object_response:
return "success"
if self.use_glob:
list_blobs_response = await bucket.list_blobs(match_glob=object_name)
if len(list_blobs_response) > 0:
return "success"
else:
blob_exists_response = await bucket.blob_exists(blob_name=object_name)
if blob_exists_response:
return "success"
return "pending"


@@ -234,6 +243,7 @@ def __init__(
poke_interval=poke_interval,
google_cloud_conn_id=google_cloud_conn_id,
hook_params=hook_params,
use_glob=False,
)
self.prefix = prefix

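Because a deferred task is resumed by the triggerer, the trigger is rebuilt from the (classpath, kwargs) pair returned by serialize(), so the new use_glob field has to round-trip through it. A small sketch illustrating this (argument values are placeholders):

from airflow.providers.google.cloud.triggers.gcs import GCSBlobTrigger

trigger = GCSBlobTrigger(
    bucket="my-bucket",            # placeholder values for illustration
    object_name="incoming/*.csv",
    use_glob=True,
    poke_interval=10.0,
    google_cloud_conn_id="google_cloud_default",
    hook_params={},
)

classpath, kwargs = trigger.serialize()
# kwargs now carries "use_glob", so an equivalent trigger instance can be
# reconstructed on the triggerer after the task has deferred.
rebuilt = GCSBlobTrigger(**kwargs)
assert rebuilt.use_glob is True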
2 changes: 1 addition & 1 deletion airflow/providers/google/provider.yaml
@@ -87,7 +87,7 @@ dependencies:
- asgiref>=3.5.2
- gcloud-aio-auth>=4.0.0,<5.0.0
- gcloud-aio-bigquery>=6.1.2
- gcloud-aio-storage
- gcloud-aio-storage>=9.0.0
- gcsfs>=2023.10.0
- google-ads>=22.1.0
- google-api-core>=2.11.0
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
a5677b0b603e8835f92da4b8b061ec268ce7257ef6b446f12593743ecf90710a
194706fc390025f473f73ce934bfe4b394b50ce76748e5df33ae643e38538357
4 changes: 2 additions & 2 deletions docs/apache-airflow/img/airflow_erd.svg
(SVG diff not rendered in this view.)
22 changes: 22 additions & 0 deletions tests/providers/google/cloud/sensors/test_gcs.py
@@ -94,6 +94,7 @@ def test_should_pass_argument_to_hook(self, mock_hook):
task_id="task-id",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
use_glob=False,
google_cloud_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -108,6 +109,27 @@ def test_should_pass_argument_to_hook(self, mock_hook):
)
mock_hook.return_value.exists.assert_called_once_with(TEST_BUCKET, TEST_OBJECT, DEFAULT_RETRY)

@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
def test_should_pass_argument_to_hook_using_glob(self, mock_hook):
task = GCSObjectExistenceSensor(
task_id="task-id",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
use_glob=True,
google_cloud_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.list.return_value = [mock.MagicMock()]

result = task.poke(mock.MagicMock())

assert result is True
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, match_glob=TEST_OBJECT)
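
A complementary negative-case sketch (not part of this commit) would assert that poke returns False when nothing matches the glob, reusing the same test constants:

@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
def test_should_return_false_when_no_glob_matches(self, mock_hook):
    task = GCSObjectExistenceSensor(
        task_id="task-id",
        bucket=TEST_BUCKET,
        object=TEST_OBJECT,
        use_glob=True,
        google_cloud_conn_id=TEST_GCP_CONN_ID,
        impersonation_chain=TEST_IMPERSONATION_CHAIN,
    )
    # An empty listing means nothing matched the glob, so poke should return False.
    mock_hook.return_value.list.return_value = []

    assert task.poke(mock.MagicMock()) is False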

@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor.defer")
def test_gcs_object_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook):
38 changes: 38 additions & 0 deletions tests/providers/google/cloud/triggers/test_gcs.py
@@ -55,6 +55,19 @@ def trigger():
return GCSBlobTrigger(
bucket=TEST_BUCKET,
object_name=TEST_OBJECT,
use_glob=False,
poke_interval=TEST_POLLING_INTERVAL,
google_cloud_conn_id=TEST_GCP_CONN_ID,
hook_params=TEST_HOOK_PARAMS,
)


@pytest.fixture
def trigger_using_glob():
return GCSBlobTrigger(
bucket=TEST_BUCKET,
object_name=TEST_OBJECT,
use_glob=True,
poke_interval=TEST_POLLING_INTERVAL,
google_cloud_conn_id=TEST_GCP_CONN_ID,
hook_params=TEST_HOOK_PARAMS,
@@ -73,6 +86,7 @@ def test_gcs_blob_trigger_serialization(self, trigger):
assert kwargs == {
"bucket": TEST_BUCKET,
"object_name": TEST_OBJECT,
"use_glob": False,
"poke_interval": TEST_POLLING_INTERVAL,
"google_cloud_conn_id": TEST_GCP_CONN_ID,
"hook_params": TEST_HOOK_PARAMS,
@@ -141,6 +155,30 @@ async def test_object_exists(self, exists, response, trigger):
assert res == response
bucket.blob_exists.assert_called_once_with(blob_name=TEST_OBJECT)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"blob_list,response",
[
([TEST_OBJECT], "success"),
([], "pending"),
],
)
async def test_object_exists_using_glob(self, blob_list, response, trigger_using_glob):
"""
Tests to check if a particular object in Google Cloud Storage
is found or not
"""
hook = AsyncMock(GCSAsyncHook)
storage = AsyncMock(Storage)
hook.get_storage_client.return_value = storage
bucket = AsyncMock(Bucket)
storage.get_bucket.return_value = bucket
bucket.list_blobs.return_value = blob_list

res = await trigger_using_glob._object_exists(hook, TEST_BUCKET, TEST_OBJECT)
assert res == response
bucket.list_blobs.assert_called_once_with(match_glob=TEST_OBJECT)


class TestGCSPrefixBlobTrigger:
TRIGGER = GCSPrefixBlobTrigger(