Add GCS Requester Pays bucket support to GCSToS3Operator (#32760)
* Add requester pays bucket support to GCSToS3Operator

* Update docstrings

* isort

* Fix failing unit tests

* Fix failing test
hankehly committed Jul 31, 2023
1 parent e46929b commit 915f9e4
Showing 7 changed files with 126 additions and 35 deletions.
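In practice, the new parameter lets a DAG copy objects out of a Requester Pays bucket by naming the project to bill. A minimal usage sketch (the task id, bucket names, prefix, and billing project below are illustrative placeholders, not values from this commit):

from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator

# Hypothetical task: all literal values are placeholders.
copy_requester_pays = GCSToS3Operator(
    task_id="copy_requester_pays",
    bucket="example-requester-pays-bucket",  # GCS source bucket with Requester Pays enabled
    prefix="data/2023/",
    dest_s3_key="s3://example-destination-bucket/data/",
    gcp_user_project="example-billing-project",  # project billed for the GCS requests
)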
32 changes: 20 additions & 12 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -80,6 +80,8 @@ class GCSToS3Operator(BaseOperator):
on the bucket is recreated within path passed in dest_s3_key.
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``)
:param gcp_user_project: (Optional) The identifier of the Google Cloud project to bill for this request.
Required for Requester Pays buckets.
"""

template_fields: Sequence[str] = (
@@ -88,6 +90,7 @@ class GCSToS3Operator(BaseOperator):
"delimiter",
"dest_s3_key",
"google_impersonation_chain",
"gcp_user_project",
)
ui_color = "#f0eee4"

@@ -107,6 +110,7 @@ def __init__(
s3_acl_policy: str | None = None,
keep_directory_structure: bool = True,
match_glob: str | None = None,
gcp_user_project: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -130,10 +134,11 @@ def __init__(
self.s3_acl_policy = s3_acl_policy
self.keep_directory_structure = keep_directory_structure
self.match_glob = match_glob
self.gcp_user_project = gcp_user_project

def execute(self, context: Context) -> list[str]:
# list all files in an Google Cloud Storage bucket
hook = GCSHook(
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
)
@@ -145,8 +150,12 @@ def execute(self, context: Context) -> list[str]:
self.prefix,
)

files = hook.list(
bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, match_glob=self.match_glob
gcs_files = gcs_hook.list(
bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter,
match_glob=self.match_glob,
user_project=self.gcp_user_project,
)

s3_hook = S3Hook(
@@ -173,24 +182,23 @@ def execute(self, context: Context) -> list[str]:
existing_files = existing_files if existing_files is not None else []
# remove the prefix for the existing files to allow the match
existing_files = [file.replace(prefix, "", 1) for file in existing_files]
files = list(set(files) - set(existing_files))
gcs_files = list(set(gcs_files) - set(existing_files))

if files:

for file in files:
with hook.provide_file(object_name=file, bucket_name=self.bucket) as local_tmp_file:
if gcs_files:
for file in gcs_files:
with gcs_hook.provide_file(
object_name=file, bucket_name=self.bucket, user_project=self.gcp_user_project
) as local_tmp_file:
dest_key = os.path.join(self.dest_s3_key, file)
self.log.info("Saving file to %s", dest_key)

s3_hook.load_file(
filename=local_tmp_file.name,
key=dest_key,
replace=self.replace,
acl_policy=self.s3_acl_policy,
)

self.log.info("All done, uploaded %d files to S3", len(files))
self.log.info("All done, uploaded %d files to S3", len(gcs_files))
else:
self.log.info("In sync, no files needed to be uploaded to S3")

return files
return gcs_files
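The execute() flow above reduces to a list-then-download pattern with the billing project threaded through both hook calls. A rough sketch of that pattern against the hook API (bucket and project names are placeholders):

from airflow.providers.google.cloud.hooks.gcs import GCSHook

gcs_hook = GCSHook(gcp_conn_id="google_cloud_default")

# List objects in a Requester Pays bucket, billing the given project.
gcs_files = gcs_hook.list(
    bucket_name="example-requester-pays-bucket",
    prefix="data/",
    user_project="example-billing-project",
)

# Stream each object through a local temporary file, as execute() does
# before handing it to S3Hook.load_file.
for file in gcs_files:
    with gcs_hook.provide_file(
        object_name=file,
        bucket_name="example-requester-pays-bucket",
        user_project="example-billing-project",
    ) as local_tmp_file:
        print(local_tmp_file.name)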
52 changes: 41 additions & 11 deletions airflow/providers/google/cloud/hooks/gcs.py
@@ -197,7 +197,6 @@ def copy(
destination_object = destination_object or source_object

if source_bucket == destination_bucket and source_object == destination_object:

raise ValueError(
f"Either source/destination bucket or source/destination object must be different, "
f"not both the same: bucket={source_bucket}, object={source_object}"
@@ -282,6 +281,7 @@ def download(
chunk_size: int | None = None,
timeout: int | None = DEFAULT_TIMEOUT,
num_max_attempts: int | None = 1,
user_project: str | None = None,
) -> bytes:
...

@@ -294,6 +294,7 @@ def download(
chunk_size: int | None = None,
timeout: int | None = DEFAULT_TIMEOUT,
num_max_attempts: int | None = 1,
user_project: str | None = None,
) -> str:
...

@@ -305,6 +306,7 @@ def download(
chunk_size: int | None = None,
timeout: int | None = DEFAULT_TIMEOUT,
num_max_attempts: int | None = 1,
user_project: str | None = None,
) -> str | bytes:
"""
Downloads a file from Google Cloud Storage.
@@ -320,6 +322,8 @@ def download(
:param chunk_size: Blob chunk size.
:param timeout: Request timeout in seconds.
:param num_max_attempts: Number of attempts to download the file.
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
"""
# TODO: future improvement check file size before downloading,
# to check for local space availability
@@ -330,7 +334,7 @@ def download(
try:
num_file_attempts += 1
client = self.get_conn()
bucket = client.bucket(bucket_name)
bucket = client.bucket(bucket_name, user_project=user_project)
blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

if filename:
@@ -395,6 +399,7 @@ def provide_file(
object_name: str | None = None,
object_url: str | None = None,
dir: str | None = None,
user_project: str | None = None,
) -> Generator[IO[bytes], None, None]:
"""
Downloads the file to a temporary directory and returns a file handle.
@@ -406,13 +411,20 @@ def provide_file(
:param object_name: The object to fetch.
:param object_url: File reference url. Must start with "gs: //"
:param dir: The tmp sub directory to download the file to. (passed to NamedTemporaryFile)
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
:return: File handler
"""
if object_name is None:
raise ValueError("Object name can not be empty")
_, _, file_name = object_name.rpartition("/")
with NamedTemporaryFile(suffix=file_name, dir=dir) as tmp_file:
self.download(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
self.download(
bucket_name=bucket_name,
object_name=object_name,
filename=tmp_file.name,
user_project=user_project,
)
tmp_file.flush()
yield tmp_file

@@ -423,6 +435,7 @@ def provide_file_and_upload(
bucket_name: str = PROVIDE_BUCKET,
object_name: str | None = None,
object_url: str | None = None,
user_project: str | None = None,
) -> Generator[IO[bytes], None, None]:
"""
Creates temporary file, returns a file handle and uploads the files content on close.
@@ -433,6 +446,8 @@ def provide_file_and_upload(
:param bucket_name: The bucket to fetch from.
:param object_name: The object to fetch.
:param object_url: File reference url. Must start with "gs: //"
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
:return: File handler
"""
if object_name is None:
@@ -442,7 +457,12 @@ def provide_file_and_upload(
with NamedTemporaryFile(suffix=file_name) as tmp_file:
yield tmp_file
tmp_file.flush()
self.upload(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
self.upload(
bucket_name=bucket_name,
object_name=object_name,
filename=tmp_file.name,
user_project=user_project,
)

def upload(
self,
@@ -458,6 +478,7 @@ def upload(
num_max_attempts: int = 1,
metadata: dict | None = None,
cache_control: str | None = None,
user_project: str | None = None,
) -> None:
"""
Uploads a local file or file data as string or bytes to Google Cloud Storage.
@@ -474,6 +495,8 @@ def upload(
:param num_max_attempts: Number of attempts to try to upload the file.
:param metadata: The metadata to be uploaded with the file.
:param cache_control: Cache-Control metadata field.
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
"""

def _call_with_retry(f: Callable[[], None]) -> None:
@@ -506,7 +529,7 @@ def _call_with_retry(f: Callable[[], None]) -> None:
continue

client = self.get_conn()
bucket = client.bucket(bucket_name)
bucket = client.bucket(bucket_name, user_project=user_project)
blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

if metadata:
@@ -596,7 +619,6 @@ def is_updated_after(self, bucket_name: str, object_name: str, ts: datetime) ->
"""
blob_update_time = self.get_blob_update_time(bucket_name, object_name)
if blob_update_time is not None:

if not ts.tzinfo:
ts = ts.replace(tzinfo=timezone.utc)
self.log.info("Verify object date: %s > %s", blob_update_time, ts)
@@ -618,7 +640,6 @@ def is_updated_between(
"""
blob_update_time = self.get_blob_update_time(bucket_name, object_name)
if blob_update_time is not None:

if not min_ts.tzinfo:
min_ts = min_ts.replace(tzinfo=timezone.utc)
if not max_ts.tzinfo:
@@ -639,7 +660,6 @@ def is_updated_before(self, bucket_name: str, object_name: str, ts: datetime) ->
"""
blob_update_time = self.get_blob_update_time(bucket_name, object_name)
if blob_update_time is not None:

if not ts.tzinfo:
ts = ts.replace(tzinfo=timezone.utc)
self.log.info("Verify object date: %s < %s", blob_update_time, ts)
@@ -681,16 +701,18 @@ def delete(self, bucket_name: str, object_name: str) -> None:

self.log.info("Blob %s deleted.", object_name)

def delete_bucket(self, bucket_name: str, force: bool = False) -> None:
def delete_bucket(self, bucket_name: str, force: bool = False, user_project: str | None = None) -> None:
"""
Delete a bucket object from the Google Cloud Storage.
:param bucket_name: name of the bucket which will be deleted
:param force: false not allow to delete non empty bucket, set force=True
allows to delete non empty bucket
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
"""
client = self.get_conn()
bucket = client.bucket(bucket_name)
bucket = client.bucket(bucket_name, user_project=user_project)

self.log.info("Deleting %s bucket", bucket_name)
try:
@@ -707,6 +729,7 @@ def list(
prefix: str | List[str] | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
user_project: str | None = None,
):
"""
List all objects from the bucket with the given a single prefix or multiple prefixes.
@@ -718,6 +741,8 @@ def list(
:param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``).
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
:return: a stream of object names matching the filtering criteria
"""
if delimiter and delimiter != "/":
@@ -739,6 +764,7 @@ def list(
prefix=prefix_item,
delimiter=delimiter,
match_glob=match_glob,
user_project=user_project,
)
)
else:
@@ -750,6 +776,7 @@ def list(
prefix=prefix,
delimiter=delimiter,
match_glob=match_glob,
user_project=user_project,
)
)
return objects
@@ -762,6 +789,7 @@ def _list(
prefix: str | None = None,
delimiter: str | None = None,
match_glob: str | None = None,
user_project: str | None = None,
) -> List:
"""
List all objects from the bucket with the give string prefix in name.
@@ -773,10 +801,12 @@ def _list(
:param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv')
:param match_glob: (Optional) filters objects based on the glob pattern given by the string
(e.g, ``'**/*/.json'``).
:param user_project: The identifier of the Google Cloud project to bill for the request.
Required for Requester Pays buckets.
:return: a stream of object names matching the filtering criteria
"""
client = self.get_conn()
bucket = client.bucket(bucket_name)
bucket = client.bucket(bucket_name, user_project=user_project)

ids = []
page_token = None
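Taken together, the hook changes mean any direct download or upload against a Requester Pays bucket can name the project to bill. An illustrative example of that usage (paths, bucket, and project are placeholders, not part of this commit):

from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")

# Download an object, billing the named project for the request.
hook.download(
    bucket_name="example-requester-pays-bucket",
    object_name="reports/latest.json",
    filename="/tmp/latest.json",
    user_project="example-billing-project",
)

# Upload the local copy back under a new name, billing the same project.
hook.upload(
    bucket_name="example-requester-pays-bucket",
    object_name="reports/latest-copy.json",
    filename="/tmp/latest.json",
    user_project="example-billing-project",
)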
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/operators/gcs.py
@@ -301,7 +301,6 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:

self.bucket_name = bucket_name
self.objects = objects
self.prefix = prefix
@@ -875,12 +874,15 @@ class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param user_project: (Optional) The identifier of the project to bill for this request.
Required for Requester Pays buckets.
"""

template_fields: Sequence[str] = (
"bucket_name",
"gcp_conn_id",
"impersonation_chain",
"user_project",
)

def __init__(
@@ -890,6 +892,7 @@ def __init__(
force: bool = True,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
user_project: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -898,10 +901,11 @@ def __init__(
self.force: bool = force
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.user_project = user_project

def execute(self, context: Context) -> None:
hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
hook.delete_bucket(bucket_name=self.bucket_name, force=self.force)
hook.delete_bucket(bucket_name=self.bucket_name, force=self.force, user_project=self.user_project)


class GCSSynchronizeBucketsOperator(GoogleCloudBaseOperator):
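At the operator level, the same option is now exposed on GCSDeleteBucketOperator. An illustrative task (bucket and project names are placeholders):

from airflow.providers.google.cloud.operators.gcs import GCSDeleteBucketOperator

delete_requester_pays_bucket = GCSDeleteBucketOperator(
    task_id="delete_requester_pays_bucket",
    bucket_name="example-requester-pays-bucket",
    force=True,  # also delete the bucket if it is not empty
    user_project="example-billing-project",  # project billed for the request
)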
6 changes: 5 additions & 1 deletion tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -69,7 +69,11 @@ def test_execute__match_glob(self, mock_hook):

operator.execute(None)
mock_hook.return_value.list.assert_called_once_with(
bucket_name=GCS_BUCKET, delimiter=None, match_glob=f"**/*{DELIMITER}", prefix=PREFIX
bucket_name=GCS_BUCKET,
delimiter=None,
match_glob=f"**/*{DELIMITER}",
prefix=PREFIX,
user_project=None,
)

@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
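A companion test for the new parameter could assert that it is stored and registered as a template field without touching S3 at all. This is a hypothetical addition, not part of the commit, and it assumes the module's existing GCS_BUCKET and PREFIX constants; the destination key is a placeholder:

def test_gcp_user_project_is_stored_and_templated(self):
    # Hypothetical check that the new constructor argument is kept and templated.
    operator = GCSToS3Operator(
        task_id="test-gcs-to-s3",
        bucket=GCS_BUCKET,
        prefix=PREFIX,
        dest_s3_key="s3://example-destination-bucket/",
        gcp_user_project="example-billing-project",
    )
    assert operator.gcp_user_project == "example-billing-project"
    assert "gcp_user_project" in operator.template_fields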
