Combine 8 into 1 (#29462)
Implement deferrable mode for S3ToGCSOperator

Update airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py

Co-authored-by: Wei Lee <weilee.rx@gmail.com>
moiseenkov and Lee-W committed Jul 11, 2023
1 parent 3f6ac2f commit 86c6cc9
Showing 11 changed files with 1,139 additions and 24 deletions.
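For context, the new deferrable mode hands the wait for the transfer job off to the triggerer instead of blocking a worker slot. A minimal usage sketch, not part of this commit; the DAG id, bucket names, prefixes and connection ids below are illustrative placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator

with DAG(dag_id="example_s3_to_gcs_deferrable", start_date=datetime(2023, 7, 1), schedule=None) as dag:
    transfer = S3ToGCSOperator(
        task_id="s3_to_gcs",
        bucket="my-s3-bucket",  # hypothetical source bucket
        prefix="data/",  # hypothetical key prefix
        dest_gcs="gs://my-gcs-bucket/data/",  # must be a directory-style GCS URL
        aws_conn_id="aws_default",
        gcp_conn_id="google_cloud_default",
        replace=False,
        deferrable=True,  # defer to the triggerer instead of occupying a worker
        poll_interval=10,  # seconds between transfer-job status polls
    )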
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -15,7 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google Storage Transfer Service Hook."""
"""
This module contains a Google Storage Transfer Service Hook.
.. spelling::
ListTransferJobsAsyncPager
StorageTransferServiceAsyncClient
"""

from __future__ import annotations

import json
@@ -24,13 +33,23 @@
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import Sequence

from typing import Any, Sequence

from google.cloud.storage_transfer_v1 import (
ListTransferJobsRequest,
StorageTransferServiceAsyncClient,
TransferJob,
TransferOperation,
)
from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
ListTransferJobsAsyncPager,
)
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
from proto import Message

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook

log = logging.getLogger(__name__)

@@ -60,6 +79,7 @@ class GcpTransferOperationStatus:
ACCESS_KEY_ID = "accessKeyId"
ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
AWS_ACCESS_KEY = "awsAccessKey"
AWS_SECRET_ACCESS_KEY = "secretAccessKey"
AWS_S3_DATA_SOURCE = "awsS3DataSource"
BODY = "body"
BUCKET_NAME = "bucketName"
@@ -73,6 +93,7 @@ class GcpTransferOperationStatus:
GCS_DATA_SOURCE = "gcsDataSource"
HOURS = "hours"
HTTP_DATA_SOURCE = "httpDataSource"
INCLUDE_PREFIXES = "includePrefixes"
JOB_NAME = "name"
LIST_URL = "list_url"
METADATA = "metadata"
@@ -81,6 +102,7 @@ class GcpTransferOperationStatus:
NAME = "name"
OBJECT_CONDITIONS = "object_conditions"
OPERATIONS = "operations"
OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK = "overwriteObjectsAlreadyExistingInSink"
PATH = "path"
PROJECT_ID = "projectId"
SCHEDULE = "schedule"
@@ -466,3 +488,50 @@ def operations_contain_expected_statuses(
f"Expected: {', '.join(expected_statuses_set)}"
)
return False


class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous hook for Google Storage Transfer Service."""

def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self._client: StorageTransferServiceAsyncClient | None = None

def get_conn(self) -> StorageTransferServiceAsyncClient:
"""
Returns async connection to the Storage Transfer Service.
:return: Google Storage Transfer asynchronous client.
"""
if not self._client:
self._client = StorageTransferServiceAsyncClient()
return self._client

async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
"""
Gets the latest state of a long-running operations in Google Storage Transfer Service.
:param job_names: (Required) List of names of the jobs to be fetched.
:return: Object that yields Transfer jobs.
"""
client = self.get_conn()
jobs_list_request = ListTransferJobsRequest(
filter=json.dumps(dict(project_id=self.project_id, job_names=job_names))
)
return await client.list_transfer_jobs(request=jobs_list_request)

async def get_latest_operation(self, job: TransferJob) -> Message | None:
"""
Gets the latest operation of the given TransferJob instance.
:param job: Transfer job instance.
:return: The latest job operation.
"""
latest_operation_name = job.latest_operation_name
if latest_operation_name:
client = self.get_conn()
response_operation = await client.transport.operations_client.get_operation(latest_operation_name)
operation = TransferOperation.deserialize(response_operation.metadata.value)
return operation
return None
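As a rough illustration of how the async hook above is meant to be consumed, the sketch below polls a set of transfer jobs until all of their latest operations succeed. It is a condensed, assumed version of what the new CloudStorageTransferServiceCreateJobsTrigger does; the real trigger also handles failed, aborted and missing operations, so treat this as illustrative only:

import asyncio

from google.cloud.storage_transfer_v1 import TransferOperation

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceAsyncHook,
)


async def wait_until_done(project_id: str, job_names: list[str], poll_interval: int = 10) -> None:
    hook = CloudDataTransferServiceAsyncHook(project_id=project_id)
    while True:
        jobs_pager = await hook.get_jobs(job_names=job_names)
        # get_latest_operation returns None until a job has produced an operation.
        operations = [await hook.get_latest_operation(job) async for job in jobs_pager]
        if all(op is not None and op.status == TransferOperation.Status.SUCCESS for op in operations):
            return
        await asyncio.sleep(poll_interval)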
airflow/providers/google/cloud/transfers/s3_to_gcs.py (164 changes: 145 additions & 19 deletions)
@@ -17,12 +17,38 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
ACCESS_KEY_ID,
AWS_ACCESS_KEY,
AWS_S3_DATA_SOURCE,
AWS_SECRET_ACCESS_KEY,
BUCKET_NAME,
GCS_DATA_SINK,
INCLUDE_PREFIXES,
OBJECT_CONDITIONS,
OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK,
PATH,
PROJECT_ID,
SCHEDULE,
SCHEDULE_END_DATE,
SCHEDULE_START_DATE,
STATUS,
TRANSFER_OPTIONS,
TRANSFER_SPEC,
CloudDataTransferServiceHook,
GcpTransferJobsStatus,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudStorageTransferServiceCreateJobsTrigger,
)

try:
from airflow.providers.amazon.aws.operators.s3 import S3ListOperator
@@ -70,7 +96,7 @@ class S3ToGCSOperator(S3ListOperator):
where you want to store the files. (templated)
:param replace: Whether you want to replace existing destination files
or not.
:param gzip: Option to compress file for upload
:param gzip: Option to compress file for upload. Parameter ignored in deferrable mode.
:param google_impersonation_chain: Optional Google service account to impersonate using
short-term credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
@@ -79,7 +105,9 @@ class S3ToGCSOperator(S3ListOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run the operator in deferrable mode
:param poll_interval: time in seconds between polling for job completion.
    The value is considered only when running in deferrable mode. Must be greater than 0.

**Example**:
@@ -108,6 +136,7 @@ class S3ToGCSOperator(S3ListOperator):
"google_impersonation_chain",
)
ui_color = "#e09411"
transfer_job_max_files_number = 1000

def __init__(
self,
@@ -123,6 +152,8 @@ def __init__(
replace=False,
gzip=False,
google_impersonation_chain: str | Sequence[str] | None = None,
deferrable=conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 10,
**kwargs,
):

@@ -134,6 +165,10 @@ def __init__(
self.verify = verify
self.gzip = gzip
self.google_impersonation_chain = google_impersonation_chain
self.deferrable = deferrable
if poll_interval <= 0:
raise ValueError("Invalid value for poll_interval. Expected value greater than 0")
self.poll_interval = poll_interval

def _check_inputs(self) -> None:
if self.dest_gcs and not gcs_object_is_directory(self.dest_gcs):
@@ -158,23 +193,13 @@ def execute(self, context: Context):
if not self.replace:
s3_objects = self.exclude_existing_objects(s3_objects=s3_objects, gcs_hook=gcs_hook)

if s3_objects:
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
for obj in s3_objects:
# GCS hook builds its own in-memory file, so we have to create
# and pass the path
file_object = hook.get_key(obj, self.bucket)
with NamedTemporaryFile(mode="wb", delete=True) as file:
file_object.download_fileobj(file)
file.flush()
gcs_file = self.s3_to_gcs_object(s3_object=obj)
gcs_hook.upload(dest_gcs_bucket, gcs_file, file.name, gzip=self.gzip)

self.log.info("All done, uploaded %d files to Google Cloud Storage", len(s3_objects))
else:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
if not s3_objects:
self.log.info("In sync, no files needed to be uploaded to Google Cloud Storage")
elif self.deferrable:
self.transfer_files_async(s3_objects, gcs_hook, s3_hook)
else:
self.transfer_files(s3_objects, gcs_hook, s3_hook)

return s3_objects

@@ -220,3 +245,104 @@ def gcs_to_s3_object(self, gcs_object: str) -> str:
if self.apply_gcs_prefix:
return self.prefix + s3_object
return s3_object

def transfer_files(self, s3_objects: list[str], gcs_hook: GCSHook, s3_hook: S3Hook) -> None:
if s3_objects:
dest_gcs_bucket, dest_gcs_object_prefix = _parse_gcs_url(self.dest_gcs)
for obj in s3_objects:
# GCS hook builds its own in-memory file, so we have to create
# and pass the path
file_object = s3_hook.get_key(obj, self.bucket)
with NamedTemporaryFile(mode="wb", delete=True) as file:
file_object.download_fileobj(file)
file.flush()
gcs_file = self.s3_to_gcs_object(s3_object=obj)
gcs_hook.upload(dest_gcs_bucket, gcs_file, file.name, gzip=self.gzip)

self.log.info("All done, uploaded %d files to Google Cloud Storage", len(s3_objects))

def transfer_files_async(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3Hook) -> None:
"""Submits Google Cloud Storage Transfer Service job to copy files from AWS S3 to GCS."""
if not len(files):
raise ValueError("List of transferring files cannot be empty")
job_names = self.submit_transfer_jobs(files=files, gcs_hook=gcs_hook, s3_hook=s3_hook)

self.defer(
trigger=CloudStorageTransferServiceCreateJobsTrigger(
project_id=gcs_hook.project_id,
job_names=job_names,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)

def submit_transfer_jobs(self, files: list[str], gcs_hook: GCSHook, s3_hook: S3Hook) -> list[str]:
now = datetime.utcnow()
one_time_schedule = {"day": now.day, "month": now.month, "year": now.year}

gcs_bucket, gcs_prefix = _parse_gcs_url(self.dest_gcs)
config = s3_hook.conn_config

body: dict[str, Any] = {
PROJECT_ID: gcs_hook.project_id,
STATUS: GcpTransferJobsStatus.ENABLED,
SCHEDULE: {
SCHEDULE_START_DATE: one_time_schedule,
SCHEDULE_END_DATE: one_time_schedule,
},
TRANSFER_SPEC: {
AWS_S3_DATA_SOURCE: {
BUCKET_NAME: self.bucket,
AWS_ACCESS_KEY: {
ACCESS_KEY_ID: config.aws_access_key_id,
AWS_SECRET_ACCESS_KEY: config.aws_secret_access_key,
},
},
OBJECT_CONDITIONS: {
INCLUDE_PREFIXES: [],
},
GCS_DATA_SINK: {BUCKET_NAME: gcs_bucket, PATH: gcs_prefix},
TRANSFER_OPTIONS: {
OVERWRITE_OBJECTS_ALREADY_EXISTING_IN_SINK: self.replace,
},
},
}

# max size of the field 'transfer_job.transfer_spec.object_conditions.include_prefixes' is 1000,
# that's why we submit multiple jobs transferring 1000 files each.
# See documentation below
# https://cloud.google.com/storage-transfer/docs/reference/rest/v1/TransferSpec#ObjectConditions
chunk_size = self.transfer_job_max_files_number
job_names = []
transfer_hook = self.get_transfer_hook()
for i in range(0, len(files), chunk_size):
files_chunk = files[i : i + chunk_size]
body[TRANSFER_SPEC][OBJECT_CONDITIONS][INCLUDE_PREFIXES] = files_chunk
job = transfer_hook.create_transfer_job(body=body)

s = "s" if len(files_chunk) > 1 else ""
self.log.info(f"Submitted job {job['name']} to transfer {len(files_chunk)} file{s}")
job_names.append(job["name"])

if len(files) > chunk_size:
js = "s" if len(job_names) > 1 else ""
fs = "s" if len(files) > 1 else ""
self.log.info(f"Overall submitted {len(job_names)} job{js} to transfer {len(files)} file{fs}")

return job_names

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info("%s completed with response %s ", self.task_id, event["message"])

def get_transfer_hook(self):
return CloudDataTransferServiceHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
)
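Finally, a standalone sketch of the chunking rule used by submit_transfer_jobs above: include_prefixes accepts at most 1000 entries per transfer job, so a listing of, say, 2500 objects is split into three jobs. The file names and counts are made up for illustration:

files = [f"data/part-{i:05d}.csv" for i in range(2500)]  # hypothetical S3 keys
chunk_size = 1000  # mirrors S3ToGCSOperator.transfer_job_max_files_number

# One transfer job is submitted per chunk; the chunk becomes that job's include_prefixes.
chunks = [files[i : i + chunk_size] for i in range(0, len(files), chunk_size)]
assert [len(chunk) for chunk in chunks] == [1000, 1000, 500]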
