From ba59f3475c160102332bb36fb7414ed7b7346383 Mon Sep 17 00:00:00 2001
From: Beata Kossakowska <109511937+bkossakowska@users.noreply.github.com>
Date: Mon, 4 Sep 2023 11:29:55 +0200
Subject: [PATCH] Add deferrable mode to Dataplex DataQuality. (#33954)

Co-authored-by: Beata Kossakowska
---
 .../providers/google/cloud/hooks/dataplex.py  |  71 ++++++++-
 .../google/cloud/operators/dataplex.py        | 115 ++++++++++++++-
 .../google/cloud/triggers/dataplex.py         | 110 ++++++++++++++
 airflow/providers/google/provider.yaml        |   3 +
 .../operators/cloud/dataplex.rst              |  16 +++
 .../google/cloud/operators/test_dataplex.py   |  71 ++++++++-
 .../google/cloud/triggers/test_dataplex.py    | 135 ++++++++++++++++++
 .../cloud/dataplex/example_dataplex_dq.py     |  28 ++++
 8 files changed, 539 insertions(+), 10 deletions(-)
 create mode 100644 airflow/providers/google/cloud/triggers/dataplex.py
 create mode 100644 tests/providers/google/cloud/triggers/test_dataplex.py

diff --git a/airflow/providers/google/cloud/hooks/dataplex.py b/airflow/providers/google/cloud/hooks/dataplex.py
index d3c958820a0a6..98ec121b65cf9 100644
--- a/airflow/providers/google/cloud/hooks/dataplex.py
+++ b/airflow/providers/google/cloud/hooks/dataplex.py
@@ -22,7 +22,7 @@
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.dataplex_v1 import DataplexServiceClient, DataScanServiceClient
+from google.cloud.dataplex_v1 import DataplexServiceClient, DataScanServiceAsyncClient, DataScanServiceClient
 from google.cloud.dataplex_v1.types import (
     Asset,
     DataScan,
@@ -35,7 +35,7 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook

 if TYPE_CHECKING:
     from google.api_core.operation import Operation
@@ -859,3 +859,70 @@ def list_data_scan_jobs(
             metadata=metadata,
         )
         return result
+
+
+class DataplexAsyncHook(GoogleBaseAsyncHook):
+    """
+    Asynchronous Hook for Google Cloud Dataplex APIs.
+
+    All the methods in the hook where project_id is used must be called with
+    keyword arguments rather than positional.
+    """
+
+    sync_hook_class = DataplexHook
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+
+    async def get_dataplex_data_scan_client(self) -> DataScanServiceAsyncClient:
+        """Returns DataScanServiceAsyncClient."""
+        client_options = ClientOptions(api_endpoint="dataplex.googleapis.com:443")
+
+        return DataScanServiceAsyncClient(
+            credentials=(await self.get_sync_hook()).get_credentials(),
+            client_info=CLIENT_INFO,
+            client_options=client_options,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def get_data_scan_job(
+        self,
+        project_id: str,
+        region: str,
+        data_scan_id: str | None = None,
+        job_id: str | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Any:
+        """
+        Gets a DataScan Job resource.
+
+        :param project_id: Required. The ID of the Google Cloud project that the lake belongs to.
+        :param region: Required. The ID of the Google Cloud region that the lake belongs to.
+        :param data_scan_id: Required. DataScan identifier.
+        :param job_id: Required. The DataScanJob ID.
+            The full resource name is built as:
+            projects/{project_id}/locations/{region}/dataScans/{data_scan_id}/jobs/{data_scan_job_id}.
+        :param retry: A retry object used to retry requests. If `None` is specified, requests
+            will not be retried.
+        :param timeout: The amount of time, in seconds, to wait for the request to complete.
+            Note that if `retry` is specified, the timeout applies to each individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = await self.get_dataplex_data_scan_client()
+
+        name = f"projects/{project_id}/locations/{region}/dataScans/{data_scan_id}/jobs/{job_id}"
+        result = await client.get_data_scan_job(
+            request={"name": name, "view": "FULL"},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        return result
diff --git a/airflow/providers/google/cloud/operators/dataplex.py b/airflow/providers/google/cloud/operators/dataplex.py
index ffbb75a7b608b..9787212ab2dec 100644
--- a/airflow/providers/google/cloud/operators/dataplex.py
+++ b/airflow/providers/google/cloud/operators/dataplex.py
@@ -22,6 +22,7 @@
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow import AirflowException
+from airflow.providers.google.cloud.triggers.dataplex import DataplexDataQualityJobTrigger

 if TYPE_CHECKING:
     from google.protobuf.field_mask_pb2 import FieldMask
@@ -34,6 +35,7 @@
 from google.cloud.dataplex_v1.types import Asset, DataScan, DataScanJob, Lake, Task, Zone
 from googleapiclient.errors import HttpError

+from airflow.configuration import conf
 from airflow.providers.google.cloud.hooks.dataplex import AirflowDataQualityScanException, DataplexHook
 from airflow.providers.google.cloud.links.dataplex import (
     DataplexLakeLink,
@@ -895,6 +897,9 @@ class DataplexRunDataQualityScanOperator(GoogleCloudBaseOperator):
     :param result_timeout: Value in seconds for which operator will wait for the Data Quality scan result
         when the flag `asynchronous = False`.
         Throws exception if there is no result found after specified amount of seconds.
+    :param polling_interval_seconds: Time in seconds between polls for the job status.
+        The value is considered only when running in deferrable mode. Must be greater than 0.
+    :param deferrable: Run the operator in deferrable mode.

     :return: Dataplex Data Quality scan job id.
     """
@@ -915,6 +920,8 @@ def __init__(
         asynchronous: bool = False,
         fail_on_dq_failure: bool = False,
         result_timeout: float = 60.0 * 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        polling_interval_seconds: int = 10,
         *args,
         **kwargs,
     ) -> None:
@@ -932,6 +939,8 @@ def __init__(
         self.asynchronous = asynchronous
         self.fail_on_dq_failure = fail_on_dq_failure
         self.result_timeout = result_timeout
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds

     def execute(self, context: Context) -> str:
         hook = DataplexHook(
@@ -949,6 +958,24 @@ def execute(self, context: Context) -> str:
             metadata=self.metadata,
         )
         job_id = result.job.name.split("/")[-1]
+
+        if self.deferrable:
+            if self.asynchronous:
+                raise AirflowException(
+                    "Both asynchronous and deferrable parameters were passed. Please provide only one."
+                )
+            self.defer(
+                trigger=DataplexDataQualityJobTrigger(
+                    job_id=job_id,
+                    data_scan_id=self.data_scan_id,
+                    project_id=self.project_id,
+                    region=self.region,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    polling_interval_seconds=self.polling_interval_seconds,
+                ),
+                method_name="execute_complete",
+            )
         if not self.asynchronous:
             job = hook.wait_for_data_scan_job(
                 job_id=job_id,
@@ -974,6 +1001,31 @@ def execute(self, context: Context) -> str:

         return job_id

+    def execute_complete(self, context, event=None) -> str:
+        """
+        Callback for when the trigger fires; returns immediately.
+
+        Relies on the trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        job_state = event["job_state"]
+        job_id = event["job_id"]
+        if job_state == DataScanJob.State.FAILED:
+            raise AirflowException(f"Job failed:\n{job_id}")
+        if job_state == DataScanJob.State.CANCELLED:
+            raise AirflowException(f"Job was cancelled:\n{job_id}")
+        if job_state == DataScanJob.State.SUCCEEDED:
+            job = event["job"]
+            if not job["data_quality_result"]["passed"]:
+                if self.fail_on_dq_failure:
+                    raise AirflowDataQualityScanException(
+                        f"Data Quality job {job_id} execution failed due to failure of its scanning "
+                        f"rules: {self.data_scan_id}"
+                    )
+            else:
+                self.log.info("Data Quality job executed successfully.")
+        return job_id
+

 class DataplexGetDataQualityScanResultOperator(GoogleCloudBaseOperator):
     """
@@ -1006,6 +1058,9 @@ class DataplexGetDataQualityScanResultOperator(GoogleCloudBaseOperator):
     :param result_timeout: Value in seconds for which operator will wait for the Data Quality scan result
         when the flag `wait_for_result = True`.
         Throws exception if there is no result found after specified amount of seconds.
+    :param polling_interval_seconds: Time in seconds between polls for the job status.
+        The value is considered only when running in deferrable mode. Must be greater than 0.
+    :param deferrable: Run the operator in deferrable mode.

     :return: Dict representing DataScanJob.
         When the job completes with a successful status, information about the Data Quality result
         is available.
@@ -1029,6 +1084,8 @@ def __init__(
         fail_on_dq_failure: bool = False,
         wait_for_results: bool = True,
         result_timeout: float = 60.0 * 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        polling_interval_seconds: int = 10,
         *args,
         **kwargs,
     ) -> None:
@@ -1046,6 +1103,8 @@ def __init__(
         self.fail_on_dq_failure = fail_on_dq_failure
         self.wait_for_results = wait_for_results
         self.result_timeout = result_timeout
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds

     def execute(self, context: Context) -> dict:
         hook = DataplexHook(
@@ -1070,13 +1129,27 @@ def execute(self, context: Context) -> dict:
         self.job_id = job_id.split("/")[-1]

         if self.wait_for_results:
-            job = hook.wait_for_data_scan_job(
-                job_id=self.job_id,
-                data_scan_id=self.data_scan_id,
-                project_id=self.project_id,
-                region=self.region,
-                result_timeout=self.result_timeout,
-            )
+            if self.deferrable:
+                self.defer(
+                    trigger=DataplexDataQualityJobTrigger(
+                        job_id=self.job_id,
+                        data_scan_id=self.data_scan_id,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
+            else:
+                job = hook.wait_for_data_scan_job(
+                    job_id=self.job_id,
+                    data_scan_id=self.data_scan_id,
+                    project_id=self.project_id,
+                    region=self.region,
+                    result_timeout=self.result_timeout,
+                )
         else:
             job = hook.get_data_scan_job(
                 project_id=self.project_id,
@@ -1105,6 +1178,34 @@ def execute(self, context: Context) -> dict:

         return result

+    def execute_complete(self, context, event=None) -> dict:
+        """
+        Callback for when the trigger fires; returns immediately.
+
+        Relies on the trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        job_state = event["job_state"]
+        job_id = event["job_id"]
+        job = event["job"]
+        if job_state == DataScanJob.State.FAILED:
+            raise AirflowException(f"Job failed:\n{job_id}")
+        if job_state == DataScanJob.State.CANCELLED:
+            raise AirflowException(f"Job was cancelled:\n{job_id}")
+        if job_state == DataScanJob.State.SUCCEEDED:
+            if not job["data_quality_result"]["passed"]:
+                if self.fail_on_dq_failure:
+                    raise AirflowDataQualityScanException(
+                        f"Data Quality job {self.job_id} execution failed due to failure of its scanning "
+                        f"rules: {self.data_scan_id}"
+                    )
+            else:
+                self.log.info("Data Quality job executed successfully.")
+        else:
+            self.log.info("Data Quality job execution returned status: %s", job_state)
+
+        return job
+

 class DataplexCreateZoneOperator(GoogleCloudBaseOperator):
     """
diff --git a/airflow/providers/google/cloud/triggers/dataplex.py b/airflow/providers/google/cloud/triggers/dataplex.py
new file mode 100644
index 0000000000000..fd0ff2fa93b65
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/dataplex.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Dataplex triggers."""
+from __future__ import annotations
+
+import asyncio
+from typing import AsyncIterator, Sequence
+
+from google.cloud.dataplex_v1.types import DataScanJob
+
+from airflow.providers.google.cloud.hooks.dataplex import DataplexAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class DataplexDataQualityJobTrigger(BaseTrigger):
+    """
+    DataplexDataQualityJobTrigger runs on the trigger worker and waits for the job to reach a terminal state.
+
+    :param job_id: Optional. The ID of a Dataplex job.
+    :param data_scan_id: Required. DataScan identifier.
+    :param project_id: Google Cloud Project where the job is running.
+    :param region: The ID of the Google Cloud region that the job belongs to.
+    :param gcp_conn_id: Optional, the connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_interval_seconds: Polling period in seconds to check for the job status.
+    """
+
+    def __init__(
+        self,
+        job_id: str | None,
+        data_scan_id: str,
+        project_id: str | None,
+        region: str,
+        gcp_conn_id: str = "google_cloud_default",
+        polling_interval_seconds: int = 10,
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.job_id = job_id
+        self.data_scan_id = data_scan_id
+        self.project_id = project_id
+        self.region = region
+        self.gcp_conn_id = gcp_conn_id
+        self.polling_interval_seconds = polling_interval_seconds
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self):
+        return (
+            "airflow.providers.google.cloud.triggers.dataplex.DataplexDataQualityJobTrigger",
+            {
+                "job_id": self.job_id,
+                "data_scan_id": self.data_scan_id,
+                "project_id": self.project_id,
+                "region": self.region,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "polling_interval_seconds": self.polling_interval_seconds,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = DataplexAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        while True:
+            job = await hook.get_data_scan_job(
+                project_id=self.project_id,
+                region=self.region,
+                job_id=self.job_id,
+                data_scan_id=self.data_scan_id,
+            )
+            state = job.state
+            if state in (DataScanJob.State.FAILED, DataScanJob.State.SUCCEEDED, DataScanJob.State.CANCELLED):
+                break
+            self.log.info(
+                "Current state is: %s, sleeping for %s seconds.",
+                DataScanJob.State(state).name,
+                self.polling_interval_seconds,
+            )
+            await asyncio.sleep(self.polling_interval_seconds)
+        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
+
+    def _convert_to_dict(self, job: DataScanJob) -> dict:
+        """Returns a representation of a DataScanJob instance as a dict."""
+        return DataScanJob.to_dict(job)
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 7670ca5966f40..778d789fe586e 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -905,6 +905,9 @@ triggers:
   - integration-name: Google Data Fusion
     python-modules:
       - airflow.providers.google.cloud.triggers.datafusion
+  - integration-name: Google Dataplex
+    python-modules:
+      - airflow.providers.google.cloud.triggers.dataplex
   - integration-name: Google Dataproc
     python-modules:
       - airflow.providers.google.cloud.triggers.dataproc
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst b/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst
index 05b96f1a1dd4b..8c0da7d68610d 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataplex.rst
@@ -217,6 +217,14 @@ To check that running Dataplex Data Quality scan succeeded you can use:
    :start-after: [START howto_dataplex_data_scan_job_state_sensor]
    :end-before: [END howto_dataplex_data_scan_job_state_sensor]

+You can also run this operator in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataplex/example_dataplex_dq.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_dataplex_run_data_quality_def_operator]
+    :end-before: [END howto_dataplex_run_data_quality_def_operator]
+
 Get a Data Quality scan job
 ---------------------------
@@ -230,6 +238,14 @@ To get a Data Quality scan job you can use:
    :start-after: [START howto_dataplex_get_data_quality_job_operator]
    :end-before: [END howto_dataplex_get_data_quality_job_operator]

+You can also run this operator in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataplex/example_dataplex_dq.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_dataplex_get_data_quality_job_def_operator]
+    :end-before: [END howto_dataplex_get_data_quality_job_def_operator]
+
 Create a zone
 -------------
diff --git a/tests/providers/google/cloud/operators/test_dataplex.py b/tests/providers/google/cloud/operators/test_dataplex.py
index 3d5d770b563c4..2cddabb4a76b2 100644
--- a/tests/providers/google/cloud/operators/test_dataplex.py
+++ b/tests/providers/google/cloud/operators/test_dataplex.py
@@ -18,8 +18,10 @@

 from unittest import mock

+import pytest
 from google.api_core.gapic_v1.method import DEFAULT

+from airflow.exceptions import TaskDeferred
 from airflow.providers.google.cloud.operators.dataplex import (
     DataplexCreateAssetOperator,
     DataplexCreateLakeOperator,
@@ -36,6 +38,8 @@
     DataplexListTasksOperator,
     DataplexRunDataQualityScanOperator,
 )
+from airflow.providers.google.cloud.triggers.dataplex import DataplexDataQualityJobTrigger
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

 HOOK_STR = "airflow.providers.google.cloud.operators.dataplex.DataplexHook"
 TASK_STR = "airflow.providers.google.cloud.operators.dataplex.Task"
@@ -289,13 +293,48 @@
             metadata=(),
         )

+    @mock.patch(HOOK_STR)
+    @mock.patch(DATASCANJOB_STR)
+    def test_execute_deferrable(self, mock_data_scan_job, hook_mock):
+        op = DataplexRunDataQualityScanOperator(
+            task_id="execute_data_scan",
+            project_id=PROJECT_ID,
+            region=REGION,
+            data_scan_id=DATA_SCAN_ID,
+            api_version=API_VERSION,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(mock.MagicMock())
+
+        hook_mock.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            api_version=API_VERSION,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        hook_mock.return_value.run_data_scan.assert_called_once_with(
+            project_id=PROJECT_ID,
+            region=REGION,
+            data_scan_id=DATA_SCAN_ID,
+            retry=DEFAULT,
+            timeout=None,
+            metadata=(),
+        )
+        hook_mock.return_value.wait_for_data_scan_job.assert_not_called()
+
+        assert isinstance(exc.value.trigger, DataplexDataQualityJobTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+

 class TestDataplexGetDataQualityScanResultOperator:
     @mock.patch(HOOK_STR)
     @mock.patch(DATASCANJOB_STR)
     def test_execute(self, mock_data_scan_job, hook_mock):
         op = DataplexGetDataQualityScanResultOperator(
-            task_id="get_data_scan",
+            task_id="get_data_scan_result",
             project_id=PROJECT_ID,
             region=REGION,
             job_id=JOB_ID,
@@ -322,6 +361,36 @@
             metadata=(),
         )

+    @mock.patch(HOOK_STR)
+    @mock.patch(DATASCANJOB_STR)
+    def test_execute_deferrable(self, mock_data_scan_job, hook_mock):
+        op = DataplexGetDataQualityScanResultOperator(
+            task_id="get_data_scan_result",
+            project_id=PROJECT_ID,
+            region=REGION,
+            job_id=JOB_ID,
+            data_scan_id=DATA_SCAN_ID,
+            api_version=API_VERSION,
+            wait_for_results=True,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(mock.MagicMock())
+
+        hook_mock.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            api_version=API_VERSION,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        hook_mock.return_value.wait_for_data_scan_job.assert_not_called()
+
+        assert isinstance(exc.value.trigger, DataplexDataQualityJobTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+

 class TestDataplexCreateAssetOperator:
     @mock.patch(HOOK_STR)
diff --git a/tests/providers/google/cloud/triggers/test_dataplex.py b/tests/providers/google/cloud/triggers/test_dataplex.py
new file mode 100644
index 0000000000000..574a6168b061d
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_dataplex.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+from unittest import mock
+
+import pytest
+from google.cloud.dataplex_v1.types import DataScanJob
+
+from airflow.providers.google.cloud.triggers.dataplex import DataplexDataQualityJobTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_PROJECT_ID = "project-id"
+TEST_REGION = "region"
+TEST_POLL_INTERVAL = 5
+TEST_GCP_CONN_ID = "test_conn"
+TEST_JOB_ID = "test_job_id"
+TEST_DATA_SCAN_ID = "test_data_scan_id"
+HOOK_STR = "airflow.providers.google.cloud.hooks.dataplex.DataplexAsyncHook.{}"
+TRIGGER_STR = "airflow.providers.google.cloud.triggers.dataplex.DataplexDataQualityJobTrigger.{}"
+
+
+@pytest.fixture
+def trigger():
+    return DataplexDataQualityJobTrigger(
+        job_id=TEST_JOB_ID,
+        data_scan_id=TEST_DATA_SCAN_ID,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=None,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+    )
+
+
+@pytest.fixture
+def async_get_data_scan_job():
+    def func(**kwargs):
+        m = mock.MagicMock()
+        m.configure_mock(**kwargs)
+        f = asyncio.Future()
+        f.set_result(m)
+        return f
+
+    return func
+
+
+class TestDataplexDataQualityJobTrigger:
+    def test_async_dataplex_job_trigger_serialization_should_execute_successfully(self, trigger):
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.dataplex.DataplexDataQualityJobTrigger"
+        assert kwargs == {
+            "job_id": TEST_JOB_ID,
+            "data_scan_id": TEST_DATA_SCAN_ID,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "impersonation_chain": None,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(TRIGGER_STR.format("_convert_to_dict"))
+    @mock.patch(HOOK_STR.format("get_data_scan_job"))
+    async def test_async_dataplex_job_triggers_on_success_should_execute_successfully(
+        self, mock_hook, mock_convert_to_dict, trigger, async_get_data_scan_job
+    ):
+        mock_hook.return_value = async_get_data_scan_job(
+            state=DataScanJob.State.SUCCEEDED,
+        )
+        mock_convert_to_dict.return_value = {}
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "job_id": TEST_JOB_ID,
+                "job_state": DataScanJob.State.SUCCEEDED,
+                "job": {},
+            }
+        )
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch(TRIGGER_STR.format("_convert_to_dict"))
+    @mock.patch(HOOK_STR.format("get_data_scan_job"))
+    async def test_async_dataplex_job_trigger_run_returns_error_event(
+        self, mock_hook, mock_convert_to_dict, trigger, async_get_data_scan_job
+    ):
+        mock_hook.return_value = async_get_data_scan_job(
+            state=DataScanJob.State.FAILED,
+        )
+        mock_convert_to_dict.return_value = {}
+
+        actual_event = await (trigger.run()).asend(None)
+        await asyncio.sleep(0.5)
+
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": DataScanJob.State.FAILED, "job": {}}
+        )
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch(HOOK_STR.format("get_data_scan_job"))
+    async def test_async_dataplex_job_run_loop_is_still_running(
+        self, mock_hook, trigger, caplog, async_get_data_scan_job
+    ):
+        mock_hook.return_value = async_get_data_scan_job(
+            state=DataScanJob.State.RUNNING,
+        )
+
+        caplog.set_level(logging.INFO)
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        assert not task.done()
+        assert (
+            f"Current state is: {DataScanJob.State.RUNNING.name}, sleeping for {TEST_POLL_INTERVAL} seconds."
+            in caplog.text
+        )
diff --git a/tests/system/providers/google/cloud/dataplex/example_dataplex_dq.py b/tests/system/providers/google/cloud/dataplex/example_dataplex_dq.py
index 4bcd9abbca031..1290698c0e97f 100644
--- a/tests/system/providers/google/cloud/dataplex/example_dataplex_dq.py
+++ b/tests/system/providers/google/cloud/dataplex/example_dataplex_dq.py
@@ -261,6 +261,31 @@
         data_scan_id=DATA_SCAN_ID,
     )
     # [END howto_dataplex_get_data_quality_job_operator]
+    # [START howto_dataplex_run_data_quality_def_operator]
+    run_data_scan_def = DataplexRunDataQualityScanOperator(
+        task_id="run_data_scan_def",
+        project_id=PROJECT_ID,
+        region=REGION,
+        data_scan_id=DATA_SCAN_ID,
+        deferrable=True,
+    )
+    # [END howto_dataplex_run_data_quality_def_operator]
+    run_data_scan_async_2 = DataplexRunDataQualityScanOperator(
+        task_id="run_data_scan_async_2",
+        project_id=PROJECT_ID,
+        region=REGION,
+        data_scan_id=DATA_SCAN_ID,
+        asynchronous=True,
+    )
+    # [START howto_dataplex_get_data_quality_job_def_operator]
+    get_data_scan_job_result_def = DataplexGetDataQualityScanResultOperator(
+        task_id="get_data_scan_job_result_def",
+        project_id=PROJECT_ID,
+        region=REGION,
+        data_scan_id=DATA_SCAN_ID,
+        deferrable=True,
+    )
+    # [END howto_dataplex_get_data_quality_job_def_operator]
     # [START howto_dataplex_delete_asset_operator]
     delete_asset = DataplexDeleteAssetOperator(
         task_id="delete_asset",
@@ -323,6 +348,9 @@
         run_data_scan_async,
         get_data_scan_job_status,
         get_data_scan_job_result_2,
+        run_data_scan_def,
+        run_data_scan_async_2,
+        get_data_scan_job_result_def,
         # TEST TEARDOWN
         delete_asset,
         delete_zone,
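
For context, this is roughly how the new deferrable mode would be used from a user DAG once this patch is applied. This is a minimal sketch, not part of the patch: the DAG id, schedule, project, region, and scan id below are illustrative placeholders.

    from __future__ import annotations

    import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataplex import (
        DataplexGetDataQualityScanResultOperator,
        DataplexRunDataQualityScanOperator,
    )

    with DAG(
        dag_id="dataplex_dq_deferrable_sketch",  # illustrative name
        schedule=None,
        start_date=datetime.datetime(2023, 9, 1),
    ) as dag:
        # With deferrable=True the task starts the scan and then defers;
        # DataplexDataQualityJobTrigger polls the job from the triggerer
        # instead of the task holding a worker slot.
        run_scan = DataplexRunDataQualityScanOperator(
            task_id="run_data_scan",
            project_id="my-project",  # illustrative
            region="us-central1",  # illustrative
            data_scan_id="my_data_scan",  # illustrative
            deferrable=True,
            polling_interval_seconds=30,
        )

        # The result operator can defer the same way while waiting for the
        # job, then returns the DataScanJob dict from execute_complete().
        get_result = DataplexGetDataQualityScanResultOperator(
            task_id="get_data_scan_result",
            project_id="my-project",  # illustrative
            region="us-central1",  # illustrative
            data_scan_id="my_data_scan",  # illustrative
            deferrable=True,
        )

        run_scan >> get_result

Note that asynchronous=True and deferrable=True are mutually exclusive on DataplexRunDataQualityScanOperator; per this patch, the operator raises an AirflowException when both are set.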