Add deferrable mode to CloudSQLExportInstanceOperator (#30852)
Co-authored-by: Beata Kossakowska <bkossakowska@google.com>
bkossakowska and Beata Kossakowska committed Jun 29, 2023
1 parent f3f69bf commit c0eaa9b
Showing 9 changed files with 666 additions and 22 deletions.
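In practice, the new flags are used like this — a minimal sketch; the DAG id, project, instance, and bucket names are hypothetical, not taken from this commit:

import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExportInstanceOperator

# Hypothetical export body; see the Cloud SQL Admin API docs for the full schema.
export_body = {
    "exportContext": {
        "fileType": "sql",
        "uri": "gs://my-bucket/my-export.sql",  # hypothetical bucket
        "sqlExportOptions": {"schemaOnly": False},
    }
}

with DAG(
    dag_id="example_cloud_sql_export_deferrable",
    start_date=datetime.datetime(2023, 1, 1),
    schedule=None,
) as dag:
    sql_export_task = CloudSQLExportInstanceOperator(
        task_id="sql_export_task",
        project_id="my-project",  # hypothetical project
        instance="my-instance",  # hypothetical instance
        body=export_body,
        deferrable=True,  # new in this commit: defer while the export runs
        poke_interval=10,  # new in this commit: seconds between status checks
    )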
50 changes: 44 additions & 6 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -39,16 +39,18 @@
from urllib.parse import quote_plus

import httpx
from aiohttp import ClientSession
from gcloud.aio.auth import AioSession, Token
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException
from requests import Session

# Number of retries - used by googleapiclient method calls to perform retries
# For requests that are "retriable"
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -300,8 +302,7 @@ def delete_database(self, instance: str, database: str, project_id: str) -> None
self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)

@GoogleBaseHook.fallback_to_default_project_id
@GoogleBaseHook.operation_in_progress_retry()
def export_instance(self, instance: str, body: dict, project_id: str) -> None:
def export_instance(self, instance: str, body: dict, project_id: str):
"""
Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump
or CSV file.
@@ -321,7 +322,7 @@ def export_instance(self, instance: str, body: dict, project_id: str) -> None:
.execute(num_retries=self.num_retries)
)
operation_name = response["name"]
self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
return operation_name

@GoogleBaseHook.fallback_to_default_project_id
def import_instance(self, instance: str, body: dict, project_id: str) -> None:
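Note the behavioral change above: export_instance no longer waits inside the hook; it returns the operation name and leaves the waiting to the caller. A minimal sketch of a synchronous caller under that contract, reusing the hypothetical export_body from the sketch above:

from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook

hook = CloudSQLHook(api_version="v1beta4", gcp_conn_id="google_cloud_default")
operation_name = hook.export_instance(
    project_id="my-project", instance="my-instance", body=export_body
)
# Waiting is now the caller's job, mirroring what the operator does below:
hook._wait_for_operation_to_complete(project_id="my-project", operation_name=operation_name)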
@@ -376,6 +377,7 @@ def clone_instance(self, instance: str, body: dict, project_id: str) -> None:
except HttpError as ex:
raise AirflowException(f"Cloning of instance {instance} failed: {ex.content}")

@GoogleBaseHook.fallback_to_default_project_id
def _wait_for_operation_to_complete(
self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS
) -> None:
@@ -412,6 +414,42 @@ def _wait_for_operation_to_complete(
)


class CloudSQLAsyncHook(GoogleBaseAsyncHook):
"""Class to get asynchronous hook for Google Cloud SQL."""

sync_hook_class = CloudSQLHook

async def _get_conn(self, session: Session, url: str):
scopes = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/sqlservice.admin",
]

async with Token(scopes=scopes) as token:
session_aio = AioSession(session)
headers = {
"Authorization": f"Bearer {await token.get()}",
}
return await session_aio.get(url=url, headers=headers)

async def get_operation_name(self, project_id: str, operation_name: str, session):
url = f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project_id}/operations/{operation_name}"
return await self._get_conn(url=str(url), session=session)

async def get_operation(self, project_id: str, operation_name: str):
async with ClientSession() as session:
try:
operation = await self.get_operation_name(
project_id=project_id,
operation_name=operation_name,
session=session,
)
operation = await operation.json(content_type=None)
except HttpError as e:
raise e
return operation


class CloudSqlProxyRunner(LoggingMixin):
"""
Downloads and runs cloud-sql-proxy as subprocess of the Python process.
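The new CloudSQLAsyncHook wraps only the operations.get endpoint. A minimal sketch of polling it directly — it assumes application default credentials are available to gcloud.aio's Token, and the project and operation names are hypothetical:

import asyncio

from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook

async def get_status(project_id: str, operation_name: str) -> str:
    hook = CloudSQLAsyncHook(gcp_conn_id="google_cloud_default")
    # get_operation returns the parsed JSON body of the operations.get response.
    operation = await hook.get_operation(project_id=project_id, operation_name=operation_name)
    return operation["status"]  # e.g. "PENDING", "RUNNING" or "DONE"

print(asyncio.run(get_status("my-project", "my-operation-name")))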
41 changes: 40 additions & 1 deletion airflow/providers/google/cloud/operators/cloud_sql.py
@@ -28,6 +28,7 @@
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
from airflow.providers.google.common.hooks.base_google import get_field
from airflow.providers.google.common.links.storage import FileDetailsLink
@@ -926,6 +927,9 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run the operator in deferrable mode.
:param poke_interval: (Deferrable mode only) Time (seconds) to wait between calls
to check the run status.
"""

# [START gcp_sql_export_template_fields]
@@ -951,10 +955,14 @@ def __init__(
api_version: str = "v1beta4",
validate_body: bool = True,
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
poke_interval: int = 10,
**kwargs,
) -> None:
self.body = body
self.validate_body = validate_body
self.deferrable = deferrable
self.poke_interval = poke_interval
super().__init__(
project_id=project_id,
instance=instance,
@@ -994,7 +1002,38 @@ def execute(self, context: Context) -> None:
uri=self.body["exportContext"]["uri"][5:],
project_id=self.project_id or hook.project_id,
)
return hook.export_instance(project_id=self.project_id, instance=self.instance, body=self.body)

operation_name = hook.export_instance(
project_id=self.project_id, instance=self.instance, body=self.body
)

if not self.deferrable:
return hook._wait_for_operation_to_complete(
project_id=self.project_id, operation_name=operation_name
)
else:
self.defer(
trigger=CloudSQLExportTrigger(
operation_name=operation_name,
project_id=self.project_id or hook.project_id,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
"""
        Callback for when the trigger fires; returns immediately.
        Relies on the trigger to raise an exception; otherwise execution is assumed
        to have been successful.
"""
if event["status"] == "success":
self.log.info("Operation %s completed successfully", event["operation_name"])
else:
self.log.exception("Unexpected error in the operation.")
raise AirflowException(event["message"])


class CloudSQLImportInstanceOperator(CloudSQLBaseOperator):
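For reference, execute_complete receives one of three event payloads from the trigger added below; these shapes are read directly off its run method:

{"operation_name": "<name>", "status": "success"}  # operation DONE without error
{"operation_name": "<name>", "status": "error", "message": "<API error>"}  # operation DONE with error
{"status": "failed", "message": "<exception text>"}  # polling raised; no operation_name key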
102 changes: 102 additions & 0 deletions airflow/providers/google/cloud/triggers/cloud_sql.py
@@ -0,0 +1,102 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Google Cloud SQL triggers."""
from __future__ import annotations

import asyncio
from typing import Sequence

from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook, CloudSqlOperationStatus
from airflow.triggers.base import BaseTrigger, TriggerEvent


class CloudSQLExportTrigger(BaseTrigger):
"""
    Trigger that periodically polls the Cloud SQL API to verify the export operation status.
    Implementation leverages asynchronous transport.
"""

def __init__(
self,
operation_name: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poke_interval: int = 20,
):
super().__init__()
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.operation_name = operation_name
self.project_id = project_id
self.poke_interval = poke_interval
self.hook = CloudSQLAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

def serialize(self):
return (
"airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger",
{
"operation_name": self.operation_name,
"project_id": self.project_id,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poke_interval": self.poke_interval,
},
)

async def run(self):
while True:
try:
operation = await self.hook.get_operation(
project_id=self.project_id, operation_name=self.operation_name
)
if operation["status"] == CloudSqlOperationStatus.DONE:
if "error" in operation:
yield TriggerEvent(
{
"operation_name": operation["name"],
"status": "error",
"message": operation["error"]["message"],
}
)
return
yield TriggerEvent(
{
"operation_name": operation["name"],
"status": "success",
}
)
return
else:
self.log.info(
"Operation status is %s, sleeping for %s seconds.",
operation["status"],
self.poke_interval,
)
await asyncio.sleep(self.poke_interval)
except Exception as e:
self.log.exception("Exception occurred while checking operation status.")
yield TriggerEvent(
{
"status": "failed",
"message": str(e),
}
)
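A minimal sketch of exercising the trigger's event stream in isolation, with the hook call mocked out (the operation name is hypothetical; this is not a test from this commit):

import asyncio
from unittest import mock

from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger

async def first_event():
    trigger = CloudSQLExportTrigger(operation_name="op-123", project_id="my-project")
    # Pretend the API immediately reports the operation as finished.
    with mock.patch.object(
        trigger.hook,
        "get_operation",
        new=mock.AsyncMock(return_value={"name": "op-123", "status": "DONE"}),
    ):
        return await trigger.run().asend(None)

# The first yielded event: success, carrying the operation name.
print(asyncio.run(first_event()))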
3 changes: 3 additions & 0 deletions airflow/providers/google/provider.yaml
@@ -847,6 +847,9 @@ triggers:
- integration-name: Google Cloud Composer
python-modules:
- airflow.providers.google.cloud.triggers.cloud_composer
- integration-name: Google Cloud SQL
python-modules:
- airflow.providers.google.cloud.triggers.cloud_sql
- integration-name: Google Dataflow
python-modules:
- airflow.providers.google.cloud.triggers.dataflow
@@ -241,6 +241,14 @@ it will be retrieved from the Google Cloud connection used. Both variants are shown:
:start-after: [START howto_operator_cloudsql_export]
:end-before: [END howto_operator_cloudsql_export]

This action can also be performed by running the operator in deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
:language: python
:dedent: 4
:start-after: [START howto_operator_cloudsql_export_async]
:end-before: [END howto_operator_cloudsql_export_async]
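The included snippet lives between those START/END markers in the system test; a rough sketch of the kind of block they delimit (PROJECT_ID, INSTANCE_NAME, and the body are illustrative, not copied from the test):

# [START howto_operator_cloudsql_export_async]
sql_export_def_task = CloudSQLExportInstanceOperator(
    task_id="sql_export_def_task",
    project_id=PROJECT_ID,
    instance=INSTANCE_NAME,
    body=export_body,
    deferrable=True,
)
# [END howto_operator_cloudsql_export_async]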

Templating
""""""""""

