Add flag apply_gcs_prefix to S3ToGCSOperator (b/245077385) (#31127)
moiseenkov committed May 13, 2023
1 parent fdc7a31 commit fb6c501
Showing 3 changed files with 211 additions and 64 deletions.
123 changes: 70 additions & 53 deletions airflow/providers/google/cloud/transfers/s3_to_gcs.py
@@ -45,6 +45,15 @@ class S3ToGCSOperator(S3ListOperator):
:param bucket: The S3 bucket where to find the objects. (templated)
:param prefix: Prefix string which filters objects whose name begin with
such prefix. (templated)
:param apply_gcs_prefix: (Optional) Whether to replace the source objects' path with the given GCS destination path.
If apply_gcs_prefix is False (default), then objects from S3 will be copied to the GCS bucket under the given
GCS path, and the source path will be kept inside it. For example:
<s3_bucket><s3_prefix><content> => <gcs_prefix><s3_prefix><content>
If apply_gcs_prefix is True, then objects from S3 will be copied to the GCS bucket under the given
GCS path, and the source path will be omitted. For example:
<s3_bucket><s3_prefix><content> => <gcs_prefix><content>
:param delimiter: The delimiter that marks the key hierarchy. (templated)
:param aws_conn_id: The source S3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
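
For orientation, a minimal usage sketch of the new flag; the bucket names, prefixes, and object paths below are illustrative and not taken from this commit:

from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator

# With apply_gcs_prefix=True, s3://my-s3-bucket/raw/2023/file.csv would be
# uploaded as gs://my-gcs-bucket/landing/2023/file.csv (the S3 prefix "raw/"
# is dropped). With the default apply_gcs_prefix=False it would be uploaded
# as gs://my-gcs-bucket/landing/raw/2023/file.csv instead.
transfer = S3ToGCSOperator(
    task_id="s3_to_gcs_example",        # illustrative task id
    bucket="my-s3-bucket",              # illustrative S3 bucket
    prefix="raw/",                      # illustrative S3 prefix
    dest_gcs="gs://my-gcs-bucket/landing/",
    apply_gcs_prefix=True,
)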
@@ -106,6 +115,7 @@ def __init__(
*,
bucket,
prefix="",
apply_gcs_prefix=False,
delimiter="",
aws_conn_id="aws_default",
verify=None,
@@ -118,6 +128,7 @@
):

super().__init__(bucket=bucket, prefix=prefix, delimiter=delimiter, aws_conn_id=aws_conn_id, **kwargs)
self.apply_gcs_prefix = apply_gcs_prefix
self.gcp_conn_id = gcp_conn_id
self.dest_gcs = dest_gcs
self.replace = replace
@@ -139,68 +150,74 @@ def _check_inputs(self) -> None:
def execute(self, context: Context):
self._check_inputs()
# use the super method to list all the files in an S3 bucket/key
files = super().execute(context)
s3_objects = super().execute(context)

gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
)

if not self.replace:
# if we are not replacing -> list all files in the GCS bucket
# and only keep those files which are present in
# S3 and not in Google Cloud Storage
bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs)
existing_files_prefixed = gcs_hook.list(bucket_name, prefix=object_prefix)

existing_files = []

if existing_files_prefixed:
# Remove the object prefix itself, an empty directory was found
if object_prefix in existing_files_prefixed:
existing_files_prefixed.remove(object_prefix)

# Remove the object prefix from all object string paths
for f in existing_files_prefixed:
if f.startswith(object_prefix):
existing_files.append(f[len(object_prefix) :])
else:
existing_files.append(f)

files = list(set(files) - set(existing_files))
if len(files) > 0:
self.log.info("%s files are going to be synced: %s.", len(files), files)
else:
self.log.info("There are no new files to sync. Have a nice day!")

if files:
s3_objects = self.exclude_existing_objects(s3_objects=s3_objects, gcs_hook=gcs_hook)

if s3_objects:
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

for file in files:
# GCS hook builds its own in-memory file so we have to create
dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
for obj in s3_objects:
# GCS hook builds its own in-memory file, so we have to create
# and pass the path
file_object = hook.get_key(file, self.bucket)
with NamedTemporaryFile(mode="wb", delete=True) as f:
file_object.download_fileobj(f)
f.flush()

dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
# There will always be a '/' before file because it is
# enforced at instantiation time
dest_gcs_object = dest_gcs_object_prefix + file

# Sync is sequential and the hook already logs too much
# so skip this for now
# self.log.info(
# 'Saving file {0} from S3 bucket {1} in GCS bucket {2}'
# ' as object {3}'.format(file, self.bucket,
# dest_gcs_bucket,
# dest_gcs_object))

gcs_hook.upload(dest_gcs_bucket, dest_gcs_object, f.name, gzip=self.gzip)

self.log.info("All done, uploaded %d files to Google Cloud Storage", len(files))
file_object = hook.get_key(obj, self.bucket)
with NamedTemporaryFile(mode="wb", delete=True) as file:
file_object.download_fileobj(file)
file.flush()
gcs_file = self.s3_to_gcs_object(s3_object=obj)
gcs_hook.upload(dest_gcs_bucket, gcs_file, file.name, gzip=self.gzip)

self.log.info("All done, uploaded %d files to Google Cloud Storage", len(s3_objects))
else:
self.log.info("In sync, no files needed to be uploaded to Google Cloud Storage")

return files
return s3_objects

def exclude_existing_objects(self, s3_objects: list[str], gcs_hook: GCSHook) -> list[str]:
"""Excludes from the list objects that already exist in GCS bucket."""
bucket_name, object_prefix = _parse_gcs_url(self.dest_gcs)

existing_gcs_objects = set(gcs_hook.list(bucket_name, prefix=object_prefix))

s3_paths = set(self.gcs_to_s3_object(gcs_object=gcs_object) for gcs_object in existing_gcs_objects)
s3_objects_reduced = list(set(s3_objects) - s3_paths)

if s3_objects_reduced:
self.log.info("%s files are going to be synced: %s.", len(s3_objects_reduced), s3_objects_reduced)
else:
self.log.info("There are no new files to sync. Have a nice day!")
return s3_objects_reduced

def s3_to_gcs_object(self, s3_object: str) -> str:
"""
Transform an S3 path into a GCS path according to the operator's logic.
If apply_gcs_prefix == True then <s3_prefix><content> => <gcs_prefix><content>
If apply_gcs_prefix == False then <s3_prefix><content> => <gcs_prefix><s3_prefix><content>
"""
gcs_bucket, gcs_prefix = _parse_gcs_url(self.dest_gcs)
if self.apply_gcs_prefix:
gcs_object = s3_object.replace(self.prefix, gcs_prefix, 1)
return gcs_object
return gcs_prefix + s3_object

def gcs_to_s3_object(self, gcs_object: str) -> str:
"""
Transform a GCS path into an S3 path according to the operator's logic.
If apply_gcs_prefix == True then <gcs_prefix><content> => <s3_prefix><content>
If apply_gcs_prefix == False then <gcs_prefix><s3_prefix><content> => <s3_prefix><content>
"""
gcs_bucket, gcs_prefix = _parse_gcs_url(self.dest_gcs)
s3_object = gcs_object.replace(gcs_prefix, "", 1)
if self.apply_gcs_prefix:
return self.prefix + s3_object
return s3_object
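
The two helpers above are pure path transformations and are meant to invert each other, since exclude_existing_objects maps existing GCS objects back to S3 keys before comparing them with the S3 listing. A quick hedged sketch of that round trip, with illustrative operator arguments and paths:

from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator

op = S3ToGCSOperator(
    task_id="roundtrip_check",          # illustrative task id
    bucket="my-s3-bucket",
    prefix="raw/",
    dest_gcs="gs://my-gcs-bucket/landing/",
    apply_gcs_prefix=True,
)
s3_key = "raw/2023/file.csv"
# s3_to_gcs_object swaps the S3 prefix for the GCS prefix ...
assert op.s3_to_gcs_object(s3_object=s3_key) == "landing/2023/file.csv"
# ... and gcs_to_s3_object reverses it, which keeps the replace=False
# deduplication consistent when apply_gcs_prefix is enabled.
assert op.gcs_to_s3_object(gcs_object="landing/2023/file.csv") == s3_key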
146 changes: 136 additions & 10 deletions tests/providers/google/cloud/transfers/test_s3_to_gcs.py
@@ -19,17 +19,39 @@

from unittest import mock

import pytest

from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator

TASK_ID = "test-s3-gcs-operator"
S3_BUCKET = "test-bucket"
S3_PREFIX = "TEST"
S3_DELIMITER = "/"
GCS_PATH_PREFIX = "gs://gcs-bucket/data/"
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
GCS_BUCKET = "gcs-bucket"
GCS_BUCKET_URI = "gs://" + GCS_BUCKET
GCS_PREFIX = "data/"
GCS_PATH_PREFIX = GCS_BUCKET_URI + "/" + GCS_PREFIX
MOCK_FILE_1 = "TEST1.csv"
MOCK_FILE_2 = "TEST2.csv"
MOCK_FILE_3 = "TEST3.csv"
MOCK_FILES = [MOCK_FILE_1, MOCK_FILE_2, MOCK_FILE_3]
AWS_CONN_ID = "aws_default"
GCS_CONN_ID = "google_cloud_default"
IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
APPLY_GCS_PREFIX = False
PARAMETRIZED_OBJECT_PATHS = (
"apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object",
[
(False, "", MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + MOCK_FILE_1),
(False, S3_PREFIX, MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + S3_PREFIX + MOCK_FILE_1),
(False, "", MOCK_FILE_1, GCS_BUCKET_URI, MOCK_FILE_1),
(False, S3_PREFIX, MOCK_FILE_1, GCS_BUCKET_URI, S3_PREFIX + MOCK_FILE_1),
(True, "", MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + MOCK_FILE_1),
(True, S3_PREFIX, MOCK_FILE_1, GCS_PATH_PREFIX, GCS_PREFIX + MOCK_FILE_1),
(True, "", MOCK_FILE_1, GCS_BUCKET_URI, MOCK_FILE_1),
(True, S3_PREFIX, MOCK_FILE_1, GCS_BUCKET_URI, MOCK_FILE_1),
],
)


class TestS3ToGoogleCloudStorageOperator:
@@ -44,6 +66,7 @@ def test_init(self):
gcp_conn_id=GCS_CONN_ID,
dest_gcs=GCS_PATH_PREFIX,
google_impersonation_chain=IMPERSONATION_CHAIN,
apply_gcs_prefix=APPLY_GCS_PREFIX,
)

assert operator.task_id == TASK_ID
@@ -53,6 +76,7 @@ def test_init(self):
assert operator.gcp_conn_id == GCS_CONN_ID
assert operator.dest_gcs == GCS_PATH_PREFIX
assert operator.google_impersonation_chain == IMPERSONATION_CHAIN
assert operator.apply_gcs_prefix == APPLY_GCS_PREFIX

@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@@ -73,12 +97,12 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES
s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES

uploaded_files = operator.execute(None)
uploaded_files = operator.execute(context={})
gcs_mock_hook.return_value.upload.assert_has_calls(
[
mock.call("gcs-bucket", "data/TEST1.csv", mock.ANY, gzip=False),
mock.call("gcs-bucket", "data/TEST3.csv", mock.ANY, gzip=False),
mock.call("gcs-bucket", "data/TEST2.csv", mock.ANY, gzip=False),
mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_1, mock.ANY, gzip=False),
mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_2, mock.ANY, gzip=False),
mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_3, mock.ANY, gzip=False),
],
any_order=True,
)
@@ -112,16 +136,118 @@ def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_ho
s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES
s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES

operator.execute(None)
operator.execute(context={})
gcs_mock_hook.assert_called_once_with(
gcp_conn_id=GCS_CONN_ID,
impersonation_chain=None,
)
gcs_mock_hook.return_value.upload.assert_has_calls(
[
mock.call("gcs-bucket", "data/TEST2.csv", mock.ANY, gzip=True),
mock.call("gcs-bucket", "data/TEST1.csv", mock.ANY, gzip=True),
mock.call("gcs-bucket", "data/TEST3.csv", mock.ANY, gzip=True),
mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_1, mock.ANY, gzip=True),
mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_2, mock.ANY, gzip=True),
mock.call(GCS_BUCKET, GCS_PREFIX + MOCK_FILE_3, mock.ANY, gzip=True),
],
any_order=True,
)

@pytest.mark.parametrize(
"source_objects, existing_objects, objects_expected",
[
(MOCK_FILES, [], MOCK_FILES),
(MOCK_FILES, [MOCK_FILE_1], [MOCK_FILE_2, MOCK_FILE_3]),
(MOCK_FILES, [MOCK_FILE_1, MOCK_FILE_2], [MOCK_FILE_3]),
(MOCK_FILES, [MOCK_FILE_3, MOCK_FILE_2], [MOCK_FILE_1]),
(MOCK_FILES, MOCK_FILES, []),
],
)
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
def test_exclude_existing_objects(
self, mock_gcs_hook, source_objects, existing_objects, objects_expected
):
operator = S3ToGCSOperator(
task_id=TASK_ID,
bucket=S3_BUCKET,
prefix=S3_PREFIX,
delimiter=S3_DELIMITER,
gcp_conn_id=GCS_CONN_ID,
dest_gcs=GCS_PATH_PREFIX,
gzip=True,
)
mock_gcs_hook.list.return_value = existing_objects
files_reduced = operator.exclude_existing_objects(s3_objects=source_objects, gcs_hook=mock_gcs_hook)
assert set(files_reduced) == set(objects_expected)

@pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
def test_s3_to_gcs_object(self, apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object):
operator = S3ToGCSOperator(
task_id=TASK_ID,
bucket=S3_BUCKET,
prefix=s3_prefix,
delimiter=S3_DELIMITER,
gcp_conn_id=GCS_CONN_ID,
dest_gcs=gcs_destination,
gzip=True,
apply_gcs_prefix=apply_gcs_prefix,
)
assert operator.s3_to_gcs_object(s3_object=s3_prefix + s3_object) == gcs_object

@pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
def test_gcs_to_s3_object(self, apply_gcs_prefix, s3_prefix, s3_object, gcs_destination, gcs_object):
operator = S3ToGCSOperator(
task_id=TASK_ID,
bucket=S3_BUCKET,
prefix=s3_prefix,
delimiter=S3_DELIMITER,
gcp_conn_id=GCS_CONN_ID,
dest_gcs=gcs_destination,
gzip=True,
apply_gcs_prefix=apply_gcs_prefix,
)
assert operator.gcs_to_s3_object(gcs_object=gcs_object) == s3_prefix + s3_object

@pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
def test_execute_apply_gcs_prefix(
self,
gcs_mock_hook,
s3_one_mock_hook,
s3_two_mock_hook,
apply_gcs_prefix,
s3_prefix,
s3_object,
gcs_destination,
gcs_object,
):

operator = S3ToGCSOperator(
task_id=TASK_ID,
bucket=S3_BUCKET,
prefix=s3_prefix,
delimiter=S3_DELIMITER,
gcp_conn_id=GCS_CONN_ID,
dest_gcs=gcs_destination,
google_impersonation_chain=IMPERSONATION_CHAIN,
apply_gcs_prefix=apply_gcs_prefix,
)

s3_one_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object]
s3_two_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object]

uploaded_files = operator.execute(context={})
gcs_mock_hook.return_value.upload.assert_has_calls(
[
mock.call(GCS_BUCKET, gcs_object, mock.ANY, gzip=False),
],
any_order=True,
)

s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
gcs_mock_hook.assert_called_once_with(
gcp_conn_id=GCS_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)

assert sorted([s3_prefix + s3_object]) == sorted(uploaded_files)
6 changes: 5 additions & 1 deletion tests/system/providers/google/cloud/gcs/example_s3_to_gcs.py
@@ -62,7 +62,11 @@ def upload_file():
)
# [START howto_transfer_s3togcs_operator]
transfer_to_gcs = S3ToGCSOperator(
task_id="s3_to_gcs_task", bucket=BUCKET_NAME, prefix=PREFIX, dest_gcs=GCS_BUCKET_URL
task_id="s3_to_gcs_task",
bucket=BUCKET_NAME,
prefix=PREFIX,
dest_gcs=GCS_BUCKET_URL,
apply_gcs_prefix=True,
)
# [END howto_transfer_s3togcs_operator]
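
For context, a brief hedged note on why the example enables the new flag (BUCKET_NAME, PREFIX, and GCS_BUCKET_URL are defined earlier in the system test and are not shown here):

# apply_gcs_prefix=True drops the S3 PREFIX from the destination object names,
# so transferred objects land directly under GCS_BUCKET_URL instead of under
# GCS_BUCKET_URL + PREFIX (the default behaviour).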

