Add deferrable mode for Big Query Transfer operator (#27833)
MrGeorgeOwl committed Jan 20, 2023
1 parent 4b3a9ca commit 5fcdd32
Showing 7 changed files with 577 additions and 13 deletions.
75 changes: 73 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery_dts.py
@@ -23,7 +23,7 @@

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.retry import Retry
-from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient
+from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceAsyncClient, DataTransferServiceClient
from google.cloud.bigquery_datatransfer_v1.types import (
StartManualTransferRunsResponse,
TransferConfig,
@@ -32,7 +32,11 @@
from googleapiclient.discovery import Resource

from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import (
+    PROVIDE_PROJECT_ID,
+    GoogleBaseAsyncHook,
+    GoogleBaseHook,
+)


def get_object_id(obj: dict) -> str:
@@ -263,3 +267,70 @@ def get_transfer_run(
return client.get_transfer_run(
request={"name": name}, retry=retry, timeout=timeout, metadata=metadata or ()
)


class AsyncBiqQueryDataTransferServiceHook(GoogleBaseAsyncHook):
"""Hook of the BigQuery service to be used with async client of the Google library."""

sync_hook_class = BiqQueryDataTransferServiceHook

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: str | None = None,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__(
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
location=location,
impersonation_chain=impersonation_chain,
)
self._conn: DataTransferServiceAsyncClient | None = None

async def _get_conn(self) -> DataTransferServiceAsyncClient:
if not self._conn:
credentials = (await self.get_sync_hook()).get_credentials()
self._conn = DataTransferServiceAsyncClient(credentials=credentials, client_info=CLIENT_INFO)
return self._conn

async def _get_project_id(self) -> str:
sync_hook = await self.get_sync_hook()
return sync_hook.project_id

async def get_transfer_run(
self,
config_id: str,
run_id: str,
project_id: str | None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
):
"""
        Returns information about a particular transfer run.

        :param run_id: ID of the transfer run.
:param config_id: ID of transfer config to be used.
:param project_id: The BigQuery project id where the transfer configuration should be
created. If set to None or missing, the default project_id from the Google Cloud connection
is used.
:param retry: A retry object used to retry requests. If `None` is
specified, requests will not be retried.
:param timeout: The amount of time, in seconds, to wait for the request to
complete. Note that if retry is specified, the timeout applies to each individual
attempt.
:param metadata: Additional metadata that is provided to the method.
        :return: A ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance.
"""
project_id = project_id or (await self._get_project_id())
client = await self._get_conn()
name = f"projects/{project_id}/transferConfigs/{config_id}/runs/{run_id}"
transfer_run = await client.get_transfer_run(
name=name,
retry=retry,
timeout=timeout,
metadata=metadata,
)
return transfer_run
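
For orientation, a minimal sketch (not part of this commit) of how the new async hook can be exercised directly from an event loop. The config and run ids are hypothetical placeholders; passing project_id=None falls back to the default project resolved from the underlying sync hook:

import asyncio

from airflow.providers.google.cloud.hooks.bigquery_dts import AsyncBiqQueryDataTransferServiceHook


async def check_transfer_run() -> None:
    # Hypothetical ids; substitute a real transfer config and run.
    hook = AsyncBiqQueryDataTransferServiceHook(gcp_conn_id="google_cloud_default")
    transfer_run = await hook.get_transfer_run(
        config_id="my-transfer-config-id",
        run_id="my-run-id",
        project_id=None,  # resolved from the Google Cloud connection's default project
    )
    print(transfer_run.state)


asyncio.run(check_transfer_run())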
106 changes: 99 additions & 7 deletions airflow/providers/google/cloud/operators/bigquery_dts.py
@@ -18,15 +18,24 @@
"""This module contains Google BigQuery Data Transfer Service operators."""
from __future__ import annotations

import time
from typing import TYPE_CHECKING, Sequence

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.retry import Retry
-from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig
+from google.cloud.bigquery_datatransfer_v1 import (
+    StartManualTransferRunsResponse,
+    TransferConfig,
+    TransferRun,
+    TransferState,
+)

from airflow import AirflowException
from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id
from airflow.providers.google.cloud.links.bigquery_dts import BigQueryDataTransferConfigLink
from airflow.providers.google.cloud.triggers.bigquery_dts import BigQueryDataTransferRunTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -224,7 +233,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
must be of the same form as the protobuf message
`~google.cloud.bigquery_datatransfer_v1.types.Timestamp`
:param project_id: The BigQuery project id where the transfer configuration should be
-        created. If set to None or missing, the default project_id from the Google Cloud connection is used.
+        created.
:param location: BigQuery Transfer Service location for regional transfers.
:param retry: A retry object used to retry requests. If `None` is
specified, requests will not be retried.
@@ -241,6 +250,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
    :param deferrable: Run the operator in deferrable mode.
"""

template_fields: Sequence[str] = (
@@ -266,6 +276,7 @@ def __init__(
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id="google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -279,13 +290,20 @@ def __init__(
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable

-    def execute(self, context: Context):
+    @cached_property
+    def hook(self) -> BiqQueryDataTransferServiceHook:
        hook = BiqQueryDataTransferServiceHook(
-            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+            location=self.location,
        )
+        return hook

def execute(self, context: Context):
self.log.info("Submitting manual transfer for %s", self.transfer_config_id)
-        response = hook.start_manual_transfer_runs(
+        response = self.hook.start_manual_transfer_runs(
transfer_config_id=self.transfer_config_id,
requested_time_range=self.requested_time_range,
requested_run_time=self.requested_run_time,
@@ -307,5 +325,79 @@ def execute(self, context: Context):
result = StartManualTransferRunsResponse.to_dict(response)
run_id = get_object_id(result["runs"][0])
self.xcom_push(context, key="run_id", value=run_id)
self.log.info("Transfer run %s submitted successfully.", run_id)
return result

if not self.deferrable:
result = self._wait_for_transfer_to_be_done(
run_id=run_id,
transfer_config_id=transfer_config["config_id"],
)
self.log.info("Transfer run %s submitted successfully.", run_id)
return result

self.defer(
trigger=BigQueryDataTransferRunTrigger(
project_id=self.project_id,
config_id=transfer_config["config_id"],
run_id=run_id,
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_completed",
)

def _wait_for_transfer_to_be_done(self, run_id: str, transfer_config_id: str, interval: int = 10):
        if interval <= 0:
            raise ValueError("Interval must be > 0")

while True:
transfer_run: TransferRun = self.hook.get_transfer_run(
run_id=run_id,
transfer_config_id=transfer_config_id,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
state = transfer_run.state

if self._job_is_done(state):
                if state in (TransferState.FAILED, TransferState.CANCELLED):
                    raise AirflowException(f"Transfer run finished with status {state}.")

result = TransferRun.to_dict(transfer_run)
return result

self.log.info("Transfer run is still working, waiting for %s seconds...", interval)
self.log.info("Transfer run status: %s", state)
time.sleep(interval)

@staticmethod
def _job_is_done(state: TransferState) -> bool:
finished_job_statuses = [
state.SUCCEEDED,
state.CANCELLED,
state.FAILED,
]

return state in finished_job_statuses

def execute_completed(self, context: Context, event: dict):
"""Method to be executed after invoked trigger in defer method finishes its job."""
if event["status"] == "failed" or event["status"] == "cancelled":
self.log.error("Trigger finished its work with status: %s.", event["status"])
raise AirflowException(event["message"])

transfer_run: TransferRun = self.hook.get_transfer_run(
project_id=self.project_id,
run_id=event["run_id"],
transfer_config_id=event["config_id"],
)

self.log.info(
"%s finished with message: %s",
event["run_id"],
event["message"],
)

return TransferRun.to_dict(transfer_run)
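
To show the new flag in use, a minimal sketch assuming a standard Airflow 2.x deployment with a triggerer process running; the dag id, transfer config id, and project below are hypothetical, not taken from this commit:

import time
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery_dts import (
    BigQueryDataTransferServiceStartTransferRunsOperator,
)

with DAG(
    dag_id="example_bigquery_dts_deferrable",  # hypothetical
    start_date=datetime(2023, 1, 1),
    schedule=None,
) as dag:
    # With deferrable=True the operator submits the manual run, defers to
    # BigQueryDataTransferRunTrigger, and resumes in execute_completed;
    # with the default deferrable=False it polls in the worker via
    # _wait_for_transfer_to_be_done.
    start_transfer_runs = BigQueryDataTransferServiceStartTransferRunsOperator(
        task_id="start_transfer_runs",
        transfer_config_id="my-transfer-config-id",  # hypothetical
        project_id="my-gcp-project",  # hypothetical
        requested_run_time={"seconds": int(time.time() + 60)},
        deferrable=True,
    )

Deferring frees the worker slot while the transfer runs: the trigger polls the run state from the triggerer process, and the operator only resumes to fetch and return the final TransferRun.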
