Skip to content

Commit

Permalink
allow multiple prefixes in gcs delete/list hooks and operators (#30815)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Apr 23, 2023
1 parent 2d40f41 commit 432697d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 14 deletions.
54 changes: 51 additions & 3 deletions airflow/providers/google/cloud/hooks/gcs.py
Expand Up @@ -696,15 +696,63 @@ def delete_bucket(self, bucket_name: str, force: bool = False) -> None:
except NotFound:
self.log.info("Bucket %s not exists", bucket_name)

def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimiter=None) -> List:
def list(
self,
bucket_name: str,
versions: bool | None = None,
max_results: int | None = None,
prefix: str | List[str] | None = None,
delimiter: str | None = None,
):
"""
List all objects from the bucket with the given a single prefix or multiple prefixes
:param bucket_name: bucket name
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of responses
:param prefix: string or list of strings which filter objects whose name begin with it/them
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:return: a stream of object names matching the filtering criteria
"""
objects = []
if isinstance(prefix, list):
for prefix_item in prefix:
objects.extend(
self._list(
bucket_name=bucket_name,
versions=versions,
max_results=max_results,
prefix=prefix_item,
delimiter=delimiter,
)
)
else:
objects.extend(
self._list(
bucket_name=bucket_name,
versions=versions,
max_results=max_results,
prefix=prefix,
delimiter=delimiter,
)
)
return objects

def _list(
self,
bucket_name: str,
versions: bool | None = None,
max_results: int | None = None,
prefix: str | None = None,
delimiter: str | None = None,
) -> List:
"""
List all objects from the bucket with the give string prefix in name
:param bucket_name: bucket name
:param versions: if true, list all versions of the objects
:param max_results: max count of items to return in a single page of responses
:param prefix: prefix string which filters objects whose name begin with
this prefix
:param prefix: string which filters objects whose name begin with it
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:return: a stream of object names matching the filtering criteria
"""
Expand Down
15 changes: 6 additions & 9 deletions airflow/providers/google/cloud/operators/gcs.py
Expand Up @@ -163,8 +163,8 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator):
XCom in the downstream task.
:param bucket: The Google Cloud Storage bucket to find the objects. (templated)
:param prefix: Prefix string which filters objects whose name begin with
this prefix. (templated)
:param prefix: String or list of strings, which filter objects whose name begin with
it/them. (templated)
:param delimiter: The delimiter by which you want to filter the objects. (templated)
For example, to lists the CSV files from in a directory in GCS you would use
delimiter='.csv'.
Expand Down Expand Up @@ -206,7 +206,7 @@ def __init__(
self,
*,
bucket: str,
prefix: str | None = None,
prefix: str | list[str] | None = None,
delimiter: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
Expand All @@ -220,14 +220,13 @@ def __init__(
self.impersonation_chain = impersonation_chain

def execute(self, context: Context) -> list:

hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

self.log.info(
"Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
"Getting list of the files. Bucket: %s; Delimiter: %s; Prefix(es): %s",
self.bucket,
self.delimiter,
self.prefix,
Expand All @@ -239,7 +238,6 @@ def execute(self, context: Context) -> list:
uri=self.bucket,
project_id=hook.project_id,
)

return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)


Expand All @@ -252,8 +250,8 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator):
:param bucket_name: The GCS bucket to delete from
:param objects: List of objects to delete. These should be the names
of objects in the bucket, not including gs://bucket/
:param prefix: Prefix of objects to delete. All objects matching this
prefix in the bucket will be deleted.
:param prefix: String or list of strings, which filter objects whose name begin with
it/them. (templated)
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
Expand Down Expand Up @@ -307,7 +305,6 @@ def execute(self, context: Context) -> None:
objects = self.objects
else:
objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix)

self.log.info("Deleting %s objects from %s", len(objects), self.bucket_name)
for object_name in objects:
hook.delete(bucket_name=self.bucket_name, object_name=object_name)
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/google/cloud/hooks/test_gcs.py
Expand Up @@ -758,6 +758,36 @@ def test_provide_file_upload(self, mock_upload, mock_temp_file):
]
)

@pytest.mark.parametrize(
"prefix, result",
(
(
"prefix",
[mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
),
(
["prefix", "prefix_2"],
[
mock.call(
delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None
),
mock.call(
delimiter=",", prefix="prefix_2", versions=None, max_results=None, page_token=None
),
],
),
),
)
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_list(self, mock_service, prefix, result):
mock_service.return_value.bucket.return_value.list_blobs.return_value.next_page_token = None
self.gcs_hook.list(
bucket_name="test_bucket",
prefix=prefix,
delimiter=",",
)
assert mock_service.return_value.bucket.return_value.list_blobs.call_args_list == result

@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_list_by_timespans(self, mock_service):
test_bucket = "test_bucket"
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/google/cloud/operators/test_gcs.py
Expand Up @@ -38,6 +38,7 @@
TEST_PROJECT = "test-project"
DELIMITER = ".csv"
PREFIX = "TEST"
PREFIX_2 = "TEST2"
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv", "OTHERTEST1.csv"]
TEST_OBJECT = "dir1/test-object"
LOCAL_FILE_PATH = "/home/airflow/gcp/test-object"
Expand Down Expand Up @@ -160,11 +161,9 @@ class TestGoogleCloudStorageListOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
def test_execute(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES

operator = GCSListObjectsOperator(
task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
)

files = operator.execute(context=mock.MagicMock())
mock_hook.return_value.list.assert_called_once_with(
bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
Expand Down

0 comments on commit 432697d

Please sign in to comment.