def list(
    self,
    bucket_name: str,
    versions: bool | None = None,
    max_results: int | None = None,
    prefix: str | List[str] | None = None,
    delimiter: str | None = None,
) -> List:
    """
    List all objects from the bucket matching a single prefix or any of several prefixes.

    Every prefix is listed with its own ``_list`` call, in the order given, and the
    results are concatenated. A name that was already returned for an earlier prefix
    (for example with ``prefix=["TEST", "TEST2"]``, where everything matching the
    second prefix also matches the first) is reported only once, so callers that act
    on every returned name, such as deleting it, do not process the same object twice.

    :param bucket_name: bucket name
    :param versions: if true, list all versions of the objects
    :param max_results: max count of items to return in a single page of responses
    :param prefix: string or list of strings which filter objects whose names begin with
        it/them. A string or ``None`` is forwarded unchanged to a single ``_list`` call;
        an empty list results in no listing call and an empty result
    :param delimiter: filters objects based on the delimiter (for e.g '.csv')
    :return: list of the object names returned by ``_list`` for the given prefix(es)
    """
    # A bare string (or None) is one filter. Anything else is treated as a collection
    # of prefixes, so tuples and other iterables are accepted as well as lists.
    prefixes = [prefix] if prefix is None or isinstance(prefix, str) else prefix

    objects = []
    already_listed = set()
    for prefix_item in prefixes:
        batch = self._list(
            bucket_name=bucket_name,
            versions=versions,
            max_results=max_results,
            prefix=prefix_item,
            delimiter=delimiter,
        )
        # Only names seen under an *earlier* prefix are dropped. Repeats inside one
        # batch are kept, so a single-prefix call returns exactly what ``_list`` returned.
        objects.extend(name for name in batch if name not in already_listed)
        already_listed.update(batch)
    return objects
of items to return in a single page of responses - :param prefix: prefix string which filters objects whose name begin with - this prefix + :param prefix: string which filters objects whose name begin with it :param delimiter: filters objects based on the delimiter (for e.g '.csv') :return: a stream of object names matching the filtering criteria """ diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index e2936c0933cc1..e2ac68c90d659 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -163,8 +163,8 @@ class GCSListObjectsOperator(GoogleCloudBaseOperator): XCom in the downstream task. :param bucket: The Google Cloud Storage bucket to find the objects. (templated) - :param prefix: Prefix string which filters objects whose name begin with - this prefix. (templated) + :param prefix: String or list of strings, which filter objects whose name begin with + it/them. (templated) :param delimiter: The delimiter by which you want to filter the objects. (templated) For example, to lists the CSV files from in a directory in GCS you would use delimiter='.csv'. @@ -206,7 +206,7 @@ def __init__( self, *, bucket: str, - prefix: str | None = None, + prefix: str | list[str] | None = None, delimiter: str | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -220,14 +220,13 @@ def __init__( self.impersonation_chain = impersonation_chain def execute(self, context: Context) -> list: - hook = GCSHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) self.log.info( - "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s", + "Getting list of the files. 
@pytest.mark.parametrize(
    "prefix, result",
    [
        (
            "prefix",
            [mock.call(delimiter=",", prefix="prefix", versions=None, max_results=None, page_token=None)],
        ),
        (
            ["prefix", "prefix_2"],
            [
                mock.call(delimiter=",", prefix=expected, versions=None, max_results=None, page_token=None)
                for expected in ("prefix", "prefix_2")
            ],
        ),
    ],
)
@mock.patch(GCS_STRING.format("GCSHook.get_conn"))
def test_list(self, mock_service, prefix, result):
    """Check the ``list_blobs`` calls made by ``GCSHook.list`` for one prefix and for a list of prefixes."""
    # Bind the mocked client method once; it is both configured and inspected below.
    list_blobs = mock_service.return_value.bucket.return_value.list_blobs
    list_blobs.return_value.next_page_token = None

    self.gcs_hook.list(bucket_name="test_bucket", prefix=prefix, delimiter=",")

    assert list_blobs.call_args_list == result